Shortcuts

Source code for lumin.plotting.training

from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ..nn.callbacks.abs_callback import AbsCallback
from .plot_settings import PlotSettings

__all__ = ["plot_train_history", "plot_lr_finders"]


def plot_train_history(
    histories: List[OrderedDict],
    savename: Optional[str] = None,
    ignore_trn: bool = False,
    settings: PlotSettings = PlotSettings(),
    show: bool = True,
    xlow: int = 0,
    log_y: bool = False,
) -> None:
    r"""
    Plot histories object returned by :meth:`~lumin.nn.training.train.train_models` showing the loss evolution over time per model trained.
    One figure is drawn for the losses, followed by one figure per metric found in the histories.

    Arguments:
        histories: list of per-model histories (a single history is also accepted and wrapped in a list).
            Each history is indexed as a pair: element 0 maps loss type (must include "Training" and "Validation")
            to values at each (sub)-epoch, element 1 maps metric name to values
        savename: Optional name of file to which to save the plots;
            the loss plot is saved as `{savename}_loss` and each metric plot as `{savename}_{metric}`, with `settings.format` appended
        ignore_trn: whether to ignore training loss
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
        show: whether or not to show the plot, or just save it
        xlow: if set, will cut out the first given number of epochs
        log_y: whether to plot the y-axis with a log scale
    """

    if not isinstance(histories, list):
        histories = [histories]
    # Number of training-loss entries per validation-loss entry; used to put both curves on a common sub-epoch axis
    n_folds = len(histories[0][0]["Training"]) // len(histories[0][0]["Validation"])

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for i, history in enumerate(histories):
            for j, (loss_name, losses) in enumerate(history[0].items()):
                # The first entry is treated as the training loss
                if j == 0 and ignore_trn:
                    continue
                if j == 0:
                    # One point per sub-epoch: skipping xlow epochs means skipping xlow * n_folds points.
                    # The same offset must be applied to the y-values, otherwise x and y differ in length.
                    n_skip = xlow * n_folds
                    x = range(1, len(losses) + 1)[n_skip:]
                else:
                    # One point every n_folds sub-epochs
                    n_skip = xlow
                    x = range(n_folds, (n_folds * len(losses)) + 1, n_folds)[n_skip:]
                # Only label the curves of the first model so the legend has one entry per loss type
                plt.plot(x, losses[n_skip:], color=palette[j], label=loss_name if i == 0 else None)

        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.xlabel("Subepoch", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel("Loss", fontsize=settings.lbl_sz, color=settings.lbl_col)
        if log_y:
            plt.yscale("log")
            # NOTE(review): grid is assumed to be enabled only together with the log scale - confirm against upstream
            plt.grid(visible=True, which="both", axis="both")
        if savename is not None:
            plt.savefig(settings.savepath / f"{savename}_loss{settings.format}", bbox_inches="tight")
        if show:
            plt.show()
        else:
            plt.close()

    # Metric names are taken from the last history; every history is expected to provide the same metrics
    for metric in list(histories[-1][1]):
        with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
            plt.figure(figsize=(settings.w_mid, settings.h_mid))
            for history in histories:
                values = history[1][metric]
                plt.plot(
                    range(n_folds, (n_folds * len(values)) + 1, n_folds)[xlow:],
                    values[xlow:],
                    color=palette[1],  # same colour as the second loss curve in the loss plot
                )

            plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
            plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
            plt.xlabel("Subepoch", fontsize=settings.lbl_sz, color=settings.lbl_col)
            plt.ylabel(metric, fontsize=settings.lbl_sz, color=settings.lbl_col)
            if savename is not None:
                plt.savefig(settings.savepath / f"{savename}_{metric}{settings.format}", bbox_inches="tight")
            if show:
                plt.show()
            else:
                plt.close()


def plot_lr_finders(
    lr_finders: List[AbsCallback],
    lr_range: Optional[Union[float, Tuple]] = None,
    loss_range: Optional[Union[float, Tuple, str]] = "auto",
    log_y: Union[str, bool] = "auto",
    savename: Optional[str] = None,
    settings: PlotSettings = PlotSettings(),
    show_plot: bool = True,
) -> None:
    r"""
    Plot mean loss evolution against learning rate for several :class:`~lumin.nn.callbacks.opt_callbacks.LRFinder` callbacks
    as returned by :meth:`~lumin.nn.optimisation.hyper_param.fold_lr_find`.

    Arguments:
        lr_finders: list of :class:`~lumin.nn.callbacks.opt_callbacks.LRFinder` callbacks used during training
            (e.g. as returned by :meth:`~lumin.nn.optimisation.hyper_param.fold_lr_find`)
        lr_range: limits the range of learning rates plotted on the x-axis: if a number, maximum LR (exclusive); if tuple, minimum & maximum LR
        loss_range: limits the range of losses plotted on the y-axis: if a number, maximum loss; if tuple, minimum & maximum loss;
            if None, no limits; if 'auto', computes limits automatically from the mean loss at learning rates below the one with the lowest mean loss
        log_y: whether to plot y-axis as log. If 'auto', will set to log if maximal fractional difference in loss values is greater than 50
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
        show_plot: whether to show the plot, or just save them
    """

    # Gather the records of every finder and concatenate once, rather than growing a DataFrame inside the loop
    frames = [lrf.get_df() for lrf in lr_finders]
    df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    if lr_range is not None:
        # A bare number (int or float) is the upper bound; the lower bound then defaults to zero
        if isinstance(lr_range, (int, float)):
            lr_range = (0, lr_range)
        df = df[(df.LR >= lr_range[0]) & (df.LR < lr_range[1])]

    if loss_range == "auto":
        # Limits = [0.8 * min, 1.2 * max] of the mean loss at LRs below the LR with the lowest mean loss
        agg = df.groupby(by="LR").agg(mean_loss=pd.NamedAgg(column="Loss", aggfunc="mean")).reset_index()
        argmin_lr = agg.loc[agg.mean_loss.idxmin(), "LR"]
        reference = agg.loc[agg.LR < argmin_lr, "mean_loss"]
        if reference.empty:
            # The lowest mean loss sits at the smallest LR, so nothing lies below it;
            # use all mean losses rather than passing NaN limits to matplotlib
            reference = agg.mean_loss
        loss_range = [0.8 * reference.min(), 1.2 * reference.max()]

    # NOTE(review): the axes style is fixed to "whitegrid" here rather than taken from settings.style - confirm intended
    with sns.axes_style("whitegrid"), sns.color_palette(settings.cat_palette):
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        # NOTE(review): `ci` is presumably deprecated in newer seaborn releases in favour of `errorbar="sd"` - verify against the pinned version
        sns.lineplot(x="LR", y="Loss", data=df, ci="sd")
        plt.xscale("log")
        if log_y == "auto":
            if df.Loss.max() / df.Loss.min() > 50:
                plt.yscale("log")
        elif log_y:
            plt.yscale("log")
        plt.grid(visible=True, which="both", axis="both")
        if loss_range is not None:
            # A bare number (int or float) is the upper limit; the lower limit is then zero
            plt.ylim((0, loss_range) if isinstance(loss_range, (int, float)) else loss_range)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.xlabel("Learning rate", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel("Loss", fontsize=settings.lbl_sz, color=settings.lbl_col)
        if savename is not None:
            # NOTE(review): extension is hard-coded to .png, not settings.format - confirm intended
            plt.savefig(settings.savepath / f"{savename}.png", bbox_inches="tight")
        if show_plot:
            plt.show()
        else:
            plt.close()

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials