Source code for mlcg.nn.mace

from typing import Any, Callable, Dict, List, Optional, Union, Final

import torch
from e3nn import o3

try:
    from mace.modules.radial import ZBLBasis
    from mace.tools.scatter import scatter_sum
    from mace.tools import to_one_hot

    from mace.modules.blocks import (
        EquivariantProductBasisBlock,
        LinearNodeEmbeddingBlock,
        LinearReadoutBlock,
        NonLinearReadoutBlock,
        RadialEmbeddingBlock,
    )
    from mace.modules.utils import get_edge_vectors_and_lengths

except ImportError as e:
    print(e)
    print(
        "Please install or set mace to your path before using this interface. "
        + "To install you can either run 'pip install git+https://github.com/ACEsuit/mace.git@v0.3.13', "
        + "or clone the repository and add it to your PYTHONPATH."
        ""
    )

# from ..pl.model import get_class_from_str
from ..data.atomic_data import AtomicData, ENERGY_KEY
from ..neighbor_list.neighbor_list import (
    atomic_data2neighbor_list,
    validate_neighborlist,
)

from e3nn.util.jit import compile_mode


[docs] @compile_mode("script") class MACE(torch.nn.Module): """ Implementation of MACE neural network model from https://github.com/ACEsuit/mace Args: atomic_numbers (torch.Tensor): Tensor of atomic numbers present in the system. node_embedding (torch.nn.Module): Module for embedding node (atom) attributes. radial_embedding (torch.nn.Module): Module for embedding radial (distance) features. spherical_harmonics (torch.nn.Module): Module for computing spherical harmonics of edge vectors. interactions (List[torch.nn.Module]): List of interaction blocks. products (List[torch.nn.Module]): List of product basis blocks. readouts (List[torch.nn.Module]): List of readout blocks. r_max (float): Cutoff radius for neighbor list. max_num_neighbors (int): Maximum number of neighbors per atom. pair_repulsion_fn (torch.nn.Module, optional): Optional pairwise repulsion energy function. """ name: Final[str] = "mace" def __init__( self, atomic_numbers: torch.Tensor, node_embedding: torch.nn.Module, radial_embedding: torch.nn.Module, spherical_harmonics: torch.nn.Module, interactions: List[torch.nn.Module], products: List[torch.nn.Module], readouts: List[torch.nn.Module], r_max: float, max_num_neighbors: int, pair_repulsion_fn: torch.nn.Module = None, ): super().__init__() self.register_buffer("atomic_numbers", atomic_numbers) self.node_embedding = node_embedding self.radial_embedding = radial_embedding self.spherical_harmonics = spherical_harmonics self.interactions = torch.nn.ModuleList(interactions) self.products = torch.nn.ModuleList(products) self.readouts = torch.nn.ModuleList(readouts) self.r_max = r_max self.max_num_neighbors = max_num_neighbors self.pair_repulsion_fn = pair_repulsion_fn self.register_buffer( "types_mapping", -1 * torch.ones(atomic_numbers.max() + 1, dtype=torch.long), ) self.types_mapping[atomic_numbers] = torch.arange( atomic_numbers.shape[0] )
[docs] def forward(self, data: AtomicData) -> AtomicData: """ Forward pass of the MACE model. Args: data (AtomicData): Input atomic data object. Returns: AtomicData: Output data with predicted energies in `data.out`. """ # Setup num_atoms_arange = torch.arange(data.pos.shape[0]) num_graphs = data.ptr.numel() - 1 # data.batch.max() node_heads = torch.zeros_like(data.batch) types_ids = self.types_mapping[data.atom_types].view(-1, 1) node_attrs = to_one_hot(types_ids, self.atomic_numbers.shape[0]) # Embeddings node_feats = self.node_embedding(node_attrs) neighbor_list = data.neighbor_list.get(self.name) if not self.is_nl_compatible(neighbor_list): neighbor_list = self.neighbor_list( data, self.r_max, self.max_num_neighbors )[self.name] edge_index = neighbor_list["index_mapping"] vectors, lengths = get_edge_vectors_and_lengths( positions=data.pos, edge_index=edge_index, shifts=neighbor_list["cell_shifts"], ) edge_attrs = self.spherical_harmonics(vectors) edge_feats = self.radial_embedding( lengths, node_attrs, edge_index, self.atomic_numbers ) if self.pair_repulsion_fn: pair_node_energy = self.pair_repulsion_fn( lengths, node_attrs, edge_index, self.atomic_numbers ) pair_energy = scatter_sum( src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs, ) # [n_graphs,] else: pair_energy = torch.zeros( data.batch.max() + 1, device=data.pos.device, dtype=data.pos.dtype, ) # Interactions energies = [pair_energy] for interaction, product, readout in zip( self.interactions, self.products, self.readouts ): node_feats, sc = interaction( node_attrs=node_attrs, node_feats=node_feats, edge_attrs=edge_attrs, edge_feats=edge_feats, edge_index=edge_index, ) node_feats = product( node_feats=node_feats, sc=sc, node_attrs=node_attrs ) node_energies = readout(node_feats, node_heads)[ num_atoms_arange, node_heads ] # [n_nodes, len(heads)] energy = scatter_sum( src=node_energies, index=data["batch"], dim=0, dim_size=num_graphs, ) # [n_graphs,] energies.append(energy) # Sum over energy contributions contributions = torch.stack(energies, dim=-1) total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] data.out[self.name] = {ENERGY_KEY: total_energy} return data
def is_nl_compatible(self, nl): is_compatible = False if validate_neighborlist(nl): if ( nl["order"] == 2 and nl["self_interaction"] is False and nl["rcut"] == self.r_max ): is_compatible = True return is_compatible
[docs] @staticmethod def neighbor_list( data: AtomicData, rcut: float, max_num_neighbors: int = 1000 ) -> dict: """Computes the neighborlist for :obj:`data` using a strict cutoff of :obj:`rcut`.""" return { MACE.name: atomic_data2neighbor_list( data, rcut, self_interaction=False, max_num_neighbors=max_num_neighbors, ) }
[docs] @compile_mode("script") class StandardMACE(MACE): """ Standard configuration of the MACE model. This class provides a convenient interface for constructing a MACE model with typical settings and block choices, including embedding, interaction, and readout modules. Args: r_max (float): Cutoff radius for neighbor list. num_bessel (int): Number of Bessel functions for radial basis. num_polynomial_cutoff (int): Number of polynomial cutoff functions. max_ell (int): Maximum angular momentum for spherical harmonics. interaction_cls (str): Class name for interaction blocks. interaction_cls_first (str): Class name for the first interaction block. num_interactions (int): Number of interaction blocks. hidden_irreps (str): Irreducible representations for hidden features. For example if only a scalar representation with 128 channels is used can be "128x0e". If also a vector representation is used can be "128x0e + 128x1o". MLP_irreps (str): Irreducible representations for MLP layers. avg_num_neighbors (float): Average number of neighbors per atom used for normalization and numerical stability. atomic_numbers (List[int]): List of atomic numbers in the system. correlation (Union[int, List[int]]): Correlation order(s) for product blocks. gate (Optional[Callable]): Activation function for non-linearities. max_num_neighbors (int, optional): Maximum number of neighbors per atom. pair_repulsion (bool, optional): Whether to use pairwise repulsion. distance_transform (str, optional): Distance transformation type. radial_MLP (Optional[List[int]], optional): Radial MLP architecture. radial_type (Optional[str], optional): Radial basis type. cueq_config (Optional[Dict[str, Any]], optional): Configuration for charge equilibration. """ def __init__( self, r_max: float, num_bessel: int, num_polynomial_cutoff: int, max_ell: int, interaction_cls: str, interaction_cls_first: str, num_interactions: int, hidden_irreps: str, MLP_irreps: str, avg_num_neighbors: float, atomic_numbers: List[int], correlation: Union[int, List[int]], gate: Optional[Callable], max_num_neighbors: int = 1000, pair_repulsion: bool = False, distance_transform: str = "None", radial_MLP: Optional[List[int]] = None, radial_type: Optional[str] = "bessel", cueq_config: Optional[Dict[str, Any]] = None, ): from mlcg.pl.model import get_class_from_str atomic_numbers.sort() atomic_numbers = torch.as_tensor(atomic_numbers) num_elements = atomic_numbers.shape[0] hidden_irreps = o3.Irreps(hidden_irreps) MLP_irreps = o3.Irreps(MLP_irreps) if isinstance(correlation, int): correlation = [correlation] * num_interactions # Embedding node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) node_feats_irreps = o3.Irreps( [(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))] ) node_embedding = LinearNodeEmbeddingBlock( irreps_in=node_attr_irreps, irreps_out=node_feats_irreps, cueq_config=cueq_config, ) radial_embedding = RadialEmbeddingBlock( r_max=r_max, num_bessel=num_bessel, num_polynomial_cutoff=num_polynomial_cutoff, radial_type=radial_type, distance_transform=distance_transform, ) edge_feats_irreps = o3.Irreps(f"{radial_embedding.out_dim}x0e") pair_repulsion_fn = None if pair_repulsion: pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff) sh_irreps = o3.Irreps.spherical_harmonics(max_ell) num_features = hidden_irreps.count(o3.Irrep(0, 1)) interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() spherical_harmonics = o3.SphericalHarmonics( sh_irreps, normalize=True, normalization="component" ) if radial_MLP is None: radial_MLP = [64, 64, 64] # Interactions and readout inter = get_class_from_str(interaction_cls_first)( node_attrs_irreps=node_attr_irreps, node_feats_irreps=node_feats_irreps, edge_attrs_irreps=sh_irreps, edge_feats_irreps=edge_feats_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps, avg_num_neighbors=avg_num_neighbors, radial_MLP=radial_MLP, cueq_config=cueq_config, ) interactions = [inter] # Use the appropriate self connection at the first layer for proper E0 use_sc_first = False if "Residual" in interaction_cls_first: use_sc_first = True node_feats_irreps_out = inter.target_irreps prod = EquivariantProductBasisBlock( node_feats_irreps=node_feats_irreps_out, target_irreps=hidden_irreps, correlation=correlation[0], num_elements=num_elements, use_sc=use_sc_first, cueq_config=cueq_config, ) products = [prod] readouts = [ LinearReadoutBlock(hidden_irreps, o3.Irreps("1x0e"), cueq_config) ] for i in range(num_interactions - 1): if i == num_interactions - 2: hidden_irreps_out = str( hidden_irreps[0] ) # Select only scalars for last layer else: hidden_irreps_out = hidden_irreps inter = get_class_from_str(interaction_cls)( node_attrs_irreps=node_attr_irreps, node_feats_irreps=hidden_irreps, edge_attrs_irreps=sh_irreps, edge_feats_irreps=edge_feats_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps_out, avg_num_neighbors=avg_num_neighbors, radial_MLP=radial_MLP, cueq_config=cueq_config, ) interactions.append(inter) prod = EquivariantProductBasisBlock( node_feats_irreps=interaction_irreps, target_irreps=hidden_irreps_out, correlation=correlation[i + 1], num_elements=num_elements, use_sc=True, cueq_config=cueq_config, ) products.append(prod) if i == num_interactions - 2: readouts.append( NonLinearReadoutBlock( hidden_irreps_out, (1 * MLP_irreps).simplify(), gate, o3.Irreps("1x0e"), 1, cueq_config, ) ) else: readouts.append( LinearReadoutBlock( hidden_irreps, o3.Irreps("1x0e"), cueq_config ) ) super().__init__( atomic_numbers, node_embedding, radial_embedding, spherical_harmonics, interactions, products, readouts, r_max, max_num_neighbors, pair_repulsion_fn, )