Shortcuts

Source code for lumin.nn.models.layers.batchnorms

from fastcore.all import store_attr
import math
import torch
from torch import nn, Tensor, tensor

# Public API of this module: the names exported by a star-import.
__all__ = ['LCBatchNorm1d', 'RunningBatchNorm1d', 'RunningBatchNorm2d', 'RunningBatchNorm3d']


class LCBatchNorm1d(nn.Module):
    r'''
    Adapter that lets a 1D batchnorm module be applied to data whose last two dimensions are ordered (length, channel)
    rather than (channel, length), e.g. (Batch x length x channel) tensors in NNs that are broadcast across matrix data.

    Arguments:
        bn: base 1D batchnorm module to call
    '''

    def __init__(self, bn:nn.BatchNorm1d):
        super().__init__()
        self.bn = bn

    def forward(self, x:Tensor) -> Tensor:
        r'''
        Swap the last two dimensions of `x`, pass the result through the wrapped batchnorm, and swap them back.

        Arguments:
            x: incoming tensor

        Returns:
            Tensor returned by the wrapped batchnorm, with its last two dimensions swapped back
        '''

        channel_first = x.transpose(-2, -1)
        normed = self.bn(channel_first)
        return normed.transpose(-2, -1)
class RunningBatchNorm1d(nn.Module):
    r'''
    1D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence.
    Modifications: Adaptation to 1D & 3D, add eps in mom1 calculation, type hinting, docs,
    well-defined momentum for batches of size one, variance clamped to be non-negative.

    Inputs are normalised with running statistics (sums, sums of squares, and element counts) in both training and
    evaluation mode; in training mode the running statistics are first updated using the incoming batch.
    Statistics are reduced over every dimension except dimension 1 (features/channels).
    2D inputs (batch, features) are temporarily given a trailing length-1 dimension.

    Arguments:
        nf: number of features/channels
        mom: momentum (fraction to add to running averages)
        n_warmup: warmup threshold; while the accumulated number of samples seen during training (sum of batch sizes)
            is below this value, the variance is clamped to be at least 0.01
        eps: epsilon to prevent division by zero
    '''

    def __init__(self, nf:int, mom:float=0.1, n_warmup:int=20, eps:float=1e-5):
        super().__init__()
        store_attr()  # fastcore helper: stores nf, mom, n_warmup, and eps as attributes of the same names
        self._set_params()

    def _set_params(self) -> None:
        r'''
        Create the learnable scale & shift, the running-statistic buffers, and the tuple of dimensions to reduce over.
        '''

        self.weight = nn.Parameter(torch.ones(self.nf,1))
        self.bias = nn.Parameter(torch.zeros(self.nf,1))
        self.register_buffer('sums', torch.zeros(1,self.nf,1))  # running per-feature sum
        self.register_buffer('sqrs', torch.zeros(1,self.nf,1))  # running per-feature sum of squares
        self.register_buffer('batch', tensor(0.))  # accumulated number of samples seen (compared against n_warmup)
        self.register_buffer('count', tensor(0.))  # running number of elements per feature
        self.register_buffer('step', tensor(0.))  # not read in this class; kept so saved state dicts stay loadable
        self.dims = (0,2)

    def update_stats(self, x:Tensor) -> None:
        r'''
        Update the running sums, sums of squares, and counts using a batch of data.

        Arguments:
            x: batch of data with batch as dimension 0 and features/channels as dimension 1
        '''

        bs,nc,*_ = x.shape
        # Cut the graph to previous batches so that gradients only flow through the current batch's statistics
        self.sums.detach_()
        self.sqrs.detach_()
        s = x.sum(self.dims, keepdim=True)
        ss = (x*x).sum(self.dims, keepdim=True)
        c = s.new_tensor(x.numel()/nc)  # number of elements per feature in this batch
        # Interpolation weight grows with batch size. bs-1 is floored at one: for bs == 1 the unfloored expression
        # divides by sqrt(eps), giving a large negative weight that makes the running buffers diverge.
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(max(bs-1, 1)+self.eps))
        self.sums.lerp_(s, mom1)
        self.sqrs.lerp_(ss, mom1)
        self.count.lerp_(c, mom1)
        self.batch += bs

    def forward(self, x:Tensor) -> Tensor:
        r'''
        Normalise `x` using the running statistics, updating them first if the module is in training mode.
        NOTE: if no batch has ever been passed in training mode, `count` is still zero and the output will be NaN.

        Arguments:
            x: incoming tensor, either (batch, features) or (batch, features, length)

        Returns:
            Normalised, scaled, and shifted tensor of the same shape as `x`
        '''

        squeeze = x.dim() == 2
        if squeeze: x = x.unsqueeze(-1)
        if self.training: self.update_stats(x)
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        # E[x^2]-E[x]^2 may come out slightly negative due to floating-point cancellation, which would give NaN in
        # the sqrt below, so always floor at zero; during warmup use the larger floor of 0.01.
        min_var = 0.01 if bool(self.batch < self.n_warmup) else 0.
        varns.clamp_min_(min_var)
        factor = self.weight/(varns+self.eps).sqrt()
        offset = self.bias-means*factor
        x = x*factor+offset
        return x.squeeze(-1) if squeeze else x
