from typing import Optional, Callable, Any, List, Dict
import numpy as np
from functools import partial
import torch.nn as nn
import torch
from torch import Tensor
from ..layers.activations import lookup_act
from ..initialisations import lookup_normal_init
from .abs_block import AbsBlock
__all__ = ['FullyConnected', 'MultiBlock']
class AbsBody(AbsBlock):
def __init__(self, n_in:int, feat_map:Dict[str,List[int]],
lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
lookup_act:Callable[[str],Any]=lookup_act, freeze:bool=False):
super().__init__(lookup_init=lookup_init, freeze=freeze)
self.n_in,self.feat_map,self.lookup_act = n_in,feat_map,lookup_act
class FullyConnected(AbsBody):
r'''
Fully connected set of hidden layers. Designed to be passed as a 'body' to :class:`~lumin.nn.models.model_builder.ModelBuilder`.
Supports batch normalisation and dropout.
    Order is dense->activation->BN->DO, except when res is True, in which case the BN is applied after the addition.
    Can optionally have skip connections between every pair of layers (res=True).
    Alternatively, layer outputs can be concatenated to their inputs (dense=True).
    The growth_rate parameter can be used to adjust the width of the layers with depth: layer d (zero-indexed) has width width+(width*d*growth_rate), so the final layer has width width+(width*(depth-1)*growth_rate).
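    For example, width=100, depth=4, and growth_rate=-0.3 gives layer widths of 100, 70, 40, and 10 (growth_rate is ignored when res=True).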
Arguments:
n_in: number of inputs to the block
        feat_map: dictionary mapping each input feature of the model to its output indices in the head block
depth: number of hidden layers. If res==True and depth is even, depth will be increased by one.
width: base width of each hidden layer
        do: if greater than 0, will add dropout layers with dropout rate do
bn: whether to use batch normalisation
act: string representation of argument to pass to lookup_act
        res: whether to add an additive skip connection every two dense layers. Mutually exclusive with dense.
        dense: whether to perform layer-wise concatenations after every layer. Mutually exclusive with res.
growth_rate: rate at which width of dense layers should increase with depth beyond the initial layer. Ignored if res=True. Can be negative.
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
lookup_act: function taking choice of activation function and returning an activation function layer
freeze: whether to start with module parameters set to untrainable
Examples::
>>> body = FullyConnected(n_in=32, feat_map=head.feat_map, depth=4,
... width=100, act='relu')
>>>
>>> body = FullyConnected(n_in=32, feat_map=head.feat_map, depth=4,
... width=200, act='relu', growth_rate=-0.3)
>>>
>>> body = FullyConnected(n_in=32, feat_map=head.feat_map, depth=4,
... width=100, act='swish', do=0.1, res=True)
>>>
>>> body = FullyConnected(n_in=32, feat_map=head.feat_map, depth=6,
... width=32, act='selu', dense=True,
... growth_rate=0.5)
>>>
>>> body = FullyConnected(n_in=32, feat_map=head.feat_map, depth=6,
... width=50, act='prelu', bn=True,
... lookup_init=lookup_uniform_init)
'''
    def __init__(self, n_in:int, feat_map:Dict[str,List[int]], depth:int, width:int, do:float=0, bn:bool=False, act:str='relu', res:bool=False,
                 dense:bool=False, growth_rate:float=0, lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
lookup_act:Callable[[str],Any]=lookup_act, freeze:bool=False):
super().__init__(n_in=n_in, feat_map=feat_map, lookup_init=lookup_init, lookup_act=lookup_act, freeze=freeze)
self.depth,self.width,self.do,self.bn,self.act,self.res,self.dense,self.growth_rate = depth,width,do,bn,act,res,dense,growth_rate
if self.res:
self.depth = 1+int(np.floor(self.depth/2)) # One upscale layer + each subsequent block will contain 2 layers
self.res_bns = nn.ModuleList([nn.BatchNorm1d(self.width) for d in range(self.depth-1)])
self.layers = nn.ModuleList([self._get_layer(idx=d, fan_in=self.width, fan_out=self.width)
if d > 0 else self._get_layer(idx=d, fan_in=self.n_in, fan_out=self.width)
for d in range(self.depth)])
elif self.dense:
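            # Dense connectivity: every layer receives the original inputs plus the outputs of all previous layers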
self.layers = []
for d in range(self.depth):
self.layers.append(self._get_layer(idx=d, fan_in=self.n_in if d == 0 else self.n_in+np.sum([l[0].out_features for l in self.layers]),
fan_out=max(1,self.width+int(self.width*d*self.growth_rate))))
self.layers = nn.ModuleList(self.layers)
else:
self.layers = nn.Sequential(*[self._get_layer(idx=d, fan_in=self.width+int(self.width*(d-1)*self.growth_rate),
fan_out=self.width+int(self.width*d*self.growth_rate))
if d > 0 else self._get_layer(idx=d, fan_in=self.n_in, fan_out=self.width)
for d in range(self.depth)])
if self.freeze: self.freeze_layers()
def _get_layer(self, idx:int, fan_in:Optional[int]=None, fan_out:Optional[int]=None) -> nn.Module:
fan_in = self.width if fan_in is None else fan_in
fan_out = self.width if fan_out is None else fan_out
if fan_in < 1: fan_in = 1
if fan_out < 1: fan_out = 1
layers = []
for i in range(2 if self.res and idx > 0 else 1):
layers.append(nn.Linear(fan_in, fan_out))
self.lookup_init(self.act, fan_in, fan_out)(layers[-1].weight)
nn.init.zeros_(layers[-1].bias)
if self.act != 'linear': layers.append(self.lookup_act(self.act))
if self.bn and i == 0: layers.append(nn.BatchNorm1d(fan_out)) # In case of residual, BN will be added after addition
if self.do:
if self.act == 'selu': layers.append(nn.AlphaDropout(self.do))
else: layers.append(nn.Dropout(self.do))
return nn.Sequential(*layers)
    def forward(self, x:Tensor) -> Tensor:
if self.dense:
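            # Concatenate each layer's output onto the running input; the final layer consumes the full concatenation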
for l in self.layers[:-1]: x = torch.cat((l(x), x), -1)
x = self.layers[-1](x)
elif self.res:
for i, l in enumerate(self.layers):
if i > 0:
x = l(x)+x
x = self.res_bns[i-1](x) # Renormalise after addition
else:
x = l(x)
else:
x = self.layers(x)
return x
    def get_out_size(self) -> int:
r'''
        Get the width of the output layer
Returns:
Width of output layer
'''
return self.layers[-1][0].out_features
class MultiBlock(AbsBody):
r'''
Body block allowing outputs of head block to be split amongst a series of body blocks.
    Output is the concatenation of all sub-body blocks.
    Optionally, narrow 'bottleneck' layers (of width bottleneck_sz) can be used to pass an extra input to each sub-block based on a learned function of the
    input features that block would otherwise not receive, i.e. a highly compressed representation of the rest of the feature space.
Arguments:
n_in: number of inputs to the block
        feat_map: dictionary mapping each input feature of the model to its output indices in the head block
        blocks: list of uninstantiated :class:`~lumin.nn.models.blocks.body.AbsBody` blocks to which to pass a subsection of the total inputs. Note that
            partials should be used to set any relevant parameters at initialisation time
        feats_per_block: list of lists of names of features to pass to each :class:`~lumin.nn.models.blocks.body.AbsBody`; note that the feat_map provided by
            :class:`~lumin.nn.models.blocks.head.AbsHead` will map features to their relevant head outputs
        bottleneck_sz: if greater than 0, each block will receive the output of a bottleneck layer of this width, which takes as input all the features which
            the given block does not directly take as inputs
        bottleneck_act: if set to a string representation of an activation function, the output of each bottleneck layer will be passed through the defined
            activation function before being passed to its associated block
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
lookup_act: function taking choice of activation function and returning an activation function layer
freeze: whether to start with module parameters set to untrainable
Examples::
>>> body = MultiBlock(
... blocks=[partial(FullyConnected, depth=1, width=50, act='swish'),
... partial(FullyConnected, depth=6, width=55, act='swish',
... dense=True, growth_rate=-0.1)],
... feats_per_block=[[f for f in train_feats if 'DER_' in f],
... [f for f in train_feats if 'PRI_' in f]])
>>>
>>> body = MultiBlock(
... blocks=[partial(FullyConnected, depth=1, width=50, act='swish'),
... partial(FullyConnected, depth=6, width=55, act='swish',
... dense=True, growth_rate=-0.1)],
... feats_per_block=[[f for f in train_feats if 'DER_' in f],
... [f for f in train_feats if 'PRI_' in f]],
        ...     bottleneck_sz=1)
>>>
>>> body = MultiBlock(
... blocks=[partial(FullyConnected, depth=1, width=50, act='swish'),
... partial(FullyConnected, depth=6, width=55, act='swish',
... dense=True, growth_rate=-0.1)],
... feats_per_block=[[f for f in train_feats if 'DER_' in f],
... [f for f in train_feats if 'PRI_' in f]],
        ...     bottleneck_sz=1, bottleneck_act='swish')
'''
def __init__(self, n_in:int, feat_map:Dict[str,List[int]], blocks:List[partial], feats_per_block:List[List[str]],
bottleneck_sz:int=0, bottleneck_act:Optional[str]=None,
lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
lookup_act:Callable[[str],Any]=lookup_act, freeze:bool=False):
super().__init__(n_in=n_in, feat_map=feat_map, lookup_init=lookup_init, lookup_act=lookup_act, freeze=freeze)
self.feats_per_block,self.bottleneck_sz,self.bottleneck_act = feats_per_block,bottleneck_sz,bottleneck_act
self.blocks,self.n_out,self.masks,self.bottleneck_blocks = [],0,[],None
if self.bottleneck_sz > 0:
self.bottleneck_blocks,self.bottleneck_masks = [],[]
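            # Each block's bottleneck sees the head outputs of every feature that block does not receive directly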
for fpb in self.feats_per_block:
tmp_map = {f: self.feat_map[f] for f in self.feat_map if f not in fpb}
self.bottleneck_masks.append([i for f in tmp_map for i in tmp_map[f]])
self.bottleneck_blocks.append(self._get_bottleneck(self.bottleneck_masks[-1]))
self.bottleneck_blocks = nn.ModuleList(self.bottleneck_blocks)
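        # Instantiate each sub-block on the head outputs of its assigned features (plus the bottleneck output, if used)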
for i, b in enumerate(blocks):
tmp_map = {f: self.feat_map[f] for f in self.feat_map if f in self.feats_per_block[i]}
self.masks.append([i for f in tmp_map for i in tmp_map[f]])
self.blocks.append(b(n_in=len(self.masks[-1])+self.bottleneck_sz, feat_map=tmp_map, lookup_init=self.lookup_init,
lookup_act=self.lookup_act, freeze=self.freeze))
self.n_out += self.blocks[-1].get_out_size()
self.blocks = nn.ModuleList(self.blocks)
def _get_bottleneck(self, mask:List[int]) -> nn.Module:
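        # A single linear layer compressing the features a block does not see down to bottleneck_sz outputs, with optional activation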
layers = [nn.Linear(len(mask), self.bottleneck_sz)]
if self.bottleneck_act is None:
init = self.lookup_init('linear', len(mask), self.bottleneck_sz)
else:
init = self.lookup_init(self.bottleneck_act, len(mask), self.bottleneck_sz)
layers.append(self.lookup_act(self.bottleneck_act))
init(layers[0].weight)
nn.init.zeros_(layers[0].bias)
return nn.Sequential(*layers)
    def get_out_size(self) -> int:
r'''
        Get the total width of the output
        Returns:
            Total number of outputs across all blocks
'''
return self.n_out
    def forward(self, x:Tensor) -> Tensor:
y = None
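        # Pass each feature subset (plus its bottleneck output, if any) through its block and concatenate the results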
for i, b in enumerate(self.blocks):
if self.bottleneck_sz:
a = self.bottleneck_blocks[i](x[:,self.bottleneck_masks[i]])
tmp_x = torch.cat((x[:,self.masks[i]], a), -1)
else:
tmp_x = x[:,self.masks[i]]
out = b(tmp_x)
if y is None: y = out
else: y = torch.cat((y, out), -1)
return y