# Source code for pytorchart.presets.preconfigured

from pytorchart.utils import deep_merge, deepcopy


def _preset_spec(plot_name, plot_type, meter_name, meter_type):
    """Return a preset holding one plot and one meter whose target is that plot."""
    return {
        'plots': {plot_name: {'type': plot_type}},
        'meters': {meter_name: {'type': meter_type, 'target': plot_name}},
    }


# Preconfigured presets: preset name -> {'plots': {...}, 'meters': {...}}.
# Each meter's 'target' is the key of a plot in the same preset.
# TODO: add more presets, e.g. class accuracy, AUC, PUCT.
_plot_defs = {
    'simple': _preset_spec('value', 'line', 'value', 'AverageValueMeter'),
    'image': _preset_spec('image', 'image', 'image', 'ImageMeter'),
    'confusion': _preset_spec('confusion_view', 'image', 'confusion', 'ConfusionMeter'),
    'acc': _preset_spec('acc', 'line', 'acc', 'AverageValueMeter'),
    'mse': _preset_spec('mse', 'line', 'mse', 'MSEMeter'),
    'loss': _preset_spec('loss', 'line', 'loss', 'AverageValueMeter'),
}

# Phases used when callers do not pass ``phases`` explicitly.
_default_phases = ['train', 'test']


class Config(object):
    """
    Static object for generating meter and plot configurations.

    Every method is a classmethod that reads the module-level ``_plot_defs``
    presets; nothing is stored on the class itself.
    """

    @classmethod
    def _build_phases(cls, k, phases, target=None, c=None):
        """
        Given a preset key, create one copy of each of its meters per phase.

        :param k: key into ``_plot_defs``. An unknown key yields
            ``{'meters': {}}`` (no ``'plots'`` entry).
        :param phases: iterable of phase names
        :param target: if not None, overrides every meter's ``'target'``
        :param c: if not None, used instead of ``k`` when naming the meters
        :return: deep copy of the preset whose ``'meters'`` is replaced by
            ``{phase + '_' + name: meter}``, each meter carrying ``'phase'``
        """
        meter_cfg = {}
        name = k if c is None else c
        spec = deepcopy(_plot_defs.get(k, {}))
        for meter_def in spec.get('meters', {}).values():
            for phase in phases:
                meter = deepcopy(meter_def)
                if target is not None:
                    meter['target'] = target
                meter['phase'] = phase
                # NOTE: the generated name depends only on phase and `name`, so a
                # preset with several meters would see later ones overwrite
                # earlier ones (every current preset has exactly one meter).
                meter_cfg[phase + '_' + name] = meter
        spec['meters'] = meter_cfg
        return spec

    @classmethod
    def gen_plot(cls, *keys, phases=None, plot=None):
        """
        Create one plot fed by a 'simple' preset meter per key per phase.

        :param keys: meter names (strings); each yields meters named
            ``<phase>_<key>``
        :param phases: list of strings or None. If None, the default phases
            ``['train', 'test']`` are used.
        :param plot: name of the plot. If None, the plot keeps the 'simple'
            preset's name, ``'value'``.
        :return: tuple of dicts ``(plots, meters)``
        """
        if phases is None:
            phases = _default_phases
        cfg = deep_merge(*[cls._build_phases('simple', phases, target=plot, c=k) for k in keys])
        if plot is not None:
            # Rename rather than copy: the meters now target `plot`, so a
            # leftover 'value' plot would have no meters feeding it.
            cfg['plots'][plot] = cfg['plots'].pop('value')
        return cfg['plots'], cfg['meters']

    @classmethod
    def get_presets(cls, *keys, phases=None):
        """
        Generate plots and meters for the given preconfigured preset keys.

        One meter is created per preset meter per phase.

        :param keys: preset names (keys of ``_plot_defs``)
        :param phases: list of strings or None. If None, the default phases
            ``['train', 'test']`` are used.
        :return: tuple of dicts ``(plots, meters)``
        """
        if phases is None:
            phases = _default_phases
        cfg = deep_merge(*[cls._build_phases(k, phases) for k in keys])
        return cfg['plots'], cfg['meters']

    @classmethod
    def default_cfg(cls):
        """
        Return the base ('simple') plot and meter configuration.

        A deep copy is returned so callers can modify it without altering the
        shared module-level presets.

        :return: dict
        :Usage: Config.default_cfg()
        """
        return deepcopy(_plot_defs.get('simple', {}))
def get_meters(*keys, phases=None):
    """
    Build per-phase meter configs for the given preset keys, without targets.

    :param keys: preset names (keys of ``_plot_defs``)
    :param phases: list of phase names, or None for the default phases
        (``['train', 'test']``)
    :return: dict of meter configs as produced by ``Config.get_presets``,
        with each meter's ``'target'`` entry removed
    """
    # `None` (rather than the shared module-level list) avoids a mutable
    # default argument; Config.get_presets maps None to the default phases.
    _, meters = Config.get_presets(*keys, phases=phases)
    for meter in meters.values():
        meter.pop('target', None)
    return meters
# METERS AND PLOT INFO
def preset_names():
    """
    List the names of all preconfigured presets.

    :return: list of preset keys, in definition order
    """
    return [*_plot_defs]
# def plot_types(): # return _meters # # # def meter_types(): # return _meters # def meter_info(name): # return meter_defs.get(name, None) # PRESET CONFIGURATIONS # def get_preset_logger(key, **kwargs): # if key in _plot_defs: # cfg = _plot_defs[key] # return FlexLogger(cfg['plots'], cfg['meters'], **kwargs) # # def get_preset(key): # cfg = _plot_defs.get(key, None) # return cfg['plots'], cfg['meters'] # def get_presets(*keys, phases=_default_phases): # return Config.get_presets(*keys, phases=phases)