Source code for claf.modules.attention.docqa_attention

import torch
import torch.nn as nn
import torch.nn.functional as F

from claf.modules import initializer
import claf.modules.functional as f


class DocQAAttention(nn.Module):
    """
    Bi-Attention Layer + (Self-Attention) in DocumentQA
    (https://arxiv.org/abs/1710.10723)

    * Args:
        rnn_dim: hidden size of the GRU cell
            (inputs are expected to have last dimension 2 * rnn_dim)
        linear_dim: hidden size of the linear layer

    * Kwargs:
        self_attn: (bool) if True, act as a self-attention layer
            (the diagonal of the similarity matrix is masked out)
        weight_init: (bool) apply weight initialization to the linear layers
    """

    def __init__(self, rnn_dim, linear_dim, self_attn=False, weight_init=True):
        super(DocQAAttention, self).__init__()
        self.self_attn = self_attn

        # weights of the trilinear similarity function
        self.input_w = nn.Linear(2 * rnn_dim, 1, bias=False)
        self.key_w = nn.Linear(2 * rnn_dim, 1, bias=False)
        self.dot_w = nn.Parameter(torch.randn(1, 1, rnn_dim * 2))
        torch.nn.init.xavier_uniform_(self.dot_w)

        self.bias = nn.Parameter(torch.FloatTensor([[1]]))
        self.diag_mask = nn.Parameter(torch.eye(5000))  # NOTE: (hard-coded) max_sequence_length

        if weight_init:
            initializer.weight(self.input_w)
            initializer.weight(self.key_w)

    def forward(self, x, x_mask, key, key_mask):
        # S: similarity matrix between x and key, shape (B, X_L, K_L)
        S = self._trilinear(x, key)

        if self.self_attn:
            # mask out the diagonal (a position cannot attend to itself)
            # as well as the padded positions of x and key
            seq_length = x.size(1)
            diag_mask = self.diag_mask.narrow(0, 0, seq_length).narrow(1, 0, seq_length)
            joint_mask = 1 - self._compute_attention_mask(x_mask, key_mask)
            mask = torch.clamp(diag_mask + joint_mask, 0, 1)
            masked_S = S + mask * (-1e7)

            x2key = self._x2key(masked_S, key, key_mask)
            return torch.cat((x, x2key, x * x2key), dim=-1)
        else:
            # x -> key attention (e.g. context-to-query)
            joint_mask = 1 - self._compute_attention_mask(x_mask, key_mask)
            masked_S = S + joint_mask * (-1e7)
            x2key = self._x2key(masked_S, key, key_mask)

            # key -> x attention (e.g. query-to-context)
            masked_S = f.add_masked_value(S, key_mask.unsqueeze(1), value=-1e7)
            key2x = self._key2x(masked_S.max(dim=-1)[0], x, x_mask)

            return torch.cat((x, x2key, x * x2key, x * key2x), dim=-1)

    def _compute_attention_mask(self, x_mask, key_mask):
        # joint mask of valid (non-padded) position pairs: (B, X_L, K_L)
        x_mask = x_mask.unsqueeze(2)
        key_mask = key_mask.unsqueeze(1)
        joint_mask = torch.mul(x_mask, key_mask)
        return joint_mask

    def _trilinear(self, x, key):
        # trilinear similarity: S_ij = w_x . x_i + w_k . k_j + w_dot . (x_i * k_j)
        B, X_L, K_L = x.size(0), x.size(1), key.size(1)
        matrix_shape = (B, X_L, K_L)

        x_logits = self.input_w(x).expand(matrix_shape)
        key_logits = self.key_w(key).transpose(1, 2).expand(matrix_shape)

        x_dots = torch.mul(x, self.dot_w)
        x_key = torch.matmul(x_dots, key.transpose(1, 2))

        return x_logits + key_logits + x_key

    def _x2key(self, S, key, key_mask):
        if self.self_attn:
            # softmax with a learned bias added to the normalizer
            bias = torch.exp(self.bias)
            S = torch.exp(S)
            attention = S / (S.sum(dim=-1, keepdim=True).expand(S.size()) + bias.expand(S.size()))
        else:
            attention = F.softmax(S, dim=-1)  # (B, C_L, Q_L)

        x2key = f.weighted_sum(attention=attention, matrix=key)  # (B, C_L, 2d)
        return x2key

    def _key2x(self, S, x, x_mask):
        attention = f.masked_softmax(S, x_mask)  # (B, C_L)
        key2x = f.weighted_sum(attention=attention, matrix=x)
        return key2x.unsqueeze(1).expand(x.size())  # (B, C_L, 2d)
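

For reference, a minimal usage sketch of the class above. It assumes the claf package is importable and follows the mask convention implied by the code (1 for real tokens, 0 for padding); the dimensions, tensor values, and linear_dim value are made up for illustration, not taken from the library.

    import torch

    from claf.modules.attention.docqa_attention import DocQAAttention

    rnn_dim = 100
    attention = DocQAAttention(rnn_dim=rnn_dim, linear_dim=200, self_attn=False)

    B, C_L, Q_L = 2, 30, 10
    context = torch.randn(B, C_L, 2 * rnn_dim)   # e.g. BiGRU outputs, last dim = 2 * rnn_dim
    context_mask = torch.ones(B, C_L)            # 1 = real token, 0 = padding
    question = torch.randn(B, Q_L, 2 * rnn_dim)
    question_mask = torch.ones(B, Q_L)

    output = attention(context, context_mask, question, question_mask)
    print(output.size())  # (B, C_L, 8 * rnn_dim): cat of x, x2key, x * x2key, x * key2x

With self_attn=True (and key == x), the output width is 6 * rnn_dim instead, since the key2x term is not computed.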