Shortcuts

Source code for lumin.nn.models.layers.activations

from typing import Any

import torch
import torch.nn as nn
from torch import Tensor

from ....utils.misc import hard_identity
from .mish import Mish

__all__ = ["lookup_act", "Swish"]


def lookup_act(act: str) -> Any:
    r"""
    Resolve the name of an activation function to a ready-to-use callable.

    Arguments:
        act: string representation of activation function

    Returns:
        A freshly constructed activation module for the requested name, or `hard_identity` when `act` is "linear"

    Raises:
        ValueError: if `act` does not correspond to any known activation
    """

    # Activations provided directly by torch.nn, paired with zero-argument builders.
    # The softmax variants are always applied along dimension 1.
    torch_builders = (
        ("relu", nn.ReLU),
        ("prelu", nn.PReLU),
        ("selu", nn.SELU),
        ("sigmoid", nn.Sigmoid),
        ("logsoftmax", lambda: nn.LogSoftmax(1)),
        ("softmax", lambda: nn.Softmax(1)),
    )
    for name, build in torch_builders:
        if act == name:
            return build()

    if act == "linear":
        # No module is built for a linear activation; the identity helper is returned as-is
        return hard_identity
    if "swish" in act:
        # Substring match: any name containing "swish" maps to Swish
        return Swish()
    if act == "mish":
        return Mish()

    raise ValueError("Activation not implemented")
class Swish(nn.Module):
    r"""
    Non-trainable Swish activation function, x * sigmoid(x) https://arxiv.org/abs/1710.05941

    Arguments:
        inplace: whether to apply activation inplace, i.e. overwrite the incoming tensor with the result

    Examples::
        >>> swish = Swish()
        >>>
        >>> swish = Swish(inplace=True)
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Store the caller's choice; previously this was hard-coded to False,
        # which silently ignored the argument and made the in-place branch unreachable
        self.inplace = inplace

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Pass tensor through Swish function

        Arguments:
            x: incoming tensor

        Returns:
            Resulting tensor. When `inplace` is True this is the same object as `x`, which has been overwritten
        """

        if self.inplace:
            # sigmoid(x) is evaluated before x is overwritten by the in-place multiply
            x.mul_(torch.sigmoid(x))
            # Returned so the module can be called and chained like any other layer
            return x
        return x * torch.sigmoid(x)

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials