Shortcuts

Source code for lumin.nn.callbacks.lsuv_init

r'''
This file contains code modified from https://github.com/ducha-aiki/LSUV-pytorch which is made available under the following BSD 2-Clause "Simplified" Licence:
Copyright (C) 2017, Dmytro Mishkin
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the
   distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

The Apache Licence 2.0 under which the majority of the rest of LUMIN is distributed does not apply to the code within this file.
'''

import numpy as np
from typing import Optional, Union, Tuple

import torch
import torch.nn.init
import torch.nn as nn
from torch import Tensor

from lumin.nn.callbacks.callback import Callback, OldCallback
from lumin.nn.data.batch_yielder import BatchYielder
from lumin.nn.models.abs_model import AbsModel, OldAbsModel

__all__ = ['LsuvInit']


class OldLsuvInit(OldCallback):
    r'''
    .. Attention:: This class is deprecated in favour of :class:`~lumin.nn.callbacks.lsuv_init.LsuvInit`.
        It is a copy of the old `LsuvInit` class used in lumin<=0.6.
        It will be removed in V0.8

    Applies Layer-Sequential Unit-Variance (LSUV) initialisation (Mishkin & Matas 2016, https://arxiv.org/abs/1511.06422) to the
    `Conv1d`, `Conv2d`, `Conv3d`, and `Linear` modules of the model, the first time `on_epoch_begin` is called after
    `on_train_begin`.

    Arguments:
        needed_std: desired standard deviation of layer outputs
        std_tol: tolerance for matching standard deviation with target
        max_attempts: maximum number of weight rescalings to attempt per layer
        do_orthonorm: whether to apply orthonormal initialisation first, or rescale the existing values
        verbose: whether to print out details of the rescaling
        model: model to be initialised, passed on to the parent callback; its `.model` attribute is used as the `nn.Module`
    '''

    # XXX remove in V0.8

    def __init__(self, needed_std:float=1.0, std_tol:float=0.1, max_attempts:int=10, do_orthonorm:bool=True, verbose:bool=False,
                 model:Optional[OldAbsModel]=None):
        super().__init__(model=model)
        self.needed_std,self.std_tol,self.max_attempts,self.do_orthonorm,self.verbose = needed_std,std_tol,max_attempts,do_orthonorm,verbose

    def on_train_begin(self, **kargs) -> None:
        r'''
        Sets the callback to initialise the model the first time that `on_epoch_begin` is called.
        '''

        self.init = False
        # Mutable state shared between the module-wise helpers that are passed to `nn.Module.apply`
        self.gg = {'hook_position':0, 'total_fc_conv_layers':0,'done_counter':-1,'hook':None,'act_dict':{},'counter_to_apply_correction':0,
                   'correction_needed':False,'current_coef':1.0}

    def on_epoch_begin(self, by:BatchYielder, **kargs) -> None:
        r'''
        If the LSUV process has yet to run, then it will run using all of the input data provided by the `BatchYielder`

        Arguments:
            by: BatchYielder providing data for the upcoming epoch
        '''

        if not self.init:
            print('Running LSUV initialisation')
            self._run_lsuv(by.get_inputs(on_device=True))
            self.init = True

    @staticmethod
    def _svd_orthonormal(w:np.ndarray) -> np.ndarray:
        r'''
        Returns a float32 array with the shape of `w` (its values are not used) built from the SVD of a standard-normal random
        matrix of shape (shape[0], prod(shape[1:])): the flattened result has orthonormal columns if shape[0] >= prod(shape[1:]),
        and orthonormal rows otherwise.

        Raises:
            RuntimeError: if `w` has fewer than two dimensions
        '''

        shape = w.shape
        if len(shape) < 2: raise RuntimeError("Only shapes of length 2 or more are supported.")
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # With full_matrices=False only one of u and v has the full flattened shape
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        return q.astype(np.float32)

    def _store_activations(self, module:nn.Module, input:Tensor, output:Tensor) -> None:
        r'''Forward hook: keeps a CPU numpy copy of the hooked layer's output.'''
        self.gg['act_dict'] = output.data.cpu().numpy()

    def _add_current_hook(self, m:nn.Module) -> None:
        r'''
        Module-wise function for `nn.Module.apply`: registers the activation hook on the first Conv/Linear layer that has not yet
        been initialised. `hook_position` ends up as the index (among Conv/Linear layers, in `apply` order) of the hooked layer.
        '''

        if self.gg['hook'] is not None: return
        if self._check_layer(m):
            if self.gg['hook_position'] > self.gg['done_counter']: self.gg['hook'] = m.register_forward_hook(self._store_activations)
            else:                                                  self.gg['hook_position'] += 1

    def _count_conv_fc_layers(self, m:nn.Module) -> None:
        r'''Module-wise function for `nn.Module.apply`: counts the layers which LSUV will process.'''
        if self._check_layer(m): self.gg['total_fc_conv_layers'] += 1

    @staticmethod
    def _check_layer(m:nn.Module) -> bool:
        r'''Returns whether `m` is a layer type handled by LSUV (Conv1d, Conv2d, Conv3d, or Linear).'''
        return isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear))

    def _orthogonal_weights_init(self, m:nn.Module) -> None:
        r'''
        Module-wise function for `nn.Module.apply`: replaces the weights of LSUV layers with orthonormal values and zeroes their
        biases. For weight-normalised layers (those with `weight_v`) the direction parameter `weight_v` is replaced instead.
        '''

        if not self._check_layer(m): return
        param = m.weight_v if hasattr(m, 'weight_v') else m.weight
        w_ortho = self._svd_orthonormal(param.data.cpu().numpy())
        # Keep the parameter's original device and dtype, so the model does not need to be moved afterwards
        param.data = torch.from_numpy(w_ortho).to(param.data)
        if getattr(m, 'bias', None) is not None: nn.init.zeros_(m.bias)

    def _apply_weights_correction(self, m:nn.Module) -> None:
        r'''
        Module-wise function for `nn.Module.apply`: multiplies the weights of the currently hooked layer by
        `self.gg['current_coef']`. Conv/Linear layers are counted via `counter_to_apply_correction` until index `hook_position`
        is reached, so that counter must be zero at the start of every `apply` pass.
        '''

        if self.gg['hook'] is None or not self.gg['correction_needed'] or not self._check_layer(m): return
        if self.gg['counter_to_apply_correction'] < self.gg['hook_position']:
            self.gg['counter_to_apply_correction'] += 1
            return
        coef = float(self.gg['current_coef'])
        if hasattr(m, 'weight_g'): m.weight_g.data *= coef  # Weight-normalised layer: rescale the magnitude parameter
        else:                      m.weight.data *= coef
        self.gg['correction_needed'] = False

    def _lsuv_layer(self, model:nn.Module, data:Union[Tensor,Tuple[Tensor,Tensor]], layer_idx:int) -> None:
        r'''
        Hooks the next uninitialised layer and rescales its weights until the standard deviation of its output is within
        `std_tol` of `needed_std`, making at most `max_attempts` rescalings.
        '''

        if self.verbose: print(f'Checking layer {layer_idx}')
        model.apply(self._add_current_hook)
        model(data)
        current_std = self.gg['act_dict'].std()
        if self.verbose: print(f'std at layer {layer_idx} = {current_std}')
        attempts = 0
        while np.abs(current_std-self.needed_std) > self.std_tol:
            if attempts >= self.max_attempts:
                print(f'Cannot converge in {self.max_attempts} iterations')
                break
            self.gg['current_coef'] = self.needed_std/(current_std+1e-8)
            self.gg['correction_needed'] = True
            # Restart the layer count for each pass; otherwise later passes would rescale the first layer instead of the hooked one
            self.gg['counter_to_apply_correction'] = 0
            model.apply(self._apply_weights_correction)
            model(data)
            current_std = self.gg['act_dict'].std()
            if self.verbose: print(f'std at layer {layer_idx} = {current_std} mean = {self.gg["act_dict"].mean()}')
            attempts += 1
        if self.gg['hook'] is not None: self.gg['hook'].remove()
        self.gg['done_counter'] += 1
        self.gg['counter_to_apply_correction'] = 0
        self.gg['hook_position'] = 0
        self.gg['hook'] = None
        if self.verbose: print(f'Initialised layer {layer_idx}')

    def _run_lsuv(self, data:Union[Tensor,Tuple[Tensor,Tensor]]) -> None:
        r'''
        Runs LSUV initialisation over every Conv/Linear layer of `self.model.model`, in `nn.Module.apply` order, using `data` as
        the input of every forward pass. The model's train/eval mode is restored afterwards.

        Arguments:
            data: input passed directly to the model's forward method
        '''

        model = self.model.model
        was_training = model.training
        model.eval()
        try:
            model.apply(self._count_conv_fc_layers)
            if self.verbose: print(f'Total layers to process: {self.gg["total_fc_conv_layers"]}')
            if self.do_orthonorm:
                model.apply(self._orthogonal_weights_init)
                if self.verbose: print('Orthonorm done')
            # Only layer outputs are inspected, so no autograd graph is needed for the forward passes
            with torch.no_grad():
                for layer_idx in range(self.gg['total_fc_conv_layers']): self._lsuv_layer(model, data, layer_idx)
        finally:
            # Never leave a stray hook attached if a forward pass fails, and restore the caller's train/eval mode
            if self.gg['hook'] is not None:
                self.gg['hook'].remove()
                self.gg['hook'] = None
            model.train(was_training)
        if self.verbose: print('LSUV init done!')


class LsuvInit(Callback):
    r'''
    Applies Layer-Sequential Unit-Variance (LSUV) initialisation to model, as per Mishkin & Matas 2016 https://arxiv.org/abs/1511.06422.
    When training begins for the first time, `Conv1D`, `Conv2D`, `Conv3D`, and `Linear` modules in the model will be LSUV initialised using the BatchYielder inputs.
    This involves initialising the weights with orthonormal matrices and then iteratively scaling them such that the standard deviation of the layer outputs is
    equal to a desired value, within some tolerance.

    Arguments:
        needed_std: desired standard deviation of layer outputs
        std_tol: tolerance for matching standard deviation with target
        max_attempts: maximum number of weight rescalings to attempt per layer
        do_orthonorm: whether to apply orthonormal initialisation first, or rescale the existing values
        verbose: whether to print out details of the rescaling

    Example::
        >>> lsuv = LsuvInit()
        >>>
        >>> lsuv = LsuvInit(verbose=True)
        >>>
        >>> lsuv = LsuvInit(needed_std=0.5, std_tol=0.01, max_attempts=100, do_orthonorm=True)
    '''

    def __init__(self, needed_std:float=1.0, std_tol:float=0.1, max_attempts:int=10, do_orthonorm:bool=True, verbose:bool=False):
        super().__init__()
        self.needed_std,self.std_tol,self.max_attempts,self.do_orthonorm,self.verbose = needed_std,std_tol,max_attempts,do_orthonorm,verbose

    def on_train_begin(self) -> None:
        r'''
        Sets the callback to initialise the model the first time that `on_fold_begin` is called.
        '''

        super().on_train_begin()
        self.init = False
        # Mutable state shared between the module-wise helpers that are passed to `nn.Module.apply`
        self.gg = {'hook_position':0, 'total_fc_conv_layers':0,'done_counter':-1,'hook':None,'act_dict':{},'counter_to_apply_correction':0,
                   'correction_needed':False,'current_coef':1.0}

    def on_fold_begin(self) -> None:
        r'''
        If the LSUV process has yet to run, then it will run using all of the input data provided by the current
        `BatchYielder` (`self.model.fit_params.by`)
        '''

        if not self.init:
            print('Running LSUV initialisation')
            self._run_lsuv(self.model.fit_params.by.get_inputs(on_device=True))
            self.init = True

    @staticmethod
    def _svd_orthonormal(w:np.ndarray) -> np.ndarray:
        r'''
        Returns a float32 array with the shape of `w` (its values are not used) built from the SVD of a standard-normal random
        matrix of shape (shape[0], prod(shape[1:])): the flattened result has orthonormal columns if shape[0] >= prod(shape[1:]),
        and orthonormal rows otherwise.

        Raises:
            RuntimeError: if `w` has fewer than two dimensions
        '''

        shape = w.shape
        if len(shape) < 2: raise RuntimeError("Only shapes of length 2 or more are supported.")
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # With full_matrices=False only one of u and v has the full flattened shape
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        return q.astype(np.float32)

    def _store_activations(self, module:nn.Module, input:Tensor, output:Tensor) -> None:
        r'''Forward hook: keeps a CPU numpy copy of the hooked layer's output.'''
        self.gg['act_dict'] = output.data.cpu().numpy()

    def _add_current_hook(self, m:nn.Module) -> None:
        r'''
        Module-wise function for `nn.Module.apply`: registers the activation hook on the first Conv/Linear layer that has not yet
        been initialised. `hook_position` ends up as the index (among Conv/Linear layers, in `apply` order) of the hooked layer.
        '''

        if self.gg['hook'] is not None: return
        if self._check_layer(m):
            if self.gg['hook_position'] > self.gg['done_counter']: self.gg['hook'] = m.register_forward_hook(self._store_activations)
            else:                                                  self.gg['hook_position'] += 1

    def _count_conv_fc_layers(self, m:nn.Module) -> None:
        r'''Module-wise function for `nn.Module.apply`: counts the layers which LSUV will process.'''
        if self._check_layer(m): self.gg['total_fc_conv_layers'] += 1

    @staticmethod
    def _check_layer(m:nn.Module) -> bool:
        r'''Returns whether `m` is a layer type handled by LSUV (Conv1d, Conv2d, Conv3d, or Linear).'''
        return isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear))

    def _orthogonal_weights_init(self, m:nn.Module) -> None:
        r'''
        Module-wise function for `nn.Module.apply`: replaces the weights of LSUV layers with orthonormal values and zeroes their
        biases. For weight-normalised layers (those with `weight_v`) the direction parameter `weight_v` is replaced instead.
        '''

        if not self._check_layer(m): return
        param = m.weight_v if hasattr(m, 'weight_v') else m.weight
        w_ortho = self._svd_orthonormal(param.data.cpu().numpy())
        # Keep the parameter's original device and dtype, so the model does not need to be moved afterwards
        param.data = torch.from_numpy(w_ortho).to(param.data)
        if getattr(m, 'bias', None) is not None: nn.init.zeros_(m.bias)

    def _apply_weights_correction(self, m:nn.Module) -> None:
        r'''
        Module-wise function for `nn.Module.apply`: multiplies the weights of the currently hooked layer by
        `self.gg['current_coef']`. Conv/Linear layers are counted via `counter_to_apply_correction` until index `hook_position`
        is reached, so that counter must be zero at the start of every `apply` pass.
        '''

        if self.gg['hook'] is None or not self.gg['correction_needed'] or not self._check_layer(m): return
        if self.gg['counter_to_apply_correction'] < self.gg['hook_position']:
            self.gg['counter_to_apply_correction'] += 1
            return
        coef = float(self.gg['current_coef'])
        if hasattr(m, 'weight_g'): m.weight_g.data *= coef  # Weight-normalised layer: rescale the magnitude parameter
        else:                      m.weight.data *= coef
        self.gg['correction_needed'] = False

    def _lsuv_layer(self, model:nn.Module, data:Union[Tensor,Tuple[Tensor,Tensor]], layer_idx:int) -> None:
        r'''
        Hooks the next uninitialised layer and rescales its weights until the standard deviation of its output is within
        `std_tol` of `needed_std`, making at most `max_attempts` rescalings.
        '''

        if self.verbose: print(f'Checking layer {layer_idx}')
        model.apply(self._add_current_hook)
        model(data)
        current_std = self.gg['act_dict'].std()
        if self.verbose: print(f'std at layer {layer_idx} = {current_std}')
        attempts = 0
        while np.abs(current_std-self.needed_std) > self.std_tol:
            if attempts >= self.max_attempts:
                print(f'Cannot converge in {self.max_attempts} iterations')
                break
            self.gg['current_coef'] = self.needed_std/(current_std+1e-8)
            self.gg['correction_needed'] = True
            # Restart the layer count for each pass; otherwise later passes would rescale the first layer instead of the hooked one
            self.gg['counter_to_apply_correction'] = 0
            model.apply(self._apply_weights_correction)
            model(data)
            current_std = self.gg['act_dict'].std()
            if self.verbose: print(f'std at layer {layer_idx} = {current_std} mean = {self.gg["act_dict"].mean()}')
            attempts += 1
        if self.gg['hook'] is not None: self.gg['hook'].remove()
        self.gg['done_counter'] += 1
        self.gg['counter_to_apply_correction'] = 0
        self.gg['hook_position'] = 0
        self.gg['hook'] = None
        if self.verbose: print(f'Initialised layer {layer_idx}')

    def _run_lsuv(self, data:Union[Tensor,Tuple[Tensor,Tensor]]) -> None:
        r'''
        Runs LSUV initialisation over every Conv/Linear layer of `self.model.model`, in `nn.Module.apply` order, using `data` as
        the input of every forward pass. The model's train/eval mode is restored afterwards.

        Arguments:
            data: input passed directly to the model's forward method
        '''

        model = self.model.model
        was_training = model.training
        model.eval()
        try:
            model.apply(self._count_conv_fc_layers)
            if self.verbose: print(f'Total layers to process: {self.gg["total_fc_conv_layers"]}')
            if self.do_orthonorm:
                model.apply(self._orthogonal_weights_init)
                if self.verbose: print('Orthonorm done')
            # Only layer outputs are inspected, so no autograd graph is needed for the forward passes
            with torch.no_grad():
                for layer_idx in range(self.gg['total_fc_conv_layers']): self._lsuv_layer(model, data, layer_idx)
        finally:
            # Never leave a stray hook attached if a forward pass fails, and restore the caller's train/eval mode
            if self.gg['hook'] is not None:
                self.gg['hook'].remove()
                self.gg['hook'] = None
            model.train(was_training)
        if self.verbose: print('LSUV init done!')
Read the Docs v: v0.7.2
Versions
latest
stable
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.1
v0.5.0
v0.4.0.1
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials