Source code for mlcg_tk.input_generator.prior_fit.repulsion

import torch
from typing import Dict, Optional
from scipy.integrate import trapezoid
from scipy.optimize import curve_fit
import numpy as np


def repulsion(x, sigma):
    """Evaluate the repulsive interaction ``(sigma / x) ** 6``.

    Parameters
    ----------
    x:
        Value(s) at which the interaction is evaluated.
    sigma:
        Scale parameter of the repulsion.

    Returns
    -------
    The sixth power of ``sigma / x``, built up by repeated multiplication
    rather than a call to ``pow``.
    """
    ratio = sigma / x
    ratio_sq = ratio * ratio
    ratio_fourth = ratio_sq * ratio_sq
    return ratio_fourth * ratio_sq
def fit_repulsion_from_potential_estimates(
    bin_centers_nz: torch.Tensor, **kwargs
) -> Dict:
    r"""Estimate the repulsion parameter ``sigma`` from histogram bin centers.

    No curve fitting is performed here. ``sigma`` is placed half a bin width
    below the first bin center, where the bin width is taken as the
    difference between the first two entries of ``bin_centers_nz``.

    Parameters
    ----------
    bin_centers_nz:
        Bin centers of a discrete histogram. Only the first two entries are
        read, so at least two are required. Presumably these are the centers
        of the non-empty bins (``_nz``) -- confirm against the caller.
    **kwargs:
        Accepted but ignored (for example an energy estimate ``dG_nz``,
        i.e. :math:`U(x) = -\frac{1}{\beta}\log{\left(p(x)\right)}`).

    Returns
    -------
    Dict:
        Dictionary with the single key ``"sigma"``.
    """
    first_center = bin_centers_nz[0]
    bin_width = bin_centers_nz[1] - first_center
    # Shift down by half a bin: this is the lower edge of the first bin.
    lower_edge = first_center - 0.5 * bin_width
    return {"sigma": lower_edge}
def fit_repulsion_from_values(
    bin_centers_nz: torch.Tensor,
    ncounts_nz: torch.Tensor,
    percentile: float,
    cutoff: Optional[float] = None,
    **kwargs,
) -> Dict:
    """Estimate the repulsion parameter ``sigma`` as a percentile of the data.

    The histogram is expanded back into individual samples by repeating each
    bin center by its count, and ``sigma`` is the requested percentile of
    those samples. No curve fitting is performed.

    Parameters
    ----------
    bin_centers_nz:
        Histogram bin centers, one per bin.
    ncounts_nz:
        Number of samples in each bin; converted to integers before use.
    percentile:
        Percentile at which ``sigma`` is placed, in the range 0 to 100 as
        expected by ``numpy.percentile`` (e.g. ``percentile=1`` puts
        ``sigma`` at the 1st percentile). Useful for distributions with long
        lower tails or low-value outliers.
    cutoff:
        If given, only samples strictly below this value are used when
        evaluating the percentile.
    **kwargs:
        Accepted but ignored.

    Returns
    -------
    Dict:
        Dictionary with the single key ``"sigma"`` holding a 0-dimensional
        ``torch.Tensor``.
    """
    # detach()/cpu() are no-ops for plain CPU tensors but let tensors that
    # track gradients or live on an accelerator be converted to NumPy.
    centers = bin_centers_nz.detach().cpu().numpy()
    counts = ncounts_nz.detach().cpu().int().numpy()

    # Rebuild the sample set represented by the histogram.
    values = np.repeat(centers, counts)
    if cutoff is not None:
        values = values[values < cutoff]

    sigma = torch.tensor(np.percentile(values, percentile))
    return {"sigma": sigma}