Source code for claf.model.reading_comprehension.qanet

import torch
import torch.nn as nn

from overrides import overrides

from claf.decorator import register
from claf.model.base import ModelWithTokenEmbedder
from claf.model.reading_comprehension.mixin import SQuADv1
import claf.modules.functional as f
import claf.modules.attention as attention
import claf.modules.encoder as encoder
import claf.modules.conv as conv
import claf.modules.layer as layer


@register("model:qanet")
class QANet(SQuADv1, ModelWithTokenEmbedder):
    """
    Document Reader Model. `Span Detector`

    Implementation of the model presented in
    QANet: Combining Local Convolution with Global Self-Attention for Reading Comprehension
    (https://arxiv.org/abs/1804.09541)

    - Input Embedding Layer
    - Embedding Encoder Layer
    - Context-Query Attention Layer
    - Model Encoder Layer
    - Output Layer

    * Args:
        token_embedder: 'QATokenEmbedder', used to embed the 'context' and 'question'.

    * Kwargs:
        lang_code: dataset language code [en|ko]
        aligned_query_embedding: f_align(p_i) = sum_j(a_ij * E(q_j)), where the attention score
            a_ij captures the similarity between p_i and each question word q_j.
            These features add soft alignments between similar but non-identical words
            (e.g., car and vehicle). It is only applied to 'context_embed'.
        answer_maxlen: consider only answer spans of length less than or equal to {answer_maxlen}
        model_dim: the model's hidden dimension

        * Encoder Block Parameters (embedding, modeling):
            kernel_size: convolution kernel size in the encoder block
            num_head: the number of heads in multi-head attention
            num_conv_block: the number of convolution blocks in the encoder block
                [LayerNorm -> Conv (residual)]
            num_encoder_block: the number of encoder blocks
                [position_encoding -> [n repeated conv blocks] -> LayerNorm -> Self-attention (residual)
                 -> LayerNorm -> Feedforward (residual)]

        dropout: the dropout probability
        layer_dropout: the survival probability for stochastic depth applied to residual sub-layers
            (cf. Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382))
    """

    def __init__(
        self,
        token_embedder,
        lang_code="en",
        aligned_query_embedding=False,
        answer_maxlen=None,
        model_dim=128,
        kernel_size_in_embedding=7,
        num_head_in_embedding=8,
        num_conv_block_in_embedding=4,
        num_embedding_encoder_block=1,
        kernel_size_in_modeling=5,
        num_head_in_modeling=8,
        num_conv_block_in_modeling=2,
        num_modeling_encoder_block=7,
        dropout=0.1,
        layer_dropout=0.9,
    ):
        super(QANet, self).__init__(token_embedder)
        self.lang_code = lang_code
        self.aligned_query_embedding = aligned_query_embedding
        self.answer_maxlen = answer_maxlen

        self.token_embedder = token_embedder
        context_embed_dim, query_embed_dim = token_embedder.get_embed_dim()
        if self.aligned_query_embedding:
            context_embed_dim += query_embed_dim

        if context_embed_dim != query_embed_dim:
            self.context_highway = layer.Highway(context_embed_dim)
            self.query_highway = layer.Highway(query_embed_dim)

            self.context_embed_pointwise_conv = conv.PointwiseConv(context_embed_dim, model_dim)
            self.query_embed_pointwise_conv = conv.PointwiseConv(query_embed_dim, model_dim)
        else:
            highway = layer.Highway(context_embed_dim)
            self.context_highway = highway
            self.query_highway = highway

            embed_pointwise_conv = conv.PointwiseConv(context_embed_dim, model_dim)
            self.context_embed_pointwise_conv = embed_pointwise_conv
            self.query_embed_pointwise_conv = embed_pointwise_conv

        self.embed_encoder_blocks = nn.ModuleList(
            [
                EncoderBlock(
                    model_dim=model_dim,
                    kernel_size=kernel_size_in_embedding,
                    num_head=num_head_in_embedding,
                    num_conv_block=num_conv_block_in_embedding,
                    dropout=dropout,
                    layer_dropout=layer_dropout,
                )
                for _ in range(num_embedding_encoder_block)
            ]
        )

        self.co_attention = attention.CoAttention(model_dim)
        self.pointwise_conv = conv.PointwiseConv(model_dim * 4, model_dim)

        self.model_encoder_blocks = nn.ModuleList(
            [
                EncoderBlock(
                    model_dim=model_dim,
                    kernel_size=kernel_size_in_modeling,
                    num_head=num_head_in_modeling,
                    num_conv_block=num_conv_block_in_modeling,
                    dropout=dropout,
                    layer_dropout=layer_dropout,
                )
                for _ in range(num_modeling_encoder_block)
            ]
        )

        self.span_start_linear = nn.Linear(model_dim * 2, 1, bias=False)
        self.span_end_linear = nn.Linear(model_dim * 2, 1, bias=False)

        self.dropout = nn.Dropout(p=dropout)
        self.criterion = nn.CrossEntropyLoss()
    @overrides
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
                 "feature_name2": ...}

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor, "label_name2": tensor}
                The loss is not calculated when there is no label (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: representing unnormalized log probabilities of the span start position.
            - end_logits: representing unnormalized log probabilities of the span end position.
            - best_span: the string from the original passage that the model thinks is the best answer to the question.
            - data_idx: the question id, mapped to the answer
            - loss: a scalar loss to be optimised.
        """
        context = features["context"]
        question = features["question"]

        # 1. Input Embedding Layer
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context, question, query_params=query_params, query_align=self.aligned_query_embedding
        )

        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        context_embed = self.context_highway(context_embed)
        context_embed = self.dropout(context_embed)
        context_embed = self.context_embed_pointwise_conv(context_embed)

        query_embed = self.query_highway(query_embed)
        query_embed = self.dropout(query_embed)
        query_embed = self.query_embed_pointwise_conv(query_embed)

        # 2. Embedding Encoder Layer
        for encoder_block in self.embed_encoder_blocks:
            context = encoder_block(context_embed)
            context_embed = context

            query = encoder_block(query_embed)
            query_embed = query

        # 3. Context-Query Attention Layer
        context_query_attention = self.co_attention(context, query, context_mask, query_mask)

        # Projection (memory issue)
        context_query_attention = self.pointwise_conv(context_query_attention)
        context_query_attention = self.dropout(context_query_attention)

        # 4. Model Encoder Layer
        model_encoder_block_inputs = context_query_attention

        # Stacked Model Encoder Block: the same stack is applied three times,
        # and each pass contributes one output to the output layer.
        stacked_model_encoder_blocks = []
        for i in range(3):
            for _, model_encoder_block in enumerate(self.model_encoder_blocks):
                output = model_encoder_block(model_encoder_block_inputs, context_mask)
                model_encoder_block_inputs = output
            stacked_model_encoder_blocks.append(output)

        # 5. Output Layer
        span_start_inputs = torch.cat(
            [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[1]], dim=-1
        )
        span_start_inputs = self.dropout(span_start_inputs)
        span_start_logits = self.span_start_linear(span_start_inputs).squeeze(-1)

        span_end_inputs = torch.cat(
            [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[2]], dim=-1
        )
        span_end_inputs = self.dropout(span_end_inputs)
        span_end_logits = self.span_end_linear(span_end_inputs).squeeze(-1)

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

        output_dict = {
            "start_logits": span_start_logits,
            "end_logits": span_end_logits,
            "best_span": self.get_best_span(span_start_logits, span_end_logits),
        }

        if labels:
            data_idx = labels["data_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["data_idx"] = data_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

        return output_dict
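
# --- Illustrative sketch (not part of the original claf module) ---
# The `aligned_query_embedding` option documented above computes
# f_align(p_i) = sum_j(a_ij * E(q_j)), where a_ij is a softmax over a similarity
# score between context word p_i and question word q_j. The helper below is a
# minimal, hypothetical version of that soft alignment using a plain dot-product
# similarity; claf's actual implementation lives inside the token embedder and
# may differ (e.g., it may project the embeddings before comparing them).
def _aligned_query_embedding_sketch(context_embed, query_embed, query_mask):
    """context_embed: (B, C, D), query_embed: (B, Q, D), query_mask: (B, Q) with 1 for real tokens."""
    scores = torch.bmm(context_embed, query_embed.transpose(1, 2))   # (B, C, Q) similarity p_i . q_j
    scores = scores.masked_fill(query_mask.unsqueeze(1) == 0, -1e7)  # ignore padded question tokens
    alpha = torch.softmax(scores, dim=-1)                            # a_ij, normalized over question words
    return torch.bmm(alpha, query_embed)                             # (B, C, D): f_align for each p_i
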
class EncoderBlock(nn.Module):
    """
    Encoder Block ([...]: residual connection)

    position_encoding -> [convolution-layer] x num_conv_block -> [self-attention-layer] -> [feed-forward-layer]

    - convolution-layer: depthwise separable convolutions
    - self-attention-layer: multi-head attention
    - feed-forward-layer: pointwise convolution

    * Args:
        model_dim: the model's hidden dimension
        num_head: the number of heads in multi-head attention
        kernel_size: convolution kernel size
        num_conv_block: the number of convolution blocks
        dropout: the dropout probability
        layer_dropout: the survival probability for stochastic depth applied to residual sub-layers
            (cf. Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382))
    """

    def __init__(
        self,
        model_dim=128,
        num_head=8,
        kernel_size=5,
        num_conv_block=4,
        dropout=0.1,
        layer_dropout=0.9,
    ):
        super(EncoderBlock, self).__init__()
        self.position_encoding = encoder.PositionalEncoding(model_dim)
        self.dropout = nn.Dropout(dropout)

        self.num_conv_block = num_conv_block
        self.conv_blocks = nn.ModuleList(
            [conv.DepSepConv(model_dim, model_dim, kernel_size) for _ in range(num_conv_block)]
        )

        self.self_attention = attention.MultiHeadAttention(
            num_head=num_head, model_dim=model_dim, dropout=dropout
        )
        self.feedforward_layer = layer.PositionwiseFeedForward(
            model_dim, model_dim * 4, dropout=dropout
        )

        # Survival probability for stochastic depth
        if layer_dropout < 1.0:
            L = num_conv_block + 2 - 1  # index of the last sub-layer (conv blocks + attention + feedforward)
            layer_dropout_prob = round(1 - (1 / L) * (1 - layer_dropout), 3)
            self.residuals = nn.ModuleList(
                layer.ResidualConnection(
                    model_dim, layer_dropout=layer_dropout_prob, layernorm=True
                )
                for l in range(num_conv_block + 2)
            )
        else:
            self.residuals = nn.ModuleList(
                layer.ResidualConnection(model_dim, layernorm=True)
                for l in range(num_conv_block + 2)
            )
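
    # Worked example of the stochastic-depth computation above (illustrative note,
    # not part of the original claf module): with the defaults num_conv_block=4 and
    # layer_dropout=0.9, L = 4 + 2 - 1 = 5 and layer_dropout_prob = 1 - (1 / 5) * (1 - 0.9) = 0.98.
    # Interpreted as a survival probability (per the comment above), every residual sub-layer
    # shares this same value rather than decaying linearly with depth as in the stochastic-depth paper.
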
    def forward(self, x, mask=None):
        # Positional Encoding
        x = self.position_encoding(x)

        # Convolution Block (LayerNorm -> Conv)
        for i, conv_block in enumerate(self.conv_blocks):
            x = self.residuals[i](x, sub_layer_fn=conv_block)
            x = self.dropout(x)

        # LayerNorm -> Self-attention
        self_attention = lambda t: self.self_attention(q=t, k=t, v=t, mask=mask)
        x = self.residuals[self.num_conv_block](x, sub_layer_fn=self_attention)
        x = self.dropout(x)

        # LayerNorm -> Feedforward layer
        x = self.residuals[self.num_conv_block + 1](x, sub_layer_fn=self.feedforward_layer)
        x = self.dropout(x)
        return x
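
# --- Illustrative usage sketch (not part of the original claf module) ---
# Sanity-check an EncoderBlock in isolation: feed it a random
# (batch, seq_len, model_dim) tensor and confirm the output shape is preserved,
# as the stacking in QANet.forward relies on. The batch size, sequence length,
# and all-ones mask below are arbitrary assumptions; in QANet.forward the mask
# comes from f.get_mask_from_tokens.
if __name__ == "__main__":
    block = EncoderBlock(model_dim=128, num_head=8, kernel_size=5, num_conv_block=2)
    dummy_input = torch.randn(2, 50, 128)         # (batch, seq_len, model_dim)
    dummy_mask = torch.ones(2, 50)                # 1 for every (non-padded) position
    output = block(dummy_input, mask=dummy_mask)
    assert output.shape == dummy_input.shape      # the block preserves (batch, seq_len, model_dim)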