Shortcuts

Source code for lumin.nn.training.metric_logger

from typing import Tuple, List, Optional
from IPython.display import display
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from ...plotting.plot_settings import PlotSettings


__all__ = ['MetricLogger']


# TODO: Non-notebook version?


class MetricLogger():
    r'''
    Provides live feedback during training showing a variety of metrics to help highlight problems or test hyper-parameters without completing a full training.

    Arguments:
        loss_names: List of names of losses which will be passed to the logger in the order in which they will be passed.
            By convention the first name will be used as the training loss when computing the ratio of training to validation losses
        n_folds: Number of folds present in the training data.
            The logger assumes that one of these folds is for validation, and so 1 training epoch = (n_fold-1) folds.
        autolog_scale: Whether to automatically change the scale of the y-axis for loss to logarithmic when the current loss drops below one 50th of its starting value
        extra_detail: Whether to include extra detail plots (loss velocity and training validation ratio), slightly slower but potentially useful.
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Examples::
        >>> metric_log = MetricLogger(loss_names=['Train', 'Validation'], n_folds=train_fy.n_folds)
        >>> val_losses = []
        >>> metric_log.reset()  # Initialises plots and variables
        >>> for epoch in epochs:
        >>>     for fold in train_folds:
        >>>         # train for one fold (subepoch)
        >>>         metric_log.update_vals([train_loss, val_loss])
        >>>     metric_log.update_plot(best=best_val_loss)
        >>> plt.clf()
    '''

    def __init__(self, loss_names:List[str], n_folds:int, autolog_scale:bool=True, extra_detail:bool=True,
                 plot_settings:PlotSettings=PlotSettings()):
        # NOTE(review): PlotSettings() as a default is a shared mutable default; it appears to be read-only here,
        # but confirm PlotSettings is never mutated by callers.
        self.loss_names,self.n_folds,self.autolog_scale,self.extra_detail,self.settings = loss_names,n_folds,autolog_scale,extra_detail,plot_settings
def add_loss_name(self, name:str) -> None:
    r'''
    Registers an extra loss name for display. Its history (losses, velocities and
    generalisation ratios) is backfilled with zeros for all subepochs that have
    already elapsed, so every tracked series keeps the same length.

    Arguments:
        name: name of loss to be added
    '''
    self.loss_names.append(name)
    # Pad each history with a zero series matching the length of the first tracked loss
    for history in (self.loss_vals, self.vel_vals, self.gen_vals):
        history.append(list(np.zeros_like(history[0])))
    # No running mean yet for the new loss
    self.mean_losses.append(None)
def update_vals(self, vals:List[float]) -> None:
    r'''
    Appends values to the losses. This is interpreted as one subepoch having elapsed
    (i.e. one training fold).

    Arguments:
        vals: loss values from the last subepoch in the order of `loss_names`
    '''
    for idx, val in enumerate(vals):
        self.loss_vals[idx].append(val)
        # Switch the loss axis to log scale once a loss falls below 1/50th of its start value
        if self.autolog_scale and not self.log and self.loss_vals[idx][0]/self.loss_vals[idx][-1] > 50:
            self.log = True
    self.subepochs.append(self.subepochs[-1]+1)
    if not self.extra_detail: return
    self.count += 1
    if self.count < self.n_folds: return  # a full epoch (n_folds-1 training folds) has not yet elapsed
    self.count = 1
    self.epochs.append(self.epochs[-1]+1)
    # Per-epoch extras: loss velocity for every loss, generalisation ratio for every non-training loss
    for idx, history in enumerate(self.loss_vals):
        vel, self.mean_losses[idx] = self._get_vel(history, self.mean_losses[idx])
        self.vel_vals[idx].append(vel)
        if idx > 0: self.gen_vals[idx-1].append(self._get_gen_err(history))
