Source code for claf.model.token_classification.bert


from overrides import overrides
from pytorch_transformers import BertModel
import torch.nn as nn

from claf.data.data_handler import CachePath
from claf.decorator import register
from claf.model.base import ModelWithoutTokenEmbedder
from claf.model.token_classification.mixin import TokenClassification

from claf.model import cls_utils


[docs]@register("model:bert_for_tok_cls") class BertForTokCls(TokenClassification, ModelWithoutTokenEmbedder): """ Implementation of Single Sentence Tagging model presented in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (https://arxiv.org/abs/1810.04805) * Args: token_embedder: used to embed the sequence num_tags: number of classified tags ignore_tag_idx: index of the tag to ignore when calculating loss (tag pad value) * Kwargs: pretrained_model_name: the name of a pre-trained model dropout: classification layer dropout """ def __init__( self, token_makers, num_tags, ignore_tag_idx, pretrained_model_name=None, dropout=0.2 ): super(BertForTokCls, self).__init__(token_makers) self.use_pytorch_transformers = True # for optimizer's model parameters self.ignore_tag_idx = ignore_tag_idx self.num_tags = num_tags self._model = BertModel.from_pretrained( pretrained_model_name, cache_dir=str(CachePath.ROOT) ) self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(self._model.config.hidden_size, num_tags) ) self.classifier.apply(self._model.init_weights) self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_tag_idx)
    @overrides
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {
                    "bert_input": {
                        "feature": [
                            [100, 576, 21, 45, 7, 91, 101, 0, 0, ...],
                            ...,
                        ]
                    },
                    "token_type": {
                        "feature": [
                            [0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
                            ...,
                        ]
                    },
                    "tagged_sub_token_idxs": {
                        "feature": [
                            [1, 3, 4, 0, 0, 0, 0, 0, 0, ...],
                            ...,
                        ]
                    },
                    "num_tokens": {
                        "feature": [3, 5, ...]
                    }
                }

        * Kwargs:
            labels: label dictionary like below.
                {
                    "tag_idxs": [[2, 1, 0, 4, 5, ...], ...],
                    "data_idx": [2, 4, 5, 7, 2, 1, ...]
                }
                Do not calculate loss when there is no label.
                (inference/predict mode)

        * Returns: output_dict (dict) consisting of
            - sequence_embed: embedding vector of the sequence
            - tag_logits: unnormalized log probabilities of the tags
            - tag_idxs: target tag indices
            - data_idx: data index
            - loss: a scalar loss to be optimized
        """
        bert_inputs = features["bert_input"]["feature"]
        token_type_ids = features["token_type"]["feature"]
        tagged_sub_token_idxs = features["tagged_sub_token_idxs"]["feature"]
        num_tokens = features["num_tokens"]["feature"]

        attention_mask = (bert_inputs > 0).long()
        outputs = self._model(
            bert_inputs, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        token_encodings = outputs[0]
        pooled_output = outputs[1]

        tag_logits = self.classifier(token_encodings)  # [B, L, num_tags]

        # Gather the logits at the tagged sub-token positions
        # (the first sub-token of each original token).
        gather_token_pos_idxs = tagged_sub_token_idxs.unsqueeze(-1).repeat(1, 1, self.num_tags)
        token_tag_logits = tag_logits.gather(1, gather_token_pos_idxs)  # [B, num_tokens, num_tags]

        sliced_token_tag_logits = [
            token_tag_logits[idx, :n, :] for idx, n in enumerate(num_tokens)
        ]

        output_dict = {"sequence_embed": pooled_output, "tag_logits": sliced_token_tag_logits}

        if labels:
            tag_idxs = labels["tag_idxs"]
            data_idx = labels["data_idx"]

            output_dict["tag_idxs"] = tag_idxs
            output_dict["data_idx"] = data_idx

            # Loss
            loss = self.criterion(token_tag_logits.view(-1, self.num_tags), tag_idxs.view(-1))
            output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

        return output_dict
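The sub-token gather in `forward` is the subtle step: BERT encodes WordPiece sub-tokens, while tags apply to original tokens, so the logits at each token's first sub-token position are selected and the zero-padded tail is sliced off per example. A self-contained sketch of that indexing with dummy tensors (shapes chosen to match the description above; not CLAF code):

import torch

batch_size, seq_len, num_tags = 2, 6, 3
tag_logits = torch.randn(batch_size, seq_len, num_tags)          # [B, L, num_tags]

# first-sub-token position of each original token, zero-padded per example
tagged_sub_token_idxs = torch.tensor([[1, 3, 4, 0, 0, 0],
                                      [1, 2, 0, 0, 0, 0]])       # [B, max_num_tokens]
num_tokens = [3, 2]                                              # real token counts

# repeat indices across the tag dimension so gather selects whole logit rows
gather_idxs = tagged_sub_token_idxs.unsqueeze(-1).repeat(1, 1, num_tags)
token_tag_logits = tag_logits.gather(1, gather_idxs)             # [B, max_num_tokens, num_tags]

# drop the padded positions, leaving one logit row per real token
sliced = [token_tag_logits[i, :n, :] for i, n in enumerate(num_tokens)]
print([t.shape for t in sliced])   # [torch.Size([3, 3]), torch.Size([2, 3])]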
    @overrides
    def print_examples(self, index, inputs, predictions):
        """
        Print evaluation examples

        * Args:
            index: data index
            inputs: mini-batch inputs
            predictions: prediction dictionary consisting of
                - key: 'id' (sequence id)
                - value: dictionary consisting of
                    - tag_idxs

        * Returns:
            print(Sequence, Sequence Tokens, Target Tags, Target Slots,
                Predicted Tags, Predicted Slots)
        """
        data_idx = inputs["labels"]["data_idx"][index].item()
        data_id = self._dataset.get_id(data_idx)

        helper = self._dataset.helper
        sequence = helper["examples"][data_id]["sequence"]
        target_tag_texts = helper["examples"][data_id]["tag_texts"]

        pred_tag_idxs = predictions[data_id]["tag_idxs"]
        pred_tag_texts = self._dataset.get_tag_texts_with_idxs(pred_tag_idxs)

        sequence_tokens = helper["examples"][data_id]["sequence_sub_tokens"]

        print()
        print("- Sequence:", sequence)
        print("- Sequence Tokens:", sequence_tokens)
        print("- Target:")
        print("    Tags:", target_tag_texts)
        print("    (Slots)", cls_utils.get_tag_dict(sequence, target_tag_texts))
        print("- Predict:")
        print("    Tags:", pred_tag_texts)
        print("    (Slots)", cls_utils.get_tag_dict(sequence, pred_tag_texts))
        print()
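For context on the "(Slots)" lines above: `cls_utils.get_tag_dict` groups per-token tags into span-level slots. Its actual implementation is not reproduced here; the sketch below is only a hypothetical illustration of that kind of BIO-to-slot grouping, assuming whitespace tokenization and `B-`/`I-` tag prefixes:

def bio_tags_to_slots(sequence, tag_texts):
    """Illustrative only -- NOT the actual cls_utils.get_tag_dict implementation."""
    words = sequence.split()
    slots, current_label, current_words = {}, None, []
    for word, tag in zip(words, tag_texts):
        if tag.startswith("B-"):
            # a new span starts; flush any span in progress
            if current_label:
                slots.setdefault(current_label, []).append(" ".join(current_words))
            current_label, current_words = tag[2:], [word]
        elif tag.startswith("I-") and current_label == tag[2:]:
            # continuation of the current span
            current_words.append(word)
        else:
            # "O" tag or inconsistent "I-" tag ends the span
            if current_label:
                slots.setdefault(current_label, []).append(" ".join(current_words))
            current_label, current_words = None, []
    if current_label:
        slots.setdefault(current_label, []).append(" ".join(current_words))
    return slots

# e.g. bio_tags_to_slots("book a flight to seoul", ["O", "O", "O", "O", "B-city"])
# -> {"city": ["seoul"]}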