Source code for pytorchart.presets.functional

import torch
import torch.nn as nn
from pytorchart.Loggers.style_utils import _spec
from inspect import signature, _empty, Parameter
import pprint
import random


_index = ['layer', 'name', 'data', 'func']


def _modules_requiring_grads():
    """
    Generate a list of types that have weights, and therefore will get Grads.
    It is a bit hacky,
    :return:
    """
    has_grad = []
    for k, cls in nn.__dict__.items():
        required = 0
        if not isinstance(cls, type):
            continue
        ds = signature(cls)
        # cover the RNN and container cases up front
        if issubclass(cls, nn.RNNBase) or issubclass(cls, nn.Container):
            has_grad.append(cls)
            continue
        for _k, param in ds.parameters.items():
            if param.default == _empty and param.kind != Parameter.VAR_POSITIONAL:
                required += 1
        args = [i+1 for i in range(required)]
        try:
            inst = cls(*args)
            # for most modules its in weight
            if hasattr(inst, 'weight'):
                has_grad.append(cls)
        except Exception as err:
            print('cannot inspect', cls.__name__, required, str(err))
    # [print(x) for x in has_grad]
    return has_grad


IGrad = _modules_requiring_grads()
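
# Illustrative check, a sketch only (exact membership depends on the
# installed torch version):
#
#     assert nn.Linear in IGrad    # owns a weight parameter
#     assert nn.ReLU not in IGrad  # stateless activation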
###########################################################


def _identity(*x):
    return x


def summarize(model, fn=_identity):
    """
    Given an nn.Module, recursively walk its submodules and apply a function.

    :param model: nn.Module or subclass thereof
    :param fn: function applied as fn(name, module) to each leaf module to
        create a summary
    :return: list of summaries generated by fn; None results are skipped
    """
    res = []
    for key, module in model._modules.items():
        if type(module) in [
            nn.Container,
            nn.Sequential,
            nn.ModuleList,
            # nn.Module,
        ]:
            res += summarize(module, fn=fn)
        elif any(module._modules):
            res += summarize(module, fn=fn)
        elif 'weight' in module._parameters:
            summary = fn(key, module)
            if summary is not None:
                res.append(summary)
    return res
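

# Usage sketch with a hypothetical model (illustrative only):
#
#     net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
#     summarize(net, fn=lambda name, m: name)  # -> ['0', '2']
#
# Only leaf modules holding a ``weight`` parameter reach ``fn``, and a
# ``None`` result drops that module from the summary.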


#############################


def _grad_wrt_weight(module, grad_in):
    # Find the entry of grad_in whose shape matches the module's weight;
    # for nn.Linear the matching gradient arrives transposed.
    m = list(module._parameters['weight'].size())
    if type(module) in [nn.Linear]:
        m = m[::-1]
    if grad_in is not None:
        for g in grad_in:
            if g is not None and list(g.size()) == m:
                return g.data


def grad_norm_mean(module, grad_in, grad_out):
    # mean of the weight gradient normalized by the largest weight entry
    m = module._parameters['weight'].data.max()
    g_wrt_w_data = _grad_wrt_weight(module, grad_in)
    n = g_wrt_w_data / m
    return n.mean()


def grad_std(module, grad_in, grad_out):
    return _grad_wrt_weight(module, grad_in).std()


def gen_grad_wrt_w(fn):
    # adapt a tensor reduction ``fn`` into a backward-hook callable
    def hook(m, i, o):
        return fn(_grad_wrt_weight(m, i))
    return hook
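

# Usage sketch (hypothetical layer; how the surrounding logger consumes the
# hook's return value is assumed, not shown here):
#
#     layer = nn.Linear(4, 2)
#     hook = gen_grad_wrt_w(lambda g: g.norm())
#
# ``hook(module, grad_input, grad_output)`` matches the backward-hook
# signature and returns the norm of the gradient w.r.t. ``layer.weight``.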


def gen_module_wght(fn):
    # adapt a tensor reduction ``fn`` into a hook over the weight data
    def hook(m, i, o):
        return fn(m._parameters['weight'].data)
    return hook


def grad_mean(module, grad_in, grad_out):
    return _grad_wrt_weight(module, grad_in).abs().mean()


def grad_norm(module, grad_in, grad_out):
    return _grad_wrt_weight(module, grad_in).norm()


def has_grad(name, module_type):
    if module_type in IGrad:
        return name
    return None


#############################

_grad_layers = lambda name, m: name if type(m) in IGrad else None

_def_plot = {
    'type': 'line',
    'opts': {'layout': {'yaxis': {'type': 'log', 'autorange': True}}}
}

fspecs = {
    'grad_norms': {
        'doc': "common statistics to track about the gradient",
        'layer': _grad_layers,
        'data': 'backward',
        'name': 'grad_norm',
        'func': [grad_mean, grad_norm],
        'same': {'layer': 'line.color', 'func': 'line.dash'}},
    'snr': {
        'doc': "gradient signal-to-noise ratio, after Tishby et al.",
        'layer': _grad_layers,
        'data': 'backward',
        'name': 'std_meter',
        'func': [grad_norm, grad_std],
        'same': {'layer': 'line.color', 'func': 'line.dash'}},
}
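

# Reading a preset, a sketch: ``layer`` filters which modules are tracked,
# ``data`` names the hook phase, ``func`` lists the statistics to compute,
# and ``same`` ties plot dimensions to style channels (one color per layer,
# one dash pattern per statistic):
#
#     spec = fspecs['grad_norms']
#     spec['func']           # [grad_mean, grad_norm]
#     spec['same']['layer']  # 'line.color'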


class LayerLegend(object):
    """
    Keeps track of layers and their assigned colors.
    """
    def __init__(self):
        self.legend = {}

    def color_for(self, layer):
        if layer not in self.legend:
            self.legend[layer] = self._gen_color()
        return self.legend[layer]

    def _gen_color(self):
        r, g, b = [random.randint(0, 255) for _ in range(3)]
        return "#{0:02x}{1:02x}{2:02x}".format(r, g, b)

    def register_model(self, model, fn=_grad_layers):
        layers = summarize(model, fn=fn)
        for layer_name in layers:
            self.color_for(layer_name)
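

# Usage sketch (hypothetical model):
#
#     legend = LayerLegend()
#     legend.register_model(nn.Sequential(nn.Linear(3, 3)))
#     legend.color_for('0')  # stable random hex color, e.g. '#a1b2c3'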


def _generate_layer(model, colors=None, fn=None, target='plot', **kwargs):
    """
    :param model: nn.Module whose layers are indexed
    :param colors: a LayerLegend containing layers previously indexed
    :param fn: function for gathering layers; defaults to the spec's own
    :param target: key into ``fspecs`` selecting the preset to expand
    :param kwargs: pass ``debug=True`` to print the gathered layers
    :return: (list of meter definitions, LayerLegend), or None if the
        target is unknown
    """
    data = []
    spec = fspecs.get(target, None)
    if spec is None:
        return
    funcs = spec.get('func', None)
    lyrfn = spec.get('layer', _identity) if fn is None else fn
    _summary = summarize(model, fn=lyrfn)
    if colors is None:
        colors = LayerLegend()
    styles = _spec.get('line.dash')
    if kwargs.get('debug', None) is True:
        print(_summary)
    for i_f, f in enumerate(funcs):
        for module_name in _summary:
            res = {
                'layer': module_name,
                'data': spec['data'],
                'target': target,
                'func': f,
                'display': {'line': {'dash': styles[i_f],
                                     'color': colors.color_for(module_name)}}
            }
            data.append(res)
    return data, colors


def generate_layers(model, colors=None, fn=None, targets=None):
    """
    :param model: nn.Module to generate meter definitions for
    :param colors: optional LayerLegend to reuse layer colors
    :param fn: optional layer-gathering function forwarded to _generate_layer
    :param targets: list of ``fspecs`` keys to expand
    :return: (meters, plots)
    """
    meters, plots = [], {}
    for tgt in targets or []:
        plots[tgt] = _def_plot.copy()
        d, colors = _generate_layer(model, colors=colors, fn=fn, target=tgt)
        meters += d
    return meters, plots
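

# End-to-end sketch (hypothetical model and target):
#
#     net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
#     meters, plots = generate_layers(net, targets=['grad_norms'])
#
# ``meters`` holds one entry per (layer, statistic) pair, styled with one
# color per layer and one dash pattern per statistic; ``plots`` maps each
# target to a log-scaled line-plot definition.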