# Source code for claf.model.token_classification.mixin

from pathlib import Path
import logging

import numpy as np
import torch
import pycm
from pycm.pycm_obj import pycmVectorError

from claf.decorator import arguments_required
import claf.utils as common_utils
from claf.model import cls_utils
from claf.metric.classification import macro_f1, macro_precision, macro_recall
from seqeval.metrics import accuracy_score as conlleval_accuracy
from seqeval.metrics import f1_score as conlleval_f1

logger = logging.getLogger(__name__)

class TokenClassification:
    """
    Token Classification Mixin Class

    Adds prediction, inference, metric and logging helpers for token-level
    (tagging) models. The host class must provide ``self._dataset``,
    ``self._log_dir``, ``self._train_counter`` and a ``write_predictions``
    reachable through ``super()`` (presumably ``ModelBase`` -- TODO confirm
    against the concrete model classes).
    """

    def make_predictions(self, output_dict):
        """
        Make predictions with model's output_dict

        * Args:
            output_dict: model's output dictionary consisting of
                - sequence_embed: embedding vector of the sequence
                - tag_logits: representing unnormalized log probabilities of the tag
                - tag_idxs: target tag idxs
                - data_idx: data idx
                - loss: a scalar loss to be optimized

        * Returns:
            predictions: prediction dictionary consisting of
                - key: 'id' (sequence id)
                - value: dictionary consisting of
                    - tag_idxs
        """
        data_indices = output_dict["data_idx"]
        pred_tag_logits = output_dict["tag_logits"]

        # argmax over the last dimension picks one tag idx per token; .tolist()
        # yields plain Python ints so make_metrics can use them as keys into
        # self._dataset.tag_idx2text.
        pred_tag_idxs = [
            torch.argmax(pred_tag_logit, dim=-1).tolist() for pred_tag_logit in pred_tag_logits
        ]

        return {
            self._dataset.get_id(data_idx.item()): {"tag_idxs": pred_tag_idx}
            for data_idx, pred_tag_idx in zip(list(data_indices), pred_tag_idxs)
        }

    @arguments_required(["sequence"])
    def predict(self, output_dict, arguments, helper):
        """
        Inference by raw_feature

        * Args:
            output_dict: model's output dictionary consisting of
                - sequence_embed: embedding vector of the sequence
                - tag_logits: representing unnormalized log probabilities of the tags.
            arguments: arguments dictionary consisting of user_input
            helper: dictionary to get the classification result, consisting of
                - tag_idx2text: dictionary converting tag_idx to tag_text

        * Returns: output dict (dict) consisting of
            - tag_logits: representing unnormalized log probabilities of the tags
            - tag_idxs: predicted tag idxs
            - tag_texts: predicted tag texts
            - tag_dict: predicted tag slots (built by cls_utils.get_tag_dict)
        """
        sequence = arguments["sequence"]

        # Only the first batch entry is used (presumably a batch holding the
        # single raw input -- confirm with the caller).
        tag_logits = output_dict["tag_logits"][0]
        tag_idxs = [tag_logit.argmax(dim=-1) for tag_logit in tag_logits]
        tag_texts = [helper["tag_idx2text"][tag_idx.item()] for tag_idx in tag_idxs]

        return {
            "tag_logits": tag_logits,
            "tag_idxs": tag_idxs,
            "tag_texts": tag_texts,
            "tag_dict": cls_utils.get_tag_dict(sequence, tag_texts),
        }

    def make_metrics(self, predictions):
        """
        Make metrics with prediction dictionary

        * Args:
            predictions: prediction dictionary consisting of
                - key: 'id' (sequence id)
                - value: dictionary consisting of
                    - tag_idxs

        * Returns:
            metrics: metric dictionary consisting of
                - 'accuracy': sequence level accuracy
                - 'tag_accuracy': tag level accuracy
                - 'macro_f1': tag prediction macro(unweighted mean) f1
                - 'macro_precision': tag prediction macro(unweighted mean) precision
                - 'macro_recall': tag prediction macro(unweighted mean) recall
                - 'conlleval_accuracy': seqeval accuracy_score over the tag sequences
                - 'conlleval_f1': seqeval f1_score over the tag sequences

        * Raises:
            pycmVectorError: for any pycm vector error other than the
                single-class case (which is handled by returning 1.0 for every metric)
        """
        pred_tag_idxs_list = []
        target_tag_idxs_list = []
        accurate_sequence = []

        for data_idx, pred in predictions.items():
            target = self._dataset.get_ground_truth(data_idx)

            pred_tag_idxs_list.append(pred["tag_idxs"])
            target_tag_idxs_list.append(target["tag_idxs"])

            # np.array_equal never broadcasts. Sequences of different lengths
            # therefore count as inaccurate, instead of a length-1 sequence
            # spuriously matching or `==` raising on incompatible shapes.
            accurate_sequence.append(
                1 if np.array_equal(target["tag_idxs"], pred["tag_idxs"]) else 0
            )

        tag_idx2text = self._dataset.tag_idx2text
        pred_tags = [
            [tag_idx2text[tag_idx] for tag_idx in tag_idxs] for tag_idxs in pred_tag_idxs_list
        ]
        target_tags = [
            [tag_idx2text[tag_idx] for tag_idx in tag_idxs] for tag_idxs in target_tag_idxs_list
        ]

        flat_pred_tags = list(common_utils.flatten(pred_tags))
        flat_target_tags = list(common_utils.flatten(target_tags))

        # confusion matrix
        try:
            pycm_obj = pycm.ConfusionMatrix(
                actual_vector=flat_target_tags, predict_vector=flat_pred_tags
            )
        except pycmVectorError as e:
            # pycm refuses to build a matrix with fewer than two classes; in that
            # case every metric is reported as 1.0 and nothing is written to disk.
            if str(e) == "Number of the classes is lower than 2":
                logger.warning(
                    "Number of tags in the batch is 1. Sanity check is highly recommended."
                )
                return {
                    "accuracy": 1.,
                    "tag_accuracy": 1.,
                    "macro_f1": 1.,
                    "macro_precision": 1.,
                    "macro_recall": 1.,
                    "conlleval_accuracy": 1.,
                    "conlleval_f1": 1.,
                }
            raise

        self.write_predictions(
            {"target": flat_target_tags, "predict": flat_pred_tags}, pycm_obj=pycm_obj
        )

        sequence_accuracy = sum(accurate_sequence) / len(accurate_sequence)

        return {
            "accuracy": sequence_accuracy,
            "tag_accuracy": pycm_obj.Overall_ACC,
            "macro_f1": macro_f1(pycm_obj),
            "macro_precision": macro_precision(pycm_obj),
            "macro_recall": macro_recall(pycm_obj),
            "conlleval_accuracy": conlleval_accuracy(target_tags, pred_tags),
            "conlleval_f1": conlleval_f1(target_tags, pred_tags),
        }

    def write_predictions(self, predictions, file_path=None, is_dict=True, pycm_obj=None):
        """
        Override write_predictions() in ModelBase to log confusion matrix

        * Args:
            predictions: forwarded unchanged to the base write_predictions()
            file_path: forwarded to the base write_predictions()
            is_dict: forwarded to the base write_predictions()
            pycm_obj: optional pycm.ConfusionMatrix. When given, its statistics are
                saved via save_csv() and the confusion matrix is written via
                cls_utils.write_confusion_matrix_to_csv(), both under
                <self._log_dir>/predictions/
        """
        super().write_predictions(predictions, file_path=file_path, is_dict=is_dict)

        if pycm_obj is None:
            return

        # NOTE(review): the condition here was lost in the rendered source.
        # torch.nn.Module's `training` flag is assumed -- confirm against ModelBase.
        data_type = "train" if self.training else "valid"
        prefix = f"predictions-{data_type}-{self._train_counter.get_display()}"
        predictions_dir = Path(self._log_dir) / "predictions"

        pycm_obj.save_csv(str(predictions_dir / f"{prefix}-stats"))
        cls_utils.write_confusion_matrix_to_csv(
            str(predictions_dir / f"{prefix}-confusion_matrix"), pycm_obj
        )

    def print_examples(self, index, inputs, predictions):
        """
        Print evaluation examples

        * Args:
            index: data index
            inputs: mini-batch inputs
            predictions: prediction dictionary consisting of
                - key: 'id' (sequence id)
                - value: dictionary consisting of
                    - tag_idxs

        * Returns:
            print(Sequence, Target Tags, Target Slots, Predicted Tags, Predicted Slots)
        """
        data_idx = inputs["labels"]["data_idx"][index].item()
        data_id = self._dataset.get_id(data_idx)

        helper = self._dataset.helper
        example = helper["examples"][data_id]
        sequence = example["sequence"]
        target_tag_texts = example["tag_texts"]

        pred_tag_idxs = predictions[data_id]["tag_idxs"]
        pred_tag_texts = self._dataset.get_tag_texts_with_idxs(pred_tag_idxs)

        print()
        print("- Sequence:", sequence)
        print("- Target:")
        print(" Tags:", target_tag_texts)
        print(" (Slots)", cls_utils.get_tag_dict(sequence, target_tag_texts))
        print("- Predict:")
        print(" Tags:", pred_tag_texts)
        print(" (Slots)", cls_utils.get_tag_dict(sequence, pred_tag_texts))
        print()