import json
from pathlib import Path
import torch.nn as nn
class ModelBase(nn.Module):
    """
    Model Base Class

    Declares the hooks concrete models implement (``forward``,
    ``make_metrics``, ``make_predictions``, ``predict``, ``print_examples``)
    and stores bookkeeping state (config, log_dir, dataset, metrics,
    train_counter, vocabs) that is exposed through properties.

    The constructor takes no arguments; every bookkeeping attribute starts
    as ``None`` until it is assigned through its property setter (or by a
    subclass).
    """

    def __init__(self):
        super().__init__()
        # Initialise every slot up front so that is_ready() and the property
        # getters report "not set" (None) instead of raising AttributeError.
        self._config = None
        self._log_dir = None
        self._dataset = None
        self._metrics = None
        self._train_counter = None
        self._vocabs = None

    def forward(self, inputs):
        """Run the model on ``inputs``; must be implemented by subclasses."""
        raise NotImplementedError

    def make_metrics(self, predictions):
        """Build metrics from ``predictions``; must be implemented by subclasses."""
        raise NotImplementedError

    def make_predictions(self, features):
        """
        for Metrics

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def predict(self, features):
        """
        Inference

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def print_examples(self, params):
        """
        Print evaluation examples

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def write_predictions(self, predictions, file_path=None, is_dict=True):
        """
        Write ``predictions`` to a file under ``<log_dir>/predictions``.

        Args:
            predictions: object to write. Serialised with ``json.dumps``
                (indent 4) when ``is_dict`` is true, otherwise passed to
                ``file.write`` unchanged (so it must already be a ``str``).
            file_path: file name relative to the predictions directory.
                When ``None`` the name is
                ``predictions-{train|valid}-{train_counter.get_display()}.json``.
            is_dict: whether ``predictions`` should be JSON-serialised.
        """
        # nn.Module.training distinguishes train mode from eval mode.
        data_type = "train" if self.training else "valid"

        pred_dir = Path(self._log_dir) / "predictions"
        # parents=True: the log directory itself may not exist yet.
        pred_dir.mkdir(parents=True, exist_ok=True)

        if file_path is None:
            file_path = f"predictions-{data_type}-{self._train_counter.get_display()}.json"

        pred_path = pred_dir / file_path
        with open(pred_path, "w") as out_file:
            if is_dict:
                out_file.write(json.dumps(predictions, indent=4))
            else:
                out_file.write(predictions)

    def is_ready(self):
        """
        Return ``True`` when config, log_dir, train_counter and vocabs are
        all set (i.e. not ``None``).

        ``dataset`` and ``metrics`` are deliberately not checked: per the
        original notes they are assigned later (at ``_run_epoch()`` and
        ``save()`` respectively).
        """
        required = (
            self._config,
            self._log_dir,
            self._train_counter,
            self._vocabs,
        )
        return all(value is not None for value in required)

    @property
    def config(self):
        return self._config

    @config.setter
    def config(self, config):
        self._config = config

    @property
    def log_dir(self):
        return self._log_dir

    @log_dir.setter
    def log_dir(self, log_dir):
        self._log_dir = log_dir

    @property
    def dataset(self):
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    @property
    def metrics(self):
        return self._metrics

    @metrics.setter
    def metrics(self, metrics):
        self._metrics = metrics

    @property
    def train_counter(self):
        return self._train_counter

    @train_counter.setter
    def train_counter(self, train_counter):
        self._train_counter = train_counter

    @property
    def vocabs(self):
        return self._vocabs

    @vocabs.setter
    def vocabs(self, vocabs):
        self._vocabs = vocabs
class ModelWithTokenEmbedder(ModelBase):
    """
    Model that owns a token embedder.

    Args:
        token_embedder: stored as ``self.token_embedder``. When it is not
            ``None``, its ``vocabs`` attribute becomes this model's vocabs.
    """

    def __init__(self, token_embedder):
        super().__init__()
        self.token_embedder = token_embedder

        # Without an embedder there is nothing to take vocabs from.
        if token_embedder is None:
            return
        self._vocabs = token_embedder.vocabs
class ModelWithoutTokenEmbedder(ModelBase):
    """
    Model that works from token makers directly, with no token embedder.

    Args:
        token_makers: mapping of token name to token maker; stored as
            ``self.token_makers``. Each maker's ``vocab`` attribute is
            collected into this model's vocabs under the same token name.
    """

    def __init__(self, token_makers):
        super().__init__()
        self.token_makers = token_makers

        # Collect one vocab per token maker, keyed by token name.
        collected = {}
        for name, maker in token_makers.items():
            collected[name] = maker.vocab
        self._vocabs = collected