Source code for lumin.nn.callbacks.opt_callbacks

import numpy as np
import math
from typing import Tuple, Optional
import pandas as pd

from .callback import Callback
from ..models.abs_model import AbsModel
from ...plotting.plot_settings import PlotSettings

import seaborn as sns
import matplotlib.pyplot as plt

__all__ = ['LRFinder']


class LRFinder(Callback):
    r'''
    Callback class for Smith learning-rate range test (https://arxiv.org/abs/1803.09820)

    Arguments:
        nb: number of batches in a (sub-)epoch
        lr_bounds: tuple of initial and final LR
        model: :class:`~lumin.nn.models.Model` to alter, alternatively call :meth:`set_model`
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    def __init__(self, nb:int, lr_bounds:Tuple[float,float]=[1e-7, 10], model:Optional[AbsModel]=None,
                 plot_settings:PlotSettings=PlotSettings()):
        super().__init__(model=model, plot_settings=plot_settings)
        self.lr_bounds = lr_bounds
        self.lr_mult = (self.lr_bounds[1]/self.lr_bounds[0])**(1/nb)

    def on_train_begin(self, **kargs) -> None:
        r'''
        Prepares variables and optimiser for new training
        '''

        self.best,self.iter = math.inf,0
        self.model.set_lr(self.lr_bounds[0])
        self.history = {'loss': [], 'lr': []}

    def _calc_lr(self): return self.lr_bounds[0]*(self.lr_mult**self.iter)

    def plot(self, n_skip:int=0, n_max:Optional[int]=None, lim_y:Optional[Tuple[float,float]]=None) -> None:
        r'''
        Plot the loss as a function of the LR.

        Arguments:
            n_skip: Number of initial iterations to skip in plotting
            n_max: Maximum iteration number to plot
            lim_y: y-range for plotting
        '''

        # TODO: Decide on whether to keep this; could just pass to plot_lr_finders
        with sns.axes_style(self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette):
            plt.figure(figsize=(self.plot_settings.w_mid, self.plot_settings.h_mid))
            plt.plot(self.history['lr'][n_skip:n_max], self.history['loss'][n_skip:n_max], label='Training loss', color='g')
            if np.log10(self.lr_bounds[1])-np.log10(self.lr_bounds[0]) >= 3: plt.xscale('log')
            plt.ylim(lim_y)
            plt.grid(True, which="both")
            plt.legend(loc=self.plot_settings.leg_loc, fontsize=self.plot_settings.leg_sz)
            plt.xticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.yticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.ylabel("Loss", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.xlabel("Learning rate", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.show()

    def plot_lr(self) -> None:
        r'''
        Plot the LR as a function of iterations.
        '''

        with sns.axes_style(self.plot_settings.style), sns.color_palette(self.plot_settings.cat_palette):
            plt.figure(figsize=(self.plot_settings.h_small, self.plot_settings.h_small))
            plt.plot(range(len(self.history['lr'])), self.history['lr'])
            plt.xticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.yticks(fontsize=self.plot_settings.tk_sz, color=self.plot_settings.tk_col)
            plt.ylabel("Learning rate", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.xlabel("Iterations", fontsize=self.plot_settings.lbl_sz, color=self.plot_settings.lbl_col)
            plt.show()

    def get_df(self) -> pd.DataFrame:
        r'''
        Returns a DataFrame of LRs and losses
        '''

        return pd.DataFrame({'LR': self.history['lr'], 'Loss': self.history['loss']})

    def on_batch_end(self, loss:float, **kargs) -> None:
        r'''
        Records loss and increments LR

        Arguments:
            loss: training loss for most recent batch
        '''

        self.history['loss'].append(loss)
        self.history['lr'].append(self.model.opt.param_groups[0]['lr'])
        self.iter += 1
        lr = self._calc_lr()
        self.model.opt.param_groups[0]['lr'] = lr
        if math.isnan(loss) or loss > self.best*100 or lr > self.lr_bounds[1]: self.model.stop_train = True
        if loss < self.best and self.iter > 10: self.best = loss
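The sweep performed by this callback is geometric: starting from lr_bounds[0], every batch multiplies the LR by lr_mult = (lr_bounds[1]/lr_bounds[0])**(1/nb), so after nb batches the LR has climbed from the lower to the upper bound, spending an equal number of iterations in each decade. A standalone worked example of the schedule computed by _calc_lr (the value of nb below is purely illustrative):

# Reproduces LRFinder's schedule: lr_i = lr_min * lr_mult**i
lr_min, lr_max = 1e-7, 10                  # default lr_bounds
nb = 100                                   # illustrative number of batches in the sweep
lr_mult = (lr_max / lr_min) ** (1 / nb)    # per-batch multiplier, ~1.202 here

for i in (0, 25, 50, 75, 100):
    print(f'iteration {i:3d}: lr = {lr_min * lr_mult ** i:.3e}')

# approximate output:
# iteration   0: lr = 1.000e-07
# iteration  25: lr = 1.000e-05
# iteration  50: lr = 1.000e-03   (geometric midpoint of the bounds)
# iteration  75: lr = 1.000e-01
# iteration 100: lr = 1.000e+01   (upper bound reached; on_batch_end then stops training)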
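To see the callback hooks end to end, the sketch below drives an LRFinder by hand using a hypothetical _DummyModel stub that exposes only the attributes the callback touches (set_lr, opt.param_groups and stop_train), and assumes the base Callback simply stores the model it is given. This illustrates the hook protocol only; it is not LUMIN's actual training loop, and in real use the model argument would be a lumin.nn.models.Model whose training machinery runs the sweep.

import math
from lumin.nn.callbacks.opt_callbacks import LRFinder

class _DummyOpt:
    # Hypothetical stand-in for a torch optimiser: only param_groups is needed here
    def __init__(self, lr): self.param_groups = [{'lr': lr}]

class _DummyModel:
    # Hypothetical stub exposing just the attributes LRFinder uses
    def __init__(self):
        self.opt, self.stop_train = _DummyOpt(1e-7), False
    def set_lr(self, lr): self.opt.param_groups[0]['lr'] = lr

model = _DummyModel()
lr_find = LRFinder(nb=100, lr_bounds=[1e-7, 10], model=model)

lr_find.on_train_begin()
for _ in range(100):
    lr = model.opt.param_groups[0]['lr']
    fake_loss = 0.5 + (math.log10(lr) + 3)**2 / 10   # toy loss, minimal near lr = 1e-3
    lr_find.on_batch_end(loss=fake_loss)
    if model.stop_train: break                       # set once the LR leaves its bounds

print(lr_find.get_df().tail())   # recorded LRs and losses
lr_find.plot()                   # loss as a function of LR, as defined above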