from typing import Union
import torch
import torch.nn as nn
import numpy as np
import math
class _Cutoff(nn.Module):
    r"""Base class for radial cutoff envelopes.

    Both bounds (``cutoff_lower`` and ``cutoff_upper``) start out as ``None``;
    concrete subclasses assign them and implement :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        self.cutoff_lower = None
        self.cutoff_upper = None

    def check_cutoff(self):
        """Validate that the upper bound is not below the lower bound.

        Raises
        ------
        ValueError
            If ``cutoff_upper < cutoff_lower``.
        """
        lower, upper = self.cutoff_lower, self.cutoff_upper
        if not upper < lower:
            return
        message = "Upper cutoff {} is less than lower cutoff {}".format(
            upper, lower
        )
        raise ValueError(message)

    def forward(self):
        """Placeholder; concrete cutoff classes must override this."""
        raise NotImplementedError
class _OneSidedCutoff(_Cutoff):
    r"""Abstract class for cutoff functions with a fixed lower cutoff of 0.

    Subclasses only assign ``cutoff_upper`` and implement :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        # The lower bound is pinned to zero; the upper bound is left unset
        # for the concrete subclass to fill in.
        self.cutoff_lower, self.cutoff_upper = 0, None

    def forward(self):
        """Placeholder; concrete one-sided cutoff classes must override this."""
        raise NotImplementedError
class IdentityCutoff(_Cutoff):
    r"""Cutoff function that is one everywhere, but retains the
    ``cutoff_lower`` and ``cutoff_upper`` attributes.

    Parameters
    ----------
    cutoff_lower:
        Left bound for the radial cutoff distance.
    cutoff_upper:
        Right bound for the radial cutoff distance.

    Raises
    ------
    ValueError
        If ``cutoff_upper`` is less than ``cutoff_lower``.
    """

    def __init__(self, cutoff_lower: float = 0, cutoff_upper: float = np.inf):
        super(IdentityCutoff, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        r"""Return a cutoff envelope in which every value is one.

        The stored bounds are not applied here; they are only kept so this
        class exposes the same attributes as the other cutoffs.

        Parameters
        ----------
        distances:
            Input distances, e.g. of shape (total_num_edges).

        Returns
        -------
        torch.Tensor
            Tensor of ones with the same shape, dtype and device as
            ``distances``.
        """
        return torch.ones_like(distances)
 
class CosineCutoff(_Cutoff):
    r"""Cutoff envelope based on a cosine signal on the interval
    ``[cutoff_lower, cutoff_upper]``.

    With the default lower cutoff of zero (more precisely, whenever
    ``cutoff_lower <= 0``) the envelope is

    .. math::
        0.5 \cos{\left( \pi r_{ij} / r_{high}\right)} + 0.5

    for :math:`r_{ij} < r_{high}` and zero otherwise, so it decays from one
    at :math:`r_{ij} = 0` to zero at the upper cutoff.

    NOTE: The behavior of the cutoff is qualitatively different for lower
    cutoff values greater than zero when compared to the zero lower cutoff
    default. We recommend visualizing your basis to see if it makes physical
    sense. For ``cutoff_lower > 0`` the envelope is

    .. math::
        0.5 \cos{ \left[ \pi \left(2 \frac{r_{ij} - r_{low}}{r_{high}
         - r_{low}} + 1.0 \right)\right]} + 0.5

    for :math:`r_{low} < r_{ij} < r_{high}` and zero otherwise. This is a bump
    that vanishes at both cutoffs and peaks at one halfway between them.

    Parameters
    ----------
    cutoff_lower:
        Lower bound :math:`r_{low}` of the cutoff interval.
    cutoff_upper:
        Upper bound :math:`r_{high}` of the cutoff interval.

    Raises
    ------
    ValueError
        If ``cutoff_upper`` is less than ``cutoff_lower``.
    """

    def __init__(self, cutoff_lower: float = 0.0, cutoff_upper: float = 5.0):
        super(CosineCutoff, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """Evaluate the cutoff envelope at the given distances.

        Parameters
        ----------
        distances:
            Distances, e.g. of shape (total_num_edges).

        Returns
        -------
        cutoffs:
            Envelope values (not distances multiplied by the envelope) with
            the same shape as ``distances``. Entries outside the cutoff
            interval are exactly zero.
        """
        if self.cutoff_lower > 0:
            cutoffs = 0.5 * (
                torch.cos(
                    math.pi
                    * (
                        2
                        * (distances - self.cutoff_lower)
                        / (self.cutoff_upper - self.cutoff_lower)
                        + 1.0
                    )
                )
                + 1.0
            )
            # Keep only the open interval (cutoff_lower, cutoff_upper).
            inside = (distances > self.cutoff_lower) & (
                distances < self.cutoff_upper
            )
        else:
            cutoffs = 0.5 * (
                torch.cos(distances * math.pi / self.cutoff_upper) + 1.0
            )
            # Remove contributions beyond the upper cutoff radius.
            inside = distances < self.cutoff_upper
        # torch.where (instead of multiplying by a 0/1 mask) keeps masked-out
        # entries at exactly zero even where the envelope itself is NaN,
        # e.g. for a zero-width interval with cutoff_lower == cutoff_upper.
        return torch.where(inside, cutoffs, torch.zeros_like(cutoffs))
 
class ShiftedCosineCutoff(_OneSidedCutoff):
    r"""Class of Behler cosine cutoff with an additional smoothing parameter.

    The envelope is one for :math:`r_{ij} \le r_{high} - \sigma`, zero for
    :math:`r_{ij} > r_{high}`, and in between it decays as

    .. math::
        0.5 + 0.5  \cos{ \left[ \pi \left( \frac{r_{ij} - r_{high} +
        \sigma}{\sigma}\right)\right]}

    where :math:`\sigma` is the smoothing width and :math:`r_{high}` the
    cutoff radius. The lower cutoff is fixed at 0.

    Parameters
    ----------
    cutoff:
        Cutoff radius :math:`r_{high}`, stored as ``cutoff_upper``.
    smooth_width:
        Width :math:`\sigma` of the region just below the cutoff radius over
        which the envelope decays from one to zero.

    Raises
    ------
    ValueError
        If ``cutoff`` is less than the fixed lower cutoff of 0.
    """

    def __init__(
        self,
        cutoff: Union[int, float] = 5.0,
        smooth_width: Union[int, float] = 0.5,
    ):
        super(ShiftedCosineCutoff, self).__init__()
        self.cutoff_upper = cutoff
        self.smooth_width = smooth_width
        # Validate the bounds, as the other cutoff classes do.
        self.check_cutoff()

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """Compute the smoothed cutoff envelope.

        Parameters
        ----------
        distances:
            Values of interatomic distances.

        Returns
        -------
        torch.Tensor
            Flattened (1-D) tensor of envelope values. Values are one up to
            ``cutoff_upper - smooth_width``, cosine-decaying between that
            point and ``cutoff_upper``, and zero beyond ``cutoff_upper``.
        """
        cutoffs = torch.ones_like(distances)
        # Only the smoothing region needs the cosine evaluation.
        mask = distances > self.cutoff_upper - self.smooth_width
        cutoffs[mask] = 0.5 + 0.5 * torch.cos(
            math.pi
            * (distances[mask] - self.cutoff_upper + self.smooth_width)
            / self.smooth_width
        )
        cutoffs[distances > self.cutoff_upper] = 0.0
        # reshape rather than view: ones_like preserves the input's strides,
        # so for a non-contiguous input .view(-1) would raise.
        return cutoffs.reshape(-1)