from overrides import overrides
import torch
from torch.autograd import Variable
from claf.data import utils
class PadCollator:
    """
    Collator that applies padding and converts values to tensors.
    Minimizes the amount of padding needed while producing a mini-batch.

    * Kwargs:
        cuda_device_id: CUDA device id that produced tensors are moved to
            Default is None (tensors stay on the CPU)
        pad_value: padding value forwarded to ``utils.padding_tokens`` for features
            Default is 0
        skip_keys: keys to skip when making tensors
            Default is ["text"]
            NOTE(review): stored as ``self.skip_keys`` but not read by any method
            of this class — confirm whether skipping is implemented elsewhere.
    """

    def __init__(self, cuda_device_id=None, pad_value=0, skip_keys=None):
        self.cuda_device_id = cuda_device_id
        self.pad_value = pad_value
        # Build a fresh list per instance: a list literal as the default value
        # would be one object shared by every collator created without skip_keys.
        self.skip_keys = ["text"] if skip_keys is None else skip_keys

    def __call__(self, features, labels):
        """
        Pad and tensorize ``features``, tensorize ``labels``, then build a batch.

        * Args:
            features: dict of name -> value, or name -> dict of token_name -> value
            labels: dict of name -> value (labels are never padded here)

        * Returns:
            whatever ``utils.make_batch(features, labels)`` returns

        Note: ``features`` and ``labels`` are modified in place.
        """
        self.collate(features, pad_value=self.pad_value)
        self.collate(labels, apply_pad=False, pad_value=self.pad_value)
        return utils.make_batch(features, labels)

    def collate(self, datas, apply_pad=True, pad_value=0):
        """
        Replace every value of ``datas`` in place with its tensor form.

        Values that are themselves dicts are processed one level deep, and the
        inner key is forwarded as ``token_name`` to the padding helper.

        * Args:
            datas: dict whose values are raw values or dicts of raw values

        * Kwargs:
            apply_pad: pad values before converting them. Default is True
            pad_value: padding value. Default is 0
        """
        for data_name, data in datas.items():
            if isinstance(data, dict):
                for key, value in data.items():
                    data[key] = self._collate(
                        value, apply_pad=apply_pad, token_name=key, pad_value=pad_value)
            else:
                # Forward pad_value here as well; it used to be dropped, so flat
                # (non-dict) entries were always padded with 0.
                datas[data_name] = self._collate(
                    data, apply_pad=apply_pad, pad_value=pad_value)

    def _collate(self, value, apply_pad=True, token_name=None, pad_value=0):
        """Optionally pad ``value``, then convert it to a tensor."""
        if apply_pad:
            value = self._apply_pad(value, token_name=token_name, pad_value=pad_value)
        return self._make_tensor(value)

    def _apply_pad(self, value, token_name=None, pad_value=0):
        """Delegate padding of ``value`` to ``utils.padding_tokens``."""
        return utils.padding_tokens(value, token_name=token_name, pad_value=pad_value)

    def _make_tensor(self, value):
        """
        Convert ``value`` to a tensor wrapped with ``requires_grad=False``.

        A non-tensor value becomes a ``LongTensor`` when ``utils.get_token_type``
        reports ``int``, and a ``FloatTensor`` otherwise. The result is moved to
        ``self.cuda_device_id`` when that is set.
        """
        if not isinstance(value, torch.Tensor):
            value_type = utils.get_token_type(value)
            if value_type == int:
                value = torch.LongTensor(value)
            else:
                value = torch.FloatTensor(value)

        # ``Variable`` is the legacy autograd wrapper, kept for compatibility.
        value = Variable(value, requires_grad=False)
        if self.cuda_device_id is not None:
            value = value.cuda(self.cuda_device_id)
        return value
class FeatLabelPadCollator(PadCollator):
    """
    Collator that applies padding and converts values to tensors.
    Minimizes the amount of padding needed while producing a mini-batch.
    FeatLabelPadCollator can pad selected labels as well as features.
    Construction (``__init__``) is inherited from ``PadCollator``.

    * Kwargs:
        cuda_device_id: CUDA device id that produced tensors are moved to
            Default is None (CPU)
        pad_value: padding value used for features
            Default is 0
        skip_keys: keys to skip when making tensors
    """

    @overrides
    def __call__(self, features, labels, apply_pad_labels=(), apply_pad_values=()):
        """
        Pad and tensorize ``features`` and ``labels``, then build a batch.

        * Args:
            features: dict of name -> value, or name -> dict of token_name -> value
            labels: dict of name -> value

        * Kwargs:
            apply_pad_labels: names of labels that should be padded
            apply_pad_values: padding value for each entry of ``apply_pad_labels``,
                matched by position

        * Returns:
            whatever ``utils.make_batch(features, labels)`` returns

        Note: ``features`` and ``labels`` are modified in place.
        """
        # Forward the configured pad_value; it was previously ignored here,
        # so features were always padded with 0.
        self.collate(features, pad_value=self.pad_value)
        self.collate(
            labels,
            apply_pad=False,
            apply_pad_labels=apply_pad_labels,
            apply_pad_values=apply_pad_values,
            pad_value=self.pad_value,
        )
        return utils.make_batch(features, labels)

    # NOTE(review): this signature differs positionally from PadCollator.collate;
    # newer releases of the `overrides` package check signature compatibility —
    # confirm against the pinned version.
    @overrides
    def collate(self, datas, apply_pad=True, apply_pad_labels=(), apply_pad_values=(),
                pad_value=0):
        """
        Replace every value of ``datas`` in place with its tensor form.

        * Args:
            datas: dict whose values are raw values or dicts of raw values

        * Kwargs:
            apply_pad: pad values before converting them. Default is True
            apply_pad_labels: names that are padded even when ``apply_pad`` is False
            apply_pad_values: padding value for each entry of ``apply_pad_labels``,
                matched by position
            pad_value: padding value for all other entries. Default is 0
                (appended last so existing positional callers are unaffected)
        """
        for data_name, data in datas.items():
            if not apply_pad and data_name in apply_pad_labels:
                # Explicitly requested entry: pad it regardless of apply_pad,
                # using the value paired with its name by position.
                _apply_pad = True
                _pad_value = apply_pad_values[apply_pad_labels.index(data_name)]
            else:
                _apply_pad = apply_pad
                _pad_value = pad_value

            if isinstance(data, dict):
                for key, value in data.items():
                    data[key] = self._collate(
                        value, apply_pad=_apply_pad, token_name=key, pad_value=_pad_value)
            else:
                datas[data_name] = self._collate(
                    data, apply_pad=_apply_pad, pad_value=_pad_value)