Source code for claf.data.reader.bert.seq_cls


import json
import logging
import uuid

from overrides import overrides
from tqdm import tqdm

from claf.data.dataset import SeqClsBertDataset
from claf.data.dto import BertFeature, Helper
from claf.data.reader.base import DataReader
from claf.data import utils
from claf.decorator import register

logger = logging.getLogger(__name__)


[docs]@register("reader:seq_cls_bert") class SeqClsBertReader(DataReader): """ DataReader for Sequence (Single and Pair) Classification using BERT * Args: file_paths: .json file paths (train and dev) tokenizers: define tokenizers config (subword) * Kwargs: class_key: name of the label in .json file to use for classification """ CLASS_DATA = None METRIC_KEY = None def __init__( self, file_paths, tokenizers, sequence_max_length=None, class_key="class", cls_token="[CLS]", sep_token="[SEP]", input_type="bert", is_test=False, ): super(SeqClsBertReader, self).__init__(file_paths, SeqClsBertDataset) self.sequence_max_length = sequence_max_length self.text_columns = ["bert_input", "sequence"] # Tokenizers # - BERT: Word + Subword or Word + Char # - RoBERTa: BPE if input_type == "bert": self.tokenizer = tokenizers.get("subword", None) if self.tokenizer is None: self.tokenizer["char"] elif input_type == "roberta": self.tokenizer = tokenizers["bpe"] else: raise ValueError("'bert' and 'roberta' are available input_type.") self.class_key = class_key self.cls_token = cls_token self.sep_token = sep_token self.input_type = input_type self.is_test = is_test def _get_data(self, file_path, **kwargs): data = self.data_handler.read(file_path) seq_cls_data = json.loads(data) return seq_cls_data["data"] def _get_class_dicts(self, **kwargs): seq_cls_data = kwargs["data"] if self.class_key is None: class_data = self.CLASS_DATA else: class_data = [item[self.class_key] for item in seq_cls_data] class_data = list(set(class_data)) # remove duplicate class_idx2text = { class_idx: str(class_text) for class_idx, class_text in enumerate(class_data) } class_text2idx = {class_text: class_idx for class_idx, class_text in class_idx2text.items()} return class_idx2text, class_text2idx @overrides def _read(self, file_path, data_type=None): """ .json file structure should be something like this: { "data": [ { "sequence": "what a wonderful day!", "emotion": "happy" }, ... ], "emotion": [ // class_key "angry", "happy", "sad", ... 
] } """ data = self._get_data(file_path, data_type=data_type) class_idx2text, class_text2idx = self._get_class_dicts(data=data) helper = Helper(**{ "file_path": file_path, "class_idx2text": class_idx2text, "class_text2idx": class_text2idx, "cls_token": self.cls_token, "sep_token": self.sep_token, "dataset": SeqClsBertDataset, "metric_key": self.METRIC_KEY, }) helper.set_model_parameter({ "num_classes": len(class_idx2text), }) helper.set_predict_helper({ "class_idx2text": class_idx2text, }) features, labels = [], [] for example in tqdm(data, desc=data_type): sequence_a = utils.get_sequence_a(example) sequence_b = example.get("sequence_b", None) sequence_a_tokens = self.tokenizer.tokenize(sequence_a) sequence_b_tokens = None if sequence_b: sequence_b_tokens = self.tokenizer.tokenize(sequence_b) bert_input = utils.make_bert_input( sequence_a, sequence_b, self.tokenizer, max_seq_length=self.sequence_max_length, data_type=data_type, cls_token=self.cls_token, sep_token=self.sep_token, input_type=self.input_type, ) if bert_input is None: continue if "uid" in example: data_uid = example["uid"] else: data_uid = str(uuid.uuid1()) # token_type(segment_ids) will be added in dataset feature_row = { "id": data_uid, "bert_input": bert_input, } features.append(feature_row) class_text = example[self.class_key] label_row = { "id": data_uid, "class_idx": class_text2idx[class_text], "class_text": class_text, } labels.append(label_row) helper.set_example(data_uid, { "sequence_a": sequence_a, "sequence_a_tokens": sequence_a_tokens, "sequence_b": sequence_b, "sequence_b_tokens": sequence_b_tokens, "class_idx": class_text2idx[class_text], "class_text": class_text, }) if self.is_test and len(features) >= 10: break return utils.make_batch(features, labels), helper.to_dict()
    def read_one_example(self, inputs):
        """ inputs keys: sequence_a and sequence_b """
        sequence_a = utils.get_sequence_a(inputs)
        sequence_b = inputs.get("sequence_b", None)

        bert_feature = BertFeature()
        bert_feature.set_input_with_speical_token(
            sequence_a,
            sequence_b,
            self.tokenizer,
            max_seq_length=self.sequence_max_length,
            data_type="predict",
            cls_token=self.cls_token,
            sep_token=self.sep_token,
            input_type=self.input_type,
        )

        features = [bert_feature.to_dict()]
        helper = {}
        return features, helper
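
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module). The file path and tokenizer
# below are hypothetical stand-ins: in claf, tokenizers and file paths are
# normally wired in from an experiment config. Any object exposing
# `.tokenize(text) -> List[str]` matches what this reader calls directly.
#
#   class WhitespaceTokenizer:  # hypothetical stand-in for a subword tokenizer
#       def tokenize(self, text):
#           return text.split()
#
#   reader = SeqClsBertReader(
#       file_paths={"train": "train.json"},            # hypothetical path
#       tokenizers={"subword": WhitespaceTokenizer()},
#       sequence_max_length=128,
#       class_key="emotion",
#   )
#   features, helper = reader.read_one_example(
#       {"sequence_a": "what a wonderful day!"}
#   )
# ---------------------------------------------------------------------------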