
Source code for lumin.nn.models.blocks.gnn_blocks

from typing import List, Callable, Any, Optional, Tuple, Union, Dict
from fastcore.all import store_attr, is_listy
from abc import abstractmethod, ABCMeta
import numpy as np
from functools import partial

import torch
from torch import nn, Tensor

from ..initialisations  import lookup_normal_init
from ..layers.batchnorms import LCBatchNorm1d
from ..layers.activations import lookup_act
from ..layers.self_attention import SelfAttention
from ....utils.misc import to_device

__all__ = ['GraphCollapser', 'InteractionNet', 'GravNet', 'GravNetLayer']


class AbsGraphBlock(nn.Module):
    r'''
    Abstract class for implementing graph neural-network layers and blocks
    
    Arguments:
        n_v: number of vertices per data point to expect
        n_fpv: number of features per vertex to expect
        do: dropout rate to be applied to hidden layers in the NNs
        bn: whether batch normalisation should be applied to hidden layers in the NNs
        act: activation function to apply to hidden layers in the NNs
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d` 
    '''
    
    def __init__(self, n_v:int, n_fpv:int, do:float=0, bn:bool=False, act:str='relu',
                 lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
                 lookup_act:Callable[[str],Any]=lookup_act, bn_class:Callable[[int],nn.Module]=nn.BatchNorm1d):
        super().__init__()
        store_attr()
        
    def _get_nn(self, n_in:int, n_outs:List[int]) -> nn.Sequential:
        if not is_listy(n_outs): n_outs = [n_outs]
        return nn.Sequential(*[self._get_layer(n_in if i == 0 else n_outs[i-1], n) for i,n in enumerate(n_outs)])
    
    def _get_layer(self, fan_in:int, fan_out:int) -> nn.Module:   
        layers = [nn.Linear(fan_in, fan_out)]
        self.lookup_init(self.act, fan_in, fan_out)(layers[-1].weight)
        nn.init.zeros_(layers[-1].bias)
        if self.act != 'linear': layers.append(self.lookup_act(self.act))
        if self.bn: layers.append(LCBatchNorm1d(self.bn_class(fan_out)))
        if self.do: 
            if self.act == 'selu': layers.append(nn.AlphaDropout(self.do))
            else:                  layers.append(nn.Dropout(self.do))
        return nn.Sequential(*layers)
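
# Minimal usage sketch of the shared helpers (illustrative; the widths and shapes below are
# assumptions, not part of the API contract): with the defaults act='relu', bn=False, do=0,
# _get_nn(n_in=6, n_outs=[16, 8]) stacks Linear(6,16)+ReLU and Linear(16,8)+ReLU, each weight
# initialised via lookup_normal_init, and acts on the last (feature) dimension of the input.
_example_block = AbsGraphBlock(n_v=20, n_fpv=6)
_example_mlp = _example_block._get_nn(n_in=6, n_outs=[16, 8])
assert _example_mlp(torch.randn(2, 20, 6)).shape == (2, 20, 8)  # applied per vertex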


class AbsGraphFeatExtractor(AbsGraphBlock, metaclass=ABCMeta):
    r'''
    Abstract class for implementing feature extraction for graph neural-network blocks.
    Overridden `forward` should return features per vertex.
    
    Arguments:
        n_v: number of vertices per data point to expect
        n_fpv: number of features per vertex to expect
        do: dropout rate to be applied to hidden layers in the NNs
        bn: whether batch normalisation should be applied to hidden layers in the NNs
        act: activation function to apply to hidden layers in the NNs
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d` 
    '''
    
    row_wise:Optional[bool] = None
    
    def __init__(self, n_v:int, n_fpv:int,
                 do:float=0, bn:bool=False, act:str='relu',
                 lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
                 lookup_act:Callable[[str],Any]=lookup_act, bn_class:Callable[[int],nn.Module]=nn.BatchNorm1d):
        super().__init__(n_v=n_v, n_fpv=n_fpv, do=do, bn=bn, act=act, lookup_init=lookup_init, lookup_act=lookup_act, bn_class=bn_class)
    
    @abstractmethod    
    def get_out_size(self) -> Tuple[int,int]: pass  # N-vertices, N-features per vertex
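
# Sketch of the extractor contract (a hypothetical subclass for illustration only; not part of
# the library): concrete extractors declare `row_wise` to signal the orientation they expect and
# report their output shape via `get_out_size`, while `forward` returns features per vertex.
class _IdentityFeatExtractor(AbsGraphFeatExtractor):
    row_wise:Optional[bool] = True  # expects row-wise data (batch x vertices x features)

    def forward(self, x:Tensor) -> Tensor: return x  # leave the per-vertex features unchanged

    def get_out_size(self) -> Tuple[int,int]: return self.n_v, self.n_fpv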


class GraphCollapser(AbsGraphBlock):
    r'''
    Class for collapsing features per vertex (batch x vertices x features) down to flat data (batch x features).
    Can act in two ways:

        1. Compute aggregate features by taking the average and maximum of each feature across all vertices (does not assume any order to the vertices)
        2. Flatten out the vertices by reshaping (does assume an ordering to the vertices)

    Regardless of the flattening approach, features per vertex can be revised beforehand via neural networks and self-attention.

    Arguments:
        n_v: number of vertices per data point to expect
        n_fpv: number of features per vertex to expect
        flatten: if True will flatten (reshape) data into (batch x features), otherwise will compute aggregate features (average and max)
        f_initial_outs: list of widths for the NN layers in an NN before self-attention (None = no NN)
        n_sa_layers: number of self-attention layers (outputs will be fed into subsequent layers)
        sa_width: width of self-attention representation (paper recommends n_fpv//4)
        f_final_outs: list of widths for the NN layers in an NN after self-attention (None = no NN)
        global_feat_vec: if True and f_initial_outs or f_final_outs are not None, will concatenate the mean of each feature as new features to each vertex prior to the last network
        agg_methods: list of text representations of aggregation methods. Default is mean and max.
        do: dropout rate to be applied to hidden layers in the NNs
        bn: whether batch normalisation should be applied to hidden layers in the NNs
        act: activation function to apply to hidden layers in the NNs
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d`
        sa_class: class to use for self-attention layers, default is :class:`~lumin.nn.models.layers.self_attention.SelfAttention`
    '''

    def __init__(self, n_v:int, n_fpv:int, flatten:bool, f_initial_outs:Optional[List[int]]=None,
                 n_sa_layers:int=0, sa_width:Optional[int]=None,
                 f_final_outs:Optional[List[int]]=None, global_feat_vec:bool=False,
                 agg_methods:Union[List[str],str]=['mean','max'],
                 do:float=0, bn:bool=False, act:str='relu',
                 lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
                 lookup_act:Callable[[str],Any]=lookup_act, bn_class:Callable[[int],nn.Module]=nn.BatchNorm1d,
                 sa_class:Callable[[int],nn.Module]=SelfAttention):
        super().__init__(n_v=n_v, n_fpv=n_fpv, do=do, bn=bn, act=act, lookup_init=lookup_init, lookup_act=lookup_act, bn_class=bn_class)
        store_attr(but=['n_v', 'n_fpv', 'do', 'bn', 'act', 'lookup_init', 'lookup_act', 'bn_class', 'agg_methods'])
        self._check_agg_methods(agg_methods)
        self.gfv_pos = None
        if self.f_initial_outs is None:
            self.f_inital = lambda x: x
            self.f_initial_outs = [self.n_fpv]
        else:
            if not is_listy(self.f_initial_outs): self.f_initial_outs = [self.f_initial_outs]
            fpv = 2*self.n_fpv if self.global_feat_vec else self.n_fpv
            self.f_inital = self._get_nn(fpv, self.f_initial_outs)
            if self.global_feat_vec: self.gfv_pos = 'pre-initial'
        if self.n_sa_layers > 0:
            if self.sa_width is None: raise ValueError("Please set a value for sa_width, the width of the self-attention layers.")
            self.sa_layers = nn.ModuleList([self.sa_class(n_fpv=self.f_initial_outs[-1], n_a=self.sa_width, act=act, bn=bn, do=do,
                                                          lookup_act=lookup_act, lookup_init=self.lookup_init) for _ in range(self.n_sa_layers)])
        if self.f_final_outs is None:
            self.f_final = lambda x: x
            self.f_final_outs = [self.f_initial_outs[-1]] if self.n_sa_layers == 0 else [self.f_initial_outs[-1]*(self.n_sa_layers)]
        else:
            if not is_listy(self.f_final_outs): self.f_final_outs = [self.f_final_outs]
            fpv = self.f_initial_outs[-1] if self.n_sa_layers == 0 else self.f_initial_outs[-1]*(self.n_sa_layers)
            if self.global_feat_vec and self.gfv_pos is None: fpv *= 2
            self.f_final = self._get_nn(fpv, self.f_final_outs)
            if self.global_feat_vec and self.gfv_pos is None: self.gfv_pos = 'pre-final'

    def _check_agg_methods(self, agg_methods:Union[List[str],str]) -> None:
        self.agg_methods = []
        if not is_listy(agg_methods): agg_methods = [agg_methods]
        for m in agg_methods:
            m = m.lower()
            if   m == 'mean': self.agg_methods.append(lambda x: torch.mean(x,dim=1))
            elif m == 'max':  self.agg_methods.append(lambda x: torch.max(x,dim=1)[0])
            else: raise ValueError(f'{m} not in [mean, max]')

    def _agg(self, x:Tensor) -> Tensor:
        if self.flatten: return x.reshape((len(x),-1))
        else:            return torch.cat([agg(x) for agg in self.agg_methods],-1)

    def forward(self, x:Tensor) -> Tensor:
        r'''
        Collapses features per vertex down to flat features

        Arguments:
            x: incoming data (batch x vertices x features)

        Returns:
            Flattened data (batch x flat features)
        '''

        if self.global_feat_vec and self.gfv_pos == 'pre-initial':
            x = torch.cat([x,x.mean(1).unsqueeze(2).repeat_interleave(repeats=x.shape[1],dim=2).transpose(1,2)],dim=2)
        x = self.f_inital(x)
        if self.n_sa_layers > 0:
            outs = []
            for i, sa in enumerate(self.sa_layers): outs.append(sa(outs[-1] if i > 0 else x))
            x = torch.cat(outs, dim=-1)
        if self.global_feat_vec and self.gfv_pos == 'pre-final':
            x = torch.cat([x,x.mean(1).unsqueeze(2).repeat_interleave(repeats=x.shape[1],dim=2).transpose(1,2)],dim=2)
        x = self.f_final(x)
        return self._agg(x)

    def get_out_size(self) -> int:
        r'''
        Get size of output

        Returns:
            Width of output representation
        '''

        return self.f_final_outs[-1]*self.n_v if self.flatten else self.f_final_outs[-1]*len(self.agg_methods)
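
# Minimal usage sketch (illustrative; the batch size and widths are assumptions): revise the
# 10 features per vertex down to 8 with a single NN layer, then aggregate with the default
# mean and max across the 128 vertices, giving 2*8 = 16 flat features per data point.
_example_gc = GraphCollapser(n_v=128, n_fpv=10, flatten=False, f_initial_outs=[8])
assert _example_gc(torch.randn(3, 128, 10)).shape == (3, 16)
assert _example_gc.get_out_size() == 16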


class InteractionNet(AbsGraphFeatExtractor):
    r'''
    Implementation of the Interaction Graph-Network (https://arxiv.org/abs/1612.00222).
    Shown to be applicable for embedding many 4-momenta in e.g. https://arxiv.org/abs/1908.05318

    Receives column-wise data and returns column-wise data.

    Arguments:
        n_v: number of vertices to expect per datapoint
        n_fpv: number of features per vertex
        intfunc_outs: list of widths for the internal NN layers
        outfunc_outs: list of widths for the output NN layers
        do: dropout rate to be applied to hidden layers in the interaction-representation and post-interaction networks
        bn: whether batch normalisation should be applied to hidden layers in the interaction-representation and post-interaction networks
        act: activation function to apply to hidden layers in the interaction-representation and post-interaction networks
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        bn_class: class to use for BatchNorm, default is `nn.BatchNorm1d`

    Examples::
        >>> inet = InteractionNet(n_v=128, n_fpv=10, intfunc_outs=[20,10], outfunc_outs=[20,4])
    '''

    row_wise:Optional[bool] = False

    def __init__(self, n_v:int, n_fpv:int, intfunc_outs:List[int], outfunc_outs:List[int],
                 do:float=0, bn:bool=False, act:str='relu',
                 lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
                 lookup_act:Callable[[str],Any]=lookup_act, bn_class:Callable[[int],nn.Module]=nn.BatchNorm1d):
        super().__init__(n_v=n_v, n_fpv=n_fpv, do=do, bn=bn, act=act, lookup_init=lookup_init, lookup_act=lookup_act, bn_class=bn_class)
        store_attr(names='intfunc_outs,outfunc_outs')
        if not is_listy(self.intfunc_outs): self.intfunc_outs = [self.intfunc_outs]
        if not is_listy(self.outfunc_outs): self.outfunc_outs = [self.outfunc_outs]
        self.n_e = self.n_v*(self.n_v-1)
        self.mat_rr,self.mat_rs = self._get_mat_r()
        self.mat_rr_t = self.mat_rr.t()
        self.fr = self._get_nn(2*self.n_fpv, self.intfunc_outs)
        self.fo = self._get_nn(self.n_fpv+self.intfunc_outs[-1], self.outfunc_outs)

    def _get_mat_r(self) -> Tuple[Tensor,Tensor]:
        mat_rr,mat_rs = torch.zeros((self.n_v,self.n_e)),torch.zeros((self.n_v,self.n_e))
        for i in range(self.n_e):
            j = i % self.n_v
            mat_rr[j,(i+1) % self.n_e] = 1
            mat_rs[j,i] = 1
        return to_device(mat_rr),to_device(mat_rs)

    def forward(self, x:Tensor) -> Tensor:
        r'''
        Learn new features per vertex

        Arguments:
            x: columnwise matrix data (batch x features x vertices)

        Returns:
            columnwise matrix data (batch x new features x vertices)
        '''

        mat_o = torch.cat((x@self.mat_rr, x@self.mat_rs), 1)
        mat_o = self.fr(mat_o.transpose(1, 2)).transpose(1, 2)
        mat_o = mat_o@self.mat_rr_t
        mat_o = torch.cat((x,mat_o), 1)
        mat_o = self.fo(mat_o.transpose(1, 2)).transpose(1, 2)
        return mat_o

    def get_out_size(self) -> Tuple[int,int]:
        return self.n_v, self.outfunc_outs[-1]
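
# Minimal usage sketch (illustrative; assumes a CPU-only run so that the relation matrices
# built via to_device share a device with the input): column-wise data in
# (batch x features x vertices), column-wise data out with outfunc_outs[-1] new features per vertex.
_example_inet = InteractionNet(n_v=30, n_fpv=10, intfunc_outs=[20,10], outfunc_outs=[20,4])
assert _example_inet(torch.randn(2, 10, 30)).shape == (2, 4, 30)
assert _example_inet.get_out_size() == (30, 4)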


class GravNetLayer(AbsGraphBlock):
    r'''
    Single GravNet GNN layer (Qasim, Kieseler, Iiyama, & Pierini, 2019 https://link.springer.com/article/10.1140/epjc/s10052-019-7113-9).
    Designed to be used as a sub-layer of a head block, e.g. :class:`~lumin.nn.models.blocks.head.GravNetHead`
    Passes features per vertex through an NN to compute new features and the coordinates of the vertex in a latent space.
    Each vertex then receives additional features based on aggregations of distance-weighted features of its k-nearest vertices in the latent space.
    A second NN then transforms the features per vertex.
    Input (batch x vertices x features) --> Output (batch x vertices x new features)

    Arguments:
        n_fpv: number of features per vertex to expect
        n_s: number of latent-spatial dimensions to compute
        n_lr: number of features to compute per vertex for latent representation
        k: number of neighbours (including self) each vertex should consider when aggregating latent-representation features
        agg_methods: list of functions to use to aggregate distance-weighted latent-representation features
        n_out: number of output features to compute per vertex
        cat_means: if True, will extend the incoming features per vertex by including the means of all features across all vertices.
            :class:`~lumin.nn.models.blocks.head.GNNHead` also has a `cat_means` argument, which should be set to `False` if enabled here (otherwise averaging happens twice).
        f_slr_depth: number of layers to use for the latent-rep. NN
        f_out_depth: number of layers to use for the output NN
        potential: function to control distance weighting (default is the exp(-d^2) potential used in the paper)
        use_sa: if True, will apply a self-attention layer to the neighbourhood features per vertex prior to aggregation
        do: dropout rate to be applied to hidden layers in the NNs
        bn: whether batch normalisation should be applied to hidden layers in the NNs
        act: activation function to apply to hidden layers in the NNs
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        freeze: whether to start with module parameters set to untrainable
        bn_class: class to use for BatchNorm, default is :class:`~lumin.nn.models.layers.batchnorms.LCBatchNorm1d`
        sa_class: class to use for self-attention layers, default is :class:`~lumin.nn.models.layers.self_attention.SelfAttention`
    '''

    def __init__(self, n_fpv:int, n_s:int, n_lr:int, k:int, agg_methods:List[Callable[[Tensor],Tensor]], n_out:int,
                 cat_means:bool=True, f_slr_depth:int=1, f_out_depth:int=1,
                 potential:Callable[[Tensor],Tensor]=lambda x: torch.exp(-10*(x**2)), use_sa:bool=False,
                 do:float=0, bn:bool=False, act:str='relu',
                 lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
                 lookup_act:Callable[[str],Any]=lookup_act, bn_class:Callable[[int],nn.Module]=nn.BatchNorm1d,
                 sa_class:Callable[[int],nn.Module]=SelfAttention):
        super().__init__(n_v=None, n_fpv=n_fpv, do=do, bn=bn, act=act, lookup_init=lookup_init, lookup_act=lookup_act, bn_class=bn_class)
        store_attr()
        if self.cat_means: self.n_fpv *= 2
        self.f_slr = self._get_nn(fan_in=self.n_fpv, width=2*self.n_s+self.n_lr, fan_out=self.n_s+self.n_lr, depth=self.f_slr_depth)
        self.f_out = self._get_nn(fan_in=self.n_fpv+(len(self.agg_methods)*self.n_lr), width=2*self.n_out, fan_out=self.n_out, depth=self.f_out_depth)
        if self.use_sa: self.sa_agg = self.sa_class(self.n_lr, self.n_out//4, do=self.do, bn=False, act=self.act,
                                                    lookup_act=self.lookup_act, lookup_init=self.lookup_init)

    def _get_nn(self, fan_in:int, width:int, fan_out:int, depth:int) -> nn.Module:
        return nn.Sequential(*[self._get_layer(fan_in if i == 0 else width, width if i+1 < depth else fan_out) for i in range(depth)])

    def _knn(self, dists:Tensor) -> Tuple[Tensor,Tensor,Tensor]:
        idxs = dists.argsort()
        i = np.arange(dists.shape[0])[:,None,None]
        j = np.arange(dists.shape[1])[None,:,None]
        return i,j,idxs[:,:,:self.k]

    def forward(self, x:Tensor) -> Tensor:
        r'''
        Pass batch of vertices through GravNet layer and return new features per vertex

        Arguments:
            x: Incoming data (batch x vertices x features)

        Returns:
            Data with new features per vertex (batch x vertices x new features)
        '''

        # Concat means
        if self.cat_means: x = torch.cat([x,x.mean(1).unsqueeze(2).repeat_interleave(repeats=x.shape[1],dim=-1).transpose(1,2)],dim=2)
        # Compute spatial and vertex features
        slr = self.f_slr(x)
        s,lr = slr[:,:,:self.n_s],slr[:,:,self.n_s:]
        # kNN
        d_jk = torch.norm(s[:,:,None]-s[:,None], dim=-1)
        idxs = self._knn(d_jk)
        d_jk = d_jk[idxs]
        lr = lr[:,None].expand(lr.shape[0],lr.shape[1],lr.shape[1],lr.shape[2])[idxs]
        # Aggregate feats
        v_jk = self.potential(d_jk)
        ft_ijk = lr*v_jk.unsqueeze(-1)
        if self.use_sa: ft_ijk = self.sa_agg(ft_ijk)
        fp = [x]
        for agg in self.agg_methods: fp.append(agg(ft_ijk))
        fp = torch.cat(fp,dim=-1)
        # Output
        return self.f_out(fp)

    def get_out_size(self) -> int:
        return self.n_out
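
# Minimal usage sketch (illustrative; the aggregation lambdas below mirror those passed in by
# the GravNet block further down, and the widths are assumptions): each of the 30 vertices
# gathers distance-weighted latent features from its k=8 nearest latent-space neighbours
# before the output NN maps it to n_out=16 features.
_example_aggs = [lambda x: torch.mean(x, dim=2), lambda x: torch.max(x, dim=2)[0]]
_example_gnl = GravNetLayer(n_fpv=10, n_s=4, n_lr=6, k=8, agg_methods=_example_aggs, n_out=16)
assert _example_gnl(torch.randn(2, 30, 10)).shape == (2, 30, 16)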


class GravNet(AbsGraphFeatExtractor):
    r'''
    GravNet GNN head (Qasim, Kieseler, Iiyama, & Pierini, 2019 https://link.springer.com/article/10.1140/epjc/s10052-019-7113-9).
    Passes features per vertex (batch x vertices x features) through several :class:`~lumin.nn.models.blocks.gnn_blocks.GravNetLayer` layers.
    Like the model in the paper, the outputs of each GravNet layer are cached and concatenated prior to the final layer.
    The features per vertex are then flattened/aggregated across the vertices to flat data (batch x features).

    Arguments:
        n_v: number of vertices to expect per datapoint
        n_fpv: number of features per vertex
        cat_means: if True, will extend the incoming features per vertex by including the means of all features across all vertices
        f_slr_depth: number of layers to use for the latent-rep. NN
        n_s: number of latent-spatial dimensions to compute
        n_lr: number of features to compute per vertex for latent representation
        k: number of neighbours (including self) each vertex should consider when aggregating latent-representation features
        f_out_depth: number of layers to use for the output NN
        n_out: number of output features to compute per vertex
        gn_class: class to use for GravNet layers, default is :class:`~lumin.nn.models.blocks.gnn_blocks.GravNetLayer`
        use_sa: if True, will apply a self-attention layer to the neighbourhood features per vertex prior to aggregation
        do: dropout rate to be applied to hidden layers in the NNs
        bn: whether batch normalisation should be applied to hidden layers in the NNs
        act: activation function to apply to hidden layers in the NNs
        lookup_init: function taking choice of activation function, number of inputs, and number of outputs and returning a function to initialise layer weights.
        lookup_act: function taking choice of activation function and returning an activation function layer
        freeze: whether to start with module parameters set to untrainable
        bn_class: class to use for BatchNorm, default is `nn.BatchNorm1d`
        sa_class: class to use for self-attention layers, default is :class:`~lumin.nn.models.layers.self_attention.SelfAttention`
    '''

    row_wise:Optional[bool] = True

    def __init__(self, n_v:int, n_fpv:int, cat_means:bool, f_slr_depth:int, n_s:int, n_lr:int, k:int, f_out_depth:int, n_out:Union[List[int],int],
                 gn_class:Callable[[Dict[str,Any]],GravNetLayer]=GravNetLayer, use_sa:bool=False,
                 do:float=0, bn:bool=False, act:str='relu',
                 lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
                 lookup_act:Callable[[str],Any]=lookup_act, bn_class:Callable[[int],nn.Module]=nn.BatchNorm1d,
                 sa_class:Callable[[int],nn.Module]=SelfAttention, **kargs):
        super().__init__(n_v=n_v, n_fpv=n_fpv, do=do, bn=bn, act=act, lookup_init=lookup_init, lookup_act=lookup_act, bn_class=bn_class)
        store_attr(but=['n_v', 'n_fpv', 'do', 'bn', 'act', 'lookup_init', 'lookup_act', 'bn_class'])
        if not is_listy(self.n_out): self.n_out = [self.n_out]
        self.agg_methods = [lambda x: torch.mean(x,dim=2), lambda x: torch.max(x,dim=2)[0]]
        self.grav_layers = self._get_grav_layers()
        self.out_sz = (self.n_v, np.sum([l.get_out_size() for l in self.grav_layers]))

    def _get_grav_layers(self) -> nn.ModuleList:
        ls = []
        gl = partial(self.gn_class, f_slr_depth=self.f_slr_depth, n_s=self.n_s, n_lr=self.n_lr, k=self.k, agg_methods=self.agg_methods,
                     f_out_depth=self.f_out_depth, cat_means=self.cat_means, use_sa=self.use_sa,
                     do=self.do, bn=self.bn, act=self.act, lookup_init=self.lookup_init, lookup_act=self.lookup_act,
                     bn_class=self.bn_class, sa_class=self.sa_class)
        n_fpv = self.n_fpv
        for n in self.n_out:
            ls.append(gl(n_fpv=n_fpv, n_out=n))
            n_fpv = ls[-1].get_out_size()
        return nn.ModuleList(ls)

    def forward(self, x:Tensor) -> Tensor:
        r'''
        Passes input through the GravNet head.

        Arguments:
            x: row-wise tensor (batch x vertices x features)

        Returns:
            Resulting row-wise tensor (batch x vertices x new features)
        '''

        outs = [x]
        for l in self.grav_layers: outs.append(l(outs[-1]))
        return torch.cat(outs[1:], dim=-1)

    def get_out_size(self) -> Tuple[int,int]:
        return self.out_sz
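
# Minimal usage sketch (illustrative; layer widths and batch size are assumptions): two
# GravNet layers of 16 output features each, with their outputs concatenated, so the block
# returns 32 new features per vertex and reports an output size of (n_v, 32).
_example_gn = GravNet(n_v=30, n_fpv=10, cat_means=False, f_slr_depth=2, n_s=4, n_lr=6, k=8,
                      f_out_depth=2, n_out=[16,16])
assert _example_gn(torch.randn(2, 30, 10)).shape == (2, 30, 32)
assert _example_gn.get_out_size() == (30, 32)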