Source code for claf.modules.attention.bi_attention


import torch
import torch.nn as nn

import claf.modules.functional as f


class BiAttention(nn.Module):
    """
    Attention Flow Layer in BiDAF (https://arxiv.org/pdf/1611.01603.pdf)

    Computes the similarity matrix, context-to-query attention (C2Q),
    and query-to-context attention (Q2C).

    * Args:
        model_dim: the model dimension d; context and query encodings are
            expected to carry 2 * d features
    """

    def __init__(self, model_dim):
        super(BiAttention, self).__init__()
        self.model_dim = model_dim
        self.W = nn.Linear(6 * model_dim, 1, bias=False)
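    # Similarity function from the BiDAF paper linked above:
    #     alpha(h, u) = w^T [h; u; h ◦ u]
    # where h (context) and u (query) each have 2 * model_dim features, so the
    # learned weight self.W maps the 6 * model_dim concatenation to one score.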
    def forward(self, context, context_mask, query, query_mask):
        c, c_mask, q, q_mask = context, context_mask, query, query_mask

        S = self._make_similarity_matrix(c, q)  # (B, C_L, Q_L)
        masked_S = f.add_masked_value(S, query_mask.unsqueeze(1), value=-1e7)

        c2q = self._context2query(S, q, q_mask)
        # Q2C pools the context by each context word's max similarity to any query word.
        q2c = self._query2context(masked_S.max(dim=-1)[0], c, c_mask)

        # G = [h; u˜; h◦u˜; h◦h˜] ~ (B, C_L, 8d)
        G = torch.cat((c, c2q, c * c2q, c * q2c), dim=-1)
        return G
    def _make_similarity_matrix(self, c, q):
        # B: batch_size, C_L: context_maxlen, Q_L: query_maxlen
        B, C_L, Q_L = c.size(0), c.size(1), q.size(1)
        matrix_shape = (B, C_L, Q_L, self.model_dim * 2)

        c_aug = c.unsqueeze(2).expand(matrix_shape)  # (B, C_L, Q_L, 2d)
        q_aug = q.unsqueeze(1).expand(matrix_shape)  # (B, C_L, Q_L, 2d)
        c_q = torch.mul(c_aug, q_aug)  # element-wise multiplication

        concated_vector = torch.cat((c_aug, q_aug, c_q), dim=3)  # [h; u; h◦u]
        return self.W(concated_vector).view(B, C_L, Q_L)

    def _context2query(self, S, q, q_mask):
        attention = f.last_dim_masked_softmax(S, q_mask)  # (B, C_L, Q_L)
        c2q = f.weighted_sum(attention=attention, matrix=q)  # (B, C_L, 2d)
        return c2q

    def _query2context(self, S, c, c_mask):
        attention = f.masked_softmax(S, c_mask)  # (B, C_L)
        q2c = f.weighted_sum(attention=attention, matrix=c)  # (B, 2d)
        return q2c.unsqueeze(1).expand(c.size())  # (B, C_L, 2d)
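
A minimal usage sketch (not part of the original module), assuming the claf package is importable. The tensor shapes, the all-ones masks (taken here to mean 1 = real token), and model_dim = 100 are illustrative placeholders; the only hard requirement from the code above is that the context and query encodings carry 2 * model_dim features.

import torch

from claf.modules.attention.bi_attention import BiAttention

model_dim = 100
bi_attention = BiAttention(model_dim)

B, C_L, Q_L = 2, 30, 10  # batch size, context length, query length
context = torch.randn(B, C_L, 2 * model_dim)
query = torch.randn(B, Q_L, 2 * model_dim)
context_mask = torch.ones(B, C_L)  # assumed mask convention: 1 = keep
query_mask = torch.ones(B, Q_L)

G = bi_attention(context, context_mask, query, query_mask)
print(G.shape)  # torch.Size([2, 30, 800]) == (B, C_L, 8 * model_dim)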