Source code for claf.model.sequence_classification.mixin


from pathlib import Path
import logging

import torch
import pycm
from pycm.pycm_obj import pycmVectorError

from claf.model import cls_utils
from claf.model.base import ModelBase
from claf.metric.classification import macro_f1, macro_precision, macro_recall
from claf.metric.glue import simple_accuracy, f1, matthews_corr

logger = logging.getLogger(__name__)


class SequenceClassification:
    """
    Sequence Classification Mixin Class

    Bundles prediction, inference, metric and reporting helpers for sequence
    classification models. The methods read `self._dataset`, and
    `write_predictions` additionally reads `self._log_dir`,
    `self._train_counter` and `self.training`, so the host class has to
    provide those attributes.
    """

    def make_predictions(self, output_dict):
        """
        Make predictions with model's output_dict

        * Args:
            output_dict: model's output dictionary consisting of
                - sequence_embed: embedding vector of the sequence
                - logits: representing unnormalized log probabilities of the class
                - class_idx: target class idx
                - data_idx: data idx
                - loss: a scalar loss to be optimized
              (only 'data_idx' and 'logits' are read here)

        * Returns:
            predictions: prediction dictionary consisting of
                - key: 'id' (sequence id)
                - value: dictionary consisting of
                    - class_idx
        """
        data_indices = output_dict["data_idx"]
        pred_class_idxs = torch.argmax(output_dict["logits"], dim=-1)

        # Iterating a tensor walks its first dimension, so the two tensors can
        # be zipped directly; `.item()` turns each element into a Python number.
        return {
            self._dataset.get_id(data_idx.item()): {"class_idx": pred_class_idx.item()}
            for data_idx, pred_class_idx in zip(data_indices, pred_class_idxs)
        }

    def predict(self, output_dict, arguments, helper):
        """
        Inference by raw_feature

        * Args:
            output_dict: model's output dictionary consisting of
                - sequence_embed: embedding vector of the sequence
                - logits: representing unnormalized log probabilities of the class.
            arguments: arguments dictionary consisting of user_input
                (not read by this method)
            helper: dictionary to get the classification result, consisting of
                - class_idx2text: dictionary converting class_idx to class_text

        * Returns: output dict (dict) consisting of
            - logits: representing unnormalized log probabilities of the class
            - class_idx: predicted class idx (tensor returned by argmax)
            - class_text: predicted class text
        """
        logits = output_dict["logits"]
        class_idx = logits.argmax(dim=-1)

        # `.item()` only works on a single-element tensor, so this method
        # handles exactly one prediction per call.
        return {
            "logits": logits,
            "class_idx": class_idx,
            "class_text": helper["class_idx2text"][class_idx.item()],
        }

    def make_metrics(self, predictions):
        """
        Make metrics with prediction dictionary

        * Args:
            predictions: prediction dictionary consisting of
                - key: 'id' (sequence id)
                - value: dictionary consisting of
                    - class_idx

        * Returns:
            metrics: metric dictionary consisting of
                - 'accuracy': simple_accuracy over predicted and target class idx
                - the entries returned by f1(), added only when the dataset
                  defines exactly two classes
                - the entries returned by matthews_corr()
        """
        target_count = len(self._dataset.class_idx2text)

        pred_idx = []
        target_idx = []
        for data_id, pred in predictions.items():
            target = self._dataset.get_ground_truth(data_id)
            pred_idx.append(pred["class_idx"])
            target_idx.append(target["class_idx"])

        metrics = {
            "accuracy": simple_accuracy(pred_idx, target_idx),
        }

        if target_count == 2:
            # binary class
            metrics.update(f1(pred_idx, target_idx))

        # NOTE(review): the nesting of this call could not be recovered from the
        # source; it is placed outside the binary-only branch on the assumption
        # that matthews_corr() accepts multi-class input -- confirm.
        metrics.update(matthews_corr(pred_idx, target_idx))
        return metrics

    def write_predictions(self, predictions, file_path=None, is_dict=True, pycm_obj=None):
        """
        Override write_predictions() in ModelBase to log confusion matrix

        * Args:
            predictions: prediction dictionary, forwarded to
                ModelBase.write_predictions()
            file_path: forwarded to ModelBase.write_predictions()
            is_dict: forwarded to ModelBase.write_predictions()
            pycm_obj: optional object exposing save_csv() (presumably a
                pycm.ConfusionMatrix -- confirm); when given, its statistics
                and confusion matrix are written under
                '<log_dir>/predictions/'
        """
        try:
            super().write_predictions(predictions, file_path=file_path, is_dict=is_dict)
        except AttributeError:
            # TODO: Need to Fix
            # Fallback used when the call above raises AttributeError: copy the
            # state the writer reads onto a throwaway ModelBase and let it
            # write the predictions instead.
            model_base = ModelBase()
            model_base._log_dir = self._log_dir
            model_base._train_counter = self._train_counter
            model_base.training = self.training
            model_base.write_predictions(predictions, file_path=file_path, is_dict=is_dict)

        data_type = "train" if self.training else "valid"

        if pycm_obj is not None:
            predictions_dir = Path(self._log_dir) / "predictions"
            # Build the shared prefix once so both files carry the same label.
            file_prefix = f"predictions-{data_type}-{self._train_counter.get_display()}"

            pycm_obj.save_csv(str(predictions_dir / f"{file_prefix}-stats"))
            cls_utils.write_confusion_matrix_to_csv(
                str(predictions_dir / f"{file_prefix}-confusion_matrix"), pycm_obj
            )

    def print_examples(self, index, inputs, predictions):
        """
        Print evaluation examples

        * Args:
            index: data index
            inputs: mini-batch inputs
            predictions: prediction dictionary consisting of
                - key: 'id' (sequence id)
                - value: dictionary consisting of
                    - class_idx

        * Returns:
            print(Sequence, Target Class, Predicted Class)
        """
        data_idx = inputs["labels"]["data_idx"][index].item()
        data_id = self._dataset.get_id(data_idx)

        example = self._dataset.helper["examples"][data_id]
        sequence = example["sequence"]
        target_class_text = example["class_text"]

        pred_class_idx = predictions[data_id]["class_idx"]
        pred_class_text = self._dataset.get_class_text_with_idx(pred_class_idx)

        print()
        print("- Sequence:", sequence)
        print("- Target:")
        print(" Class:", target_class_text)
        print("- Predict:")
        print(" Class:", pred_class_text)
        print()