Source code for lumin.nn.callbacks.pred_handlers
import numpy as np
import torch
from .callback import Callback
class PredHandler(Callback):
    r'''
    Default callback for predictions. Collects the model's predictions batch by batch and, once prediction ends,
    concatenates them into a single tensor.

    Predictions are only recorded while `model.fit_params.state` is `'test'`; each batch is detached and copied to the CPU
    before being stored.
    '''

    def on_pred_begin(self) -> None:
        r'''
        Runs the parent callback's `on_pred_begin`, then resets the store of per-batch predictions to an empty list.
        '''

        super().on_pred_begin()
        self.preds = []

    def on_pred_end(self) -> None:
        r'''
        Replaces the list of per-batch predictions with their concatenation (`torch.cat`, default dimension 0).
        '''

        self.preds = torch.cat(self.preds)

    def get_preds(self) -> torch.Tensor:
        r'''
        Returns the stored predictions.

        Returns:
            Tensor of concatenated predictions once `on_pred_end` has run; before that point `self.preds` is still the list
            of per-batch tensors
        '''

        return self.preds

    def on_forwards_end(self) -> None:
        r'''
        Appends the current batch's predictions (`model.fit_params.y_pred`) to the store, but only in the `'test'` state.
        '''

        if self.model.fit_params.state == 'test':
            # Detach and move to CPU so the stored batches do not keep the autograd graph or device memory alive
            self.preds.append(self.model.fit_params.y_pred.detach().cpu())