Source code for lumin.nn.callbacks.callback
from ...plotting.plot_settings import PlotSettings
from ..models.abs_model import AbsModel
from .abs_callback import AbsCallback
# Public API of this module: only the Callback base class is exported.
__all__ = ["Callback"]
class Callback(AbsCallback):
    r"""
    Base callback class from which other callbacks should inherit.

    Holds a reference to the model being trained or evaluated (initially unset) and the plot settings used for any
    plots the callback produces. The model must be supplied via :meth:`set_model` before :meth:`on_train_begin` or
    :meth:`on_pred_begin` is called.
    """

    def __init__(self):
        r"""
        Initialises the callback with no model attached and a default-constructed PlotSettings
        """

        # The model is attached later through set_model; None marks "not yet set".
        self.model, self.plot_settings = None, PlotSettings()

    def on_train_begin(self) -> None:
        r"""
        Checks that a model has been attached to the callback before training starts

        Raises:
            AttributeError: if no model has been set
        """

        if self.model is None:
            raise AttributeError(
                f"The model for {type(self).__name__} callback has not been set. Please call set_model before on_train_begin."
            )

    def set_model(self, model: AbsModel) -> None:
        r"""
        Sets the callback's model in order to allow the callback to access and adjust model parameters

        Arguments:
            model: model to refer to during training
        """

        self.model = model

    def set_plot_settings(self, plot_settings: PlotSettings) -> None:
        r"""
        Sets the plot settings for any plots produced by the callback

        Arguments:
            plot_settings: PlotSettings class
        """

        self.plot_settings = plot_settings

    def on_pred_begin(self) -> None:
        r"""
        Checks that a model has been attached to the callback before prediction starts

        Raises:
            AttributeError: if no model has been set
        """

        if self.model is None:
            # Message names the actual setter and hook of this class (previously referred to the
            # non-existent set_wrapper / on_model_begin).
            raise AttributeError(
                f"The model for {type(self).__name__} callback has not been set. Please call set_model before on_pred_begin."
            )