import torch
from typing import Any, List, Optional, Sequence, Union
from ..data.atomic_data import AtomicData
from ..data._keys import *
class SumOut(torch.nn.Module):
r"""Property pooling wrapper for models
Parameters
----------
models:
        Dictionary of predictor models keyed by their name attribute
targets:
List of prediction targets that will be pooled
Example
-------
To combine SchNet force predictions with prior interactions:
.. code-block:: python
import torch
from mlcg.nn import (StandardSchNet, HarmonicBonds, HarmonicAngles,
GradientsOut, SumOut, CosineCutoff,
GaussianBasis)
from mlcg.data._keys import FORCE_KEY, ENERGY_KEY
bond_terms = GradientsOut(HarmonicBonds(bond_stats), FORCE_KEY)
angle_terms = GradientsOut(HarmonicAngles(angle_stats), FORCE_KEY)
cutoff = CosineCutoff()
rbf = GaussianBasis(cutoff)
energy_network = StandardSchNet(cutoff, rbf, [128])
        force_network = GradientsOut(energy_network, FORCE_KEY)
        models = torch.nn.ModuleDict({
            "bonds": bond_terms,
            "angles": angle_terms,
            "SchNet": force_network,
        })
full_model = SumOut(models, targets=[ENERGY_KEY, FORCE_KEY])
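    The pooled model can then be evaluated on an AtomicData batch. As a
    minimal sketch (assuming ``data`` is an AtomicData instance that already
    carries the neighbor lists required by each predictor), the summed
    predictions appear directly under the target keys in ``data.out``:
    .. code-block:: python
        # ``data`` is an assumed, pre-built AtomicData batch
        data = full_model(data)
        total_energy = data.out[ENERGY_KEY]
        total_forces = data.out[FORCE_KEY]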
"""
name: str = "SumOut"
def __init__(
self,
models: torch.nn.ModuleDict,
        targets: Optional[List[str]] = None,
):
super(SumOut, self).__init__()
if targets is None:
targets = [ENERGY_KEY, FORCE_KEY]
self.targets = targets
self.models = models
def forward(self, data: AtomicData) -> AtomicData:
r"""Sums output properties from individual models into global
property predictions
Parameters
----------
data:
AtomicData instance whose 'out' field has been populated
for each predictor in the model. For example:
            .. code-block:: python
                AtomicData(
                    out: {
                        SchNet: {
                            ENERGY_KEY: ...,
                            FORCE_KEY: ...,
                        },
                        bonds: {
                            ENERGY_KEY: ...,
                            FORCE_KEY: ...,
                        },
                        ...
                    }
                )
Returns
-------
data:
            AtomicData instance with an updated 'out' field that now contains
            prediction target keys mapping to tensors with the summed
            contributions from each predictor in the model. For example:
            .. code-block:: python
                AtomicData(
                    out: {
                        SchNet: {
                            ENERGY_KEY: ...,
                            FORCE_KEY: ...,
                        },
                        bonds: {
                            ENERGY_KEY: ...,
                            FORCE_KEY: ...,
                        },
                        ENERGY_KEY: ...,
                        FORCE_KEY: ...,
                        ...
                    }
                )
"""
        # initialize each pooled target, then accumulate every predictor's
        # contribution into the corresponding global prediction
        for target in self.targets:
            data.out[target] = 0.0
        for name in self.models.keys():
            data = self.models[name](data)
            for target in self.targets:
                data.out[target] += data.out[name][target]
return data
    def neighbor_list(self, **kwargs):
        """Merge the neighbor list specifications of all wrapped models."""
        nl = {}
        for _, model in self.models.items():
            nl.update(**model.neighbor_list(**kwargs))
        return nl
class EnergyOut(torch.nn.Module):
r"""Extractor for energy computed via SchNet
Parameters
----------
model:
        Model whose prediction targets should be extracted
targets:
List of prediction targets that will be extracted
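    Example
    -------
    A minimal sketch (assuming ``schnet`` is an energy predictor, e.g. a
    StandardSchNet instance, whose forward pass fills
    ``data.out[schnet.name][ENERGY_KEY]``):
    .. code-block:: python
        from mlcg.data._keys import ENERGY_KEY
        # ``schnet`` and ``data`` are assumed, pre-built objects
        energy_extractor = EnergyOut(schnet, targets=[ENERGY_KEY])
        data = energy_extractor(data)
        energy = data.out[ENERGY_KEY]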
"""
name: str = "EnergyOut"
def __init__(
self,
model: torch.nn.Module,
        targets: Optional[List[str]] = None,
):
super().__init__()
if targets is None:
            targets = [ENERGY_KEY]
self.targets = targets
self.model = model
self.name = self.model.name
def forward(self, data: AtomicData) -> AtomicData:
data = self.model(data)
for target in self.targets:
data.out[target] = data.out[self.name][target]
return data
class GradientsOut(torch.nn.Module):
r"""Gradient wrapper for models.
Parameters
----------
targets:
The gradient targets to produce from a model output. These can be any
of the gradient properties referenced in `mlcg.data._keys`.
At the moment only forces are implemented.
Example
-------
To predict forces from an energy model, one would supply a model that
predicts a scalar atom property (an energy) and specify the `FORCE_KEY`
in the targets.
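    A minimal sketch (assuming ``energy_model`` is such an energy predictor,
    e.g. a StandardSchNet instance, and ``data`` is an AtomicData batch that
    already carries the required neighbor lists):
    .. code-block:: python
        from mlcg.nn import GradientsOut
        from mlcg.data._keys import FORCE_KEY
        # ``energy_model`` and ``data`` are assumed, pre-built objects
        force_model = GradientsOut(energy_model, targets=FORCE_KEY)
        data = force_model(data)
        # forces are stored under the wrapped model's name
        forces = data.out[force_model.name][FORCE_KEY]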
"""
_targets = {FORCE_KEY: ENERGY_KEY}
    def __init__(
        self,
        model: torch.nn.Module,
        targets: Union[str, Sequence[str]] = FORCE_KEY,
    ):
super(GradientsOut, self).__init__()
self.model = model
self.name = self.model.name
self.targets = []
if isinstance(targets, str):
self.targets = [targets]
elif isinstance(targets, Sequence):
self.targets = targets
assert any(
[k in GradientsOut._targets for k in self.targets]
), f"targets={self.targets} should be any of {GradientsOut._targets}"
def forward(self, data: AtomicData) -> AtomicData:
"""Forward pass through the gradient layer.
Parameters
----------
data:
AtomicData instance
Returns
-------
data:
            Updated AtomicData instance, where the "out" field has been
            populated with the base predictions of the model (e.g., the
            energy) as well as the target predictions produced through
            gradient operations.
"""
data.pos.requires_grad_(True)
data = self.model(data)
        if FORCE_KEY in self.targets:
            # the energy to differentiate sits at the top level for a pooled
            # SumOut model, and under the model's own name otherwise
            if self.name == "SumOut":
                y = data.out[ENERGY_KEY]
            else:
                y = data.out[self.name][ENERGY_KEY]
            # differentiate the total energy with respect to the atomic
            # positions; the forces are the negative of this gradient
            dy_dr = torch.autograd.grad(
                y.sum(),
                data.pos,
                # grad_outputs=torch.ones_like(y),
                # retain_graph=self.training,
                create_graph=self.training,
            )[0]
if self.name == "SumOut":
data.out[FORCE_KEY] = -dy_dr
else:
data.out[self.name][FORCE_KEY] = -dy_dr
# assert not torch.any(torch.isnan(dy_dr)), f"nan in {self.name}"
data.pos = data.pos.detach()
return data
def neighbor_list(self, **kwargs: Any):
return self.model.neighbor_list(**kwargs)