Source code for claf.model.reading_comprehension.mixin


from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F

from claf.decorator import arguments_required
from claf.metric import korquad_v1_official, squad_v1_official, squad_v2_official
from claf.model.base import ModelBase


class ReadingComprehension:
    """
    Reading Comprehension Mixin Class

    * Args:
        token_embedder: 'RCTokenEmbedder', Used to embed the 'context' and 'question'.
    """

    def get_best_span(self, span_start_logits, span_end_logits, answer_maxlen=None):
        """
        Take argmax of constrained score_s * score_e.

        * Args:
            span_start_logits: independent start logits
            span_end_logits: independent end logits

        * Kwargs:
            answer_maxlen: max span length to consider (default is None -> All)
        """
        B = span_start_logits.size(0)
        best_word_span = span_start_logits.new_zeros((B, 2), dtype=torch.long)

        score_starts = F.softmax(span_start_logits, dim=-1)
        score_ends = F.softmax(span_end_logits, dim=-1)

        max_len = answer_maxlen or score_starts.size(1)
        for i in range(score_starts.size(0)):
            # Outer product of scores to get full p_s * p_e matrix
            scores = torch.ger(score_starts[i], score_ends[i])

            # Zero out negative length and over-length span scores
            scores.triu_().tril_(max_len - 1)

            # Take argmax or top n
            scores = scores.detach().cpu().numpy()
            scores_flat = scores.flatten()

            idx_sort = [np.argmax(scores_flat)]
            s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
            best_word_span[i, 0] = int(s_idx[0])
            best_word_span[i, 1] = int(e_idx[0])

        return best_word_span

    def _make_span_metrics(self, predictions):
        """ span accuracy metrics """
        start_accuracy, end_accuracy, span_accuracy = 0, 0, 0

        for index, preds in predictions.items():
            _, _, (answer_start, answer_end) = self._dataset.get_ground_truths(index)

            start_acc = 1 if preds["pred_span_start"] == answer_start else 0
            end_acc = 1 if preds["pred_span_end"] == answer_end else 0
            span_acc = 1 if start_acc == 1 and end_acc == 1 else 0

            start_accuracy += start_acc
            end_accuracy += end_acc
            span_accuracy += span_acc

        start_accuracy = 100.0 * start_accuracy / len(self._dataset)
        end_accuracy = 100.0 * end_accuracy / len(self._dataset)
        span_accuracy = 100.0 * span_accuracy / len(self._dataset)

        return {"start_acc": start_accuracy, "end_acc": end_accuracy, "span_acc": span_accuracy}

    def make_predictions(self, output_dict):
        """
        Make predictions with model's output_dict

        * Args:
            output_dict: model's output dictionary consisting of
                - data_idx: question id
                - best_span: the best span computed from span_start_logits and span_end_logits
                - start_logits: span start logits
                - end_logits: span end logits

        * Returns:
            predictions: prediction dictionary consisting of
                - key: 'id' (question id)
                - value: dictionary consisting of
                    predict_text, pred_span_start, pred_span_end,
                    start_logits, end_logits
        """
        data_indices = output_dict["data_idx"]
        best_word_span = output_dict["best_span"]

        return OrderedDict(
            [
                (
                    index.item(),
                    {
                        "predict_text": self._dataset.get_text_with_index(
                            index.item(), best_span[0], best_span[1]
                        ),
                        "pred_span_start": best_span[0],
                        "pred_span_end": best_span[1],
                        "start_logits": start_logits,
                        "end_logits": end_logits,
                    },
                )
                for index, best_span, start_logits, end_logits in zip(
                    list(data_indices.data),
                    list(best_word_span.data),
                    list(output_dict["start_logits"].data),
                    list(output_dict["end_logits"].data),
                )
            ]
        )

    @arguments_required(["context", "question"])
    def predict(self, output_dict, arguments, helper):
        """
        Inference by raw_feature

        * Args:
            output_dict: model's output dictionary consisting of
                - data_idx: question id
                - best_span: the best span computed from span_start_logits and span_end_logits
            arguments: arguments dictionary consisting of user_input
            helper: dictionary for helping get answer

        * Returns:
            answer dictionary consisting of 'text' (predicted answer span) and 'score'
        """
        span_start, span_end = list(output_dict["best_span"][0].data)
        word_start = span_start.item()
        word_end = span_end.item()

        text_span = helper["text_span"]
        char_start = text_span[word_start][0]
        char_end = text_span[word_end][1]

        context_text = arguments["context"]
        answer_text = context_text[char_start:char_end]

        start_logit = output_dict["start_logits"][0]
        end_logit = output_dict["end_logits"][0]
        score = start_logit[span_start] + end_logit[span_end]
        score = score.item()

        return {"text": answer_text, "score": score}

    def print_examples(self, index, inputs, predictions):
        """
        Print evaluation examples

        * Args:
            index: data index
            inputs: mini-batch inputs
            predictions: prediction dictionary consisting of
                - key: 'id' (question id)
                - value: dictionary consisting of
                    predict_text, pred_span_start, pred_span_end,
                    start_logits, end_logits

        * Returns:
            print(Context, Question, Answers and Predict)
        """
        data_index = inputs["labels"]["data_idx"][index].item()
        qid = self._dataset.get_qid(data_index)
        if "#" in qid:  # bert case (qid#index)
            qid = qid.split("#")[0]

        helper = self._dataset.helper
        context = helper["examples"][qid]["context"]
        question = helper["examples"][qid]["question"]
        answers = helper["examples"][qid]["answers"]

        predict_text = predictions[data_index]["predict_text"]

        print()
        print("- Context:", context)
        print("- Question:", question)
        print("- Answers:", answers)
        print("- Predict:", predict_text)
        print()

    def write_predictions(self, predictions, file_path=None, is_dict=True):
        pass
        # TODO: start and end logits (TypeError: Object of type 'Tensor' is not JSON serializable)

        # try:
        #     super(ReadingComprehension, self).write_predictions(
        #         predictions, file_path=file_path, is_dict=is_dict
        #     )
        # except AttributeError:
        #     # TODO: Need to Fix
        #     model_base = ModelBase()
        #     model_base._log_dir = self._log_dir
        #     model_base._train_counter = self._train_counter
        #     model_base.training = self.training
        #     model_base.write_predictions(predictions, file_path=file_path, is_dict=is_dict)
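
# Illustrative usage sketch (not part of the original module): the dummy
# logits below show how the constrained argmax in
# ReadingComprehension.get_best_span behaves. All values are made up for
# illustration only.
def _example_get_best_span():
    start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0]])  # (batch=1, context_len=4)
    end_logits = torch.tensor([[0.0, 0.2, 3.0, 0.1]])

    rc = ReadingComprehension()  # get_best_span touches no instance state
    span = rc.get_best_span(start_logits, end_logits, answer_maxlen=3)
    # span[0] is tensor([1, 2]): the highest-scoring pair with start <= end
    # and span length <= answer_maxlen (enforced by the triu_()/tril_() masking).
    return span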


class SQuADv1(ReadingComprehension):
    """
    Reading Comprehension Mixin Class with SQuAD v1.1 evaluation

    * Args:
        token_embedder: 'QATokenEmbedder', Used to embed the 'context' and 'question'.
    """

    def make_metrics(self, predictions):
        """
        Make metrics with prediction dictionary

        * Args:
            predictions: prediction dictionary consisting of
                - key: 'id' (question id)
                - value: (predict_text, pred_span_start, pred_span_end)

        * Returns:
            metrics: metric dictionary consisting of
                - 'em': exact_match (SQuAD v1.1 official evaluation)
                - 'f1': f1 (SQuAD v1.1 official evaluation)
                - 'start_acc': span_start accuracy
                - 'end_acc': span_end accuracy
                - 'span_acc': span accuracy (start and end)
        """
        preds = {}
        for index, prediction in predictions.items():
            _, _, (answer_start, answer_end) = self._dataset.get_ground_truths(index)

            qid = self._dataset.get_qid(index)
            preds[qid] = prediction["predict_text"]

        self.write_predictions(preds)

        squad_official_metrics = self._make_metrics_with_official(preds)

        metrics = self._make_span_metrics(predictions)
        metrics.update(squad_official_metrics)
        return metrics

    def _make_metrics_with_official(self, preds):
        """ SQuAD v1.1 official evaluation """
        dataset = self._dataset.raw_dataset

        if self.lang_code.startswith("ko"):
            scores = korquad_v1_official.evaluate(dataset, preds)
        else:
            scores = squad_v1_official.evaluate(dataset, preds)
        return scores
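
# Illustrative sketch (not part of the original module): the ``preds``
# dictionary handed to the official SQuAD v1.1 / KorQuAD v1.0 evaluators
# maps question id -> predicted answer text. The ids and texts below are
# invented for illustration only.
_example_squad_v1_preds = {
    "56be4db0acb8001400a502ec": "Denver Broncos",
    "56be4db0acb8001400a502ed": "Santa Clara, California",
}
# Passing such a dict to _make_metrics_with_official returns the official
# exact-match / F1 scores (korquad_v1_official is used when lang_code starts
# with "ko", squad_v1_official otherwise).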


class SQuADv1ForBert(SQuADv1):
    """
    Reading Comprehension Mixin Class with SQuAD v1.1 evaluation, for BERT-style models

    * Args:
        token_embedder: 'QATokenEmbedder', Used to embed the 'context' and 'question'.
    """

    def make_metrics(self, predictions):
        """ BERT predictions need to get nbest result """
        best_predictions = {}

        for index, prediction in predictions.items():
            qid = self._dataset.get_qid(index)

            predict_text = prediction["predict_text"]
            start_logit = prediction["start_logits"][prediction["pred_span_start"]]
            end_logit = prediction["end_logits"][prediction["pred_span_end"]]
            predict_score = start_logit.item() + end_logit.item()

            if qid not in best_predictions:
                best_predictions[qid] = []
            best_predictions[qid].append((predict_text, predict_score))

        for qid, candidates in best_predictions.items():
            # Keep only the highest-scoring candidate text per question id
            sorted_predictions = sorted(candidates, key=lambda x: x[1], reverse=True)
            best_predictions[qid] = sorted_predictions[0][0]

        self.write_predictions(best_predictions)
        return self._make_metrics_with_official(best_predictions)

    def predict(self, output_dict, arguments, helper):
        """
        Inference by raw_feature

        * Args:
            output_dict: model's output dictionary consisting of
                - data_idx: question id
                - best_span: the best span computed from span_start_logits and span_end_logits
            arguments: arguments dictionary consisting of user_input
            helper: dictionary for helping get answer

        * Returns:
            answer dictionary consisting of 'text' (predicted answer span) and 'score'
        """
        context_text = arguments["context"]
        bert_tokens = helper["bert_token"]

        predictions = [
            (best_span, start_logits, end_logits)
            for best_span, start_logits, end_logits in zip(
                list(output_dict["best_span"].data),
                list(output_dict["start_logits"].data),
                list(output_dict["end_logits"].data),
            )
        ]

        best_predictions = []
        for index, prediction in enumerate(predictions):
            bert_token = bert_tokens[index]
            best_span, start_logits, end_logits = prediction
            pred_start, pred_end = best_span

            predict_text = ""
            if (
                pred_start < len(bert_token)
                and pred_end < len(bert_token)
                and bert_token[pred_start].text_span is not None
                and bert_token[pred_end].text_span is not None
            ):
                char_start = bert_token[pred_start].text_span[0]
                char_end = bert_token[pred_end].text_span[1]
                predict_text = context_text[char_start:char_end]

            start_logit = start_logits[pred_start]
            end_logit = end_logits[pred_end]
            predict_score = start_logit.item() + end_logit.item()

            best_predictions.append((predict_text, predict_score))

        sorted_predictions = sorted(best_predictions, key=lambda x: x[1], reverse=True)
        return {"text": sorted_predictions[0][0], "score": sorted_predictions[0][1]}
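
# Illustrative sketch (not part of the original module): SQuADv1ForBert
# collects one (text, score) candidate per BERT window of the same question
# (qid#index) and keeps the highest-scoring text, as in make_metrics above.
# The values below are made up for illustration only.
def _example_bert_nbest_aggregation():
    candidates = {
        "q1": [("Denver Broncos", 7.2), ("Broncos", 5.9)],
        "q2": [("", 1.3), ("Santa Clara", 4.8)],
    }
    # Equivalent to sorting each candidate list by score and taking the first item
    best = {qid: max(guesses, key=lambda g: g[1])[0] for qid, guesses in candidates.items()}
    # best == {"q1": "Denver Broncos", "q2": "Santa Clara"}
    return best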


class SQuADv2(ReadingComprehension):
    """
    Reading Comprehension Mixin Class with SQuAD v2.0 evaluation

    * Args:
        token_embedder: 'RCTokenEmbedder', Used to embed the 'context' and 'question'.
    """

    def make_metrics(self, predictions):
        """
        Make metrics with prediction dictionary

        * Args:
            predictions: prediction dictionary consisting of
                - key: 'id' (question id)
                - value: dictionary consisting of
                    predict_text, pred_span_start, pred_span_end,
                    start_logits, end_logits

        * Returns:
            metrics: metric dictionary consisting of
                - 'start_acc': span_start accuracy
                - 'end_acc': span_end accuracy
                - 'span_acc': span accuracy (start and end)
                - 'em': exact_match (SQuAD v2.0 official evaluation)
                - 'f1': f1 (SQuAD v2.0 official evaluation)
                - 'HasAns_exact': has answer exact_match
                - 'HasAns_f1': has answer f1
                - 'NoAns_exact': no answer exact_match
                - 'NoAns_f1': no answer f1
                - 'best_exact': best exact_match score with best_exact_thresh
                - 'best_exact_thresh': best exact_match answerable threshold
                - 'best_f1': best f1 score with best_f1_thresh
                - 'best_f1_thresh': best f1 answerable threshold
        """
        preds, na_probs = {}, {}
        for index, prediction in predictions.items():
            _, _, (answer_start, answer_end) = self._dataset.get_ground_truths(index)

            # Metrics (SQuAD official metric)
            predict_text = prediction["predict_text"]
            if predict_text == "<noanswer>":
                predict_text = ""

            qid = self._dataset.get_qid(index)
            preds[qid] = predict_text

            span_start_probs = F.softmax(prediction["start_logits"], dim=-1)
            span_end_probs = F.softmax(prediction["end_logits"], dim=-1)

            start_no_prob = span_start_probs[-1].item()
            end_no_prob = span_end_probs[-1].item()
            no_answer_prob = start_no_prob * end_no_prob
            na_probs[qid] = no_answer_prob

        self.write_predictions(preds)

        model_type = "train" if self.training else "valid"
        self.write_predictions(
            na_probs, file_path=f"na_probs-{model_type}-{self._train_counter.get_display()}.json"
        )

        squad_official_metrics = self._make_metrics_with_official(preds, na_probs)

        metrics = self._make_span_metrics(predictions)
        metrics.update(squad_official_metrics)
        return metrics

    def _make_metrics_with_official(self, preds, na_probs, na_prob_thresh=1.0):
        """ SQuAD 2.0 official evaluation """
        dataset = self._dataset.raw_dataset

        squad_scores = squad_v2_official.evaluate(dataset, na_probs, preds)
        squad_scores["em"] = squad_scores["exact"]

        remove_keys = ["total", "exact", "HasAns_total", "NoAns_total"]
        for key in remove_keys:
            if key in squad_scores:
                del squad_scores[key]
        return squad_scores
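
# Illustrative sketch (not part of the original module): the no-answer
# probability used in SQuADv2.make_metrics treats the last context position
# as the "no answer" slot and multiplies its start/end probabilities.
def _example_no_answer_prob(start_logits, end_logits):
    span_start_probs = F.softmax(start_logits, dim=-1)
    span_end_probs = F.softmax(end_logits, dim=-1)
    return span_start_probs[-1].item() * span_end_probs[-1].item()
# e.g. _example_no_answer_prob(torch.tensor([0.2, 1.5, 0.1, 2.0]),
#                              torch.tensor([0.3, 0.4, 1.0, 2.5]))
# returns a value in (0, 1); such values are passed to squad_v2_official.evaluate
# as na_probs[qid], where answerability thresholds are computed.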