Source code for lumin.nn.models.layers.self_attention
import math
from typing import Any, Callable, Optional
import torch
from fastcore.all import store_attr
from torch import Tensor, nn
from ..initialisations import lookup_normal_init
from .activations import lookup_act
from .batchnorms import LCBatchNorm1d
# Public names exported by this module.
__all__ = ["SelfAttention", "OffsetSelfAttention"]
class SelfAttention(nn.Module):
    r"""
    Class for applying self attention (Vaswani et al. 2017 (https://arxiv.org/abs/1706.03762)) to features per vertex.

    The query and key projections have width `n_a`; the value projection and the output block keep the width at
    `n_fpv`, and the result is added back onto the input (residual connection).

    Arguments:
        n_fpv: number of features per vertex to expect
        n_a: width of self attention representation
        do: dropout rate applied at the end of the output block (`AlphaDropout` when `act` is "selu")
        bn: whether batch normalisation should be applied in the output block
        act: activation function to apply in the output block ("linear" adds no activation layer)
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and
            returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is :class:`torch.nn.BatchNorm1d`. The instance is wrapped in
            :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d` before being added to the output block.
    """

    def __init__(
        self,
        n_fpv: int,
        n_a: int,
        do: float = 0,
        bn: bool = False,
        act: str = "relu",
        lookup_init: Callable[[str, Optional[int], Optional[int]], Callable[[Tensor], None]] = lookup_normal_init,
        lookup_act: Callable[[str], Any] = lookup_act,
        bn_class: Callable[[int], nn.Module] = nn.BatchNorm1d,
    ):
        super().__init__()
        # The methods below read the constructor arguments back as `self.<name>`.
        store_attr()
        self.q = self._get_layer(self.n_fpv, self.n_a)
        self.k = self._get_layer(self.n_fpv, self.n_a)
        self.v = self._get_layer(self.n_fpv, self.n_fpv)
        self.out = self._get_out()

    def _get_out(self) -> nn.Sequential:
        r"""
        Builds the output block: linear layer, then optional activation, batchnorm, and dropout (in that order).

        Returns:
            Sequential block mapping `n_fpv` features to `n_fpv` features
        """

        layers = [self._get_layer(self.n_fpv, self.n_fpv)]
        if self.act != "linear":
            layers.append(self.lookup_act(self.act))
        if self.bn:
            layers.append(LCBatchNorm1d(self.bn_class(self.n_fpv)))
        if self.do:
            if self.act == "selu":
                layers.append(nn.AlphaDropout(self.do))
            else:
                layers.append(nn.Dropout(self.do))
        return nn.Sequential(*layers)

    def _get_layer(self, fan_in: int, fan_out: int) -> nn.Module:
        r"""
        Creates a linear layer whose weights are set by the "linear" initialiser from `lookup_init` and whose bias is zero.

        Arguments:
            fan_in: number of input features
            fan_out: number of output features

        Returns:
            Initialised linear layer
        """

        layer = nn.Linear(fan_in, fan_out)
        self.lookup_init("linear", fan_in, fan_out)(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x: Tensor) -> Tensor:  # B N C
        r"""
        Augments features per vertex

        Arguments:
            x: incoming data (batch x vertices x features)

        Returns:
            augmented features (batch x vertices x features); the feature width is unchanged
        """

        # Scaled dot-product between queries and keys
        a = (self.q(x) @ self.k(x).transpose(-1, -2)) / math.sqrt(self.n_a)  # B N N
        a = torch.softmax(a, dim=-1)  # Normalise over the last axis
        sa = a @ self.v(x)  # B N C
        return x + self.out(sa)  # B N C

    def get_out_size(self) -> int:
        r"""
        Get the number of features per vertex in the output, which equals `n_fpv`.

        Returns:
            Number of output features per vertex
        """

        return self.n_fpv
class OffsetSelfAttention(SelfAttention):
    r"""
    Class for applying offset-self attention (Guo et al. 2020 (https://arxiv.org/abs/2012.09688)) to features per vertex.

    Layers are built by :class:`SelfAttention`; only the forward pass differs.

    Arguments:
        n_fpv: number of features per vertex to expect
        n_a: width of self attention representation (paper recommends n_fpv//4)
        do: dropout rate applied at the end of the output block
        bn: whether batch normalisation should be applied in the output block
        act: activation function to apply in the output block
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and
            returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is :class:`torch.nn.BatchNorm1d`. The instance is wrapped in
            :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d` by the parent class.
    """

    def forward(self, x: Tensor) -> Tensor:  # B N C
        r"""
        Augments features per vertex

        Arguments:
            x: incoming data (batch x vertices x features)

        Returns:
            augmented features (batch x vertices x features); the feature width is unchanged
        """

        queries, keys = self.q(x), self.k(x)
        scores = torch.matmul(queries, keys.transpose(-1, -2))  # B N N, no 1/sqrt(n_a) scaling here
        weights = scores.softmax(dim=-2)  # Softmax over the second-to-last axis
        # L1-normalise over the last axis; the small constant keeps the denominator non-zero
        totals = weights.sum(-1, keepdim=True) + 1e-17
        weights = weights / totals
        attended = torch.matmul(weights, self.v(x))  # B N C
        offset = x - attended  # Difference between the input and its attention output
        return x + self.out(offset)  # B N C