Source code for claf.model.reading_comprehension.docqa

from overrides import overrides
import torch
import torch.nn as nn
import torch.nn.functional as F

from claf.decorator import register
from claf.model.base import ModelWithTokenEmbedder
from claf.model.reading_comprehension.mixin import SQuADv1
from claf.modules import attention, initializer
from claf.modules import functional as f


@register("model:docqa")
class DocQA(SQuADv1, ModelWithTokenEmbedder):
    """
    Document Reader Model. `Span Detector`

    Implementation of the model presented in
    Simple and Effective Multi-Paragraph Reading Comprehension
    (https://arxiv.org/abs/1710.10723)

    - Embedding (Word + Char -> Contextual)
    - Attention
    - Residual self-attention
    - Output

    * Args:
        token_embedder: 'QATokenEmbedder', used to embed the 'context' and 'question'.

    * Kwargs:
        lang_code: Dataset language code [en|ko]
        aligned_query_embedding: f_align(p_i) = sum_j(a_ij * E(q_j)), where the attention score
            a_ij captures the similarity between context word p_i and each question word q_j.
            These features add soft alignments between similar but non-identical words
            (e.g., car and vehicle). It is applied only to 'context_embed'.
        answer_maxlen: predict the most probable answer span of length less than or equal
            to {answer_maxlen}
        rnn_dim: the hidden dimension of the RNN cells
        linear_dim: the dimension of the attention linear layers
        preprocess_rnn_num_layer: the number of recurrent layers (preprocess)
        modeling_rnn_num_layer: the number of recurrent layers (modeling)
        predict_rnn_num_layer: the number of recurrent layers (predict)
        dropout: the dropout probability
        weight_init: whether to apply the weight initializer to the model's modules
    """

    def __init__(
        self,
        token_embedder,
        lang_code="en",
        aligned_query_embedding=False,
        answer_maxlen=17,
        rnn_dim=100,
        linear_dim=200,
        preprocess_rnn_num_layer=1,
        modeling_rnn_num_layer=2,
        predict_rnn_num_layer=1,
        dropout=0.2,
        weight_init=True,
    ):
        super(DocQA, self).__init__(token_embedder)

        self.lang_code = lang_code
        self.aligned_query_embedding = aligned_query_embedding
        self.answer_maxlen = answer_maxlen

        self.token_embedder = token_embedder
        self.dropout = nn.Dropout(p=dropout)

        context_embed_dim, query_embed_dim = token_embedder.get_embed_dim()
        if self.aligned_query_embedding:
            context_embed_dim += query_embed_dim

        if context_embed_dim != query_embed_dim:
            self.context_preprocess_rnn = nn.GRU(
                input_size=context_embed_dim,
                hidden_size=rnn_dim,
                bidirectional=True,
                num_layers=preprocess_rnn_num_layer,
                batch_first=True,
            )
            self.query_preprocess_rnn = nn.GRU(
                input_size=query_embed_dim,
                hidden_size=rnn_dim,
                bidirectional=True,
                num_layers=preprocess_rnn_num_layer,
                batch_first=True,
            )
        else:
            preprocess_rnn = nn.GRU(
                input_size=context_embed_dim,
                hidden_size=rnn_dim,
                bidirectional=True,
                num_layers=preprocess_rnn_num_layer,
                batch_first=True,
            )
            self.context_preprocess_rnn = preprocess_rnn
            self.query_preprocess_rnn = preprocess_rnn

        self.bi_attention = attention.DocQAAttention(rnn_dim, linear_dim)
        self.attn_linear = nn.Linear(rnn_dim * 8, linear_dim)

        self.modeling_rnn = nn.GRU(
            input_size=linear_dim,
            hidden_size=rnn_dim,
            num_layers=modeling_rnn_num_layer,
            bidirectional=True,
            batch_first=True,
        )

        self.self_attention = SelfAttention(rnn_dim, linear_dim, weight_init=weight_init)

        self.span_start_rnn = nn.GRU(
            input_size=linear_dim,
            hidden_size=rnn_dim,
            bidirectional=True,
            num_layers=predict_rnn_num_layer,
            batch_first=True,
        )
        self.span_start_linear = nn.Linear(rnn_dim * 2, 1)

        self.span_end_rnn = nn.GRU(
            input_size=linear_dim + rnn_dim * 2,
            hidden_size=rnn_dim,
            bidirectional=True,
            num_layers=predict_rnn_num_layer,
            batch_first=True,
        )
        self.span_end_linear = nn.Linear(rnn_dim * 2, 1)

        self.activation_fn = F.relu
        self.criterion = nn.CrossEntropyLoss()

        if weight_init:
            modules = [
                self.context_preprocess_rnn,
                self.query_preprocess_rnn,
                self.modeling_rnn,
                self.attn_linear,
                self.span_start_rnn,
                self.span_start_linear,
                self.span_end_rnn,
                self.span_end_linear,
            ]
            initializer.weight(modules)
    @overrides
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
                 "feature_name2": ...}

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor, "label_name2": tensor}
                Do not calculate loss when there is no label. (inference/predict mode)

        * Returns:
            output_dict (dict) consisting of
                - start_logits: representing unnormalized log probabilities of the span start position.
                - end_logits: representing unnormalized log probabilities of the span end position.
                - best_span: the string from the original passage that the model thinks is the best answer to the question.
                - data_idx: the question id, mapping with answer
                - loss: A scalar loss to be optimised.
        """

        context = features["context"]
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context, question, query_params=query_params, query_align=self.aligned_query_embedding
        )

        context_mask = f.get_mask_from_tokens(context).float()  # B X 1 X C_L
        query_mask = f.get_mask_from_tokens(question).float()  # B X 1 X Q_L

        # Pre-process
        context_embed = self.dropout(context_embed)
        context_encoded = f.forward_rnn_with_pack(
            self.context_preprocess_rnn, context_embed, context_seq_config
        )
        context_encoded = self.dropout(context_encoded)

        query_embed = self.dropout(query_embed)
        query_encoded = f.forward_rnn_with_pack(
            self.query_preprocess_rnn, query_embed, query_seq_config
        )
        query_encoded = self.dropout(query_encoded)

        # Attention -> Projection
        context_attnded = self.bi_attention(
            context_encoded, context_mask, query_encoded, query_mask
        )
        context_attnded = self.activation_fn(self.attn_linear(context_attnded))  # B X C_L X dim*2

        # Residual Self-Attention
        context_attnded = self.dropout(context_attnded)
        context_encoded = f.forward_rnn_with_pack(
            self.modeling_rnn, context_attnded, context_seq_config
        )
        context_encoded = self.dropout(context_encoded)

        context_self_attnded = self.self_attention(context_encoded, context_mask)  # B X C_L X dim*2
        context_final = self.dropout(context_attnded + context_self_attnded)  # B X C_L X dim*2

        # Prediction
        span_start_input = f.forward_rnn_with_pack(
            self.span_start_rnn, context_final, context_seq_config
        )  # B X C_L X dim*2
        span_start_input = self.dropout(span_start_input)
        span_start_logits = self.span_start_linear(span_start_input).squeeze(-1)  # B X C_L

        span_end_input = torch.cat([span_start_input, context_final], dim=-1)  # B X C_L X dim*4
        span_end_input = f.forward_rnn_with_pack(
            self.span_end_rnn, span_end_input, context_seq_config
        )  # B X C_L X dim*2
        span_end_input = self.dropout(span_end_input)
        span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)  # B X C_L

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

        output_dict = {
            "start_logits": span_start_logits,
            "end_logits": span_end_logits,
            "best_span": self.get_best_span(
                span_start_logits, span_end_logits, answer_maxlen=self.answer_maxlen
            ),
        }

        if labels:
            data_idx = labels["data_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["data_idx"] = data_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

        return output_dict
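
# --- Illustration only (not part of claf) ---------------------------------------------------
# The `aligned_query_embedding` option documented in the DocQA docstring computes
#     f_align(p_i) = sum_j a_ij * E(q_j)
# where a_ij scores the similarity between context word p_i and question word q_j. In claf that
# computation lives inside the token embedder (enabled via `query_align=True`), so it does not
# appear in this file. The function below is a minimal, self-contained sketch of the idea; the
# name and the dot-product scoring are assumptions made for illustration, not the claf API.


def _aligned_query_embedding_sketch(context_embed, query_embed, query_mask):
    """context_embed: B x C_L x D, query_embed: B x Q_L x D, query_mask: B x Q_L (1 = real token)."""
    # a_ij: similarity between every context word and every question word -> B x C_L x Q_L
    scores = torch.bmm(context_embed, query_embed.transpose(1, 2))
    # remove padded question positions before normalizing
    scores = scores.masked_fill(query_mask.unsqueeze(1) == 0, -1e7)
    attn = F.softmax(scores, dim=-1)
    # f_align(p_i) = sum_j a_ij * E(q_j) -> B x C_L x D, concatenated to the context embedding upstream
    return torch.bmm(attn, query_embed)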
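
# --- Illustration only (not part of claf) ---------------------------------------------------
# `get_best_span`, called in `forward` above, is inherited from the SQuADv1 mixin and is not
# shown in this file. A common way to decode the most probable span of length <= answer_maxlen
# from the start/end logits is sketched below; the exact claf implementation may differ.


def _best_span_sketch(span_start_logits, span_end_logits, answer_maxlen=17):
    """span_start_logits, span_end_logits: B x C_L. Returns a B x 2 tensor of (start, end) indices."""
    # score every (start, end) pair as the sum of the two logits -> B x C_L x C_L
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)

    # keep only valid spans: end >= start and span length (end - start + 1) <= answer_maxlen
    context_len = span_start_logits.size(1)
    idx = torch.arange(context_len, device=span_start_logits.device)
    span_len = idx.unsqueeze(0) - idx.unsqueeze(1)  # end index minus start index
    valid = (span_len >= 0) & (span_len < answer_maxlen)
    span_scores = span_scores.masked_fill(~valid, -1e7)

    # argmax over the flattened (start, end) grid, then recover the two indices
    best_flat = span_scores.view(span_scores.size(0), -1).argmax(dim=-1)
    return torch.stack([best_flat // context_len, best_flat % context_len], dim=-1)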
class SelfAttention(nn.Module):
    """
    Same bi-attention mechanism, only now between the passage and itself.
    """

    def __init__(self, rnn_dim, linear_dim, dropout=0.2, weight_init=True):
        super(SelfAttention, self).__init__()
        self.self_attention = attention.DocQAAttention(
            rnn_dim, linear_dim, self_attn=True, weight_init=weight_init
        )
        self.self_attn_Linear = nn.Linear(rnn_dim * 6, linear_dim)

        self.dropout = nn.Dropout(p=dropout)
        self.activation_fn = F.relu

        if weight_init:
            initializer.weight(self.self_attn_Linear)
    def forward(self, context, context_mask):
        context_self_attnded = self.self_attention(context, context_mask, context, context_mask)
        return self.activation_fn(self.self_attn_Linear(context_self_attnded))
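
# --- Illustration only (not part of claf) ---------------------------------------------------
# `attention.DocQAAttention`, used by both DocQA and SelfAttention above, is defined elsewhere
# in claf. In the reference paper (https://arxiv.org/abs/1710.10723) the same bi-attention is
# reused between the passage and itself, except that each token is prevented from attending to
# its own position. The sketch below shows that self-attention pattern with simple dot-product
# scoring; the paper's trilinear scoring and output concatenation are omitted here, and the
# function is an assumption for illustration rather than the claf module.


def _self_attention_sketch(context, context_mask):
    """context: B x C_L x D, context_mask: B x C_L (1 = real token)."""
    scores = torch.bmm(context, context.transpose(1, 2))  # B x C_L x C_L
    # mask padded positions ...
    scores = scores.masked_fill(context_mask.unsqueeze(1) == 0, -1e7)
    # ... and keep each token from attending to itself, as in the paper's self-attention
    eye = torch.eye(context.size(1), dtype=torch.bool, device=context.device)
    scores = scores.masked_fill(eye, -1e7)
    attn = F.softmax(scores, dim=-1)
    return torch.bmm(attn, context)  # B x C_L x D, later combined with the residual connection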