
Source code for lumin.nn.callbacks.pred_handlers

import numpy as np
import torch

from .callback import Callback


class PredHandler(Callback):
    r"""
    Default callback for predictions. Collects predictions over batches and returns them as a stacked array
    """

    def on_pred_begin(self) -> None:
        super().on_pred_begin()
        self.preds = []  # reset the per-batch buffer for each prediction run

    def on_pred_end(self) -> None:
        # Concatenate the per-batch CPU tensors along the batch dimension and convert to NumPy,
        # so that get_preds returns the np.ndarray its signature promises
        self.preds = torch.cat(self.preds).numpy()

    def get_preds(self) -> np.ndarray:
        return self.preds

    def on_forwards_end(self) -> None:
        # Only record outputs produced while predicting ("test" state)
        if self.model.fit_params.state == "test":
            self.preds.append(self.model.fit_params.y_pred.detach().cpu())
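The order in which the hooks fire is what makes this callback work: the buffer is reset when prediction begins, one batch output is appended after every forward pass made in the "test" state, and the buffer is stacked once prediction ends. The sketch below mimics that lifecycle without PyTorch or LUMIN. It is a minimal illustration under stated assumptions: `SketchPredHandler` and `run_fake_prediction` are hypothetical names, the model is a `SimpleNamespace` exposing only `fit_params.state` and `fit_params.y_pred`, `np.concatenate` stands in for `torch.cat`, and the "forward pass" is just a doubling of each input batch.

```python
import numpy as np
from types import SimpleNamespace


class SketchPredHandler:
    """Hypothetical torch-free stand-in mirroring PredHandler's hook logic."""

    def __init__(self, model):
        # Assumed interface: model.fit_params.state and model.fit_params.y_pred
        self.model = model

    def on_pred_begin(self):
        self.preds = []  # fresh buffer for each prediction run

    def on_forwards_end(self):
        # Only record outputs produced while the model is in the "test" state
        if self.model.fit_params.state == "test":
            self.preds.append(np.asarray(self.model.fit_params.y_pred))

    def on_pred_end(self):
        # Stack the per-batch outputs along the batch axis (torch.cat equivalent)
        self.preds = np.concatenate(self.preds)

    def get_preds(self):
        return self.preds


def run_fake_prediction(batches, handler, model):
    """Hypothetical driver calling the hooks in the order a prediction loop would."""
    handler.on_pred_begin()
    for batch in batches:
        model.fit_params.y_pred = batch * 2.0  # stand-in for a real forward pass
        handler.on_forwards_end()
    handler.on_pred_end()
    return handler.get_preds()


model = SimpleNamespace(fit_params=SimpleNamespace(state="test", y_pred=None))
handler = SketchPredHandler(model)
batches = [np.ones((4, 1)), np.zeros((3, 1))]
preds = run_fake_prediction(batches, handler, model)
print(preds.shape)  # (7, 1)
```

Two batches of 4 and 3 rows come back as a single (7, 1) array, which is the "stacked array" the docstring describes. Stacking once at the end, rather than growing an array after every batch, avoids repeated reallocation during the loop.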