def update_plot(self, best:Optional[float]=None) -> None:
    r'''
    Updates the plot(s), optionally showing the user-chosen best loss achieved.

    Arguments:
        best: the value of the best loss achieved so far
    '''
    # Loss
    self.loss_ax.clear()
    with sns.axes_style(**self.settings.style), sns.color_palette(self.settings.cat_palette):
        for v,m in zip(self.loss_vals,self.loss_names): self.loss_ax.plot(self.subepochs[1:], v, label=m)
    if best is not None: self.loss_ax.plot(self.subepochs[1:], np.ones_like(self.subepochs[1:])*best, label=f'Best = {best:.3E}', linestyle='--')
    if self.log:
        # FIX: `nonposy` was renamed `nonpositive` in Matplotlib 3.3 and removed in 3.5;
        # try the modern keyword first and fall back for older Matplotlib installations.
        try: self.loss_ax.set_yscale('log', nonpositive='clip')
        except (TypeError, ValueError): self.loss_ax.set_yscale('log', nonposy='clip')
    self.loss_ax.tick_params(axis='y', labelsize=0.8*self.settings.tk_sz, labelcolor=self.settings.tk_col, which='both')
    self.loss_ax.grid(True, which="both")
    self.loss_ax.legend(loc='upper right', fontsize=0.8*self.settings.leg_sz)
    self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
    self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
    if self.extra_detail:
        # Velocity: per-epoch change in mean loss
        self.vel_ax.clear()
        self.vel_ax.tick_params(axis='y', labelsize=0.8*self.settings.tk_sz, labelcolor=self.settings.tk_col, which='both')
        self.vel_ax.grid(True, which="both")
        with sns.color_palette(self.settings.cat_palette):
            for v,m in zip(self.vel_vals,self.loss_names): self.vel_ax.plot(self.epochs[1:], v, label=f'{m} {v[-1]:.2E}')
        self.vel_ax.legend(loc='lower right', fontsize=0.8*self.settings.leg_sz)
        self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
        # Generalisation: validation/train loss ratio per epoch; palette offset by one so colours
        # match the corresponding (non-training) loss in the other plots
        self.gen_ax.clear()
        self.gen_ax.grid(True, which="both")
        with sns.color_palette(self.settings.cat_palette) as palette:
            for i, (v,m) in enumerate(zip(self.gen_vals,self.loss_names[1:])):
                self.gen_ax.plot(self.epochs[1:], v, label=f'{m} {v[-1]:.2f}', color=palette[i+1])
        self.gen_ax.legend(loc='upper left', fontsize=0.8*self.settings.leg_sz)
        self.gen_ax.set_xlabel('Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
        self.gen_ax.set_ylabel('Validation / Train', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
        # Keep the detail plots to a sliding window of the last 5 epochs
        if len(self.epochs) > 5:
            self.epochs = self.epochs[1:]
            for i in range(len(self.vel_vals)): self.vel_vals[i] = self.vel_vals[i][1:]
            for i in range(len(self.gen_vals)): self.gen_vals[i] = self.gen_vals[i][1:]
        self.display.update(self.fig)
    else:
        self.display.update(self.loss_ax.figure)
def _get_vel(self, losses:List[float], old_mean:Optional[float]=None) -> Tuple[float,float]: mean = np.mean(losses[1-self.n_folds:]) if old_mean is None: old_mean = losses[0] return mean-old_mean, mean def _get_gen_err(self, losses:List[float]) -> float: trn = np.mean(self.loss_vals[0][1-self.n_folds:]) return (np.mean(losses[1-self.n_folds:]))/trn
def reset(self) -> None:
    r'''
    Resets/initialises the logger's values and plots, and produces a placeholder plot.
    Should be called prior to `update_vals` or `update_plot`.
    '''
    # Fresh histories: one series per loss; generalisation ratios only for non-training losses
    n = len(self.loss_names)
    self.loss_vals = [[] for _ in range(n)]
    self.vel_vals = [[] for _ in range(n)]
    self.gen_vals = [[] for _ in range(n-1)]
    self.mean_losses = [None for _ in range(n)]
    self.subepochs, self.epochs = [0], [0]
    self.count, self.log = 1, False

    # Shared styling kwargs, applied identically in both layouts
    lbl_kw = dict(fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
    tick_kw = dict(labelsize=0.8*self.settings.tk_sz, labelcolor=self.settings.tk_col)
    with sns.axes_style(**self.settings.style):
        if self.extra_detail:
            # 2x3 grid: loss fills the left two columns, velocity and generalisation stack on the right
            self.fig = plt.figure(figsize=(self.settings.w_mid, self.settings.h_mid), constrained_layout=True)
            grid = self.fig.add_gridspec(2, 3)
            self.loss_ax = self.fig.add_subplot(grid[:,:-1])
            self.vel_ax = self.fig.add_subplot(grid[:1,2:])
            self.gen_ax = self.fig.add_subplot(grid[1:2,2:])
            for ax in (self.loss_ax, self.vel_ax, self.gen_ax):
                ax.tick_params(axis='x', **tick_kw)
                ax.tick_params(axis='y', **tick_kw)
            self.loss_ax.set_xlabel('Sub-Epoch', **lbl_kw)
            self.loss_ax.set_ylabel('Loss', **lbl_kw)
            self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', **lbl_kw)
            self.gen_ax.set_xlabel('Epoch', **lbl_kw)
            self.gen_ax.set_ylabel('Validation / Train', **lbl_kw)
            # Keep a display handle so update_plot can refresh the notebook output in place
            self.display = display(self.fig, display_id=True)
        else:
            self.fig, self.loss_ax = plt.subplots(1, figsize=(self.settings.w_mid, self.settings.h_mid))
            self.loss_ax.tick_params(axis='x', **tick_kw)
            self.loss_ax.tick_params(axis='y', **tick_kw)
            self.loss_ax.set_xlabel('Sub-Epoch', **lbl_kw)
            self.loss_ax.set_ylabel('Loss', **lbl_kw)
            self.display = display(self.loss_ax.figure, display_id=True)
Read the Docs v: v0.6.0
Versions
latest
stable
v0.6.0
v0.5.1
v0.5.0
v0.4.0.1
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials