Source code for claf.learn.tensorboard


import os
from collections.abc import Mapping

from tensorboardX import SummaryWriter

from claf import nsml


class TensorBoard:
    """TensorBoard wrapper for PyTorch.

    Scalar summaries are sent to ``nsml.report`` when ``nsml.IS_ON_NSML`` is
    truthy; otherwise they are written through a ``tensorboardX.SummaryWriter``
    rooted at ``log_dir``.
    """

    def __init__(self, log_dir):
        """Ensure ``log_dir`` exists and open a SummaryWriter on it.

        Args:
            log_dir: Directory the SummaryWriter writes its event files to.
                Created (including parents) if it does not exist yet.
        """
        # exist_ok=True avoids the check-then-create race of
        # `os.path.exists` followed by `os.makedirs`.
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=log_dir)

    def scalar_summaries(self, step, summary):
        """Log several scalar values for the same step.

        Args:
            step: Step value attached to every scalar.
            summary: Mapping of tag -> scalar value.

        Raises:
            ValueError: If ``summary`` is not a mapping.
        """
        # Validate before choosing a backend so both paths reject bad input
        # the same way (and dict subclasses such as OrderedDict are accepted).
        if not isinstance(summary, Mapping):
            raise ValueError(f"summary type must be dict, not {type(summary)}")

        if nsml.IS_ON_NSML:
            # NOTE: no extra locals are introduced above this line, so
            # `locals()` captures exactly (self, step, summary) as before.
            kwargs = {"summary": True, "scope": locals(), "step": step}
            # Keys in `summary` are passed as keyword args to nsml.report;
            # a key named "summary"/"scope"/"step" would override the defaults.
            kwargs.update(summary)
            nsml.report(**kwargs)
        else:
            for tag, value in summary.items():
                self.scalar_summary(step, tag, value)

    def scalar_summary(self, step, tag, value):
        """Log a single scalar variable.

        Args:
            step: Step value attached to the scalar.
            tag: Name the scalar is logged under.
            value: Scalar value to log.
        """
        if nsml.IS_ON_NSML:
            # NOTE: `tag` becomes a keyword argument of nsml.report; a tag of
            # "summary", "scope" or "step" would override those entries.
            nsml.report(**{"summary": True, "scope": locals(), "step": step, tag: value})
        else:
            self.writer.add_scalar(tag, value, step)

    def image_summary(self, tag, images, step):
        """Log a list of images (not implemented)."""
        raise NotImplementedError()

    def embedding_summary(self, features, metadata=None, label_img=None):
        """Log an embedding projection (not implemented)."""
        raise NotImplementedError()

    def histogram_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values (not implemented)."""
        raise NotImplementedError()

    def graph_summary(self, model, input_to_model=None):
        """Log a model graph (not implemented)."""
        raise NotImplementedError()

    def close(self):
        """Flush pending events and close the underlying SummaryWriter."""
        self.writer.close()