
Source code for lumin.nn.models.layers.self_attention

from typing import Callable, Optional, Any
import math
from fastcore.all import store_attr

import torch
from torch import nn, Tensor

from .activations import lookup_act
from .batchnorms import LCBatchNorm1d
from ..initialisations import lookup_normal_init

__all__ = ['SelfAttention']


class SelfAttention(nn.Module):
    r'''
    Class for applying self attention (Vaswani et al. 2017 (https://arxiv.org/abs/1706.03762)) to features per vertex.

    Arguments:
        n_fpv: number of features per vertex to expect
        n_a: width of self attention representation (paper recommends n_fpv//4)
        do: dropout rate to be applied to hidden layers in the NNs
        bn: whether batch normalisation should be applied to hidden layers in the NNs
        act: activation function to apply to hidden layers in the NNs
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d`
    '''

    def __init__(self, n_fpv:int, n_a:int, do:float=0, bn:bool=False, act:str='relu',
                 lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
                 lookup_act:Callable[[str],Any]=lookup_act, bn_class:Callable[[int],nn.Module]=nn.BatchNorm1d):
        super().__init__()
        store_attr()
        self.q = self._get_layer(self.n_fpv, self.n_a)
        self.k = self._get_layer(self.n_fpv, self.n_a)
        self.v = self._get_layer(self.n_fpv, self.n_fpv)
        self.out = self._get_out()

    def _get_out(self) -> nn.Sequential:
        layers = [self._get_layer(self.n_fpv, self.n_fpv)]
        if self.act != 'linear': layers.append(self.lookup_act(self.act))
        if self.bn: layers.append(LCBatchNorm1d(self.bn_class(self.n_fpv)))
        if self.do:
            if self.act == 'selu': layers.append(nn.AlphaDropout(self.do))
            else:                  layers.append(nn.Dropout(self.do))
        return nn.Sequential(*layers)

    def _get_layer(self, fan_in:int, fan_out:int) -> nn.Module:
        l = nn.Linear(fan_in, fan_out)
        self.lookup_init('linear', fan_in, fan_out)(l.weight)
        nn.init.zeros_(l.bias)
        return l
    def forward(self, x:Tensor) -> Tensor:  # B N C
        r'''
        Augments features per vertex

        Arguments:
            x: incoming data (batch x vertices x features)

        Returns:
            augmented features (batch x vertices x new features)
        '''

        a = (self.q(x)@self.k(x).transpose(-1,-2))/math.sqrt(self.n_a)  # B N N
        a = torch.softmax(a, dim=-1)  # Softmax norm columns
        sa = a@self.v(x)  # B N C
        return x+self.out(sa)  # B N C
    def get_out_size(self) -> int:
        return self.n_fpv
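A minimal usage sketch (not part of the library source) of how SelfAttention might be applied to per-vertex features; the batch size, vertex count, and feature widths below are arbitrary illustrative values, and the import path follows the module name in the page title.

    import torch
    from lumin.nn.models.layers.self_attention import SelfAttention

    sa = SelfAttention(n_fpv=8, n_a=2)  # n_a ~ n_fpv//4, as the docstring recommends
    x = torch.randn(32, 10, 8)          # batch x vertices x features per vertex
    y = sa(x)                           # residual self-attention over the vertices
    assert y.shape == (32, 10, sa.get_out_size())  # feature width is preserved (n_fpv)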
class OffsetSelfAttention(SelfAttention):
    r'''
    Class for applying offset-self attention (Guo et al. 2020 (https://arxiv.org/abs/2012.09688)) to features per vertex.

    Arguments:
        n_fpv: number of features per vertex to expect
        n_a: width of self attention representation (paper recommends n_fpv//4)
        do: dropout rate to be applied to hidden layers in the NNs
        bn: whether batch normalisation should be applied to hidden layers in the NNs
        act: activation function to apply to hidden layers in the NNs
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d`
    '''

    def forward(self, x:Tensor) -> Tensor:  # B N C
        r'''
        Augments features per vertex

        Arguments:
            x: incoming data (batch x vertices x features)

        Returns:
            augmented features (batch x vertices x new features)
        '''

        a = self.q(x)@self.k(x).transpose(-1,-2)  # B N N
        a = torch.softmax(a, dim=-2)  # Softmax norm rows
        a = a/(a.sum(-1, keepdim=True)+1e-17)  # L1 norm columns
        sa = a@self.v(x)  # B N C
        return x+self.out(x-sa)  # B N C
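A similar hedged sketch for the offset variant. Note that OffsetSelfAttention is defined in this module but not listed in __all__, so it is imported explicitly here; its residual connection is taken on x - sa rather than on sa.

    import torch
    from lumin.nn.models.layers.self_attention import OffsetSelfAttention

    osa = OffsetSelfAttention(n_fpv=8, n_a=2)  # constructor is inherited from SelfAttention
    x = torch.randn(32, 10, 8)                 # batch x vertices x features per vertex
    y = osa(x)                                 # offset attention: out applied to x - sa
    assert y.shape == x.shape                  # feature width is again n_fpv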