Shortcuts

Source code for lumin.nn.callbacks.loss_callbacks

from typing import Optional

import torch.nn as nn

from .callback import Callback, OldCallback
from ..models.abs_model import AbsModel, OldAbsModel

__all__ = ['GradClip']


class OldGradClip(OldCallback):
    r'''
    .. Attention:: This class is deprecated in favour of :class:`~lumin.nn.callbacks.loss_callbacks.GradClip`.
        It is a copy of the old `GradClip` class used in lumin<=0.6.
        It will be removed in V0.8

    Arguments:
        clip: value to clip at; clipping is skipped entirely when <= 0
        clip_norm: if True clip by norm (`torch.nn.utils.clip_grad_norm_`),
            otherwise clip by value (`torch.nn.utils.clip_grad_value_`)
        model: optional model to act on (passed through to :class:`OldCallback`)
    '''

    # XXX remove in V0.8

    def __init__(self, clip:float, clip_norm:bool=True, model:Optional[OldAbsModel]=None):
        super().__init__(model=model)
        self.clip = clip
        # Select the clipping strategy once at construction time
        if clip_norm:
            self.func = nn.utils.clip_grad_norm_
        else:
            self.func = nn.utils.clip_grad_value_

    def on_backwards_end(self, **kwargs) -> None:
        r'''
        Clips gradients prior to parameter updates
        '''

        # A non-positive clip value disables clipping
        if self.clip <= 0: return
        self.func(self.model.parameters(), self.clip)


class GradClip(Callback):
    r'''
    Callback for clipping gradients by norm or value.

    Arguments:
        clip: value to clip at; a non-positive value disables clipping
        clip_norm: whether to clip according to norm (`torch.nn.utils.clip_grad_norm_`)
            or value (`torch.nn.utils.clip_grad_value_`)

    Examples::
        >>> grad_clip = GradClip(1e-5)
    '''

    def __init__(self, clip:float, clip_norm:bool=True):
        super().__init__()
        self.clip = clip
        # Choose the clipping function once, up front
        self.func = nn.utils.clip_grad_norm_ if clip_norm else nn.utils.clip_grad_value_

    def on_backwards_end(self) -> None:
        r'''
        Clips gradients prior to parameter updates
        '''

        # self.model is set by the Callback machinery before training begins
        if self.clip > 0: self.func(self.model.parameters(), self.clip)
Read the Docs v: v0.7.0
Versions
latest
stable
v0.7.0
v0.6.0
v0.5.1
v0.5.0
v0.4.0.1
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials