Shortcuts

Source code for lumin.nn.models.initialisations

from functools import partial
from typing import Callable, Optional

import numpy as np
import torch.nn as nn
from torch import Tensor

__all__ = ["lookup_normal_init", "lookup_uniform_init"]


def lookup_normal_init(
    act: str, fan_in: Optional[int] = None, fan_out: Optional[int] = None
) -> Callable[[Tensor], None]:
    r"""
    Lookup for weight initialisation using Normal distributions

    The activation name selects one of three initialisers:

    - "relu", "prelu", "mish" and any name containing "swish":
      `nn.init.kaiming_normal_` with `nonlinearity="relu"` and `a=0`
    - "sigmoid", "logsoftmax", "softmax" and "linear": `nn.init.xavier_normal_`
    - "selu": `nn.init.normal_` with `std = 1 / sqrt(fan_in)`

    Arguments:
        act: string representation of activation function
        fan_in: number of inputs to neuron; required (and must be positive) when `act` is "selu",
            not used by the other activations
        fan_out: number of outputs from neuron; accepted for interface symmetry but not used here

    Returns:
        Callable to initialise weight tensor

    Raises:
        ValueError: if `act` is not one of the supported activations, or if `act` is "selu" and
            `fan_in` is missing or not positive
    """

    if act in ("relu", "prelu", "mish") or "swish" in act:
        return partial(nn.init.kaiming_normal_, nonlinearity="relu", a=0)
    if act in ("sigmoid", "logsoftmax", "softmax", "linear"):
        return nn.init.xavier_normal_
    if act == "selu":
        # Fail early with a clear message: the std below is undefined without a positive fan_in
        # (None would crash inside numpy, zero/negative would give an inf/nan std).
        if fan_in is None or fan_in <= 0:
            raise ValueError("A positive fan_in is required for selu initialisation")
        return partial(nn.init.normal_, std=1 / np.sqrt(fan_in))
    raise ValueError("Activation not implemented")
def lookup_uniform_init(
    act: str, fan_in: Optional[int] = None, fan_out: Optional[int] = None
) -> Callable[[Tensor], None]:
    r"""
    Lookup weight initialisation using Uniform distributions

    The activation name selects one of three initialisers:

    - "relu", "prelu", "mish" and any name containing "swish":
      `nn.init.kaiming_uniform_` with `nonlinearity="relu"` and `a=0`
    - "sigmoid", "logsoftmax", "softmax" and "linear": `nn.init.xavier_uniform_`
    - "selu": `nn.init.uniform_` on the interval `[-1 / sqrt(fan_in), 1 / sqrt(fan_in)]`

    Arguments:
        act: string representation of activation function
        fan_in: number of inputs to neuron; required (and must be positive) when `act` is "selu",
            not used by the other activations
        fan_out: number of outputs from neuron; accepted for interface symmetry but not used here

    Returns:
        Callable to initialise weight tensor

    Raises:
        ValueError: if `act` is not one of the supported activations, or if `act` is "selu" and
            `fan_in` is missing or not positive
    """

    if act in ("relu", "prelu", "mish") or "swish" in act:
        return partial(nn.init.kaiming_uniform_, nonlinearity="relu", a=0)
    if act in ("sigmoid", "logsoftmax", "softmax", "linear"):
        return nn.init.xavier_uniform_
    if act == "selu":
        # Fail early with a clear message: the bounds below are undefined without a positive fan_in
        # (None would crash inside numpy, zero/negative would give inf/nan bounds).
        if fan_in is None or fan_in <= 0:
            raise ValueError("A positive fan_in is required for selu initialisation")
        return partial(nn.init.uniform_, a=-1 / np.sqrt(fan_in), b=1 / np.sqrt(fan_in))
    raise ValueError("Activation not implemented")

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials