Source code for lumin.nn.losses.hep_losses

import torch.nn as nn
import torch
from torch import Tensor
from typing import Callable

__all__ = ['SignificanceLoss']

class SignificanceLoss(nn.Module):
    r'''
    General class for implementing significance-based loss functions, e.g. Asimov Loss.
    For compatibility with using basic PyTorch losses, event weights are passed during initialisation rather than when
    computing the loss.

    Arguments:
        weight: sample weights as PyTorch Tensor, to be used with data to be passed when computing the loss
        sig_wgt: total weight of signal events
        bkg_wgt: total weight of background events
        func: callable which returns a float based on signal and background weights

    Examples::
        >>> loss = SignificanceLoss(weight, sig_wgt=sig_wgt,
        ...                         bkg_wgt=bkg_wgt, func=calc_ams_torch)
        >>>
        >>> loss = SignificanceLoss(weight, sig_wgt=sig_wgt,
        ...                         bkg_wgt=bkg_wgt,
        ...                         func=partial(calc_ams_torch, br=10))
    '''

    def __init__(self, weight: Tensor, sig_wgt: float, bkg_wgt: float,
                 func: Callable[[Tensor, Tensor], Tensor]) -> None:
        super().__init__()
        # Squeezed to match the squeezed input/target in forward; torch.dot there needs 1-D operands.
        self.weight = weight.squeeze()
        self.sig_wgt = sig_wgt
        self.bkg_wgt = bkg_wgt
        self.func = func

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r'''
        Evaluate loss for given predictions

        Arguments:
            input: prediction tensor; must squeeze to 1-D with the same length as the weights
            target: target tensor (1 for signal, 0 for background); must squeeze like input

        Returns:
            (weighted) loss: the inverse of ``func(s, b)``, so minimising the loss maximises the significance

        Note:
            If the batch contains no signal (or no background) events, the corresponding normalising dot product
            below is zero and the loss becomes NaN/inf.
        '''
        input, target = input.squeeze(), target.squeeze()
        bkg_target = 1 - target

        # Reweight according to batch size: rescale per-event weights so that, within this batch,
        # signal weights sum to self.sig_wgt and background weights sum to self.bkg_wgt.
        sig_wgt = (target * self.weight) * self.sig_wgt / torch.dot(target, self.weight)
        bkg_wgt = (bkg_target * self.weight) * self.bkg_wgt / torch.dot(bkg_target, self.weight)

        # Compute signal and background weights without a hard cut: predictions act as soft
        # selection efficiencies, keeping the loss differentiable.
        s = torch.dot(sig_wgt * input, target)
        b = torch.dot(bkg_wgt * input, bkg_target)
        return 1 / self.func(s, b)  # Return inverse of significance
