# Source code for lumin.nn.models.layers.batchnorms
import math
import torch
from fastcore.all import store_attr
from torch import Tensor, nn, tensor
# Public names of this module, i.e. what a star-import of the module exposes.
__all__ = ["LCBatchNorm1d", "RunningBatchNorm1d", "RunningBatchNorm2d", "RunningBatchNorm3d"]
class LCBatchNorm1d(nn.Module):
    r"""
    Wrapper class for 1D batchnorm to make it run over (Batch x length x channel) data for use in NNs designed to be
    broadcast across matrix data.

    Arguments:
        bn: base 1D batchnorm module to call
    """

    def __init__(self, bn: nn.BatchNorm1d):
        super().__init__()
        self.bn = bn

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Passes the data through the wrapped batchnorm with its last two dimensions swapped, then swaps them back.

        Arguments:
            x: incoming data with channels in the last dimension

        Returns:
            Tensor with the same dimension ordering as `x`, after applying `bn`
        """

        # `bn` is called on the transposed tensor, so it sees the channels in dimension -2 rather than -1.
        return self.bn(x.transpose(-1, -2)).transpose(-1, -2)
class RunningBatchNorm1d(nn.Module):
    r"""
    1D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence.
    Modifications: Adaptation to 1D & 3D, add eps in mom1 calculation, type hinting, docs

    Arguments:
        nf: number of features/channels
        mom: momentum (fraction to add to running averages)
        n_warmup: warm-up threshold; while the running total of samples seen (the `batch` buffer, which grows by the
            batch size on every update) is below this value, the variance is clamped to be at least 0.01
        eps: epsilon to prevent division by zero
        affine: whether to apply a learnable linear transformation to incoming data
    """

    def __init__(self, nf: int, mom: float = 0.1, n_warmup: int = 20, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        store_attr()  # the constructor arguments are read back below as self.nf, self.mom, etc.
        self._set_params()

    def _set_params(self) -> None:
        r"""
        Creates the affine parameters (or their constant stand-ins) and the running-statistic buffers.
        Shapes are chosen to broadcast against (batch, features, length) data.
        """

        if self.affine:
            self.weight = nn.Parameter(torch.ones(self.nf, 1))
            self.bias = nn.Parameter(torch.zeros(self.nf, 1))
        else:
            # Plain constants: scaling by 1 and shifting by 0 leaves the normalised data untouched
            self.weight = 1
            self.bias = 0
        self.register_buffer("sums", torch.zeros(1, self.nf, 1))
        self.register_buffer("sqrs", torch.zeros(1, self.nf, 1))
        self.register_buffer("batch", tensor(0.0))
        self.register_buffer("count", tensor(0.0))
        self.dims = (0, 2)  # dimensions reduced over when accumulating per-channel statistics

    def update_stats(self, x: Tensor) -> None:
        r"""
        Folds the statistics of the current batch into the running sums, squared sums, and element count.

        Arguments:
            x: incoming data with the batch in dimension 0 and features/channels in dimension 1
        """

        bs, nc, *_ = x.shape
        # Detach the buffers in-place before they are updated in-place below
        self.sums.detach_()
        self.sqrs.detach_()
        s = x.sum(self.dims, keepdim=True)
        ss = (x * x).sum(self.dims, keepdim=True)
        c = s.new_tensor(x.numel() / nc)  # number of elements per channel in this batch
        # Interpolation weight depends on the batch size; eps keeps the square root non-zero when bs == 1.
        # NOTE(review): for bs == 1 the denominator is sqrt(eps), which makes mom1 strongly negative - confirm intended.
        mom1 = s.new_tensor(1 - (1 - self.mom) / math.sqrt(bs - 1 + self.eps))
        self.sums.lerp_(s, mom1)
        self.sqrs.lerp_(ss, mom1)
        self.count.lerp_(c, mom1)
        self.batch += bs

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Normalises the data using the running statistics, updating them first when in training mode.

        Arguments:
            x: incoming data of shape (batch, features) or (batch, features, length)

        Returns:
            Normalised tensor with the same shape as `x`. The result is NaN while `count` is still zero,
            i.e. if the statistics have never been updated.
        """

        # Two-dimensional input gets a temporary trailing axis so it matches the buffer layout
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(-1)
        if self.training:
            self.update_stats(x)
        means = self.sums / self.count
        varns = (self.sqrs / self.count).sub_(means * means)
        # E[x^2] - E[x]^2 can dip below zero through floating-point cancellation, which would turn the square root
        # below into NaN, so the variance is never allowed to be negative. During warm-up the floor is 0.01.
        min_var = 0.01 if bool(self.batch < self.n_warmup) else 0.0
        varns.clamp_min_(min_var)
        factor = self.weight / (varns + self.eps).sqrt()
        offset = self.bias - means * factor
        x = x * factor + offset
        if squeeze:
            x = x.squeeze(-1)
        return x
