Shortcuts

Source code for lumin.plotting.training

from typing import Optional, List, Dict, Union, Tuple
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from collections import OrderedDict
from fastcore.all import is_listy

from .plot_settings import PlotSettings
from ..nn.callbacks.abs_callback import AbsCallback

__all__ = ['plot_train_history', 'plot_lr_finders']


def old_plot_train_history(histories:List[Dict[str,List[float]]], savename:Optional[str]=None, ignore_trn:bool=False, settings:PlotSettings=PlotSettings(),
                           show:bool=True, xlow:int=0, log_y:bool=False) -> None:
    r'''
    .. Attention:: This class is depreciated in favour of :class:`~lumin.plotting.training.plot_train_history`.
        It is a copy of the old `plot_train_history` class used in lumin<=0.6.
        It will be removed in V0.8

    Plot histories object returned by :meth:`~lumin.nn.training.fold_train.fold_train_ensemble` showing the loss evolution over time per model trained.

    Arguments:
        histories: list of dictionaries mapping loss type to values at each (sub)-epoch
        savename: Optional name of file to which to save the plot of feature importances
        ignore_trn: whether to ignore training loss
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
        show: whether or not to show the plot, or just save it
    '''

    # XXX remove in V0.8

    def _lookup_name(name:str) -> str:
        if name == 'trn_loss': return 'Training'
        if name == 'val_loss': return 'Validation'
        if '_trn' in name:     return name[:name.find('_trn')] + 'Training'
        if '_val' in name:     return name[:name.find('_val')] + 'Validation'

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for i, history in enumerate(histories):
            if i == 0:
                for j, l in enumerate(history):
                    if not('trn' in l and ignore_trn): plt.plot(range(xlow,len(history[l])), history[l][xlow:], color=palette[j], label=_lookup_name(l))
            else:
                for j, l in enumerate(history):
                    if not('trn' in l and ignore_trn): plt.plot(range(xlow,len(history[l])), history[l][xlow:], color=palette[j])

        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.xlabel("Subepoch", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel("Loss", fontsize=settings.lbl_sz, color=settings.lbl_col)
        if log_y:
            plt.yscale('log')
            plt.grid(b=True, which="both", axis="both")
        if savename is not None: plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        if show: plt.show()


[docs]def plot_train_history(histories:List[OrderedDict], savename:Optional[str]=None, ignore_trn:bool=False, settings:PlotSettings=PlotSettings(), show:bool=True, xlow:int=0, log_y:bool=False) -> None: r''' Plot histories object returned by :meth:`~lumin.nn.training.train.train_models` showing the loss evolution over time per model trained. Arguments: histories: list of dictionaries mapping loss type to values at each (sub)-epoch savename: Optional name of file to which to save the plot of feature importances ignore_trn: whether to ignore training loss settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance show: whether or not to show the plot, or just save it xlow: if set, will cut out the first given number of epochs log_y: whether to plot the y-axis with a log scale ''' if not isinstance(histories, list): histories = [histories] n_folds = len(histories[0][0]['Training'])//len(histories[0][0]['Validation']) with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette: plt.figure(figsize=(settings.w_mid, settings.h_mid)) for i, history in enumerate(histories): for j, l in enumerate(history[0]): if j > 0 or not ignore_trn: x = range(1,len(history[0][l])+1)[xlow*n_folds:] if j == 0 else range(n_folds,(n_folds*len(history[0][l]))+1,n_folds)[xlow:] plt.plot(x, history[0][l][xlow:], color=palette[j], label=l if i == 0 else None) plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz) plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col) plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col) plt.xlabel("Subepoch", fontsize=settings.lbl_sz, color=settings.lbl_col) plt.ylabel("Loss", fontsize=settings.lbl_sz, color=settings.lbl_col) if log_y: plt.yscale('log') plt.grid(b=True, which="both", axis="both") if savename is not None: plt.savefig(settings.savepath/f'{savename}_loss{settings.format}', bbox_inches='tight') if show: plt.show() for metric in history[1].keys(): with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette: plt.figure(figsize=(settings.w_mid, settings.h_mid)) for i, history in enumerate(histories): plt.plot(range(n_folds,(n_folds*len(history[1][metric]))+1,n_folds)[xlow:], history[1][metric][xlow:], color=palette[1]) plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col) plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col) plt.xlabel("Subepoch", fontsize=settings.lbl_sz, color=settings.lbl_col) plt.ylabel(metric, fontsize=settings.lbl_sz, color=settings.lbl_col) if savename is not None: plt.savefig(settings.savepath/f'{savename}_{metric}{settings.format}', bbox_inches='tight') if show: plt.show()
[docs]def plot_lr_finders(lr_finders:List[AbsCallback], lr_range:Optional[Union[float,Tuple]]=None, loss_range:Optional[Union[float,Tuple,str]]='auto', log_y:Union[str,bool]='auto', savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None: r''' Plot mean loss evolution against learning rate for several :class:`~lumin.nn.callbacks.opt_callbacks.LRFinder callbacks as returned by :meth:`~lumin.nn.optimisation.hyper_param.fold_lr_find`. Arguments: lr_finders: list of :class:`~lumin.nn.callbacks.opt_callbacks.LRFinder callbacks used during training (e.g. as returned by :meth:`~lumin.nn.optimisation.hyper_param.fold_lr_find`) lr_range: limits the range of learning rates plotted on the x-axis: if float, maximum LR; if tuple, minimum & maximum LR loss_range: limits the range of losses plotted on the x-axis: if float, maximum loss; if tuple, minimum & maximum loss; if None, no limits; if 'auto', computes an upper limit automatically log_y: whether to plot y-axis as log. If 'auto', will set to log if maximal fractional difference in loss values is greater than 50 savename: Optional name of file to which to save the plot settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance ''' df = pd.DataFrame() for lrf in lr_finders: df = df.append(lrf.get_df(), ignore_index=True) if lr_range is not None: if isinstance(lr_range, float): lr_range = (0, lr_range) df = df[(df.LR >= lr_range[0]) & (df.LR < lr_range[1])] if loss_range == 'auto': # Max loss = 1.1 * max mean-loss at LR less than LR at min mean-loss agg = df.groupby(by='LR').agg(mean_loss=pd.NamedAgg(column='Loss', aggfunc='mean')) agg.reset_index(inplace=True) argmin_lr = agg.loc[agg.mean_loss.idxmin(), 'LR'] loss_range = [0.8*agg.loc[agg.LR < argmin_lr, 'mean_loss'].min(), 1.2*agg.loc[agg.LR < argmin_lr, 'mean_loss'].max()] with sns.axes_style('whitegrid'), sns.color_palette(settings.cat_palette): plt.figure(figsize=(settings.w_mid, settings.h_mid)) sns.lineplot(x='LR', y='Loss', data=df, ci='sd') plt.xscale('log') if log_y == 'auto': if df.Loss.max()/df.Loss.min() > 50: plt.yscale('log') elif log_y: plt.yscale('log') plt.grid(b=True, which="both", axis="both") if loss_range is not None: plt.ylim((0,loss_range) if isinstance(loss_range, float) else loss_range) plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col) plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col) plt.xlabel("Learning rate", fontsize=settings.lbl_sz, color=settings.lbl_col) plt.ylabel("Loss", fontsize=settings.lbl_sz, color=settings.lbl_col) if savename is not None: plt.savefig(settings.savepath/f'{savename}.png', bbox_inches='tight') plt.show()
Read the Docs v: v0.7.2
Versions
latest
stable
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.1
v0.5.0
v0.4.0.1
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials