from overrides import overrides
import torch
from claf.config.namespace import NestedNamespace
from claf.learn.optimization.learning_rate_scheduler import get_lr_schedulers
from claf.learn.optimization.learning_rate_scheduler import (
LearningRateWithoutMetricsWrapper,
LearningRateWithMetricsWrapper,
)
from claf.learn.optimization.optimizer import get_optimizer_by_name
from claf.model.sequence_classification import BertForSeqCls, RobertaForSeqCls
from .base import Factory


class OptimizerFactory(Factory):
"""
Optimizer Factory Class
include optimizer, learning_rate_scheduler and exponential_moving_average
* Args:
config: optimizer config from argument (config.optimizer)
"""
def __init__(self, config):
# Optimizer
self.op_type = config.op_type
self.optimizer_params = {"lr": config.learning_rate}
op_config = getattr(config, self.op_type, None)
if op_config is not None:
op_config = vars(op_config)
self.optimizer_params.update(op_config)
# LearningRate Scheduler
self.lr_scheduler_type = getattr(config, "lr_scheduler_type", None)
if self.lr_scheduler_type is not None:
self.lr_scheduler_config = getattr(config, self.lr_scheduler_type, {})
            if isinstance(self.lr_scheduler_config, NestedNamespace):
self.lr_scheduler_config = vars(self.lr_scheduler_config)
if "warmup" in self.lr_scheduler_type:
self.lr_scheduler_config["t_total"] = config.num_train_steps
self.set_warmup_steps(config)
# EMA
self.ema = getattr(config, "exponential_moving_average", 0)

    def set_warmup_steps(self, config):
warmup_proportion = self.lr_scheduler_config.get("warmup_proportion", None)
warmup_steps = self.lr_scheduler_config.get("warmup_steps", None)
        if warmup_steps and warmup_proportion:
            raise ValueError("Set only one of 'warmup_steps' and 'warmup_proportion', not both.")
elif not warmup_steps and warmup_proportion:
self.lr_scheduler_config["warmup_steps"] = int(config.num_train_steps * warmup_proportion) + 1
del self.lr_scheduler_config["warmup_proportion"]
elif warmup_steps and not warmup_proportion:
pass
        else:
            raise ValueError("One of 'warmup_steps' or 'warmup_proportion' is required.")

    @overrides
def create(self, model):
        if not isinstance(model, torch.nn.Module):
            raise ValueError("optimizer model must be a subclass of torch.nn.Module.")
if getattr(model, "use_pytorch_transformers", False):
weight_decay = self.optimizer_params.get("weight_decay", 0)
model_parameters = self._group_parameters_for_transformers(model, weight_decay=weight_decay)
else:
model_parameters = [param for param in model.parameters() if param.requires_grad]
optimizer = get_optimizer_by_name(self.op_type)(model_parameters, **self.optimizer_params)
op_dict = {"optimizer": optimizer}
# learning_rate_scheduler
if self.lr_scheduler_type:
self.lr_scheduler_config["optimizer"] = op_dict["optimizer"]
lr_scheduler = get_lr_schedulers()[self.lr_scheduler_type](**self.lr_scheduler_config)
if self.lr_scheduler_type == "reduce_on_plateau":
lr_scheduler = LearningRateWithMetricsWrapper(lr_scheduler)
else:
lr_scheduler = LearningRateWithoutMetricsWrapper(lr_scheduler)
op_dict["learning_rate_scheduler"] = lr_scheduler
# exponential_moving_average
if self.ema and self.ema > 0:
op_dict["exponential_moving_average"] = self.ema
return op_dict
def _group_parameters_for_transformers(self, model, weight_decay=0):
# Prepare optimizer
param_optimizer = list(model.named_parameters())
        # hack: drop the unused pooler parameters, which would otherwise
        # produce None grads that break apex
        if not isinstance(model, (BertForSeqCls, RobertaForSeqCls)):
param_optimizer = [n for n in param_optimizer if "pooler" not in n[0]]
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
"weight_decay": weight_decay,
},
{
"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
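

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; not part of the original module).
    # It assumes "adam" is a name known to get_optimizer_by_name, and uses
    # types.SimpleNamespace as a stand-in for the NestedNamespace experiment
    # config, since OptimizerFactory only reads attributes from it.
    from types import SimpleNamespace

    import torch.nn as nn

    example_config = SimpleNamespace(op_type="adam", learning_rate=1e-3)
    toy_model = nn.Linear(10, 2)  # any torch.nn.Module with trainable parameters

    op_dict = OptimizerFactory(example_config).create(toy_model)
    print(op_dict["optimizer"])
    # op_dict may also contain "learning_rate_scheduler" and
    # "exponential_moving_average" when those options are configured.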