
Source code for lumin.nn.callbacks.monitors

from fastcore.all import store_attr
import math
from typing import Optional
import numpy as np
from fastprogress.fastprogress import IN_NOTEBOOK
from IPython.display import display
from collections import OrderedDict

import matplotlib.pyplot as plt
import seaborn as sns

from .callback import Callback
from ..models.abs_model import AbsModel
from ...plotting.plot_settings import PlotSettings

__all__ = ['EarlyStopping', 'SaveBest', 'MetricLogger']


class EarlyStopping(Callback):
    r'''
    Tracks validation loss during training and terminates training if the loss doesn't decrease after `patience` number of epochs.
    Losses are assumed to be averaged and will be re-averaged over the epoch unless `loss_is_meaned` is false.

    Arguments:
        patience: number of epochs to wait without improvement before stopping training
        loss_is_meaned: if the batch loss value has been averaged over the number of elements in the batch, this should be true;
            average loss will be computed over all elements in the batch.
            If the batch loss is not an average value, then the average will be computed over the number of batches.
    '''

    def __init__(self, patience:int, loss_is_meaned:bool=True):
        super().__init__()
        store_attr()

    def _reset(self) -> None:
        self.epochs,self.min_loss = 0,math.inf

    def on_train_begin(self) -> None:
        r'''
        Resets variables and prepares for new training
        '''

        super().on_train_begin()
        self._reset()
        self.cyclic_cb = None if len(self.model.fit_params.cyclic_cbs) == 0 else self.model.fit_params.cyclic_cbs[-1]
        self.improve_in_cycle = False

    def on_epoch_begin(self) -> None:
        r'''
        Prepares to track validation losses
        '''

        if self.model.fit_params.state != 'valid': return
        self.cnt = 0
        self.loss = [0] + [0 for _ in self.model.fit_params.loss_cbs]  # Consider all losses, e.g. SWA loss

    def on_forwards_end(self) -> None:
        r'''
        Records losses for batch
        '''

        if self.model.fit_params.state != 'valid': return
        sz = len(self.model.fit_params.x) if self.loss_is_meaned else 1
        self.loss[0] += self.model.fit_params.loss_val.data.item()*sz
        for i,c in enumerate(self.model.fit_params.loss_cbs): self.loss[i+1] += c.get_loss()*sz
        self.cnt += sz

    def on_epoch_end(self) -> None:
        r'''
        Computes best average validation loss and acts according to it
        '''

        if self.model.fit_params.state != 'valid': return
        loss = np.min(self.loss)/self.cnt
        if loss <= self.min_loss:
            self.min_loss = loss
            self.epochs = 0
            self.improve_in_cycle = True
            if self.cyclic_cb is not None and self.cyclic_cb.cycle_end: self.improve_in_cycle = False
        elif self.cyclic_cb is not None:
            if self.cyclic_cb.cycle_end:
                if self.improve_in_cycle:
                    self.epochs = 0
                    self.improve_in_cycle = False
                else:
                    self.epochs += 1
        else:
            self.epochs += 1
        if self.epochs >= self.patience:
            print('Early stopping')
            self.model.fit_params.stop = True

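# --- Editor's note: illustrative usage sketch, not part of the library source.
# EarlyStopping watches the per-epoch validation loss and halts training once
# `patience` epochs pass without improvement.  In the sketch below, `model`
# (a LUMIN Model) and `fy` (a FoldYielder) are assumed to exist already, and
# the `model.fit(..., cbs=[...])` call is an assumption about the training
# interface; check the Model.fit documentation for the version in use.
def _example_early_stopping(model, fy):
    # Allow up to 10 validation epochs without improvement before stopping
    stopper = EarlyStopping(patience=10, loss_is_meaned=True)
    model.fit(n_epochs=100, fy=fy, bs=256, cbs=[stopper])  # hypothetical fit call
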
class SaveBest(Callback):
    r'''
    Tracks validation loss during training and automatically saves a copy of the weights to the indicated file whenever validation loss decreases.
    Losses are assumed to be averaged and will be re-averaged over the epoch unless `loss_is_meaned` is false.

    Arguments:
        auto_reload: if true, will automatically reload the best model at the end of training
        loss_is_meaned: if the batch loss value has been averaged over the number of elements in the batch, this should be true;
            average loss will be computed over all elements in the batch.
            If the batch loss is not an average value, then the average will be computed over the number of batches.
    '''

    def __init__(self, auto_reload:bool=True, loss_is_meaned:bool=True):
        super().__init__()
        store_attr()

    def _reset(self) -> None:
        self.min_loss = math.inf

    def on_train_begin(self) -> None:
        r'''
        Resets variables and prepares for new training
        '''

        super().on_train_begin()
        self._reset()

    def on_epoch_begin(self) -> None:
        r'''
        Prepares to track validation losses
        '''

        if self.model.fit_params.state != 'valid': return
        self.cnt = 0
        self.loss = [0] + [0 for _ in self.model.fit_params.loss_cbs]  # Consider all losses, e.g. SWA loss

    def on_forwards_end(self) -> None:
        r'''
        Records losses for batch
        '''

        if self.model.fit_params.state != 'valid': return
        sz = len(self.model.fit_params.x) if self.loss_is_meaned else 1
        self.loss[0] += self.model.fit_params.loss_val.data.item()*sz
        for i,c in enumerate(self.model.fit_params.loss_cbs): self.loss[i+1] += c.get_loss()*sz
        self.cnt += sz

    def on_epoch_end(self) -> None:
        r'''
        Computes the best average validation loss and, if it is better than the current best, saves a copy of the model which produced it
        '''

        if self.model.fit_params.state != 'valid': return
        loss = np.array(self.loss)/self.cnt
        lm = np.min(loss)
        if lm < self.min_loss:
            self.min_loss = lm
            lam = np.argmin(loss)
            m = self.model
            if lam > 0: m = self.model.fit_params.loss_cbs[lam-1].test_model
            m.save(self.model.fit_params.cb_savepath/'best.h5')

    def on_train_end(self) -> None:
        r'''
        Optionally reload best performing model
        '''

        if self.auto_reload:
            print(f'Loading best model with loss {self.min_loss:.3E}')
            self.model.load(self.model.fit_params.cb_savepath/'best.h5')

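# --- Editor's note: illustrative usage sketch, not part of the library source.
# SaveBest checkpoints the weights to cb_savepath/'best.h5' whenever the mean
# validation loss improves and, with auto_reload=True, reloads that checkpoint
# when training ends.  As above, `model`, `fy`, and the exact fit signature are
# assumptions for illustration only.
def _example_save_best(model, fy):
    # Combine checkpointing with early stopping so the reloaded weights are the best seen
    cbs = [SaveBest(auto_reload=True), EarlyStopping(patience=10)]
    model.fit(n_epochs=100, fy=fy, bs=256, cbs=cbs)  # hypothetical fit call
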
class MetricLogger(Callback):
    r'''
    Provides live feedback during training showing a variety of metrics to help highlight problems or test hyper-parameters without completing a full training.
    If `show_plots` is false, will instead print training and validation losses at the end of each epoch.
    The full history is available as a dictionary by calling :meth:`~lumin.nn.callbacks.monitors.MetricLogger.get_loss_history`.

    Arguments:
        show_plots: whether to produce live loss plots during training rather than printing losses at the end of each epoch; defaults to true when running in a notebook
        extra_detail: whether to include extra detail plots (loss velocity and validation/train ratio), slightly slower but potentially useful
        loss_is_meaned: if the batch loss value has been averaged over the number of elements in the batch, this should be true;
            average loss will be computed over all elements in the batch.
            If the batch loss is not an average value, then the average will be computed over the number of batches.
    '''

    def __init__(self, show_plots:bool=IN_NOTEBOOK, extra_detail:bool=True, loss_is_meaned:bool=True):
        super().__init__()
        store_attr()

    def on_train_begin(self) -> None:
        r'''
        Prepare for new training
        '''

        super().on_train_begin()
        self._reset()
        for c in self.model.fit_params.loss_cbs: self._add_loss_name(type(c).__name__)

    def on_epoch_begin(self) -> None:
        r'''
        Prepare to track new loss
        '''

        self.loss,self.cnt = 0,0

    def on_fold_begin(self) -> None:
        r'''
        Prepare to track new loss
        '''

        self.on_epoch_begin()

    def on_fold_end(self) -> None:
        r'''
        Record training loss for fold
        '''

        if self.model.fit_params.state != 'train': return
        self.loss_vals[0].append(self.loss/self.cnt)

    def on_epoch_end(self) -> None:
        r'''
        If the validation epoch has finished, record validation losses, compute info and update plots
        '''

        if self.model.fit_params.state != 'valid': return
        self.epochs.append(self.epochs[-1]+1)
        self.loss_vals[1].append(self.loss/self.cnt)
        for i,c in enumerate(self.model.fit_params.loss_cbs): self.loss_vals[i+2].append(c.get_loss())
        if self.show_plots:
            for i, v in enumerate(self.loss_vals[1:]):
                if len(self.loss_vals[1]) > 1 and self.extra_detail:
                    self.vel_vals[i].append(v[-1]-v[-2])
                    self.gen_vals[i].append(v[-1]/self.loss_vals[0][-1])
                if self.loss_vals[i+1][-1] <= self.best_loss: self.best_loss = self.loss_vals[i+1][-1]
            self.update_plot()
        else:
            self.print_losses()

    def on_forwards_end(self) -> None:
        r'''
        Record batch loss
        '''

        sz = len(self.model.fit_params.x) if self.loss_is_meaned else 1
        self.loss += self.model.fit_params.loss_val.data.item()*sz
        self.cnt += sz

    def _add_loss_name(self, name:str) -> None:
        self.loss_names.append(name)
        self.loss_vals.append(list(np.zeros_like(self.loss_vals[1])))
        self.vel_vals.append(list(np.zeros_like(self.vel_vals[0])))
        self.gen_vals.append(list(np.zeros_like(self.gen_vals[0])))

    def print_losses(self) -> None:
        r'''
        Print training and validation losses for the last epoch
        '''

        p = f'Epoch {len(self.loss_vals[1])}: Training = {np.mean(self.loss_vals[0][-self.n_trn_flds:]):.2E}'
        for v,m in zip(self.loss_vals[1:],self.loss_names[1:]): p += f' {m} = {v[-1]:.2E}'
        print(p)

    def update_plot(self) -> None:
        r'''
        Updates the plot(s).

        # TODO: make this faster
        '''

        # Loss
        self.loss_ax.clear()
        with sns.axes_style(**self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette):
            self.loss_ax.plot(range(1,len(self.loss_vals[0])+1), self.loss_vals[0], label=self.loss_names[0])
            x = range(self.n_trn_flds, self.n_trn_flds*len(self.loss_vals[1])+1, self.n_trn_flds)
            for v,m in zip(self.loss_vals[1:],self.loss_names[1:]): self.loss_ax.plot(x, v, label=m)
            self.loss_ax.plot([1,x[-1]], [self.best_loss,self.best_loss], label=f'Best = {self.best_loss:.3E}', linestyle='--')
        if self.log:
            self.loss_ax.set_yscale('log', nonposy='clip')
            self.loss_ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col, which='both')
        self.loss_ax.grid(True, which="both")
        self.loss_ax.legend(loc='upper right', fontsize=0.8*self.plot_settings.leg_sz)
        self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
        self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)

        if self.extra_detail and len(self.loss_vals[1]) > 1:
            # Velocity
            self.vel_ax.clear()
            self.vel_ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col, which='both')
            self.vel_ax.grid(True, which="both")
            with sns.color_palette(self.plot_settings.cat_palette) as palette:
                for i, (v,m) in enumerate(zip(self.vel_vals,self.loss_names[1:])):
                    self.vel_ax.plot(self.epochs[2:], v, label=f'{m} {v[-1]:.2E}', color=palette[i+1])
            self.vel_ax.legend(loc='lower right', fontsize=0.8*self.plot_settings.leg_sz)
            self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)

            # Generalisation
            self.gen_ax.clear()
            self.gen_ax.grid(True, which="both")
            with sns.color_palette(self.plot_settings.cat_palette) as palette:
                for i, (v,m) in enumerate(zip(self.gen_vals,self.loss_names[1:])):
                    self.gen_ax.plot(self.epochs[2:], v, label=f'{m} {v[-1]:.2f}', color=palette[i+1])
            self.gen_ax.legend(loc='upper left', fontsize=0.8*self.plot_settings.leg_sz)
            self.gen_ax.set_xlabel('Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            self.gen_ax.set_ylabel('Validation / Train', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)

            if len(self.epochs) > 5:
                self.epochs = self.epochs[1:]
                for i in range(len(self.vel_vals)): self.vel_vals[i],self.gen_vals[i] = self.vel_vals[i][1:],self.gen_vals[i][1:]
            self.display.update(self.fig)
        else:
            self.display.update(self.loss_ax.figure)

    def _reset(self) -> None:
        self.loss_names = ['Training', 'Validation']
        self.loss_vals = [[] for _ in self.loss_names]
        self.vel_vals, self.gen_vals = [[] for _ in range(len(self.loss_names)-1)], [[] for _ in range(len(self.loss_names)-1)]
        self.n_trn_flds = len(self.model.fit_params.trn_idxs)
        self.log = 'regress' in self.model.objective.lower()
        self.best_loss,self.epochs = math.inf,[0]

        if self.show_plots:
            with sns.axes_style(**self.plot_settings.style):
                if self.extra_detail:
                    self.fig = plt.figure(figsize=(self.plot_settings.w_mid, self.plot_settings.h_mid), constrained_layout=True)
                    gs = self.fig.add_gridspec(2, 3)
                    self.loss_ax = self.fig.add_subplot(gs[:,:-1])
                    self.vel_ax  = self.fig.add_subplot(gs[:1,2:])
                    self.gen_ax  = self.fig.add_subplot(gs[1:2,2:])
                    for ax in [self.loss_ax, self.vel_ax, self.gen_ax]:
                        ax.tick_params(axis='x', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                        ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                    self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.gen_ax.set_xlabel('Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.gen_ax.set_ylabel('Validation / Train', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.display = display(self.fig, display_id=True)
                else:
                    self.fig, self.loss_ax = plt.subplots(1, figsize=(self.plot_settings.w_mid, self.plot_settings.h_mid))
                    self.loss_ax.tick_params(axis='x', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                    self.loss_ax.tick_params(axis='y', labelsize=0.8*self.plot_settings.tk_sz, labelcolor=self.plot_settings.tk_col)
                    self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
                    self.display = display(self.loss_ax.figure, display_id=True)

    def get_loss_history(self) -> OrderedDict:
        r'''
        Get the current history of losses

        Returns:
            history: ordered dictionary (training first, validations subsequent) mapping loss names to lists of loss values
        '''

        history = OrderedDict()
        for v,m in zip(self.loss_vals,self.loss_names): history[m] = v
        return history

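# --- Editor's note: illustrative usage sketch, not part of the library source.
# MetricLogger shows live loss plots in a notebook (or prints losses per epoch
# otherwise) and exposes the recorded values through get_loss_history().  The
# training call below is an assumption about the fit interface; only the
# callback construction and the history access mirror the code above.
def _example_metric_logger(model, fy):
    logger = MetricLogger(show_plots=IN_NOTEBOOK, extra_detail=True)
    model.fit(n_epochs=30, fy=fy, bs=256, cbs=[logger])  # hypothetical fit call
    history = logger.get_loss_history()  # OrderedDict: {'Training': [...], 'Validation': [...], ...}
    print({name: vals[-1] for name, vals in history.items() if len(vals) > 0})  # final value per loss series
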