from typing import Union
import torch
import torch.nn as nn
import numpy as np
import math
class _Cutoff(nn.Module):
r"""Abstract cutoff class"""
def __init__(self):
super(_Cutoff, self).__init__()
self.cutoff_lower = None
self.cutoff_upper = None
def check_cutoff(self):
if self.cutoff_upper < self.cutoff_lower:
raise ValueError(
"Upper cutoff {} is less than lower cutoff {}".format(
self.cutoff_upper, self.cutoff_lower
)
)
def forward(self):
raise NotImplementedError
class _OneSidedCutoff(_Cutoff):
    r"""Abstract class for cutoff functions whose lower cutoff is fixed at 0.

    Subclasses only need to supply ``cutoff_upper`` and ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Lower bound pinned to zero; upper bound left for subclasses.
        self.cutoff_lower, self.cutoff_upper = 0, None

    def forward(self):
        # Abstract: concrete cutoffs implement the envelope computation.
        raise NotImplementedError
class IdentityCutoff(_Cutoff):
    r"""Cutoff function that is one everywhere, but retains the
    ``cutoff_lower`` and ``cutoff_upper`` attributes.

    Parameters
    ----------
    cutoff_lower:
        Left bound for the radial cutoff distance.
    cutoff_upper:
        Right bound for the radial cutoff distance.

    Raises
    ------
    ValueError
        If ``cutoff_upper`` is less than ``cutoff_lower``.
    """

    def __init__(self, cutoff_lower: float = 0, cutoff_upper: float = np.inf):
        super(IdentityCutoff, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        r"""Forward method that returns a cutoff envelope where all values are
        one.

        Parameters
        ----------
        distances:
            Input distances, typically of shape (total_num_edges); any shape
            is accepted.

        Returns
        -------
        Cutoff envelope filled with ones, with the same shape, dtype and
        device as ``distances``.
        """
        # The bounds are intentionally ignored: this envelope never masks.
        return torch.ones_like(distances)
class CosineCutoff(_Cutoff):
    r"""Cutoff envelope based on a cosine signal in the interval
    ``[cutoff_lower, cutoff_upper]``.

    When the lower cutoff is 0 (or negative), the envelope takes the form

    .. math::
        0.5 \left[ \cos{\left( r_{ij} \times \pi / r_{high}\right)} + 1.0 \right]

    for :math:`r_{ij} < r_{high}` and zero elsewhere.

    For lower cutoff values greater than zero, the envelope takes the form

    .. math::
        0.5 \cos{ \left[ \pi \left(2 \frac{r_{ij} - r_{low}}{r_{high}
        - r_{low}} + 1.0 \right)\right]} + 0.5

    for :math:`r_{low} < r_{ij} < r_{high}` and zero elsewhere.

    .. note::
        The behavior of the cutoff is qualitatively different for lower
        cutoff values greater than zero when compared to the zero lower
        cutoff default. With a zero lower cutoff, the envelope decays from
        one at :math:`r_{ij} = 0` to zero at :math:`r_{high}`. With a
        positive lower cutoff, it is a bump that is zero at both bounds and
        one at their midpoint. We recommend visualizing the resulting basis
        (e.g. via a ``plot`` method, where available).

    Parameters
    ----------
    cutoff_lower:
        Lower bound :math:`r_{low}` of the cutoff interval.
    cutoff_upper:
        Upper bound :math:`r_{high}` of the cutoff interval.

    Raises
    ------
    ValueError
        If ``cutoff_upper`` is less than ``cutoff_lower``.
    """

    def __init__(self, cutoff_lower: float = 0.0, cutoff_upper: float = 5.0):
        super(CosineCutoff, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """Compute the cutoff envelope for the given distances.

        Parameters
        ----------
        distances:
            Distances of shape (total_num_edges).

        Returns
        -------
        cutoffs:
            Cutoff envelope values (not the distances multiplied by the
            envelope), with the same shape as ``distances``. Inside the
            interval the values lie in ``[0, 1]``; outside it they are zeroed.
        """
        if self.cutoff_lower > 0:
            # Map [r_low, r_high] onto the cosine argument range [pi, 3*pi]:
            # the envelope is 0 at r_low, 1 at the midpoint, 0 at r_high.
            cutoffs = 0.5 * (
                torch.cos(
                    math.pi
                    * (
                        2
                        * (distances - self.cutoff_lower)
                        / (self.cutoff_upper - self.cutoff_lower)
                        + 1.0
                    )
                )
                + 1.0
            )
            # Keep only contributions strictly inside (r_low, r_high).
            inside = (distances > self.cutoff_lower) & (
                distances < self.cutoff_upper
            )
        else:
            # Half cosine period: 1 at r = 0, decaying to 0 at r_high.
            cutoffs = 0.5 * (
                torch.cos(distances * math.pi / self.cutoff_upper) + 1.0
            )
            # Remove contributions at or beyond the upper cutoff radius.
            inside = distances < self.cutoff_upper
        return cutoffs * inside.to(distances.dtype)
class ShiftedCosineCutoff(_OneSidedCutoff):
    r"""Behler-style cosine cutoff with an additional smoothing width.

    The envelope is one for :math:`r_{ij} \leq r_{high} - \sigma`, zero for
    :math:`r_{ij} > r_{high}`, and in between follows

    .. math::
        0.5 + 0.5 \cos{ \left[ \pi \left( \frac{r_{ij} - r_{high} +
        \sigma}{\sigma}\right)\right]}

    where :math:`\sigma` is the smoothing width.

    Parameters
    ----------
    cutoff:
        Cutoff radius :math:`r_{high}`.
    smooth_width:
        Parameter :math:`\sigma` that controls the extent of smoothing in
        the cutoff envelope: the width of the region below the cutoff radius
        over which the envelope decays from one to zero.

    Raises
    ------
    ValueError
        If ``cutoff`` is below the fixed lower cutoff of 0.
    """

    def __init__(
        self,
        cutoff: Union[int, float] = 5.0,
        smooth_width: Union[int, float] = 0.5,
    ):
        super(ShiftedCosineCutoff, self).__init__()
        self.cutoff_upper = cutoff
        self.smooth_width = smooth_width
        # Validate bounds like the sibling cutoff classes do (lower is 0).
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """Compute the cutoff envelope.

        Parameters
        ----------
        distances:
            Values of interatomic distances, of any shape.

        Returns
        -------
        cutoffs:
            Values of the cutoff function, flattened to one dimension.
        """
        cutoffs = torch.ones_like(distances)
        # Smooth cosine decay for distances above r_high - sigma.
        mask = distances > self.cutoff_upper - self.smooth_width
        cutoffs[mask] = 0.5 + 0.5 * torch.cos(
            math.pi
            * (distances[mask] - self.cutoff_upper + self.smooth_width)
            / self.smooth_width
        )
        # Hard zero beyond the cutoff radius.
        cutoffs[distances > self.cutoff_upper] = 0.0
        # reshape (not view): ones_like preserves input strides, so a
        # non-contiguous input would make view(-1) raise.
        return cutoffs.reshape(-1)