
Source code for lumin.nn.callbacks.pred_handlers

import numpy as np

import torch

from .callback import Callback


class PredHandler(Callback):
    r'''
    Default callback for predictions. Collects predictions over batches and returns them as a stacked array
    '''

    def on_pred_begin(self) -> None:
        super().on_pred_begin()
        self.preds = []  # reset the per-batch prediction buffer

    def on_pred_end(self) -> None:
        # Join the per-batch outputs along dimension 0
        self.preds = torch.cat(self.preds)

    def get_preds(self) -> np.ndarray:
        return self.preds

    def on_forwards_end(self) -> None:
        # Only collect outputs while predicting; detach from the graph and move to CPU
        if self.model.fit_params.state == 'test':
            self.preds.append(self.model.fit_params.y_pred.detach().cpu())
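PredHandler relies on lumin's Model to fire its hooks in order. As a rough illustration of the same collect-then-concatenate pattern outside lumin, the sketch below re-implements it in isolation. `MiniPredHandler`, `run_prediction`, and the `SimpleNamespace` stand-in for `model.fit_params` are hypothetical names, and the hook order (begin, one forward-end per batch, end) is an assumption based on the method names, not lumin's actual driver loop.

```python
from types import SimpleNamespace
from typing import List

import torch


class MiniPredHandler:
    """Standalone sketch of the PredHandler pattern (hypothetical, not lumin's API)."""

    def __init__(self, model) -> None:
        # In lumin the Model is attached to the callback; here a stub is passed in directly
        self.model = model
        self.preds: List[torch.Tensor] = []

    def on_pred_begin(self) -> None:
        self.preds = []  # reset the buffer at the start of each prediction run

    def on_forwards_end(self) -> None:
        # Keep outputs only in the 'test' state, detached from the graph and on the CPU
        if self.model.fit_params.state == 'test':
            self.preds.append(self.model.fit_params.y_pred.detach().cpu())

    def on_pred_end(self) -> None:
        self.preds = torch.cat(self.preds)  # stack the batches along dim 0

    def get_preds(self) -> torch.Tensor:
        return self.preds


def run_prediction(net: torch.nn.Module, batches: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical driver loop that calls the hooks in the assumed order."""
    model = SimpleNamespace(fit_params=SimpleNamespace(state='test', y_pred=None))
    handler = MiniPredHandler(model)
    handler.on_pred_begin()
    for x in batches:
        model.fit_params.y_pred = net(x)  # forward pass; output still tracks gradients
        handler.on_forwards_end()
    handler.on_pred_end()
    return handler.get_preds()


torch.manual_seed(0)
net = torch.nn.Linear(4, 1)
batches = [torch.randn(8, 4), torch.randn(8, 4), torch.randn(3, 4)]
preds = run_prediction(net, batches)
print(preds.shape)  # torch.Size([19, 1])
```

Because `torch.cat` joins along dimension 0, batches of unequal size (here 8, 8, and 3 rows) concatenate cleanly as long as their trailing dimensions match. In this sketch `get_preds` is annotated as `torch.Tensor`, since that is what `torch.cat` returns; converting to NumPy, if needed, would be a separate `.numpy()` call.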
Documentation version: v0.8.0