Source code for claf.data.reader.seq_cls


import json
import logging
import uuid

from overrides import overrides
from tqdm import tqdm

from claf.data.dataset.seq_cls import SeqClsDataset
from claf.data.dto import Helper
from claf.data.reader.base import DataReader
from claf.data import utils
from claf.decorator import register

logger = logging.getLogger(__name__)


@register("reader:seq_cls")
class SeqClsReader(DataReader):
    """
    DataReader for Sequence Classification

    * Args:
        file_paths: .json file paths (train and dev)
        tokenizers: tokenizers config (requires a "word" tokenizer)

    * Kwargs:
        class_key: name of the label in the .json file to use for classification
    """

    CLASS_DATA = None

    def __init__(self, file_paths, tokenizers, sequence_max_length=None, class_key="class"):
        super(SeqClsReader, self).__init__(file_paths, SeqClsDataset)

        self.sequence_max_length = sequence_max_length
        self.text_columns = ["sequence"]

        if "word" not in tokenizers:
            raise ValueError("WordTokenizer is required. Define a 'word' tokenizer.")
        self.word_tokenizer = tokenizers["word"]
        self.class_key = class_key

    def _get_data(self, file_path, **kwargs):
        data = self.data_handler.read(file_path)
        seq_cls_data = json.loads(data)

        return seq_cls_data["data"]

    def _get_class_dicts(self, **kwargs):
        seq_cls_data = kwargs["data"]
        if self.class_key is None:
            class_data = self.CLASS_DATA
        else:
            class_data = [item[self.class_key] for item in seq_cls_data]
            class_data = list(set(class_data))  # remove duplicates

        class_idx2text = {
            class_idx: str(class_text) for class_idx, class_text in enumerate(class_data)
        }
        class_text2idx = {class_text: class_idx for class_idx, class_text in class_idx2text.items()}

        return class_idx2text, class_text2idx

    @overrides
    def _read(self, file_path, data_type=None):
        """
        .json file structure should be something like this:

        {
            "data": [
                {
                    "sequence": "what a wonderful day!",
                    "emotion": "happy"
                },
                ...
            ],
            "emotion": [  // class_key
                "angry",
                "happy",
                "sad",
                ...
            ]
        }
        """
        data = self._get_data(file_path, data_type=data_type)
        class_idx2text, class_text2idx = self._get_class_dicts(data=data)

        helper = Helper(**{
            "file_path": file_path,
            "class_idx2text": class_idx2text,
            "class_text2idx": class_text2idx,
        })
        helper.set_model_parameter({
            "num_classes": len(class_idx2text),
        })
        helper.set_predict_helper({
            "class_idx2text": class_idx2text,
        })

        features, labels = [], []

        for example in tqdm(data, desc=data_type):
            sequence = example["sequence"].strip().replace("\n", "")
            sequence_words = self.word_tokenizer.tokenize(sequence)

            if (
                self.sequence_max_length is not None
                and data_type == "train"
                and len(sequence_words) > self.sequence_max_length
            ):
                continue

            if "uid" in example:
                data_uid = example["uid"]
            else:
                data_uid = str(uuid.uuid1())

            feature_row = {
                "id": data_uid,
                "sequence": sequence,
            }
            features.append(feature_row)

            class_text = example[self.class_key]
            label_row = {
                "id": data_uid,
                "class_idx": class_text2idx[class_text],
                "class_text": class_text,
            }
            labels.append(label_row)

            helper.set_example(data_uid, {
                "sequence": sequence,
                "class_idx": class_text2idx[class_text],
                "class_text": class_text,
            })

        return utils.make_batch(features, labels), helper.to_dict()
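    # Illustration (hypothetical values): for the "data" list shown in the
    # ``_read`` docstring above, with class_key="emotion", ``_get_class_dicts``
    # would yield mappings like
    #     class_idx2text = {0: "angry", 1: "happy", 2: "sad"}
    #     class_text2idx = {"angry": 0, "happy": 1, "sad": 2}
    # (the index order is not guaranteed, since duplicates are removed via set()).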
    def read_one_example(self, inputs):
        """ inputs keys: sequence """
        sequence = inputs["sequence"].strip().replace("\n", "")
        inputs["sequence"] = sequence

        return inputs, {}
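A minimal usage sketch follows. It is not part of the module: the WhitespaceTokenizer stand-in, the file_paths dict keyed by data type, and the temporary train file are all hypothetical, and it assumes the base DataReader wires up a data_handler that can read local file paths (a real CLaF pipeline builds tokenizers and readers from the experiment config instead).

import json
import tempfile

class WhitespaceTokenizer:
    """Hypothetical stand-in for CLaF's WordTokenizer; splits on whitespace."""
    def tokenize(self, text):
        return text.split()

# Write a tiny dataset in the structure shown in the ``_read`` docstring.
dataset = {
    "data": [
        {"sequence": "what a wonderful day!", "emotion": "happy"},
        {"sequence": "this is terrible.", "emotion": "angry"},
    ]
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(dataset, f)
    train_path = f.name

# Passing file_paths keyed by data type is an assumption about the base DataReader.
reader = SeqClsReader(
    {"train": train_path},
    tokenizers={"word": WhitespaceTokenizer()},
    sequence_max_length=128,
    class_key="emotion",
)
batch, helper = reader._read(train_path, data_type="train")
# ``batch`` holds the feature/label rows built above; ``helper`` is a plain dict
# whose exact layout depends on Helper.to_dict().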