Source code for claf.modules.attention.multi_head_attention


import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import claf.modules.functional as f


class MultiHeadAttention(nn.Module):
    """
    Transformer's Multi-Head Attention in "Attention is All You Need"
    (https://arxiv.org/abs/1706.03762)

    * Kwargs:
        num_head: the number of attention heads
        model_dim: the model (input/output) dimension
        dropout: the dropout probability applied to the attention weights
        linear_key_dim: the dimension of the projected queries and keys
        linear_value_dim: the dimension of the projected values
    """

    def __init__(
        self, num_head=8, model_dim=100, dropout=0.1, linear_key_dim=None, linear_value_dim=None
    ):
        super(MultiHeadAttention, self).__init__()

        if linear_key_dim is None:
            linear_key_dim = model_dim
        if linear_value_dim is None:
            linear_value_dim = model_dim

        # Each head gets an equal slice of the projected key/value dimensions.
        assert linear_key_dim % num_head == 0
        assert linear_value_dim % num_head == 0

        self.model_dim = model_dim
        self.num_head = num_head
        self.projection = nn.ModuleList(
            [
                nn.Linear(model_dim, linear_key_dim, bias=False),  # query
                nn.Linear(model_dim, linear_key_dim, bias=False),  # key
                nn.Linear(model_dim, linear_value_dim, bias=False),  # value
            ]
        )
        self.out_linear = nn.Linear(linear_value_dim, model_dim)

        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = lambda x: x
    def forward(self, q, k, v, mask=None):
        q, k, v = self._linear_projection(q, k, v)
        qs, ks, vs = self._split_heads(q, k, v)  # [B, num_head, seq_len, head_dim]
        outputs = self._scaled_dot_product(qs, ks, vs, mask=mask)
        output = self._concat_heads(outputs)  # [B, seq_len, num_head * head_dim]
        return self.out_linear(output)  # [B, seq_len, model_dim]
    def _linear_projection(self, query, key, value):
        q = self.projection[0](query)
        k = self.projection[1](key)
        v = self.projection[2](value)
        return q, k, v

    def _split_heads(self, query, key, value):
        B = query.size(0)
        qs, ks, vs = [
            x.view(B, -1, self.num_head, x.size(-1) // self.num_head).transpose(1, 2)
            for x in [query, key, value]
        ]
        return qs, ks, vs

    def _scaled_dot_product(self, query, key, value, mask=None):
        K_D = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(K_D)

        if mask is not None:
            # [B, 1, 1, key_len], broadcast over heads and query positions
            mask = mask.unsqueeze(1).unsqueeze(1)
            scores = f.add_masked_value(scores, mask, value=-1e7)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, value)

    def _concat_heads(self, outputs):
        B = outputs.size(0)
        dim = outputs.size(-1)  # outputs: [B, num_head, seq_len, head_dim]
        return outputs.transpose(1, 2).contiguous().view(B, -1, self.num_head * dim)
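
For reference, the core step in _scaled_dot_product is the standard scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V. The standalone single-head sketch below (with arbitrarily chosen dimensions) shows just that computation, without the projections, head splitting, masking, and dropout that the module wraps around it.

import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 5, 16)  # [batch, query_len, d_k]
k = torch.randn(2, 7, 16)  # [batch, key_len, d_k]
v = torch.randn(2, 7, 32)  # [batch, key_len, d_v]

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # [2, 5, 7]
attn = F.softmax(scores, dim=-1)  # attention weights over the 7 key positions
out = torch.matmul(attn, v)  # [2, 5, 32]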
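
A minimal usage sketch, not part of the original module: it applies MultiHeadAttention as self-attention over a padded batch. The shapes and the mask convention (1 for real tokens, 0 for padding, matching how masked positions are pushed to -1e7 before the softmax) are illustrative assumptions, not taken from the claf source.

import torch

from claf.modules.attention.multi_head_attention import MultiHeadAttention

x = torch.randn(2, 7, 100)  # hypothetical batch: 2 sequences, 7 tokens, model_dim=100
mask = torch.tensor(
    [
        [1, 1, 1, 1, 1, 0, 0],  # assumed convention: 1 = real token, 0 = padding
        [1, 1, 1, 1, 1, 1, 1],
    ],
    dtype=torch.float,
)

attention = MultiHeadAttention(num_head=4, model_dim=100, dropout=0.1)
attention.eval()  # turn off attention dropout

out = attention(x, x, x, mask=mask)  # self-attention: query = key = value = x
print(out.shape)  # torch.Size([2, 7, 100]), same shape as the input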