Shortcuts

Source code for lumin.nn.losses.basic_weighted

from typing import Optional

import torch.nn as nn
import torch
from torch import Tensor

__all__ = ['WeightedMSE', 'WeightedMAE', 'WeightedCCE']


class WeightedMSE(nn.MSELoss):
    r'''
    Class for computing Mean Squared-Error loss with optional weights per prediction.
    For compatibility with using basic PyTorch losses, weights are passed during initialisation rather than when computing the loss.

    When weights are supplied, the element-wise squared errors are multiplied by the weights and the products are averaged
    over all elements (a plain mean of the weighted losses; it is not divided by the sum of the weights).

    Arguments:
        weight: sample weights as PyTorch Tensor, to be used with data to be passed when computing the loss.
            Stored unmodified as `self.weights`.

    Examples::
        >>> loss = WeightedMSE()
        >>>
        >>> loss = WeightedMSE(weights)
    '''

    def __init__(self, weight: Optional[Tensor] = None):
        # Weighting needs the unreduced, element-wise loss; without weights PyTorch can average directly
        super().__init__(reduction='mean' if weight is None else 'none')
        self.weights = weight

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r'''
        Evaluate loss for given predictions

        Arguments:
            input: prediction tensor
            target: target tensor

        Returns:
            (weighted) loss
        '''

        loss = super().forward(input, target)
        if self.weights is None:
            return loss

        weights = self.weights
        # Weights that differ from the loss only by singleton dimensions (e.g. (N,) against (N,1)) would broadcast to an
        # (N,N) outer product and silently yield mean(weights)*mean(loss); line them up element-for-element instead
        if (isinstance(weights, Tensor) and weights.shape != loss.shape
                and weights.squeeze().shape == loss.squeeze().shape):
            weights = weights.reshape(loss.shape)
        return torch.mean(weights * loss)
class WeightedMAE(nn.L1Loss):
    r'''
    Class for computing Mean Absolute-Error loss with optional weights per prediction.
    For compatibility with using basic PyTorch losses, weights are passed during initialisation rather than when computing the loss.

    When weights are supplied, the element-wise absolute errors are multiplied by the weights and the products are averaged
    over all elements (a plain mean of the weighted losses; it is not divided by the sum of the weights).

    Arguments:
        weight: sample weights as PyTorch Tensor, to be used with data to be passed when computing the loss.
            Stored unmodified as `self.weights`.

    Examples::
        >>> loss = WeightedMAE()
        >>>
        >>> loss = WeightedMAE(weights)
    '''

    def __init__(self, weight: Optional[Tensor] = None):
        # Weighting needs the unreduced, element-wise loss; without weights PyTorch can average directly
        super().__init__(reduction='mean' if weight is None else 'none')
        self.weights = weight

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r'''
        Evaluate loss for given predictions

        Arguments:
            input: prediction tensor
            target: target tensor

        Returns:
            (weighted) loss
        '''

        loss = super().forward(input, target)
        if self.weights is None:
            return loss

        weights = self.weights
        # Weights that differ from the loss only by singleton dimensions (e.g. (N,) against (N,1)) would broadcast to an
        # (N,N) outer product and silently yield mean(weights)*mean(loss); line them up element-for-element instead
        if (isinstance(weights, Tensor) and weights.shape != loss.shape
                and weights.squeeze().shape == loss.squeeze().shape):
            weights = weights.reshape(loss.shape)
        return torch.mean(weights * loss)
class WeightedCCE(nn.NLLLoss):
    r'''
    Class for computing Categorical Cross-Entropy loss with optional weights per prediction.
    For compatibility with using basic PyTorch losses, weights are passed during initialisation rather than when computing the loss.

    The loss itself is computed by :class:`torch.nn.NLLLoss`, which expects `input` to contain log-probabilities.
    When weights are supplied, the per-sample losses are multiplied by the weights and the products are averaged
    (a plain mean of the weighted losses; it is not divided by the sum of the weights).

    Arguments:
        weight: sample weights as PyTorch Tensor, to be used with data to be passed when computing the loss.
            Stored unmodified as `self.weights`; these are per-sample weights and are not forwarded to
            `NLLLoss` as its per-class `weight`.

    Examples::
        >>> loss = WeightedCCE()
        >>>
        >>> loss = WeightedCCE(weights)
    '''

    def __init__(self, weight: Optional[Tensor] = None):
        # Per-sample weighting needs the unreduced loss: with reduction='mean' the weights would only ever multiply
        # an already-averaged scalar, giving mean(weights)*mean(loss) rather than a weighted loss
        super().__init__(reduction='mean' if weight is None else 'none')
        self.weights = weight

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r'''
        Evaluate loss for given predictions

        Arguments:
            input: prediction tensor
            target: target tensor

        Returns:
            (weighted) loss
        '''

        loss = super().forward(input, target)
        if self.weights is None:
            return loss

        weights = self.weights
        # The unreduced NLL loss has no trailing unit dimension, so column-shaped weights such as (N,1) would broadcast
        # to an (N,N) outer product; weights differing only by singleton dimensions are aligned element-for-element
        if (isinstance(weights, Tensor) and weights.shape != loss.shape
                and weights.squeeze().shape == loss.squeeze().shape):
            weights = weights.reshape(loss.shape)
        return torch.mean(weights * loss)
Read the Docs v: stable
Versions
latest
stable
v0.8.0
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.1
v0.5.0
v0.4.0.1
v0.3.1
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials