Gradient/sum wrappers
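These classes wrap other models and are used to build more complicated models: GradientsOut derives gradient properties (such as forces) from a model's output, and SumOut sums the predictions of several models.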
- class mlcg.nn.gradients.GradientsOut(model, targets='forces')
Gradient wrapper for models.
- Parameters:
targets (str) – The gradient targets to produce from a model output. These can be any of the gradient properties referenced in mlcg.data._keys. At the moment, only forces are implemented.
Example
To predict forces from an energy model, one would supply a model that predicts a scalar atom property (an energy) and specify the FORCE_KEY in the targets.
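The idea behind such a gradient wrapper can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the mlcg implementation: `ToyHarmonicEnergy` and `ToyGradientsOut` are hypothetical names, the model takes a bare position tensor rather than mlcg's data objects, and the `"energy"`/`"forces"` dictionary keys are assumptions. Forces are computed as the negative gradient of the predicted energy with respect to the positions using `torch.autograd.grad`.

```python
import torch


class ToyHarmonicEnergy(torch.nn.Module):
    """Hypothetical energy model: E = 0.5 * k * sum_i |x_i|^2."""

    def __init__(self, k: float = 2.0):
        super().__init__()
        self.k = k

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        # One scalar energy for the whole configuration
        return 0.5 * self.k * (pos ** 2).sum()


class ToyGradientsOut(torch.nn.Module):
    """Hypothetical stand-in for GradientsOut (not the mlcg code)."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, pos: torch.Tensor) -> dict:
        # Positions must be tracked by autograd to differentiate through them
        if not pos.requires_grad:
            pos = pos.clone().requires_grad_(True)
        energy = self.model(pos)
        # Forces are the negative gradient of the energy w.r.t. positions;
        # create_graph=True lets a loss on the forces be backpropagated
        (grad,) = torch.autograd.grad(energy, pos, create_graph=True)
        # Assumed output keys, chosen for illustration only
        return {"energy": energy, "forces": -grad}


pos = torch.tensor([[1.0, 0.0, 0.0], [0.0, -2.0, 0.0]])
out = ToyGradientsOut(ToyHarmonicEnergy(k=2.0))(pos)
print(out["energy"].item())  # 0.5 * 2 * (1 + 4) = 5.0
print(out["forces"])         # equals -k * pos for this harmonic energy
```

Keeping the graph with `create_graph=True` matters for training: a force-matching loss depends on the forces, which are themselves derivatives, so the optimizer needs second derivatives of the energy model.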
- class mlcg.nn.gradients.SumOut(models, targets=None)
Property pooling wrapper for models: each listed target property is summed over all wrapped models.
- Parameters:
models (ModuleDict) – Dictionary of predictor models keyed by their name attribute.
targets (List[str]) – List of prediction targets that will be pooled.
Example
To combine SchNet force predictions with prior interactions:
import torch
from mlcg.nn import (
    StandardSchNet,
    HarmonicBonds,
    HarmonicAngles,
    GradientsOut,
    SumOut,
    CosineCutoff,
    GaussianBasis,
)
from mlcg.data._keys import FORCE_KEY, ENERGY_KEY

# Prior terms; bond_stats and angle_stats are prior statistics
# computed beforehand (not defined in this example)
bond_terms = GradientsOut(HarmonicBonds(bond_stats), FORCE_KEY)
angle_terms = GradientsOut(HarmonicAngles(angle_stats), FORCE_KEY)

# SchNet energy network, wrapped so that it also predicts forces
cutoff = CosineCutoff()
rbf = GaussianBasis(cutoff)
energy_network = StandardSchNet(cutoff, rbf, [128])
force_network = GradientsOut(energy_network, FORCE_KEY)

# Sum energies and forces over the prior terms and the network
models = torch.nn.ModuleDict(
    {
        "bonds": bond_terms,
        "angles": angle_terms,
        "SchNet": force_network,
    }
)
full_model = SumOut(models, targets=[ENERGY_KEY, FORCE_KEY])
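Pooling by summation can be sketched the same way. This is again a hypothetical toy, not mlcg API: `ToyHarmonicTerm`, `ToySumOut` and the `"energy"`/`"forces"` keys are illustrative assumptions. Each sub-model in a `torch.nn.ModuleDict` returns a dictionary of properties, and the wrapper adds up each listed target across all sub-models.

```python
import torch


class ToyHarmonicTerm(torch.nn.Module):
    """Hypothetical prior term: E = 0.5 * k * sum|x|^2, F = -k * x."""

    def __init__(self, k: float):
        super().__init__()
        self.k = k

    def forward(self, pos: torch.Tensor) -> dict:
        # Assumed output keys, chosen for illustration only
        return {"energy": 0.5 * self.k * (pos ** 2).sum(), "forces": -self.k * pos}


class ToySumOut(torch.nn.Module):
    """Hypothetical stand-in for SumOut: sums each target over sub-models."""

    def __init__(self, models: torch.nn.ModuleDict, targets):
        super().__init__()
        self.models = models
        self.targets = list(targets)

    def forward(self, pos: torch.Tensor) -> dict:
        # Evaluate every sub-model on the same input
        outputs = [model(pos) for model in self.models.values()]
        # Pool each requested target by summing it across sub-models
        return {t: sum(out[t] for out in outputs) for t in self.targets}


models = torch.nn.ModuleDict(
    {"soft": ToyHarmonicTerm(1.0), "stiff": ToyHarmonicTerm(3.0)}
)
full = ToySumOut(models, targets=["energy", "forces"])
pos = torch.tensor([[1.0, 2.0, 0.0]])
out = full(pos)
print(out["energy"].item())  # 2.5 + 7.5 = 10.0
```

Summing force outputs term by term is consistent with differentiating the total energy, because the gradient is linear: the force of a sum of energy terms equals the sum of their individual forces.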