Source code for mlcg_tk.input_generator.prior_fit.fit_potentials

import torch
from typing import Dict, Callable, Optional
import numpy as np
from scipy.integrate import trapezoid
from ..prior_gen import PriorBuilder
from ..embedding_maps import CGEmbeddingMap
from copy import deepcopy


def fit_potentials(
    nl_name: str,
    prior_builder: PriorBuilder,
    embedding_map: Optional[CGEmbeddingMap],
    temperature: float,
):
    r"""
    Fit a prior energy function to the atom type-specific histograms stored
    for one neighbour list.

    For every atom-type group of the neighbour list, the non-empty histogram
    bins are Boltzmann-inverted (``V = -ln(hist) / beta``) and passed to
    ``prior_builder.prior_fit_fn``. The fitted parameters, together with the
    histogram-derived curves, are collected and handed to
    ``prior_builder.get_prior_model``.

    Assumes that the resulting prior energies will be in [kcal/mol]: the
    Boltzmann constant used here is expressed in kcal/(mol·K), so
    `temperature` must be given in Kelvin.

    Parameters
    ----------
    nl_name:
        Neighbour list label; used to select the histograms and the fit
        keyword arguments.
    prior_builder:
        PriorBuilder object containing histogram data (``histograms``), the
        fitting function (``prior_fit_fn``) and the neighbour list builder
        (``nl_builder``).
    embedding_map:
        Instance of CGEmbeddingMap defining the coarse-grained mapping. Only
        read (keys ``"GLY"`` and ``"CA"``) when
        ``prior_builder.nl_builder.replace_gly_ca_stats`` is truthy; may be
        ``None`` otherwise.
    temperature:
        Temperature of the simulation data in Kelvin. There is no default
        value; it must always be supplied.

    Returns
    -------
    prior_model:
        The object returned by ``prior_builder.get_prior_model`` (per the
        project documentation a :ref:`mlcg.nn.GradientsOut` module) built
        from the gathered statistics and estimated energy parameters.

        The statistics passed to the model are a dictionary keyed by
        atom-type group. Besides the key/value pairs returned by
        ``prior_fit_fn`` (which depend on the chosen `TargetPrior`, e.g.
        bond constants and means for `HarmonicBonds`), each entry holds:

        .. code-block:: python

            (*specific_types) : {
                ...
                "p" : torch.Tensor of shape [n_bins], the histogram divided
                    by its trapezoidal integral over the bin centers
                "p_bin" : torch.Tensor of shape [n_bins] containing the
                    bin center values
                "V" : torch.Tensor of shape [n_nonzero_bins], the
                    empirically estimated free energy curve obtained by
                    direct Boltzmann inversion of the non-empty bins
                "V_bin" : torch.Tensor of shape [n_nonzero_bins],
                    containing the centers of the non-empty bins
            }

    Raises
    ------
    ValueError
        If GLY statistics replacement is requested by the neighbour list
        builder but `embedding_map` is ``None``.
    """
    histograms = prior_builder.histograms[nl_name]
    bin_centers = prior_builder.histograms.bin_centers

    prior_fit_fn = prior_builder.prior_fit_fn
    target_fit_kwargs = prior_builder.nl_builder.get_fit_kwargs(nl_name)

    # Fail early (before any fitting work) with a clear message instead of a
    # "'NoneType' object is not subscriptable" error after the loop.
    replace_gly = getattr(prior_builder.nl_builder, "replace_gly_ca_stats", False)
    if replace_gly and embedding_map is None:
        raise ValueError(
            "embedding_map must be provided when the neighbour list builder "
            "sets replace_gly_ca_stats"
        )

    kB = 0.0019872041  # kcal/(mol·K)
    beta = 1 / (temperature * kB)

    statistics = {}
    for kf in list(histograms.keys()):
        hist = torch.tensor(histograms[kf])

        # Empty bins are excluded from the fit: log(0) would give -inf.
        mask = hist > 0
        bin_centers_nz = bin_centers[mask]
        ncounts_nz = hist[mask]
        # Direct Boltzmann inversion of the non-empty bins.
        dG_nz = -torch.log(ncounts_nz) / beta

        params = prior_fit_fn(
            bin_centers_nz=bin_centers_nz,
            dG_nz=dG_nz,
            ncounts_nz=ncounts_nz,
            **target_fit_kwargs,
        )
        statistics[kf] = params

        # Normalize the full histogram so that it integrates to one over the
        # bin centers.
        statistics[kf]["p"] = hist / trapezoid(
            hist.cpu().numpy(), x=bin_centers.cpu().numpy()
        )
        statistics[kf]["p_bin"] = bin_centers
        statistics[kf]["V"] = dG_nz
        statistics[kf]["V_bin"] = bin_centers_nz

    if replace_gly:
        statistics = replace_gly_stats(
            statistics, gly_bead=embedding_map["GLY"], ca_bead=embedding_map["CA"]
        )

    prior_model = prior_builder.get_prior_model(
        statistics, nl_name, targets="forces", **target_fit_kwargs
    )
    return prior_model
def replace_gly_stats(statistics, gly_bead, ca_bead):
    """
    Helper method for replacing poor GLY statistics for dihedral NL with
    statistics associated with general CA beads.

    For every atom-type group (key of `statistics`) that contains
    `gly_bead`, the first occurrence of `gly_bead` is substituted by
    `ca_bead` and the statistics stored under the resulting group are
    assigned to the GLY group. If that group is not present and the group
    has four beads, the variant with the two middle beads swapped is tried
    as a fallback. When neither key exists, the GLY entry is left untouched.

    Parameters
    ----------
    statistics:
        Dictionary mapping atom-type groups (sequences of bead types) to
        their statistics. It is modified in place.
    gly_bead:
        Bead type identifying GLY.
    ca_bead:
        Bead type identifying the general CA bead.

    Returns
    -------
    statistics:
        The same dictionary object that was passed in. Replaced entries
        reference the very same statistics object as the CA group they were
        taken from (no copy is made).
    """
    # Snapshot the affected keys first; values are reassigned in the loop.
    gly_atom_groups = [group for group in statistics if gly_bead in group]
    for group in gly_atom_groups:
        ca_group = list(group)
        ca_group[group.index(gly_bead)] = ca_bead

        candidates = [tuple(ca_group)]
        if len(ca_group) == 4:
            # Fallback: same group with the two middle beads swapped.
            candidates.append(
                (ca_group[0], ca_group[2], ca_group[1], ca_group[3])
            )

        for candidate in candidates:
            if candidate in statistics:
                statistics[group] = statistics[candidate]
                break
    return statistics