Source code for claf.learn.utils


from collections import OrderedDict
import json
import logging
from pathlib import Path
import os
import re

import torch
from torch.nn import DataParallel
import requests

from claf import nsml
from claf.tokens.vocabulary import Vocab


logger = logging.getLogger(__name__)


""" Train Counter """


class TrainCounter:
    global_step = 0
    epoch = 0

    def __init__(self, display_unit="epoch"):
        if type(display_unit) == int:
            display_unit = f"every_{display_unit}_global_step"
        self.display_unit = display_unit

    def get_display(self):
        if self.display_unit == "epoch":
            return self.epoch
        else:
            return self.global_step
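
# Illustrative usage sketch (not part of the original module): driving a
# TrainCounter from a training loop. `num_epochs` and `batches` below are
# hypothetical stand-ins for real training inputs.
def _example_train_counter():
    counter = TrainCounter(display_unit=100)  # display_unit -> "every_100_global_step"
    num_epochs, batches = 2, range(5)
    for epoch in range(num_epochs):
        counter.epoch = epoch
        for _ in batches:
            counter.global_step += 1
    # display_unit != "epoch", so get_display() returns the global step (10).
    return counter.get_display()
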
""" Save and Load checkpoint """

def load_model_checkpoint(model, checkpoint):
    model.load_state_dict(checkpoint["weights"])
    model.config = checkpoint["config"]
    model.metrics = checkpoint["metrics"]
    model.init_params = checkpoint["init_params"]
    model.predict_helper = checkpoint["predict_helper"]
    model.train_counter = checkpoint["train_counter"]
    model.vocabs = load_vocabs(checkpoint)

    logger.info("Load model checkpoints...!")
    return model


def load_optimizer_checkpoint(optimizer, checkpoint):
    optimizer.load_state_dict(checkpoint["optimizer"])
    logger.info("Load optimizer checkpoints...!")
    return optimizer


def load_vocabs(model_checkpoint):
    vocabs = {}
    token_config = model_checkpoint["config"]["token"]
    for token_name in token_config["names"]:
        token = token_config[token_name]
        vocab_config = token.get("vocab", {})
        texts = model_checkpoint["vocab_texts"][token_name]
        vocabs[token_name] = Vocab(token_name, **vocab_config).from_texts(texts)
    return vocabs

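
# Illustrative usage sketch (not part of the original module): restoring a
# model and optimizer from the paired files written by save_checkpoint()
# below. The paths are hypothetical placeholders, and map_location="cpu" is
# an assumption for loading GPU-trained weights on a CPU-only machine.
def _example_restore(model, optimizer):
    model_checkpoint = torch.load("logs/checkpoint/model_1.pkl", map_location="cpu")
    model = load_model_checkpoint(model, model_checkpoint)

    optimizer_checkpoint = torch.load("logs/checkpoint/optimizer_1.pkl", map_location="cpu")
    optimizer = load_optimizer_checkpoint(optimizer, optimizer_checkpoint)
    return model, optimizer
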

def save_checkpoint(path, model, optimizer, max_to_keep=10):
    path = Path(path)
    checkpoint_dir = path / "checkpoint"
    checkpoint_dir.mkdir(exist_ok=True)

    # Remove old checkpoints: keep the newest (max_to_keep - 1) on disk,
    # plus the checkpoint about to be written.
    sorted_path = get_sorted_path(checkpoint_dir)
    if len(sorted_path) > max_to_keep:
        remove_train_counts = list(sorted_path.keys())[: -(max_to_keep - 1)]
        for train_count in remove_train_counts:
            optimizer_path = sorted_path[train_count].get("optimizer", None)
            if optimizer_path:
                os.remove(optimizer_path)
            model_path = sorted_path[train_count].get("model", None)
            if model_path:
                os.remove(model_path)

    train_counter = model.train_counter
    optimizer_path = checkpoint_dir / f"optimizer_{train_counter.get_display()}.pkl"
    torch.save({"optimizer": optimizer.state_dict()}, optimizer_path)

    model_path = checkpoint_dir / f"model_{train_counter.get_display()}.pkl"
    torch.save(
        {
            "config": model.config,
            "init_params": model.init_params,
            "predict_helper": model.predict_helper,
            "metrics": model.metrics,
            "train_counter": model.train_counter,
            "vocab_texts": {k: v.to_text() for k, v in model.vocabs.items()},
            "weights": model.state_dict(),
        },
        model_path,
    )

    # Write Vocab as a text file (only once)
    vocab_dir = path / "vocab"
    vocab_dir.mkdir(exist_ok=True)
    for token_name, vocab in model.vocabs.items():
        vocab_path = vocab_dir / f"{token_name}.txt"
        if not vocab_path.exists():
            vocab.dump(vocab_path)

    logger.info(f"Save {train_counter.global_step} global_step checkpoints...!")

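
# Illustrative usage sketch (not part of the original module): calling
# save_checkpoint() once per epoch from a training loop. "logs" is a
# hypothetical output directory; file names are derived from
# model.train_counter, and old checkpoints are pruned automatically.
def _example_save_per_epoch(model, optimizer, num_epochs=3):
    for epoch in range(num_epochs):
        model.train_counter.epoch = epoch
        # ... run one epoch of training here ...
        save_checkpoint("logs", model, optimizer, max_to_keep=10)
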

def get_sorted_path(checkpoint_dir, both_exist=False):
    paths = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for f_name in files:
            if "model" in f_name or "optimizer" in f_name:
                paths.append(Path(root) / f_name)

    path_with_train_count = {}
    for path in paths:
        # File names look like "model_<count>.pkl" / "optimizer_<count>.pkl".
        train_count = int(re.findall(r"\d+", path.name)[0])
        if train_count not in path_with_train_count:
            path_with_train_count[train_count] = {}

        if "model" in path.name:
            path_with_train_count[train_count]["model"] = path
        if "optimizer" in path.name:
            path_with_train_count[train_count]["optimizer"] = path

    if both_exist:
        # Drop train counts that are missing either the model or the optimizer file.
        remove_keys = []
        for key, checkpoint in path_with_train_count.items():
            if not ("model" in checkpoint and "optimizer" in checkpoint):
                remove_keys.append(key)
        for key in remove_keys:
            del path_with_train_count[key]

    return OrderedDict(sorted(path_with_train_count.items()))

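
# Illustrative usage sketch (not part of the original module): inspecting the
# OrderedDict returned by get_sorted_path(). For a hypothetical directory
# containing model_1.pkl, optimizer_1.pkl, and model_2.pkl, this prints
# "1 ['model', 'optimizer']" then "2 ['model']"; with both_exist=True the
# incomplete entry 2 would be dropped.
def _example_list_checkpoints(checkpoint_dir="logs/checkpoint"):
    for train_count, files in get_sorted_path(checkpoint_dir).items():
        print(train_count, sorted(files))
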
""" NSML """

def bind_nsml(model, **kwargs):  # pragma: no cover
    if type(model) == DataParallel:
        model = model.module

    CHECKPOINT_FNAME = "checkpoint.bin"

    def infer(raw_data, **kwargs):
        print("raw_data:", raw_data)

    def load(dir_path, *args):
        checkpoint_path = os.path.join(dir_path, CHECKPOINT_FNAME)
        checkpoint = torch.load(checkpoint_path)

        model.load_state_dict(checkpoint["weights"])
        model.config = checkpoint["config"]
        model.metrics = checkpoint["metrics"]
        model.init_params = checkpoint["init_params"]
        model.predict_helper = checkpoint["predict_helper"]
        model.train_counter = checkpoint["train_counter"]
        model.vocabs = load_vocabs(checkpoint)

        if "optimizer" in kwargs:
            kwargs["optimizer"].load_state_dict(checkpoint["optimizer"])
        logger.info(f"Load checkpoints...! {checkpoint_path}")

    def save(dir_path, *args):
        # Save the model as a single 'checkpoint' dictionary.
        checkpoint_path = os.path.join(dir_path, CHECKPOINT_FNAME)
        checkpoint = {
            "config": model.config,
            "init_params": model.init_params,
            "predict_helper": model.predict_helper,
            "metrics": model.metrics,
            "train_counter": model.train_counter,
            "vocab_texts": {k: v.to_text() for k, v in model.vocabs.items()},
            "weights": model.state_dict(),
        }
        if "optimizer" in kwargs:
            checkpoint["optimizer"] = kwargs["optimizer"].state_dict()
        torch.save(checkpoint, checkpoint_path)

        train_counter = model.train_counter
        logger.info(
            f"Save {train_counter.global_step} global_step checkpoints...! {checkpoint_path}"
        )

    # Nested functions are used only to divide the namespace.
    nsml.bind(save, load, infer)

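
# Illustrative usage sketch (not part of the original module): binding the
# model (and optionally the optimizer) when running on NSML. nsml.IS_ON_NSML
# is the same flag used by get_session_name() below.
def _example_bind(model, optimizer):
    if nsml.IS_ON_NSML:
        bind_nsml(model, optimizer=optimizer)
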
""" Notification """

def get_session_name():
    session_name = "local"
    if nsml.IS_ON_NSML:
        session_name = nsml.SESSION_NAME
    return session_name


def send_message_to_slack(webhook_url, title=None, message=None):  # pragma: no cover
    if message is None:
        data = {"text": f"{get_session_name()} session is exited."}
    else:
        data = {"attachments": [{"title": title, "text": message, "color": "#438C56"}]}

    try:
        if webhook_url == "":
            # No webhook configured: fall back to stdout. Use .get() because
            # the attachment payload has no top-level "text" key.
            print(data.get("text", message))
        else:
            requests.post(webhook_url, data=json.dumps(data))
    except Exception as e:
        print(str(e))

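
# Illustrative usage sketch (not part of the original module): posting a
# training summary to Slack. The webhook URL is a hypothetical placeholder;
# passing an empty string makes send_message_to_slack() fall back to stdout.
def _example_notify():
    send_message_to_slack(
        "https://hooks.slack.com/services/T000/B000/XXXX",  # hypothetical webhook
        title="Training finished",
        message=f"{get_session_name()} session finished training.",
    )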