Shortcuts

Source code for lumin.nn.models.layers.batchnorms

import math

import torch
from fastcore.all import store_attr
from torch import Tensor, nn, tensor

__all__ = ["LCBatchNorm1d", "RunningBatchNorm1d", "RunningBatchNorm2d", "RunningBatchNorm3d"]


[docs]class LCBatchNorm1d(nn.Module): r""" Wrapper class for 1D batchnorm to make it run over (Batch x length x channel) data for use in NNs designed to be broadcast across matrix data. Arguments: bn: base 1D batchnorm module to call """ def __init__(self, bn: nn.BatchNorm1d): super().__init__() self.bn = bn
[docs] def forward(self, x: Tensor) -> Tensor: return self.bn(x.transpose(-1, -2)).transpose(-1, -2)
[docs]class RunningBatchNorm1d(nn.Module): r""" 1D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence. Modifcations: Adaptation to 1D & 3D, add eps in mom1 calculation, type hinting, docs Arguments: nf: number of features/channels mom: momentum (fraction to add to running averages) n_warmup: number of warmup iterations (during which variance is clamped) eps: epsilon to prevent division by zero affine: whether to apply a learnable linear transformation to incoming data """ def __init__(self, nf: int, mom: float = 0.1, n_warmup: int = 20, eps: float = 1e-5, affine: bool = True): super().__init__() store_attr() self._set_params() def _set_params(self) -> None: if self.affine: self.weight = nn.Parameter(torch.ones(self.nf, 1)) self.bias = nn.Parameter(torch.zeros(self.nf, 1)) else: self.weight = 1 self.bias = 0 self.register_buffer("sums", torch.zeros(1, self.nf, 1)) self.register_buffer("sqrs", torch.zeros(1, self.nf, 1)) self.register_buffer("batch", tensor(0.0)) self.register_buffer("count", tensor(0.0)) self.dims = (0, 2)
[docs] def update_stats(self, x: Tensor) -> None: bs, nc, *_ = x.shape self.sums.detach_() self.sqrs.detach_() s = x.sum(self.dims, keepdim=True) ss = (x * x).sum(self.dims, keepdim=True) c = s.new_tensor(x.numel() / nc) mom1 = s.new_tensor(1 - (1 - self.mom) / math.sqrt(bs - 1 + self.eps)) self.sums.lerp_(s, mom1) self.sqrs.lerp_(ss, mom1) self.count.lerp_(c, mom1) self.batch += bs
[docs] def forward(self, x: Tensor) -> Tensor: squeeze = False if len(x.shape) == 2: squeeze = True x = x.unsqueeze(-1) if self.training: self.update_stats(x) means = self.sums / self.count varns = (self.sqrs / self.count).sub_(means * means) if bool(self.batch < self.n_warmup): varns.clamp_min_(0.01) factor = self.weight / (varns + self.eps).sqrt() offset = self.bias - means * factor x = x * factor + offset if squeeze: x = x.squeeze(-1) return x
[docs]class RunningBatchNorm2d(RunningBatchNorm1d): r""" 2D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence. Modifcations: add eps in mom1 calculation, type hinting, docs Arguments: nf: number of features/channels mom: momentum (fraction to add to running averages) eps: epsilon to prevent division by zero """ def _set_params(self) -> None: if self.affine: self.weight = nn.Parameter(torch.ones(self.nf, 1, 1)) self.bias = nn.Parameter(torch.zeros(self.nf, 1, 1)) else: self.weight = 1 self.bias = 0 self.register_buffer("sums", torch.zeros(1, self.nf, 1, 1)) self.register_buffer("sqrs", torch.zeros(1, self.nf, 1, 1)) self.register_buffer("batch", tensor(0.0)) self.register_buffer("count", tensor(0.0)) self.dims = (0, 2, 3)
[docs] def forward(self, x: Tensor) -> Tensor: if self.training: self.update_stats(x) means = self.sums / self.count varns = (self.sqrs / self.count).sub_(means * means) if bool(self.batch < self.n_warmup): varns.clamp_min_(0.01) factor = self.weight / (varns + self.eps).sqrt() offset = self.bias - means * factor return x * factor + offset
[docs]class RunningBatchNorm3d(RunningBatchNorm2d): r""" 3D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence. Modifcations: Adaptation to 3D, add eps in mom1 calculation, type hinting, docs Arguments: nf: number of features/channels mom: momentum (fraction to add to running averages) eps: epsilon to prevent division by zero """ def _set_params(self) -> None: if self.affine: self.weight = nn.Parameter(torch.ones(self.nf, 1, 1, 1)) self.bias = nn.Parameter(torch.zeros(self.nf, 1, 1, 1)) else: self.weight = 1 self.bias = 0 self.register_buffer("sums", torch.zeros(1, self.nf, 1, 1, 1)) self.register_buffer("sqrs", torch.zeros(1, self.nf, 1, 1, 1)) self.register_buffer("batch", tensor(0.0)) self.register_buffer("count", tensor(0.0)) self.dims = (0, 2, 3, 4)

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials