Source code for mlcg.nn.cutoff

from typing import Union
import torch
import torch.nn as nn
import numpy as np
import math


class _Cutoff(nn.Module):
    r"""Abstract cutoff class"""

    def __init__(self):
        super(_Cutoff, self).__init__()
        self.cutoff_lower = None
        self.cutoff_upper = None

    def check_cutoff(self):
        if self.cutoff_upper < self.cutoff_lower:
            raise ValueError(
                "Upper cutoff {} is less than lower cutoff {}".format(
                    self.cutoff_upper, self.cutoff_lower
                )
            )

    def forward(self):
        raise NotImplementedError


class _OneSidedCutoff(_Cutoff):
    r"""Base class for cutoff functions whose lower cutoff is fixed at 0.

    Concrete subclasses are expected to assign ``cutoff_upper`` and to
    override :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        # Upper bound left unset for subclasses; lower bound pinned to zero.
        self.cutoff_upper = None
        self.cutoff_lower = 0

    def forward(self):
        r"""Not implemented in the base class; subclasses must override."""
        raise NotImplementedError


class IdentityCutoff(_Cutoff):
    r"""Cutoff whose envelope equals one at every distance.

    The ``cutoff_lower`` and ``cutoff_upper`` attributes are still stored and
    validated, even though they do not affect the output.

    Parameters
    ----------
    cutoff_lower:
        left bound for the radial cutoff distance
    cutoff_upper:
        right bound for the radial cutoff distance
    """

    def __init__(self, cutoff_lower: float = 0, cutoff_upper: float = np.inf):
        super().__init__()
        self.cutoff_lower, self.cutoff_upper = cutoff_lower, cutoff_upper
        # Raises ValueError if the bounds are out of order.
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        r"""Return an envelope made entirely of ones.

        Parameters
        ----------
        distances:
            Input distances of shape (total_num_edges)

        Returns
        -------
        Tensor of ones shaped like ``distances``, with the same dtype and
        device (produced by ``torch.ones_like``).
        """
        envelope = torch.ones_like(distances)
        return envelope
class CosineCutoff(_Cutoff):
    r"""Cutoff envelope based on a cosine signal in the interval
    ``[cutoff_lower, cutoff_upper]``.

    With the default lower cutoff of zero (used whenever
    ``cutoff_lower <= 0``), the envelope is

    .. math::

        0.5 \left[ \cos{\left( \pi r_{ij} / r_{high} \right)} + 1.0 \right]

    for :math:`r_{ij} < r_{high}` and zero otherwise. It decays from one at
    :math:`r_{ij} = 0` to zero at the upper cutoff.

    NOTE: The behavior of the cutoff is qualitatively different for lower
    cutoff values greater than zero when compared to the zero lower cutoff
    default. We recommend visualizing your basis to see if it makes physical
    sense. For :math:`r_{low} > 0` the envelope is

    .. math::

        0.5 \cos{ \left[ \pi \left(2 \frac{r_{ij} - r_{low}}{r_{high} - r_{low}} + 1.0 \right)\right]} + 0.5

    for :math:`r_{low} < r_{ij} < r_{high}` and zero otherwise. It is zero at
    both bounds and reaches one at the midpoint of the interval.

    Parameters
    ----------
    cutoff_lower:
        lower bound :math:`r_{low}` of the cutoff interval
    cutoff_upper:
        upper bound :math:`r_{high}` of the cutoff interval. It should be
        strictly greater than ``cutoff_lower``. Equal bounds pass validation,
        but they make the envelope's denominator zero and the output NaN.
    """

    def __init__(self, cutoff_lower: float = 0.0, cutoff_upper: float = 5.0):
        super(CosineCutoff, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        # Raises ValueError if cutoff_upper < cutoff_lower.
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        r"""Evaluate the cosine cutoff envelope at the given distances.

        Parameters
        ----------
        distances:
            Distances of shape (total_num_edges)

        Returns
        -------
        cutoffs:
            Envelope values with the same shape and dtype as ``distances``.
            Entries outside the cutoff interval are zero.
        """
        if self.cutoff_lower > 0:
            # The cosine argument runs from pi (at r_low) to 3*pi (at r_high),
            # so the envelope is zero at both bounds and one at the midpoint.
            cutoffs = 0.5 * (
                torch.cos(
                    math.pi
                    * (
                        2
                        * (distances - self.cutoff_lower)
                        / (self.cutoff_upper - self.cutoff_lower)
                        + 1.0
                    )
                )
                + 1.0
            )
            # remove contributions outside the open interval (lower, upper)
            inside = (distances > self.cutoff_lower) & (
                distances < self.cutoff_upper
            )
        else:
            cutoffs = 0.5 * (
                torch.cos(distances * math.pi / self.cutoff_upper) + 1.0
            )
            # remove contributions beyond the cutoff radius
            inside = distances < self.cutoff_upper
        # Multiply by the mask (not torch.where) so NaN inputs still propagate.
        return cutoffs * inside.to(distances.dtype)
class ShiftedCosineCutoff(_OneSidedCutoff):
    r"""Behler-style cosine cutoff with an additional smoothing width.

    For a positive smoothing width :math:`\sigma`, the envelope behaves as
    follows:

    * it is one for :math:`r_{ij} \le r_{high} - \sigma`;
    * it is zero for :math:`r_{ij} > r_{high}`;
    * in between it follows

    .. math::

        0.5 + 0.5 \cos{ \left[ \pi \left( \frac{r_{ij} - r_{high} + \sigma}{\sigma}\right)\right]}

    Here :math:`r_{high}` is the cutoff radius. The lower cutoff is fixed
    at 0.

    Parameters
    ----------
    cutoff:
        cutoff radius :math:`r_{high}`, stored as ``cutoff_upper``.
    smooth_width:
        parameter :math:`\sigma` that controls the extent of smoothing in the
        cutoff envelope, i.e. the width of the region just below the cutoff
        radius over which the envelope goes from one to zero.
    """

    def __init__(
        self,
        cutoff: Union[int, float] = 5.0,
        smooth_width: Union[int, float] = 0.5,
    ):
        super(ShiftedCosineCutoff, self).__init__()
        self.cutoff_upper = cutoff
        self.smooth_width = smooth_width

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        r"""Compute the cutoff envelope.

        Parameters
        ----------
        distances:
            Values of interatomic distances.

        Returns
        -------
        torch.Tensor
            Values of the cutoff function, flattened to one dimension.
        """
        cutoffs = torch.ones_like(distances)
        # Only the smoothing region needs the cosine. Below it the envelope
        # stays one; beyond the cutoff radius it is zeroed just below.
        in_smooth_region = distances > self.cutoff_upper - self.smooth_width
        cutoffs[in_smooth_region] = 0.5 + 0.5 * torch.cos(
            math.pi
            * (
                distances[in_smooth_region]
                - self.cutoff_upper
                + self.smooth_width
            )
            / self.smooth_width
        )
        cutoffs[distances > self.cutoff_upper] = 0.0
        # reshape rather than view: ones_like preserves the input's strides,
        # so for a non-contiguous input (e.g. a transposed tensor) view(-1)
        # raises. reshape still returns a view whenever one is possible.
        return cutoffs.reshape(-1)