Source code for claf.data.dataset.bert.squad


import json
from overrides import overrides

from claf.data import utils
from claf.data.collate import PadCollator
from claf.data.dataset.base import DatasetBase


class SQuADBertDataset(DatasetBase):
    """
    SQuAD Dataset for BERT compatible with v1.1 and v2.0

    * Args:
        batch: Batch DTO (claf.data.batch)

    * Kwargs:
        helper: helper from data_reader
    """

    def __init__(self, batch, vocab, helper=None):
        super(SQuADBertDataset, self).__init__()

        self.name = "squad_bert"
        self.vocab = vocab
        self.helper = helper
        self.raw_dataset = helper["raw_dataset"]

        # Features
        self.bert_input_idx = [feature["bert_input"] for feature in batch.features]
        sep_token = self.helper.get("sep_token", "[SEP]")
        self.token_type_idx = utils.make_bert_token_types(self.bert_input_idx, SEP_token=sep_token)

        self.features = [self.bert_input_idx, self.token_type_idx]  # for lazy_evaluation

        # Labels: one pass over batch.labels fills every label-side container.
        self.qids = {}
        self.answers = {}
        self.answer_starts = []
        self.answer_ends = []
        self.answerables = []
        for data_index, label in enumerate(batch.labels):
            self.qids[data_index] = label["id"]
            self.answers[label["id"]] = (
                label["answerable"],
                (label["answer_start"], label["answer_end"]),
            )
            self.answer_starts.append(label["answer_start"])
            self.answer_ends.append(label["answer_end"])
            self.answerables.append(label["answerable"])
        self.data_indices = list(self.qids.keys())
[docs] @overrides def collate_fn(self, cuda_device_id=None): """ collate: indexed features and labels -> tensor """ collator = PadCollator(cuda_device_id=cuda_device_id, pad_value=self.vocab.pad_index) def make_tensor_fn(data): bert_input_idxs, token_type_idxs, data_idxs, answer_starts, answer_ends, answerables = zip( *data ) features = { "bert_input": utils.transpose(bert_input_idxs, skip_keys=["text"]), "token_type": utils.transpose(token_type_idxs, skip_keys=["text"]), } labels = { "data_idx": data_idxs, "answer_start_idx": answer_starts, "answer_end_idx": answer_ends, "answerable": answerables, } return collator(features, labels) return make_tensor_fn
@overrides def __getitem__(self, index): self.lazy_evaluation(index) return ( self.bert_input_idx[index], self.token_type_idx[index], self.data_indices[index], self.answer_starts[index], self.answer_ends[index], self.answerables[index], ) def __len__(self): return len(self.qids) def __repr__(self): dataset_properties = { "name": self.name, "total_count": self.__len__(), "HasAns_count": len([True for k, v in self.answers.items() if v[1] == 1]), "NoAns_count": len([False for k, v in self.answers.items() if v[1] == 0]), "bert_input_maxlen": self.bert_input_maxlen, } return json.dumps(dataset_properties, indent=4) @property def bert_input_maxlen(self): return self._get_feature_maxlen(self.bert_input_idx)
[docs] def get_qid(self, data_index): qid = self.qids[data_index] if "#" in qid: qid = qid.split("#")[0] return qid
[docs] def get_id(self, data_index): return self.get_qid(data_index)
[docs] def get_qid_index(self, data_index): qid = self.qids[data_index] if "#" in qid: return qid.split("#")[1] return None
[docs] def get_context(self, data_index): qid = self.get_qid(data_index) return self.helper["examples"][qid]["context"]
[docs] @overrides def get_ground_truths(self, data_index): qid = self.get_qid(data_index) answer_texts = self.helper["examples"][qid]["answers"] answerable, answer_span = self.answers[qid] return answer_texts, answerable, answer_span
[docs] @overrides def get_predict(self, data_index, start, end): return self.get_text_with_index(data_index, start, end)
[docs] def get_text_with_index(self, data_index, start, end): if data_index is None: raise ValueError("data_id or text is required.") context_text = self.get_context(data_index) bert_token = self.get_bert_tokens(data_index) if ( start <= 0 or end >= len(bert_token) or bert_token[start].text_span is None or bert_token[end].text_span is None ): # No_Answer Case return "<noanswer>" char_start = bert_token[start].text_span[0] char_end = bert_token[end].text_span[1] if char_start > char_end or len(context_text) <= char_end: return "" return context_text[char_start:char_end]
[docs] def get_bert_tokens(self, data_index): qid = self.get_qid(data_index) index = self.get_qid_index(data_index) if index is None: raise ValueError("bert_qid must have 'bert_index' (bert_id: qid#bert_index)") bert_index = f"bert_tokens_{index}" return self.helper["examples"][qid][bert_index]