Source code for claf.data.reader.bert.glue.stsb


import logging

from overrides import overrides

from claf.data.reader import RegressionBertReader
from claf.decorator import register

logger = logging.getLogger(__name__)


@register("reader:stsb_bert")
class STSBBertReader(RegressionBertReader):
    """
    BERT DataReader for STS-B (Semantic Textual Similarity Benchmark).

    Each usable row of a tab-separated file becomes one regression example
    made of two sequences and a float score.

    * Args:
        file_paths: .tsv file paths (train and dev)
        tokenizers: defined tokenizers config

    * Kwargs:
        sequence_max_length, cls_token, sep_token, input_type, is_test:
            passed through unchanged to ``RegressionBertReader``
            (their semantics are defined by the base class, not here).
    """

    METRIC_KEY = "pearson_spearman_corr"

    def __init__(
        self,
        file_paths,
        tokenizers,
        sequence_max_length=None,
        cls_token="[CLS]",
        sep_token="[SEP]",
        input_type="bert",
        is_test=False,
    ):
        # The only STS-B specific choice made here is the label key ("score");
        # every other argument is handed to the base reader as received.
        super(STSBBertReader, self).__init__(
            file_paths,
            tokenizers,
            sequence_max_length,
            label_key="score",
            cls_token=cls_token,
            sep_token=sep_token,
            input_type=input_type,
            is_test=is_test,
        )

    @overrides
    def _get_data(self, file_path, **kwargs):
        """
        Read ``file_path`` and turn its rows into example dicts.

        * Args:
            file_path: path handed to ``self.data_handler.read``

        * Kwargs:
            data_type: required; embedded in every example's ``uid``

        * Returns:
            list of dicts with the keys ``uid``, ``sequence_a``,
            ``sequence_b`` and ``score`` (a float)
        """
        data_type = kwargs["data_type"]

        raw_text = self.data_handler.read(file_path)
        rows = raw_text.split("\n")

        examples = []
        # The first row is skipped (presumably the column header); numbering
        # starts at 1 so each uid still carries the row's position in the file.
        for row_index, row in enumerate(rows[1:], start=1):
            columns = row.split("\t")
            if len(columns) < 2:
                # Blank line, or a line without any tab: nothing to parse.
                continue

            uid = f"stsb-{file_path}-{data_type}-{row_index}"
            first_sentence = columns[7]
            second_sentence = columns[8]
            # The score is read from the last column of the row.
            score = float(columns[-1])

            examples.append({
                "uid": uid,
                "sequence_a": first_sentence,
                "sequence_b": second_sentence,
                "score": score,
            })

        return examples