Source code for claf.modules.encoder.lstm_cell_with_projection

This code is from allenai/allennlp

import itertools

from typing import Callable, List, Tuple, Union, Optional
import torch
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

class LstmCellWithProjection(torch.nn.Module):  # pragma: no cover
    """
    An LSTM with Recurrent Dropout and a projected and clipped hidden state and
    memory. Note: this implementation is slower than the native Pytorch LSTM because
    it cannot make use of CUDNN optimizations for stacked RNNs due to variational
    dropout and the custom nature of the cell state.

    Parameters
    ----------
    input_size : ``int``, required.
        The dimension of the inputs to the LSTM.
    hidden_size : ``int``, required.
        The dimension of the outputs of the LSTM.
    cell_size : ``int``, required.
        The dimension of the memory cell used for the LSTM.
    go_forward: ``bool``, optional (default = True)
        The direction in which the LSTM is applied to the sequence.
        Forwards by default, or backwards if False.
    recurrent_dropout_probability: ``float``, optional (default = 0.0)
        The dropout probability to be used in a dropout scheme as stated in
        `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks
        <https://arxiv.org/abs/1512.05287>`_ . Implementation wise, this simply
        applies a fixed dropout mask per sequence to the recurrent connection of the
        LSTM.
    state_projection_clip_value: ``float``, optional, (default = None)
        The magnitude with which to clip the hidden_state after projecting it.
    memory_cell_clip_value: ``float``, optional, (default = None)
        The magnitude with which to clip the memory cell.

    Returns
    -------
    output_accumulator : ``torch.FloatTensor``
        The outputs of the LSTM for each timestep. A tensor of shape
        (batch_size, max_timesteps, hidden_size) where for a given batch
        element, all outputs past the sequence length for that batch are
        zero tensors.
    final_state: ``Tuple[torch.FloatTensor, torch.FloatTensor]``
        The final (state, memory) states of the LSTM, with shape
        (1, batch_size, hidden_size) and (1, batch_size, cell_size)
        respectively. The first dimension is 1 in order to match the Pytorch
        API for returning stacked LSTM states.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        cell_size: int,
        go_forward: bool = True,
        recurrent_dropout_probability: float = 0.0,
        memory_cell_clip_value: Optional[float] = None,
        state_projection_clip_value: Optional[float] = None,
    ) -> None:
        super(LstmCellWithProjection, self).__init__()
        # Required to be wrapped with a :class:`PytorchSeq2SeqWrapper`.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell_size = cell_size

        self.go_forward = go_forward
        self.state_projection_clip_value = state_projection_clip_value
        self.memory_cell_clip_value = memory_cell_clip_value
        self.recurrent_dropout_probability = recurrent_dropout_probability

        # We do the projections for all the gates all at once.
        self.input_linearity = torch.nn.Linear(input_size, 4 * cell_size, bias=False)
        self.state_linearity = torch.nn.Linear(hidden_size, 4 * cell_size, bias=True)

        # Additional projection matrix for making the hidden state smaller.
        self.state_projection = torch.nn.Linear(cell_size, hidden_size, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialise the gate projections: block-orthogonal weights (one block per
        gate), zero biases, and forget gate biases set to 1.0.
        """
        # Use sensible default initializations for parameters. Each of the four
        # gates owns a (cell_size, fan_in) block of the concatenated weight matrix.
        block_orthogonal(self.input_linearity.weight.data, [self.cell_size, self.input_size])
        block_orthogonal(self.state_linearity.weight.data, [self.cell_size, self.hidden_size])

        self.state_linearity.bias.data.fill_(0.0)
        # Initialize forget gate biases to 1.0 as per An Empirical
        # Exploration of Recurrent Network Architectures, (Jozefowicz, 2015).
        # The forget gate is the second of the four concatenated gate chunks.
        self.state_linearity.bias.data[self.cell_size : 2 * self.cell_size].fill_(1.0)

    def forward(
        self,  # pylint: disable=arguments-differ
        inputs: torch.FloatTensor,
        batch_lengths: List[int],
        initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_timesteps, input_size)
            to apply the LSTM over.
        batch_lengths : ``List[int]``, required.
            A list of length batch_size containing the lengths of the sequences in
            batch. The batch must be ordered from longest to shortest sequence.
        initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).

        Returns
        -------
        output_accumulator : ``torch.FloatTensor``
            The outputs of the LSTM for each timestep. A tensor of shape
            (batch_size, max_timesteps, hidden_size) where for a given batch
            element, all outputs past the sequence length for that batch are
            zero tensors.
        final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
            A tuple (state, memory) representing the final hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).
        """
        batch_size = inputs.size()[0]
        total_timesteps = inputs.size()[1]

        output_accumulator = inputs.new_zeros(batch_size, total_timesteps, self.hidden_size)

        if initial_state is None:
            full_batch_previous_memory = inputs.new_zeros(batch_size, self.cell_size)
            full_batch_previous_state = inputs.new_zeros(batch_size, self.hidden_size)
        else:
            full_batch_previous_state = initial_state[0].squeeze(0)
            full_batch_previous_memory = initial_state[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
        # The same mask is reused at every timestep (variational dropout).
        if self.recurrent_dropout_probability > 0.0 and self.training:
            dropout_mask = get_dropout_mask(
                self.recurrent_dropout_probability, full_batch_previous_state
            )
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            # What we are doing here is finding the index into the batch dimension
            # which we need to use for this timestep, because the sequences have
            # variable length, so once the index is greater than the length of this
            # particular batch sequence, we no longer need to do the computation for
            # this sequence. The key thing to recognise here is that the batch inputs
            # must be _ordered_ by length from longest (first in batch) to shortest
            # (last) so initially, we are going forwards with every sequence and as we
            # pass the index at which the shortest elements of the batch finish,
            # we stop picking them up for the computation.
            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum number of elements in the batch?
                # Second conditional: Does the next shortest sequence beyond the current batch
                # index require computation use this timestep?
                while (
                    current_length_index < (len(batch_lengths) - 1)
                    and batch_lengths[current_length_index + 1] > index
                ):
                    current_length_index += 1

            num_active = current_length_index + 1

            # Actually get the slices of the batch which we
            # need for the computation at this timestep.
            # shape (num_active, cell_size)
            previous_memory = full_batch_previous_memory[0:num_active].clone()
            # Shape (num_active, hidden_size)
            previous_state = full_batch_previous_state[0:num_active].clone()
            # Shape (num_active, input_size)
            timestep_input = inputs[0:num_active, index]

            # Do the projections for all the gates all at once.
            # Both have shape (num_active, 4 * cell_size)
            projected_input = self.input_linearity(timestep_input)
            projected_state = self.state_linearity(previous_state)

            # Main LSTM equations using relevant chunks of the big linear
            # projections of the hidden state and inputs. The four chunks are,
            # in order: input gate, forget gate, candidate memory, output gate.
            input_part, forget_part, memory_part, output_part = (
                projected_input + projected_state
            ).chunk(4, dim=1)
            input_gate = torch.sigmoid(input_part)
            forget_gate = torch.sigmoid(forget_part)
            memory_init = torch.tanh(memory_part)
            output_gate = torch.sigmoid(output_part)
            memory = input_gate * memory_init + forget_gate * previous_memory

            # Here is the non-standard part of this LSTM cell; first, we clip the
            # memory cell, then we project the output of the timestep to a smaller size
            # and again clip it.
            if self.memory_cell_clip_value:
                # pylint: disable=invalid-unary-operand-type
                memory = torch.clamp(
                    memory, -self.memory_cell_clip_value, self.memory_cell_clip_value
                )

            # shape (num_active, cell_size)
            pre_projection_timestep_output = output_gate * torch.tanh(memory)

            # shape (num_active, hidden_size)
            timestep_output = self.state_projection(pre_projection_timestep_output)
            if self.state_projection_clip_value:
                # pylint: disable=invalid-unary-operand-type
                timestep_output = torch.clamp(
                    timestep_output,
                    -self.state_projection_clip_value,
                    self.state_projection_clip_value,
                )

            if dropout_mask is not None:
                timestep_output = timestep_output * dropout_mask[0:num_active]

            # We've been doing computation with less than the full batch, so here we create a new
            # variable for the whole batch at this timestep and insert the result for the
            # relevant elements of the batch into it. Cloning first keeps the tensors
            # that earlier timesteps recorded in the autograd graph untouched.
            full_batch_previous_memory = full_batch_previous_memory.clone()
            full_batch_previous_state = full_batch_previous_state.clone()
            full_batch_previous_memory[0:num_active] = memory
            full_batch_previous_state[0:num_active] = timestep_output
            output_accumulator[0:num_active, index] = timestep_output

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, ...). As this
        # LSTM cell cannot be stacked, the first dimension here is just 1.
        final_state = (
            full_batch_previous_state.unsqueeze(0),
            full_batch_previous_memory.unsqueeze(0),
        )

        return output_accumulator, final_state
def get_dropout_mask(
    dropout_probability: float, tensor_for_masking: torch.Tensor
):  # pragma: no cover
    """
    Computes and returns an element-wise dropout mask for a given tensor, where
    each element in the mask is dropped out with probability dropout_probability.
    Note that the mask is NOT applied to the tensor - the tensor is passed to retain
    the correct device for the mask.

    Parameters
    ----------
    dropout_probability : float, required.
        Probability of dropping a dimension of the input.
    tensor_for_masking : torch.Tensor, required.
        Only its size and device are used; its values are never read.

    Returns
    -------
    A torch.FloatTensor consisting of the binary mask scaled by 1/ (1 - dropout_probability).
    This scaling ensures expected values and variances of the output of applying this mask
    and the original tensor are the same.
    """
    # Draw the keep/drop decisions, then move them to the reference tensor's device.
    # (Copy-constructing via ``new_tensor`` would emit a UserWarning on every call.)
    keep_decisions = torch.rand(tensor_for_masking.size()) > dropout_probability
    binary_mask = keep_decisions.to(tensor_for_masking.device)
    # Scale mask by 1/keep_prob to preserve output statistics.
    dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
    return dropout_mask
def block_orthogonal(
    tensor: torch.Tensor, split_sizes: List[int], gain: float = 1.0
) -> None:  # pragma: no cover
    """
    An initializer which allows initializing model parameters in "blocks". This is helpful
    in the case of recurrent models which use multiple gates applied to linear projections,
    which can be computed efficiently if they are concatenated together. However, they are
    separate parameters which should be initialized independently.

    Parameters
    ----------
    tensor : ``torch.Tensor``, required.
        A tensor to initialize. It is modified in place.
    split_sizes : List[int], required.
        A list of length ``tensor.ndim()`` specifying the size of the
        blocks along that particular dimension. E.g. ``[10, 20]`` would
        result in the tensor being split into chunks of size 10 along the
        first dimension and 20 along the second.
    gain : float, optional (default = 1.0)
        The gain (scaling) applied to the orthogonal initialization.

    Raises
    ------
    ValueError
        If any tensor dimension is not divisible by its corresponding split size.
    """
    # Write through ``.data`` so the in-place block assignment is not tracked by autograd.
    data = tensor.data
    sizes = list(tensor.size())
    if any(a % b != 0 for a, b in zip(sizes, split_sizes)):
        raise ValueError(
            "tensor dimensions must be divisible by their respective "
            "split_sizes. Found size: {} and split_sizes: {}".format(sizes, split_sizes)
        )
    indexes = [list(range(0, max_size, split)) for max_size, split in zip(sizes, split_sizes)]
    # Iterate over all possible blocks within the tensor.
    for block_start_indices in itertools.product(*indexes):
        # This is a tuple of slices corresponding to:
        # tensor[index: index + step_size, ...]. This is
        # required because we could have an arbitrary number
        # of dimensions. The actual slices we need are the
        # start_index: start_index + step for each dimension in the tensor.
        block_slice = tuple(
            slice(start_index, start_index + step)
            for start_index, step in zip(block_start_indices, split_sizes)
        )
        # ``contiguous()`` may copy the block, so the initialised values are
        # assigned back into the original storage explicitly.
        data[block_slice] = torch.nn.init.orthogonal_(tensor[block_slice].contiguous(), gain=gain)
def sort_batch_by_length(tensor: torch.Tensor, sequence_lengths: torch.Tensor):  # pragma: no cover
    """
    Sort a batch first tensor by some specified lengths.

    Parameters
    ----------
    tensor : torch.FloatTensor, required.
        A batch first Pytorch tensor.
    sequence_lengths : torch.LongTensor, required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : torch.FloatTensor
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : torch.LongTensor
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : torch.LongTensor
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    permutation_index : torch.LongTensor
        The indices used to sort the tensor. This is useful if you want to sort many
        tensors using the same ordering.

    Raises
    ------
    ValueError
        If either argument is not a ``torch.Tensor``.
    """
    if not isinstance(tensor, torch.Tensor) or not isinstance(sequence_lengths, torch.Tensor):
        raise ValueError("Both the tensor and sequence lengths must be torch.Tensors.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

    # Build the index range directly with the lengths' dtype and device, rather than
    # copy-constructing it with ``new_tensor`` (which warns on every call).
    index_range = torch.arange(
        0,
        len(sequence_lengths),
        dtype=sequence_lengths.dtype,
        device=sequence_lengths.device,
    )
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
# We have two types here for the state, because storing the state in something # which is Iterable (like a tuple, below), is helpful for internal manipulation # - however, the states are consumed as either Tensors or a Tuple of Tensors, so # returning them in this format is unhelpful. RnnState = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] # pylint: disable=invalid-name RnnStateStorage = Tuple[torch.Tensor, ...] # pylint: disable=invalid-name class _EncoderBase(torch.nn.Module): # pragma: no cover # pylint: disable=abstract-method """ This abstract class serves as a base for the 3 ``Encoder`` abstractions in AllenNLP. - :class:`~allennlp.modules.seq2seq_encoders.Seq2SeqEncoders` - :class:`~allennlp.modules.seq2vec_encoders.Seq2VecEncoders` Additionally, this class provides functionality for sorting sequences by length so they can be consumed by Pytorch RNN classes, which require their inputs to be sorted by length. Finally, it also provides optional statefulness to all of it's subclasses by allowing the caching and retrieving of the hidden states of RNNs. """ def __init__(self, stateful: bool = False) -> None: super(_EncoderBase, self).__init__() self.stateful = stateful self._states: Optional[RnnStateStorage] = None def sort_and_run_forward( self, module: Callable[ [PackedSequence, Optional[RnnState]], Tuple[Union[PackedSequence, torch.Tensor], RnnState], ], inputs: torch.Tensor, mask: torch.Tensor, hidden_state: Optional[RnnState] = None, ): """ This function exists because Pytorch RNNs require that their inputs be sorted before being passed as input. As all of our Seq2xxxEncoders use this functionality, it is provided in a base class. This method can be called on any module which takes as input a ``PackedSequence`` and some ``hidden_state``, which can either be a tuple of tensors or a tensor. As all of our Seq2xxxEncoders have different return types, we return `sorted` outputs from the module, which is called directly. 
Additionally, we return the indices into the batch dimension required to restore the tensor to it's correct, unsorted order and the number of valid batch elements (i.e the number of elements in the batch which are not completely masked). This un-sorting and re-padding of the module outputs is left to the subclasses because their outputs have different types and handling them smoothly here is difficult. Parameters ---------- module : ``Callable[[PackedSequence, Optional[RnnState]], Tuple[Union[PackedSequence, torch.Tensor], RnnState]]``, required. A function to run on the inputs. In most cases, this is a ``torch.nn.Module``. inputs : ``torch.Tensor``, required. A tensor of shape ``(batch_size, sequence_length, embedding_size)`` representing the inputs to the Encoder. mask : ``torch.Tensor``, required. A tensor of shape ``(batch_size, sequence_length)``, representing masked and non-masked elements of the sequence for each element in the batch. hidden_state : ``Optional[RnnState]``, (default = None). A single tensor of shape (num_layers, batch_size, hidden_size) representing the state of an RNN with or a tuple of tensors of shapes (num_layers, batch_size, hidden_size) and (num_layers, batch_size, memory_size), representing the hidden state and memory state of an LSTM-like RNN. Returns ------- module_output : ``Union[torch.Tensor, PackedSequence]``. A Tensor or PackedSequence representing the output of the Pytorch Module. The batch size dimension will be equal to ``num_valid``, as sequences of zero length are clipped off before the module is called, as Pytorch cannot handle zero length sequences. final_states : ``Optional[RnnState]`` A Tensor representing the hidden state of the Pytorch Module. This can either be a single tensor of shape (num_layers, num_valid, hidden_size), for instance in the case of a GRU, or a tuple of tensors, such as those required for an LSTM. 
restoration_indices : ``torch.LongTensor`` A tensor of shape ``(batch_size,)``, describing the re-indexing required to transform the outputs back to their original batch order. """ # In some circumstances you may have sequences of zero length. ``pack_padded_sequence`` # requires all sequence lengths to be > 0, so remove sequences of zero length before # calling self._module, then fill with zeros. # First count how many sequences are empty. batch_size = mask.size(0) num_valid = torch.sum(mask[:, 0]).int().item() sequence_lengths = mask.long().sum(-1) sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices = sort_batch_by_length( inputs, sequence_lengths ) # Now create a PackedSequence with only the non-empty, sorted sequences. packed_sequence_input = pack_padded_sequence( sorted_inputs[:num_valid, :, :], sorted_sequence_lengths[:num_valid].data.tolist(), batch_first=True, ) # Prepare the initial states. if not self.stateful: if hidden_state is None: initial_states = hidden_state elif isinstance(hidden_state, tuple): initial_states = [ state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous() for state in hidden_state ] else: initial_states = hidden_state.index_select(1, sorting_indices)[ :, :num_valid, : ].contiguous() else: initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices) # Actually call the module on the sorted PackedSequence. module_output, final_states = module(packed_sequence_input, initial_states) return module_output, final_states, restoration_indices def _get_initial_states( self, batch_size: int, num_valid: int, sorting_indices: torch.LongTensor ) -> Optional[RnnState]: """ Returns an initial state for use in an RNN. Additionally, this method handles the batch size changing across calls by mutating the state to append initial states for new elements in the batch. 
Finally, it also handles sorting the states with respect to the sequence lengths of elements in the batch and removing rows which are completely padded. Importantly, this `mutates` the state if the current batch size is larger than when it was previously called. Parameters ---------- batch_size : ``int``, required. The batch size can change size across calls to stateful RNNs, so we need to know if we need to expand or shrink the states before returning them. Expanded states will be set to zero. num_valid : ``int``, required. The batch may contain completely padded sequences which get removed before the sequence is passed through the encoder. We also need to clip these off of the state too. sorting_indices ``torch.LongTensor``, required. Pytorch RNNs take sequences sorted by length. When we return the states to be used for a given call to ``module.forward``, we need the states to match up to the sorted sequences, so before returning them, we sort the states using the same indices used to sort the sequences. Returns ------- This method has a complex return type because it has to deal with the first time it is called, when it has no state, and the fact that types of RNN have heterogeneous states. If it is the first time the module has been called, it returns ``None``, regardless of the type of the ``Module``. Otherwise, for LSTMs, it returns a tuple of ``torch.Tensors`` with shape ``(num_layers, num_valid, state_size)`` and ``(num_layers, num_valid, memory_size)`` respectively, or for GRUs, it returns a single ``torch.Tensor`` of shape ``(num_layers, num_valid, state_size)``. """ # We don't know the state sizes the first time calling forward, # so we let the module define what it's initial hidden state looks like. if self._states is None: return None # Otherwise, we have some previous states. if batch_size > self._states[0].size(1): # This batch is larger than the all previous states. # If so, resize the states. 
num_states_to_concat = batch_size - self._states[0].size(1) resized_states = [] # state has shape (num_layers, batch_size, hidden_size) for state in self._states: # This _must_ be inside the loop because some # RNNs have states with different last dimension sizes. zeros = state.new_zeros(state.size(0), num_states_to_concat, state.size(2)) resized_states.append([state, zeros], 1)) self._states = tuple(resized_states) correctly_shaped_states = self._states elif batch_size < self._states[0].size(1): # This batch is smaller than the previous one. correctly_shaped_states = tuple(state[:, :batch_size, :] for state in self._states) else: correctly_shaped_states = self._states # At this point, our states are of shape (num_layers, batch_size, hidden_size). # However, the encoder uses sorted sequences and additionally removes elements # of the batch which are fully padded. We need the states to match up to these # sorted and filtered sequences, so we do that in the next two blocks before # returning the state/s. if len(self._states) == 1: # GRUs only have a single state. This `unpacks` it from the # tuple and returns the tensor directly. correctly_shaped_state = correctly_shaped_states[0] sorted_state = correctly_shaped_state.index_select(1, sorting_indices) return sorted_state[:, :num_valid, :] else: # LSTMs have a state tuple of (state, memory). sorted_states = [ state.index_select(1, sorting_indices) for state in correctly_shaped_states ] return tuple(state[:, :num_valid, :] for state in sorted_states) def _update_states( self, final_states: RnnStateStorage, restoration_indices: torch.LongTensor ) -> None: """ After the RNN has run forward, the states need to be updated. This method just sets the state to the updated new state, performing several pieces of book-keeping along the way - namely, unsorting the states and ensuring that the states of completely padded sequences are not updated. 
Finally, it also detaches the state variable from the computational graph, such that the graph can be garbage collected after each batch iteration. Parameters ---------- final_states : ``RnnStateStorage``, required. The hidden states returned as output from the RNN. restoration_indices : ``torch.LongTensor``, required. The indices that invert the sorting used in ``sort_and_run_forward`` to order the states with respect to the lengths of the sequences in the batch. """ # TODO(Mark): seems weird to sort here, but append zeros in the subclasses. # which way around is best? new_unsorted_states = [state.index_select(1, restoration_indices) for state in final_states] if self._states is None: # We don't already have states, so just set the # ones we receive to be the current state. self._states = tuple( for state in new_unsorted_states) else: # Now we've sorted the states back so that they correspond to the original # indices, we need to figure out what states we need to update, because if we # didn't use a state for a particular row, we want to preserve its state. # Thankfully, the rows which are all zero in the state correspond exactly # to those which aren't used, so we create masks of shape (new_batch_size,), # denoting which states were used in the RNN computation. current_state_batch_size = self._states[0].size(1) new_state_batch_size = final_states[0].size(1) # Masks for the unused states of shape (1, new_batch_size, 1) used_new_rows_mask = [ (state[0, :, :].sum(-1) != 0.0).float().view(1, new_state_batch_size, 1) for state in new_unsorted_states ] new_states = [] if current_state_batch_size > new_state_batch_size: # The new state is smaller than the old one, # so just update the indices which we used. for old_state, new_state, used_mask in zip( self._states, new_unsorted_states, used_new_rows_mask ): # zero out all rows in the previous state # which _were_ used in the current state. 
masked_old_state = old_state[:, :new_state_batch_size, :] * (1 - used_mask) # The old state is larger, so update the relevant parts of it. old_state[:, :new_state_batch_size, :] = new_state + masked_old_state new_states.append(old_state.detach()) else: # The states are the same size, so we just have to # deal with the possibility that some rows weren't used. new_states = [] for old_state, new_state, used_mask in zip( self._states, new_unsorted_states, used_new_rows_mask ): # zero out all rows which _were_ used in the current state. masked_old_state = old_state * (1 - used_mask) # The old state is larger, so update the relevant parts of it. new_state += masked_old_state new_states.append(new_state.detach()) # It looks like there should be another case handled here - when # the current_state_batch_size < new_state_batch_size. However, # this never happens, because the states themeselves are mutated # by appending zeros when calling _get_inital_states, meaning that # the new states are either of equal size, or smaller, in the case # that there are some unused elements (zero-length) for the RNN computation. self._states = tuple(new_states) def reset_states(self): self._states = None