from overrides import overrides
from claf.config.registry import Registry
from claf.model.base import ModelWithTokenEmbedder, ModelWithoutTokenEmbedder
from claf.model.reading_comprehension.mixin import ReadingComprehension
from claf.tokens import token_embedder
from .base import Factory
class ModelFactory(Factory):
    """
    Model Factory Class

    Create Concrete model according to config.name
    Get model from model registries (eg. @register("model:{model_name}"))

    * Args:
        config: model config from argument (config.model)
            - config.name selects the registered model ("model:{name}")
            - the optional attribute named after config.name holds that
              model's own keyword arguments
            - the optional attribute "independent" is stored as
              `is_independent` (defaults to False)
    """

    def __init__(self, config):
        self.registry = Registry()
        self.name = config.name

        # Model-specific options live under an attribute named after the model.
        # Take a shallow copy: vars() returns the namespace's live __dict__, and
        # holding that reference would let the factory leak runtime objects
        # (token embedder / token makers) back into the caller's config.
        model_options = getattr(config, config.name, None)
        self.model_config = dict(vars(model_options)) if model_options else {}

        # NOTE(review): not read anywhere in this class — presumably consumed
        # by the code that drives the factory; confirm against callers.
        self.is_independent = getattr(config, "independent", False)

    @overrides
    def create(self, token_makers, **params):
        """
        Instantiate the model registered under "model:{self.name}".

        * Args:
            token_makers: passed to the token embedder constructor
                (ModelWithTokenEmbedder) or directly to the model as
                `token_makers` (ModelWithoutTokenEmbedder)

        * Kwargs:
            params: extra keyword arguments forwarded to the model constructor

        * Returns:
            the model instance

        * Raises:
            ValueError: the registered model derives from neither
                ModelWithTokenEmbedder nor ModelWithoutTokenEmbedder
        """
        model = self.registry.get(f"model:{self.name}")

        # Assemble the constructor arguments in a fresh dict so that repeated
        # create() calls never mutate the factory's stored configuration.
        model_kwargs = dict(self.model_config)
        if issubclass(model, ModelWithTokenEmbedder):
            model_kwargs["token_embedder"] = self.create_token_embedder(model, token_makers)
        elif issubclass(model, ModelWithoutTokenEmbedder):
            model_kwargs["token_makers"] = token_makers
        else:
            raise ValueError(
                "Model must have inheritance. (ModelWithTokenEmbedder or ModelWithoutTokenEmbedder)"
            )

        return model(**model_kwargs, **params)

    def create_token_embedder(self, model, token_makers):
        """
        Build the token embedder that matches the given model class.

        * Args:
            model: model class looked up from the registry
            token_makers: passed to the token embedder constructor

        * Returns:
            RCTokenEmbedder for ReadingComprehension models,
            BasicTokenEmbedder otherwise
        """
        # 1. Specific case
        # ...

        # 2. Base case
        if issubclass(model, ReadingComprehension):
            return token_embedder.RCTokenEmbedder(token_makers)
        return token_embedder.BasicTokenEmbedder(token_makers)