import json
from overrides import overrides
from claf.data import utils
from claf.data.collate import FeatLabelPadCollator
from claf.data.dataset.base import DatasetBase
class TokClsBertDataset(DatasetBase):
    """
    Dataset for Token Classification with BERT

    * Args:
        batch: Batch DTO (claf.data.batch)
        vocab: vocabulary, used to look up the pad index when collating

    * Kwargs:
        helper: helper from data_reader
    """
    def __init__(self, batch, vocab, helper=None):
        super(TokClsBertDataset, self).__init__()

        self.name = "tok_cls_bert"
        self.vocab = vocab
        self.helper = helper
        self.tag_idx2text = helper["tag_idx2text"]

        # Features
        self.bert_input_idx = [feature["bert_input"] for feature in batch.features]

        SEP_token = self.helper.get("sep_token", "[SEP]")
        self.token_type_idx = utils.make_bert_token_types(self.bert_input_idx, SEP_token=SEP_token)

        self.tagged_sub_token_idxs = [
            {"feature": feature["tagged_sub_token_idxs"]} for feature in batch.features
        ]
        self.num_tokens = [{"feature": feature["num_tokens"]} for feature in batch.features]

        self.features = [self.bert_input_idx, self.token_type_idx]  # for lazy evaluation

        # Labels
        self.data_ids = {data_index: label["id"] for (data_index, label) in enumerate(batch.labels)}
        self.data_indices = list(self.data_ids.keys())
        self.tags = {
            label["id"]: {
                "tag_idxs": label["tag_idxs"],
                "tag_texts": label["tag_texts"],
            }
            for label in batch.labels
        }
        self.tag_texts = [label["tag_texts"] for label in batch.labels]
        self.tag_idxs = [label["tag_idxs"] for label in batch.labels]
        self.ignore_tag_idx = helper["ignore_tag_idx"]  # pad value for label padding in collate_fn
    @overrides
    def collate_fn(self, cuda_device_id=None):
        """ collate: indexed features and labels -> tensor """
        collator = FeatLabelPadCollator(cuda_device_id=cuda_device_id, pad_value=self.vocab.pad_index)
        def make_tensor_fn(data):
            data_idxs, bert_input_idxs, token_type_idxs, tagged_token_idxs, num_tokens, tag_idxs_list = zip(*data)

            features = {
                "bert_input": utils.transpose(bert_input_idxs, skip_keys=["text"]),
                "token_type": utils.transpose(token_type_idxs, skip_keys=["text"]),
                "tagged_sub_token_idxs": utils.transpose(tagged_token_idxs, skip_keys=["text"]),
                "num_tokens": utils.transpose(num_tokens, skip_keys=["text"]),
            }
            labels = {
                "tag_idxs": tag_idxs_list,
                "data_idx": data_idxs,
            }

            return collator(
                features,
                labels,
                apply_pad_labels=["tag_idxs"],
                apply_pad_values=[self.ignore_tag_idx],
            )

        return make_tensor_fn
    @overrides
    def __getitem__(self, index):
        self.lazy_evaluation(index)
        return (
            self.data_indices[index],
            self.bert_input_idx[index],
            self.token_type_idx[index],
            self.tagged_sub_token_idxs[index],
            self.num_tokens[index],
            self.tag_idxs[index],
        )

    def __len__(self):
        return len(self.data_ids)
    def __repr__(self):
        dataset_properties = {
            "name": self.name,
            "total_count": len(self),
            "num_tags": self.num_tags,
            "sequence_maxlen": self.sequence_maxlen,
            "tags": self.tag_idx2text,
        }
        return json.dumps(dataset_properties, indent=4)

    @property
    def num_tags(self):
        return len(self.tag_idx2text)

    @property
    def sequence_maxlen(self):
        return self._get_feature_maxlen(self.bert_input_idx)
    def get_id(self, data_index):
        return self.data_ids[data_index]

    @overrides
    def get_ground_truth(self, data_id):
        return self.tags[data_id]

    def get_tag_texts_with_idxs(self, tag_idxs):
        return [self.get_tag_text_with_idx(tag_idx) for tag_idx in tag_idxs]

    def get_tag_text_with_idx(self, tag_index):
        if tag_index is None:
            raise ValueError("tag_index is required.")
        return self.tag_idx2text[tag_index]
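
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the CLAF source): wiring this
# dataset into a standard torch DataLoader. `batch`, `vocab`, and `helper`
# are assumed to come from a claf data_reader; the function name and
# `batch_size=32` are hypothetical choices for the example.
def _example_build_loader(batch, vocab, helper):
    from torch.utils.data import DataLoader

    dataset = TokClsBertDataset(batch, vocab, helper=helper)
    return DataLoader(
        dataset,
        batch_size=32,
        # collate_fn() returns a closure that pads the indexed features and
        # labels into tensors; cuda_device_id=None keeps them on CPU.
        collate_fn=dataset.collate_fn(cuda_device_id=None),
    )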