# Source code for pytorchart.flexilogger

"""
Module defining the FlexLogger class, which connects torchnet meters to visdom plots.
"""
import time, pickle, visdom
from inspect import signature
from torchnet import meter as METERS
from collections import defaultdict
from .meter_doc import meter_defs
from .Loggers import TraceLogger
import torch, pprint


# Module-level catalogues of supported names.
# NOTE(review): none of these lists is referenced by the FlexLogger code in
# this module; presumably they are consumed elsewhere in the package.

# Every meter type documented in ``meter_doc.meter_defs``.
_meters = [*meter_defs]

# Plot types.
_plots = ['line', 'scatter']

# Logger types ('image' is a commented-out candidate in the original list).
_loggers = ['heatmap']


class FlexLogger:
    """
    Base object for logging.

    It takes specifications for meters and for plots, indexes them, and
    adds a hook to send meter values to the plots on :meth:`log`.

    :param plot_args: dictionary of definitions for plotters, or None.
    :param meter_args: dictionary of definitions for meters.
    :param kwargs: additional keyword arguments:
        ``env`` (visdom environment passed to plots, default None),
        ``uid`` (identifier, default: current unix time in whole seconds, as str),
        ``port`` (visdom port, default 8097),
        ``track_step`` (default False),
        ``step`` (initial value of the internal counter, default 0),
        ``debug`` (default False).

    Examples:

    .. code-block:: python

        meters = {
            'mymetric': {'type': 'AverageValueMeter', 'target': 'misc'},
            'test_loss': {'type': 'AverageValueMeter', 'target': 'loss'},
            'train_loss': {'type': 'AverageValueMeter', 'target': 'loss'},
            # 'target' is the plot key
        }
        plots = {'loss': {'type': 'line'}, 'misc': {'type': 'line'}}

        # initializes plots with meters in the visdom env 'bob'
        TM = FlexLogger(plots, meters, env='bob')

        # sample experiment step - for more see unittests
        TM(mymetric=69, test_loss=0.94)
        TM.step()
        TM.log(reset=True)  # log and reset
    """

    def __init__(self, plot_args, meter_args, **kwargs):
        # keep a reference to the raw configuration
        self._args = {'meter': meter_args, 'plot': plot_args}
        self._env = kwargs.get('env', None)
        self._guid = kwargs.get('uid', str(time.time()).split('.')[0])
        self._port = kwargs.get('port', 8097)
        # NOTE(review): stored but not read by any method of this class
        self._step = kwargs.get('track_step', False)
        self._ctr = kwargs.get('step', 0)
        self.debug = kwargs.get('debug', False)
        # honour the ``port`` kwarg for the shared visdom client as well
        self._viz = visdom.Visdom(port=self._port)

        # meter name -> {'obj': meter instance, 'meta': meter definition}
        self._meters = defaultdict(dict)
        # plot name -> {'obj': TraceLogger, 'meta': plot options}
        self._plots = defaultdict(dict)

        # indexes
        # NOTE(review): not read by any method of this class
        self._phase_to_meter = defaultdict(str)
        self._plot_to_meter = defaultdict(list)
        self._meter_to_plot = defaultdict(str)

        self.update_config(plot_args, meter_args)

    def update_config(self, plot_args, meter_args):
        """
        API for adding meters and plots.

        Meter->plot links are registered first, then meters are built, then
        plots (plot construction reads both the links and meter metadata).

        :param plot_args: dictionary of plot definitions, or None for no plots.
        :param meter_args: dictionary of meter definitions.
        :return: None
        """
        self._init_links(meter_args)
        self._init_meters(meter_args)
        self._init_plots(plot_args)

    def _add_link(self, k, mtr):
        """Record that meter ``k`` is drawn on plot ``mtr['target']`` (default '')."""
        target_plot = mtr.get('target', '')
        self._meter_to_plot[k] = target_plot
        linked = self._plot_to_meter[target_plot]
        # re-registering a meter (e.g. through update_config) must not
        # duplicate it, otherwise log() would emit the same series twice
        if k not in linked:
            linked.append(k)

    def _init_links(self, meter_args):
        """Register the meter->plot links for every meter definition."""
        # validate before touching the definitions so bad input gets the
        # intended message instead of an AttributeError from ``.items``/``.get``
        assert isinstance(meter_args, dict), 'meters not defined as map'
        for k, v in meter_args.items():
            assert isinstance(v, dict), 'meter {} is not map'.format(k)
            self._add_link(k, v)

    def _init_meters(self, meter_args):
        """Instantiate every meter described in ``meter_args``."""
        assert isinstance(meter_args, dict), 'meters not defined as map'
        for name, v in meter_args.items():
            assert isinstance(v, dict), 'meter {} is not map'.format(name)
            self._add_meter(name, v)

    def _init_plots(self, plot_args):
        """Create every plot described in ``plot_args``; None means no plots."""
        if plot_args is None:
            return
        assert isinstance(plot_args, dict), 'plots not defined as map'
        for name, args in plot_args.items():
            self._add_plot(name, args)

    def _add_plot(self, name, args):
        """
        Create a visdom plot and add it to the plot dictionary.

        :param name: (str) name of plot
        :param args: (dict) options for plot; must contain 'type'. The dict
            (and its nested 'opts') is not modified.
        :return: None
        """
        # Work on copies: the caller's dict is also kept in self._args, and
        # popping from it made a second update_config with the same config
        # fail with 'invalid plot type None'.
        args = dict(args)
        plot_type = args.pop('type', None)
        if plot_type is None:
            print('invalid plot type', plot_type)
            return
        port = args.pop('port', self._port)
        env = args.pop('env', self._env)
        opts = dict(args.get('opts') or {})
        # Legend lists the meters linked to this plot.  NOTE: when the plot
        # already has links this is the live list stored in _plot_to_meter,
        # not a copy.
        opts['legend'] = self._plot_to_meter.get(name, [])
        opts['title'] = name
        traces = {}
        for mtr_name in opts['legend']:
            display = self._meters.get(mtr_name, {}).get('meta', {}).get('display', {})
            # copy so the meter's own 'display' definition is not modified
            plot_opts = dict(display)
            plot_opts['name'] = mtr_name
            traces[mtr_name] = plot_opts
        opts['data'] = traces
        self._plots[name]['meta'] = opts
        self._plots[name]['obj'] = TraceLogger(
            legend=opts['legend'], port=port, opts=opts, env=env, vis=self._viz)

    def _add_meter(self, name, args):
        """
        Instantiate a torchnet meter and register it under ``name``.

        :param name: (str) meter key
        :param args: (dict) definition; 'type' is a torchnet meter class name
            (default 'AverageValueMeter'); 'opts' is an optional list/tuple of
            positional or dict of keyword constructor arguments.
        :return: None (prints a message and skips unknown meter types)
        """
        meter_type = args.get('type', 'AverageValueMeter')
        opts = args.get('opts', None)
        Klass = getattr(METERS, meter_type, None)
        # only classes can be instantiated (the module also holds submodules)
        if not isinstance(Klass, type):
            print('cannot initialize ', meter_type)
            return
        # parameter count includes ``self``; > 1 means it accepts arguments
        n_args = len(signature(Klass.__init__).parameters)
        if n_args > 1 and isinstance(opts, (list, tuple)):
            meter = Klass(*opts)
        elif n_args > 1 and isinstance(opts, dict):
            meter = Klass(**opts)
        else:
            meter = Klass()
        self._meters[name]['obj'] = meter
        self._meters[name]['meta'] = args

    def _prep_key_args(self, keys, items):
        """Normalize ``keys``: None -> all keys of ``items``; str -> [str]."""
        if keys is None:
            keys = list(items.keys())
        elif isinstance(keys, str):
            keys = [keys]
        return keys

    def add(self, kwargs=None):
        """
        Add a dictionary of values to meters.

        Scalars (int/float, or single-element tensors, which are converted
        with ``.item()``) are passed as ``meter.add(v)``; lists/tuples are
        unpacked as ``meter.add(*v)``. Unknown keys and unsupported value
        types are reported and skipped.

        :param kwargs: dict mapping meter name -> value
        :return: None
        """
        if kwargs is None:
            kwargs = {}
        for k, v in kwargs.items():
            meter = self._meters.get(k, {}).get('obj')
            if meter is None:
                print('Meter not found ', k)
                continue
            if isinstance(v, torch.Tensor) and v.numel() == 1:
                v = v.item()
            if isinstance(v, (int, float)):
                meter.add(v)
            elif isinstance(v, (list, tuple)):
                meter.add(*v)
            else:
                # previously such values were dropped silently
                print('Unsupported value type for meter ', k, type(v))

    def log(self, keys=None, reset=False, step=False):
        """
        Retrieve current values of all meters and plot them at the current step.

        ``log`` is named to keep familiarity with the torchnet interface.
        If the reset keyword is True, each plotted meter is reset after being
        read; if the step keyword is True, calls self.step() at the end.

        :param keys: list of names of plots or None. If None, plots all keys.
        :param reset: reset meters after plotting
        :param step: increment counter after plotting
        :return: None
        """
        plot_keys = self._prep_key_args(keys, self._plots)
        x = self._ctr
        for plot_key in plot_keys:
            plot = self._plots.get(plot_key, {}).get('obj', None)
            if plot is None:
                print('Key not found ', plot_key)
                continue
            ys = []
            for meter_key in self._plot_to_meter.get(plot_key, []):
                meter = self._meters.get(meter_key, {}).get('obj')
                if meter is None:
                    # meter failed to initialize: keep ys aligned with legend
                    ys.append(None)
                    continue
                val = meter.value()
                if isinstance(val, (int, float)):
                    ys.append(val)
                elif isinstance(val, (tuple, list)):
                    # e.g. AverageValueMeter returns (mean, std): plot the first
                    ys.append(val[0])
                else:
                    ys.append(None)
                if reset is True:
                    meter.reset()
            if ys:
                plot.log([x] * len(ys), ys)
        if step is True:
            self.step()

    def step(self, step=None):
        """
        Increment the internal counter, or set it to the value of ``step``.

        :param step: (int) if specified, sets the internal counter to that step
        :return: (int) updated step count
        """
        if step is None:
            self._ctr += 1
        else:
            assert isinstance(step, int), 'step must be integer, got {}'.format(type(step))
            self._ctr = step
        return self._ctr

    def save(self, file_path, plots=False):
        """
        Save this object, and the visdom state if plots is True.

        TODO: implement.

        :raises NotImplementedError: always
        """
        raise NotImplementedError('not yet implemented')

    @staticmethod
    def load(file_path):
        """
        Unpickle and return an object.

        WARNING: unpickling can execute arbitrary code; only load trusted files.

        :param file_path: a filesystem path, or an open binary file object.
        :return: the unpickled object
        """
        if hasattr(file_path, 'read'):
            return pickle.load(file_path)
        with open(file_path, 'rb') as fh:
            return pickle.load(fh)

    def reset(self, keys=None):
        """
        Reset all of own meters, or only those named in ``keys``.

        :param keys: list[str] or str of meter names to reset; None for all.
        :return: None
        """
        keys = self._prep_key_args(keys, self._meters)
        for k in keys:
            meter = self._meters.get(k, {}).get('obj')
            if meter is not None:
                meter.reset()

    def value(self, keys=None):
        """
        Return the current value of each requested meter.

        Scalars and tensors are returned as-is; for tuple/list values
        (e.g. (mean, std)) the first element is returned; anything else,
        or a missing meter, yields None.

        :param keys: list[str] or str of meter names; None for all meters.
        :return: list of values, in the order of ``keys``
        """
        keys = self._prep_key_args(keys, self._meters)
        mp = []
        for k in keys:
            meter = self._meters.get(k, {}).get('obj')
            if meter is None:
                print('missing meter with key ', k)
                mp.append(None)
                continue
            val = meter.value()
            if isinstance(val, (int, float, torch.Tensor)):
                mp.append(val)
            elif isinstance(val, (tuple, list)):
                mp.append(val[0])
            else:
                mp.append(None)
        return mp

    def add_presets(self, *args, phases=None):
        """
        Add preconfigured meters and plots.

        :param args: preset keys (see presets.preconfigured.Config)
        :param phases: optional phases forwarded to ``Config.get_presets``
        :return: None
        """
        # same import location as add_metrics_for / from_presets
        from .presets import Config
        if phases is None:
            self.update_config(*Config.get_presets(*args))
        else:
            self.update_config(*Config.get_presets(*args, phases=phases))

    def add_metrics_for(self, *args, plot=None, phases=None):
        """
        Add some metrics to the logger, with a target plot.

        If the plot already exists, they will be added there.

        :param args: list of strings corresponding to metrics
        :param plot: string - name of plot (required)
        :param phases: list of strings corresponding to phases
        :return: None

        Example:

        .. code-block:: python

            # create new flexlogger
            TM = FlexLogger.from_presets('acc')
            # ... do stuff
            # later on, need to add some stuff to a new plot
            TM.add_metrics_for('loss', 'norm_loss', 'hinge_loss', 'demorgan_loss',
                               plot='loss', phases=['test'])
        """
        assert plot is not None, 'plot must be specified'
        from .presets import Config
        self.update_config(*Config.gen_plot(*args, plot=plot, phases=phases))

    @classmethod
    def from_presets(cls, *args, phases=None):
        """
        Factory method to generate a logger from some preconfigured keys.

        See presets.preconfigured.Config for details.

        :param args: preset keys
        :param phases: optional phases forwarded to ``Config.get_presets``
        :return: instance of FlexLogger

        :Example:

        .. code-block:: python

            TM = FlexLogger.from_presets('acc', 'loss')
        """
        from .presets import Config
        return cls(*Config.get_presets(*args, phases=phases))

    @property
    def vis(self):
        """
        property: Retrieve the Visdom() object for fun or profit.

        :return: Visdom() object
        """
        return self._viz

    def show(self, meta=False):
        """
        Implementation for __repr__ with additional functionality.

        __repr__ shows only meters; show can also include the metadata of
        plots and meters.

        :param meta: (boolean) include metadata
        :return: (string) detailed representation of self.

        Example:

        .. code-block:: python

            # create Logger
            Stat = FlexLogger.from_presets('loss', 'acc')
            Stat.show()

        Output:

        .. code-block:: console

            Step: 0
            Plots:
              loss
                  train_loss   - AverageValueMeter : nan
                  test_loss    - AverageValueMeter : nan
              acc
                  test_acc     - AverageValueMeter : nan
                  train_acc    - AverageValueMeter : nan
               Not plotted:
        """
        seen, st = set(), '\n'

        def _show_meter(st, m):
            mtr = self._meters.get(m, {}).get('obj')
            if mtr is None:
                vl, cls = 0, 'na'
            else:
                vl = mtr.value()
                if not isinstance(vl, (int, float)):
                    # tuple/list/array values: show the first element; values
                    # that are not indexable (None, 0-dim tensor) pass through
                    try:
                        vl = vl[0]
                    except (TypeError, IndexError):
                        pass
                cls = mtr.__class__.__name__
            try:
                vl_str = '{:.4f}'.format(vl)
            except (TypeError, ValueError):
                vl_str = str(vl)
            st += ' ' * 6 + '{} - {} : {}\n'.format(m.ljust(12, ' '), cls, vl_str)
            if meta is True:
                _meta = self._meters.get(m, {}).get('meta', '')
                st += ' ' * 9 + 'meta: ' + str(_meta) + '\n'
            return st

        st += 'Step: {}\n'.format(self._ctr)
        st += 'Plots:\n'
        for plot, meters in self._plot_to_meter.items():
            st += ' ' * 2 + plot
            if meta is True:
                _meta = self._plots.get(plot, {}).get('meta', '')
                st += ' ' * 3 + 'meta: ' + str(_meta)
            st += '\n'
            for m in meters:
                seen.add(m)
                st = _show_meter(st, m)
        # newline so the first unplotted meter starts on its own line
        st += ' ' * 2 + ' Not plotted:\n'
        for m in self._meters.keys():
            if m not in seen:
                st = _show_meter(st, m)
        return st

    def __add__(self, other):
        """Placeholder: ignores ``other`` and returns self unchanged."""
        return self

    def __len__(self):
        """Number of meters that have a plot link."""
        return len(self._meter_to_plot)

    def __call__(self, *args, **kwargs):
        """Shorthand for ``self.add(kwargs)``; positional args are ignored."""
        self.add(kwargs)

    def __repr__(self):
        return self.show(meta=False)