class RunningBatchNorm2d(RunningBatchNorm1d):
    r"""
    2D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence.
    Modifications: add eps in mom1 calculation, type hinting, docs

    Arguments:
        nf: number of features/channels
        mom: momentum (fraction to add to running averages)
        n_warmup: warm-up threshold; while the running total of samples seen (the `batch` buffer) is below this value,
            the variance is clamped to be at least 0.01
        eps: epsilon to prevent division by zero
        affine: whether to apply a learnable linear transformation to incoming data
    """

    def _set_params(self) -> None:
        r"""
        Creates the affine parameters (or their constant stand-ins) and the running-statistic buffers.
        Shapes are chosen to broadcast against (batch, features, dim, dim) data.
        """

        if self.affine:
            self.weight = nn.Parameter(torch.ones(self.nf, 1, 1))
            self.bias = nn.Parameter(torch.zeros(self.nf, 1, 1))
        else:
            # Plain constants: scaling by 1 and shifting by 0 leaves the normalised data untouched
            self.weight = 1
            self.bias = 0
        self.register_buffer("sums", torch.zeros(1, self.nf, 1, 1))
        self.register_buffer("sqrs", torch.zeros(1, self.nf, 1, 1))
        self.register_buffer("batch", tensor(0.0))
        self.register_buffer("count", tensor(0.0))
        self.dims = (0, 2, 3)  # dimensions reduced over when accumulating per-channel statistics

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Normalises the data using the running statistics, updating them first when in training mode.

        Arguments:
            x: incoming data with the batch in dimension 0 and features/channels in dimension 1

        Returns:
            Normalised tensor with the same shape as `x`. The result is NaN while `count` is still zero,
            i.e. if the statistics have never been updated.
        """

        if self.training:
            self.update_stats(x)
        means = self.sums / self.count
        varns = (self.sqrs / self.count).sub_(means * means)
        # E[x^2] - E[x]^2 can dip below zero through floating-point cancellation, which would turn the square root
        # below into NaN, so the variance is never allowed to be negative. During warm-up the floor is 0.01.
        min_var = 0.01 if bool(self.batch < self.n_warmup) else 0.0
        varns.clamp_min_(min_var)
        factor = self.weight / (varns + self.eps).sqrt()
        offset = self.bias - means * factor
        return x * factor + offset
class RunningBatchNorm3d(RunningBatchNorm2d):
    r"""
    3D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence.
    Modifications: Adaptation to 3D, add eps in mom1 calculation, type hinting, docs

    Arguments:
        nf: number of features/channels
        mom: momentum (fraction to add to running averages)
        n_warmup: warm-up threshold; while the running total of samples seen (the `batch` buffer) is below this value,
            the variance is clamped to be at least 0.01
        eps: epsilon to prevent division by zero
        affine: whether to apply a learnable linear transformation to incoming data
    """

    def _set_params(self) -> None:
        r"""
        Creates the affine parameters (or their constant stand-ins) and the running-statistic buffers.
        Shapes are chosen to broadcast against (batch, features, dim, dim, dim) data.
        """

        if self.affine:
            self.weight = nn.Parameter(torch.ones(self.nf, 1, 1, 1))
            self.bias = nn.Parameter(torch.zeros(self.nf, 1, 1, 1))
        else:
            # Plain constants: scaling by 1 and shifting by 0 leaves the normalised data untouched
            self.weight = 1
            self.bias = 0
        self.register_buffer("sums", torch.zeros(1, self.nf, 1, 1, 1))
        self.register_buffer("sqrs", torch.zeros(1, self.nf, 1, 1, 1))
        self.register_buffer("batch", tensor(0.0))
        self.register_buffer("count", tensor(0.0))
        self.dims = (0, 2, 3, 4)  # dimensions reduced over when accumulating per-channel statistics