Source code for claf.tokens.cove

"""
This code is from salesforce/cove
(https://github.com/salesforce/cove/blob/master/cove/encoder.py)
"""

import torch
from torch import nn

from claf.data.data_handler import CachePath, DataHandler


class MTLSTM(nn.Module):
    def __init__(
        self, word_embedding, pretrained_path=None, requires_grad=False, residual_embeddings=False
    ):
        """Initialize an MTLSTM.

        Arguments:
            word_embedding (nn.Module): module that maps token indices to 300-d GloVe embeddings
            pretrained_path (str): path to the pretrained CoVe LSTM weights, resolved via DataHandler
            requires_grad (bool): if False, detach the CoVe outputs so the LSTM is not fine-tuned
            residual_embeddings (bool): if True, concatenate the input embeddings with MTLSTM outputs during forward
        """
        super(MTLSTM, self).__init__()
        self.word_embedding = word_embedding
        self.rnn = nn.LSTM(300, 300, num_layers=2, bidirectional=True, batch_first=True)

        data_handler = DataHandler(cache_path=CachePath.PRETRAINED_VECTOR)
        cove_weight_path = data_handler.read(pretrained_path, return_path=True)
        if torch.cuda.is_available():
            checkpoint = torch.load(cove_weight_path)
        else:
            checkpoint = torch.load(cove_weight_path, map_location="cpu")
        self.rnn.load_state_dict(checkpoint)

        self.residual_embeddings = residual_embeddings
        self.requires_grad = requires_grad
    def forward(self, inputs):
        """A pretrained MT-LSTM (McCann et al. 2017).

        This LSTM was trained with 300d 840B GloVe on the WMT 2017 machine translation dataset.

        Arguments:
            inputs (Long Tensor): token indices of size (batch_size, timesteps)

        Returns:
            Float Tensor of size (batch_size, timesteps, 600), or
            (batch_size, timesteps, 900) if residual_embeddings is True
        """
        embedded_inputs = self.word_embedding(inputs)
        encoded_inputs, _ = self.rnn(embedded_inputs)
        if not self.requires_grad:
            # detach() returns a new tensor; reassign so gradients do not flow into the CoVe LSTM
            encoded_inputs = encoded_inputs.detach()

        outputs = encoded_inputs
        if self.residual_embeddings:
            outputs = torch.cat([embedded_inputs, encoded_inputs], 2)
        return outputs
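
A minimal usage sketch (not part of the original module): it assumes the pretrained CoVe LSTM state dict is available locally (the file name cove_lstm.pt below is hypothetical, and is resolved through DataHandler) and that any module producing 300-d GloVe-style embeddings can serve as word_embedding.

import torch
from torch import nn

from claf.tokens.cove import MTLSTM

# Any module producing 300-d vectors works as the GloVe lookup; the vocabulary
# size and the weight path below are illustrative assumptions.
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=300)
cove = MTLSTM(
    word_embedding=embedding,
    pretrained_path="cove_lstm.pt",   # hypothetical local path to the CoVe LSTM weights
    residual_embeddings=True,         # concatenate GloVe (300) with CoVe outputs (600)
)

token_ids = torch.randint(0, 10000, (2, 20))  # (batch_size, timesteps)
outputs = cove(token_ids)                     # (2, 20, 900) with residual_embeddings=True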