class RunningBatchNorm2d(RunningBatchNorm1d):
    r'''
    2D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence.
    Modifications: add eps in mom1 calculation, type hinting, docs, variance clamped to be non-negative.

    Statistics are reduced over dimensions (0,2,3), i.e. every dimension except dimension 1 (features/channels).
    The constructor and the statistics update are inherited from `RunningBatchNorm1d`.

    Arguments:
        nf: number of features/channels
        mom: momentum (fraction to add to running averages)
        n_warmup: warmup threshold; while the accumulated number of samples seen during training (sum of batch sizes)
            is below this value, the variance is clamped to be at least 0.01
        eps: epsilon to prevent division by zero
    '''

    def _set_params(self) -> None:
        r'''
        Create the learnable scale & shift, the running-statistic buffers, and the tuple of dimensions to reduce over.
        '''

        self.weight = nn.Parameter(torch.ones(self.nf,1,1))
        self.bias = nn.Parameter(torch.zeros(self.nf,1,1))
        self.register_buffer('sums', torch.zeros(1,self.nf,1,1))  # running per-feature sum
        self.register_buffer('sqrs', torch.zeros(1,self.nf,1,1))  # running per-feature sum of squares
        self.register_buffer('batch', tensor(0.))  # accumulated number of samples seen (compared against n_warmup)
        self.register_buffer('count', tensor(0.))  # running number of elements per feature
        self.register_buffer('step', tensor(0.))  # not read in this class; kept so saved state dicts stay loadable
        self.dims = (0,2,3)

    def forward(self, x:Tensor) -> Tensor:
        r'''
        Normalise `x` using the running statistics, updating them first if the module is in training mode.
        NOTE: if no batch has ever been passed in training mode, `count` is still zero and the output will be NaN.

        Arguments:
            x: incoming tensor with batch as dimension 0 and features/channels as dimension 1

        Returns:
            Normalised, scaled, and shifted tensor of the same shape as `x`
        '''

        if self.training: self.update_stats(x)
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        # E[x^2]-E[x]^2 may come out slightly negative due to floating-point cancellation, which would give NaN in
        # the sqrt below, so always floor at zero; during warmup use the larger floor of 0.01.
        min_var = 0.01 if bool(self.batch < self.n_warmup) else 0.
        varns.clamp_min_(min_var)
        factor = self.weight/(varns+self.eps).sqrt()
        offset = self.bias-means*factor
        return x*factor+offset
class RunningBatchNorm3d(RunningBatchNorm2d):
    r'''
    3D Running batchnorm implementation from fastai (https://github.com/fastai/course-v3) distributed under apache2 licence.
    Modifications: Adaptation to 3D, add eps in mom1 calculation, type hinting, docs.

    Only the parameter and buffer shapes differ from the parent class: statistics are reduced over dimensions (0,2,3,4),
    i.e. every dimension except dimension 1 (features/channels). Everything else is inherited.

    Arguments:
        nf: number of features/channels
        mom: momentum (fraction to add to running averages)
        eps: epsilon to prevent division by zero
    '''

    def _set_params(self) -> None:
        r'''
        Create the learnable scale & shift, the running-statistic buffers, and the tuple of dimensions to reduce over.
        '''

        n_feats = self.nf
        self.weight = nn.Parameter(torch.ones((n_feats, 1, 1, 1)))
        self.bias = nn.Parameter(torch.zeros((n_feats, 1, 1, 1)))
        # Per-feature running statistics, broadcastable against (batch, features, d1, d2, d3) data
        for stat_name in ('sums', 'sqrs'):
            self.register_buffer(stat_name, torch.zeros((1, n_feats, 1, 1, 1)))
        # Scalar bookkeeping buffers, each starting at zero
        for scalar_name in ('batch', 'count', 'step'):
            self.register_buffer(scalar_name, tensor(0.))
        self.dims = (0, 2, 3, 4)
Read the Docs v: v0.8.0
Versions
latest
stable
v0.8.0
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.1
v0.5.0
v0.4.0.1
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials