Source code for claf.model.sequence_classification.structured_self_attention


from overrides import overrides

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from claf.decorator import register
from claf.model.base import ModelWithTokenEmbedder
from claf.model.sequence_classification.mixin import SequenceClassification
from claf.modules import functional as f


@register("model:structured_self_attention")
class StructuredSelfAttention(SequenceClassification, ModelWithTokenEmbedder):
    """
    Implementation of the model presented in
    A Structured Self-attentive Sentence Embedding
    (https://arxiv.org/abs/1703.03130)

    * Args:
        token_embedder: used to embed the sequence
        num_classes: number of target classes

    * Kwargs:
        encoding_rnn_hidden_dim: hidden dimension of the rnn (unidirectional)
        encoding_rnn_num_layer: the number of rnn layers
        encoding_rnn_dropout: rnn dropout probability
        attention_dim: attention dimension  # d_a in the paper
        num_attention_heads: number of attention heads  # r in the paper
        sequence_embed_dim: dimension of the sequence embedding
        dropout: classification layer dropout
        penalization_coefficient: penalty coefficient for the Frobenius norm
    """

    def __init__(
        self,
        token_embedder,
        num_classes,
        encoding_rnn_hidden_dim=300,
        encoding_rnn_num_layer=2,
        encoding_rnn_dropout=0.,
        attention_dim=350,
        num_attention_heads=30,
        sequence_embed_dim=2000,
        dropout=0.5,
        penalization_coefficient=1.,
    ):
        super(StructuredSelfAttention, self).__init__(token_embedder)

        rnn_input_dim = token_embedder.get_embed_dim()

        self.num_classes = num_classes
        self.encoding_rnn_hidden_dim = encoding_rnn_hidden_dim * 2  # bidirectional
        self.attention_dim = attention_dim
        self.num_attention_heads = num_attention_heads
        self.project_dim = sequence_embed_dim
        self.dropout = dropout
        self.penalization_coefficient = penalization_coefficient

        self.encoder = nn.LSTM(
            input_size=rnn_input_dim,
            hidden_size=encoding_rnn_hidden_dim,
            num_layers=encoding_rnn_num_layer,
            dropout=encoding_rnn_dropout,
            bidirectional=True,
            batch_first=True,
        )

        self.A = nn.Sequential(
            nn.Linear(self.encoding_rnn_hidden_dim, attention_dim, bias=False),
            nn.Tanh(),
            nn.Linear(attention_dim, num_attention_heads, bias=False),
        )

        self.fully_connected = nn.Sequential(
            nn.Linear(self.encoding_rnn_hidden_dim * num_attention_heads, sequence_embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.classifier = nn.Linear(sequence_embed_dim, num_classes)

        self.criterion = nn.CrossEntropyLoss()
    @overrides
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"sequence": [0, 3, 4, 1]}

        * Kwargs:
            labels: label dictionary like below.
                {"class_idx": 2, "data_idx": 0}

                Do not calculate loss when there is no label.
                (inference/predict mode)

        * Returns: output_dict (dict) consisting of
            - sequence_embed: embedding vector of the sequence
            - logits: unnormalized log probabilities of the classes
            - class_idx: target class idx
            - data_idx: data idx
            - loss: a scalar loss to be optimized
        """
        sequence = features["sequence"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        sequence_config = f.get_sorted_seq_config(sequence)

        token_embed = self.token_embedder(sequence)
        token_encodings = f.forward_rnn_with_pack(
            self.encoder, token_embed, sequence_config
        )  # [B, L, encoding_rnn_hidden_dim]

        attention = self.A(token_encodings).transpose(1, 2)  # [B, num_attention_heads, L]

        sequence_mask = f.get_mask_from_tokens(sequence).float()  # [B, L]
        sequence_mask = sequence_mask.unsqueeze(1).expand_as(attention)
        attention = F.softmax(f.add_masked_value(attention, sequence_mask) + 1e-13, dim=2)

        attended_encodings = torch.bmm(
            attention, token_encodings
        )  # [B, num_attention_heads, encoding_rnn_hidden_dim]
        sequence_embed = self.fully_connected(
            attended_encodings.view(attended_encodings.size(0), -1)
        )  # [B, sequence_embed_dim]

        logits = self.classifier(sequence_embed)  # [B, num_classes]

        output_dict = {"sequence_embed": sequence_embed, "logits": logits}

        if labels:
            class_idx = labels["class_idx"]
            data_idx = labels["data_idx"]

            output_dict["class_idx"] = class_idx
            output_dict["data_idx"] = data_idx

            # Loss
            loss = self.criterion(logits, class_idx)
            loss += self.penalty(attention)
            output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

        return output_dict
    def penalty(self, attention):
        aa = torch.bmm(
            attention, attention.transpose(1, 2)
        )  # [B, num_attention_heads, num_attention_heads]
        penalization_term = ((aa - aa.new_tensor(np.eye(aa.size(1)))) ** 2).sum() ** 0.5
        return penalization_term * self.penalization_coefficient
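
The core computation above follows the paper: an annotation matrix A = softmax(W_s2 tanh(W_s1 H^T)) is produced from the BiLSTM outputs H, and the sentence embedding is built from the attended encodings M = A·H. Below is a minimal, standalone sketch of that step in plain PyTorch with random inputs; it does not use claf's token embedder or the f.* masking helpers, and the dimension values are only illustrative.

# Standalone sketch of the structured self-attention step (not part of the claf API).
import torch
import torch.nn as nn
import torch.nn.functional as F

B, L = 4, 12            # batch size, sequence length (illustrative values)
hidden_dim = 300        # encoding_rnn_hidden_dim (unidirectional)
d_a, r = 350, 30        # attention_dim, num_attention_heads

H = torch.randn(B, L, 2 * hidden_dim)       # stand-in for the BiLSTM outputs, [B, L, 2u]

# A = softmax(W_s2 · tanh(W_s1 · H^T)), built the same way as self.A above
W = nn.Sequential(
    nn.Linear(2 * hidden_dim, d_a, bias=False),
    nn.Tanh(),
    nn.Linear(d_a, r, bias=False),
)
A = F.softmax(W(H).transpose(1, 2), dim=2)  # [B, r, L]; no padding mask in this sketch

M = torch.bmm(A, H)                         # [B, r, 2u] attended encodings
sequence_embed_input = M.view(B, -1)        # [B, r * 2u], what fully_connected consumes

print(A.shape, M.shape, sequence_embed_input.shape)
# torch.Size([4, 30, 12]) torch.Size([4, 30, 600]) torch.Size([4, 18000])

In the model itself, padding positions are additionally masked via f.add_masked_value before the softmax so that attention weights on padded tokens vanish; the sketch skips that step.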
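
The penalization term also comes from the paper, which penalizes the squared Frobenius norm of A·A^T − I so that the num_attention_heads attention rows focus on different positions; the penalty method above applies the (non-squared) Frobenius norm scaled by penalization_coefficient. The small check below recomputes the penalty the same way on made-up attention matrices, just to show its behavior.

# Numeric check of the penalty, computed like StructuredSelfAttention.penalty.
import numpy as np
import torch

def frobenius_penalty(attention, coefficient=1.0):
    aa = torch.bmm(attention, attention.transpose(1, 2))
    return coefficient * (((aa - aa.new_tensor(np.eye(aa.size(1)))) ** 2).sum() ** 0.5)

# Two heads attending to different positions: A·A^T equals I, so the penalty is zero.
diverse = torch.tensor([[[1.0, 0.0, 0.0],
                         [0.0, 1.0, 0.0]]])   # [B=1, r=2, L=3]

# Two heads attending to the same position: redundant heads, larger penalty.
redundant = torch.tensor([[[1.0, 0.0, 0.0],
                           [1.0, 0.0, 0.0]]])

print(frobenius_penalty(diverse))    # tensor(0.)
print(frobenius_penalty(redundant))  # tensor(1.4142)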