Source code for lumin.nn.models.layers.activations
from typing import Any
import torch
import torch.nn as nn
from torch import Tensor
from ....utils.misc import hard_identity
from .mish import Mish
__all__ = ["lookup_act", "Swish"]
def lookup_act(act: str) -> Any:
    r"""
    Resolve the name of an activation function to a ready-to-use object.

    Arguments:
        act: name of the requested activation function. Exact names handled here are
            "relu", "prelu", "selu", "sigmoid", "logsoftmax", "softmax", "linear" and "mish";
            any name containing the substring "swish" selects :class:`Swish`.

    Returns:
        A newly constructed instance of the matching activation (note: an instance, not the class).
        For "linear", the `hard_identity` object is returned as-is, without being called.

    Raises:
        ValueError: if `act` matches none of the supported names
    """

    # Zero-argument builders: nothing is constructed (and `hard_identity` is not even looked up)
    # until its name actually matches.
    exact_builders = (
        ("relu", nn.ReLU),
        ("prelu", nn.PReLU),
        ("selu", nn.SELU),
        ("sigmoid", nn.Sigmoid),
        ("logsoftmax", lambda: nn.LogSoftmax(1)),  # dim=1
        ("softmax", lambda: nn.Softmax(1)),  # dim=1
        ("linear", lambda: hard_identity),
    )
    for name, build in exact_builders:
        if act == name:
            return build()

    # Substring match, so any name that merely contains "swish" resolves to Swish
    if "swish" in act:
        return Swish()
    if act == "mish":
        return Mish()
    raise ValueError("Activation not implemented")
class Swish(nn.Module):
    r"""
    Non-trainable Swish activation function https://arxiv.org/abs/1710.05941

    Computes ``x * sigmoid(x)`` element-wise.

    Arguments:
        inplace: whether to apply activation inplace. If True, the incoming tensor is overwritten
            with the result and that same tensor object is returned; if False (default), the
            incoming tensor is left untouched and a new tensor is returned.

    Examples::
        >>> swish = Swish()
        >>>
        >>> swish = Swish(inplace=True)
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Store the caller's choice; previously this was hard-coded to False, which made the
        # `inplace` argument a no-op and the in-place branch of `forward` unreachable.
        self.inplace = inplace

    def extra_repr(self) -> str:
        r"""
        Show the `inplace` setting when the module is printed.

        Returns:
            String inserted between the parentheses of the module's repr
        """

        return f"inplace={self.inplace}"

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Pass tensor through Swish function

        Arguments:
            x: incoming tensor; modified in place when the module was built with `inplace=True`

        Returns:
            Resulting tensor (the same object as `x` when operating in place)
        """

        if self.inplace:
            # sigmoid(x) is evaluated into a temporary before x is overwritten by mul_
            x.mul_(torch.sigmoid(x))
            # The return is required: calling the module yields whatever forward returns
            return x
        return x * torch.sigmoid(x)