Source code for claf.modules.functional

"""
    Functional utilities, some adapted from allennlp: https://github.com/allenai/allennlp

    - add_masked_value : replace_masked_values (allennlp)
    - get_mask_from_tokens : get_mask_from_tokens (allennlp)
    - last_dim_masked_softmax : last_dim_masked_softmax (allennlp)
    - masked_softmax : masked_softmax (allennlp)
    - weighted_sum : weighted_sum (allennlp)
"""

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def add_masked_value(tensor, mask, value=-1e7):
    # Fill positions where mask == 0 with `value` (a large negative number by default),
    # so a following softmax assigns them ~0 probability.
    mask = mask.float()
    one_minus_mask = 1.0 - mask
    values_to_add = value * one_minus_mask
    return tensor * mask + values_to_add
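
# Illustrative usage sketch (not part of the original module): add_masked_value is typically
# used to push padded positions toward -inf before a softmax, so they receive ~0 probability.
# The variable names below (scores, pad_mask) are hypothetical.
def _example_add_masked_value():
    scores = torch.tensor([[1.0, 2.0, 0.3], [0.5, 0.1, 0.9]])  # [B, S_L] attention logits
    pad_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])            # 1 = real token, 0 = padding
    masked_scores = add_masked_value(scores, pad_mask)         # padded slots become ~-1e7
    return F.softmax(masked_scores, dim=-1)                    # padding gets ~0 probability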
def get_mask_from_tokens(tokens):
    # Build a padding mask from the lowest-dimensional tensor in `tokens`
    # (word-level indices [B, S_L], or character-level indices [B, S_L, C_L]).
    tensor_dims = [(tensor.dim(), tensor) for tensor in tokens.values()]
    tensor_dims.sort(key=lambda x: x[0])

    smallest_dim = tensor_dims[0][0]
    if smallest_dim == 2:
        token_tensor = tensor_dims[0][1]
        return (token_tensor != 0).long()
    elif smallest_dim == 3:
        character_tensor = tensor_dims[0][1]
        return ((character_tensor > 0).long().sum(dim=-1) > 0).long()
    else:
        raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(smallest_dim))
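
# Illustrative usage sketch (not part of the original module): `tokens` is expected to be a
# dict of index tensors (e.g. word-level [B, S_L] and/or character-level [B, S_L, C_L]); the
# mask is built from the lowest-dimensional one. The key name below ("word") is hypothetical.
def _example_get_mask_from_tokens():
    tokens = {
        "word": torch.tensor([[3, 7, 0], [5, 0, 0]]),  # 0 = padding index
    }
    return get_mask_from_tokens(tokens)                # -> [[1, 1, 0], [1, 0, 0]]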
def last_dim_masked_softmax(x, mask):
    # Apply masked_softmax over the last dimension of an arbitrarily shaped tensor
    # by flattening all leading dimensions, then restoring the original shape.
    x_shape = x.size()
    reshaped_x = x.view(-1, x.size()[-1])

    while mask.dim() < x.dim():
        mask = mask.unsqueeze(1)
    mask = mask.expand_as(x).contiguous().float()
    mask = mask.view(-1, mask.size()[-1])

    reshaped_result = masked_softmax(reshaped_x, mask)
    return reshaped_result.view(*x_shape)
def masked_softmax(x, mask):
    if mask is None:
        raise ValueError("mask can't be None.")

    # Softmax, then zero out masked positions and renormalize so each row sums to 1
    # (the epsilon avoids division by zero when a row is fully masked).
    # Note: the renormalization uses dim=1, so a 2D [N, L] input is expected.
    output = F.softmax(x * mask, dim=-1)
    output = output * mask
    output = output / (output.sum(dim=1, keepdim=True) + 1e-13)
    return output
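
# Illustrative usage sketch (not part of the original module): masked_softmax works on a 2D
# [N, L] tensor, while last_dim_masked_softmax accepts higher-rank inputs (e.g. a [B, S_L, S_L]
# attention matrix) and broadcasts a [B, S_L] mask over the extra dimension. Names are hypothetical.
def _example_last_dim_masked_softmax():
    similarity = torch.randn(2, 4, 4)                      # [B, S_L, S_L] attention scores
    pad_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # [B, S_L], 1 = real token
    return last_dim_masked_softmax(similarity, pad_mask)   # each row sums to 1 over unmasked keys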
def weighted_sum(attention, matrix):  # pragma: no cover
    if attention.dim() == 2 and matrix.dim() == 3:
        return attention.unsqueeze(1).bmm(matrix).squeeze(1)
    elif attention.dim() == 3 and matrix.dim() == 3:
        return attention.bmm(matrix)
    else:
        raise ValueError(
            f"attention dim {attention.dim()} and matrix dim {matrix.dim()} operation is not supported. "
            "(2, 3) and (3, 3) are the available dimensions."
        )
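
# Illustrative usage sketch (not part of the original module): weighted_sum combines a matrix of
# vectors using attention weights, e.g. pooling encoder states [B, S_L, D] with weights [B, S_L]
# into a single [B, D] vector. Names are hypothetical.
def _example_weighted_sum():
    encoder_outputs = torch.randn(2, 4, 8)                          # [B, S_L, D]
    pad_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).float()   # [B, S_L]
    attention = masked_softmax(torch.randn(2, 4), pad_mask)         # [B, S_L] weights over tokens
    return weighted_sum(attention, encoder_outputs)                 # -> [B, D]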
def masked_zero(tensor, mask):
    """ Tensor masking operation """
    while mask.dim() < tensor.dim():
        mask = mask.unsqueeze(-1)

    if isinstance(tensor, torch.FloatTensor):
        mask = mask.float()
    elif isinstance(tensor, torch.ByteTensor):
        mask = mask.byte()
    elif isinstance(tensor, torch.LongTensor):
        mask = mask.long()
    return tensor * mask
def masked_log_softmax(vector, mask):  # pragma: no cover
    if mask is not None:
        # log(mask) is 0 where mask == 1 and -inf where mask == 0,
        # so masked positions get zero probability.
        vector = vector + mask.float().log()
    return torch.nn.functional.log_softmax(vector, dim=1)
def get_sorted_seq_config(features, pad_index=0):
    # Compute sequence lengths and the permutation that sorts the batch by length in
    # descending order (required by pack_padded_sequence), together with the inverse
    # permutation used to restore the original order afterwards.
    tensor_dims = [(tensor.dim(), tensor) for tensor in features.values()]
    tensor_dims.sort(key=lambda x: x[0])

    smallest_dim = tensor_dims[0][0]
    if smallest_dim == 2:
        token_tensor = tensor_dims[0][1]
    else:
        raise ValueError("features smallest_dim must be `2` ([B, S_L])")

    seq_lengths = torch.sum(token_tensor > pad_index, dim=-1)
    seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
    _, unperm_idx = perm_idx.sort(0)

    return {"seq_lengths": seq_lengths, "perm_idx": perm_idx, "unperm_idx": unperm_idx}
def forward_rnn_with_pack(rnn_module, tensor, seq_config):
    sorted_tensor = tensor[seq_config["perm_idx"]]

    packed_input = pack_padded_sequence(
        sorted_tensor, seq_config["seq_lengths"], batch_first=True
    )
    packed_output, _ = rnn_module(packed_input)
    output, _ = pad_packed_sequence(packed_output, batch_first=True)

    output = output[seq_config["unperm_idx"]]  # restore original order
    return output
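
# Illustrative usage sketch (not part of the original module): get_sorted_seq_config and
# forward_rnn_with_pack are meant to be used together to run a batch of padded sequences
# through an RNN with packing. The feature key, embedding size, and GRU below are hypothetical.
def _example_forward_rnn_with_pack():
    features = {"word": torch.tensor([[3, 7, 2, 0], [5, 1, 0, 0]])}  # [B, S_L], 0 = padding
    embedded = torch.randn(2, 4, 16)                                 # [B, S_L, D] embeddings
    rnn = torch.nn.GRU(16, 32, batch_first=True)

    seq_config = get_sorted_seq_config(features, pad_index=0)
    # Output is padded to the longest real sequence in the batch ([2, 3, 32] here),
    # with rows restored to the original batch order.
    return forward_rnn_with_pack(rnn, embedded, seq_config)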