Source code for mlcg_tk.input_generator.prior_fit.dihedral

import torch
from typing import Dict, Optional
from scipy.integrate import trapezoid
from scipy.optimize import curve_fit
import numpy as np
from .utils import neg_log_likelihood


def dihedral(
    theta: torch.Tensor,
    v_0: torch.Tensor,
    k1s: torch.Tensor,
    k2s: torch.Tensor,
) -> torch.Tensor:
    """Evaluate a sin/cos series dihedral interaction on a set of angles.

    For every angle the value
    ``v_0 + sum_n (k1s[:, n-1] * sin(n * theta) + k2s[:, n-1] * cos(n * theta))``
    is returned, with ``n`` running from 1 to the size of the second
    dimension of `k1s`.

    Parameters
    ----------
    theta :
        Angles on which the interaction is evaluated; flattened to a column
        internally.
    v_0 :
        Constant offset. If it has more than one dimension, only its first
        column is used.
    k1s :
        Two-dimensional tensor of sin coefficients, one column per degree.
    k2s :
        Tensor of cos coefficients, combined elementwise with the same
        angle grid as `k1s`.

    Returns
    -------
    torch.Tensor:
        Dihedral interaction energy, one value per angle
    """
    _, n_terms = k1s.shape
    multiplicities = torch.arange(
        1, n_terms + 1, dtype=theta.dtype, device=theta.device
    )
    # Outer product: row i holds theta_i multiplied by every degree 1..n_terms,
    # so the grid lines up with the columns of k1s and k2s.
    phases = theta.view(-1, 1) * multiplicities.view(1, -1)
    series = k1s * torch.sin(phases) + k2s * torch.cos(phases)
    # HOTFIX to avoid shape mismatch when using specialized priors
    # TODO: think of a better fix
    if v_0.ndim > 1:
        v_0 = v_0[:, 0]
    return series.sum(dim=1) + v_0
def dihedral_wrapper_fit_func(theta: torch.Tensor, *args) -> torch.Tensor:
    """Evaluate `dihedral` from a single flat parameter sequence.

    Only the first extra positional argument is read. It must be a sequence
    laid out as ``[v_0, k1_1, ..., k1_N, k2_1, ..., k2_N]``: the constant
    offset, then the sin coefficients, then the cos coefficients.

    Parameters
    ----------
    theta :
        Angles on which the interaction is evaluated
    *args :
        ``args[0]`` is the flat parameter sequence described above; any
        further positional arguments are ignored.

    Returns
    -------
    torch.Tensor:
        Output of `dihedral` for the unpacked parameters
    """
    flat_params = args[0]
    offset = torch.tensor(flat_params[0])
    coefficients = flat_params[1:]
    # First half of the remaining entries are sin terms, second half cos terms
    half = len(coefficients) // 2
    sin_coeffs = torch.tensor(coefficients[:half]).view(-1, half)
    cos_coeffs = torch.tensor(coefficients[half:]).view(-1, half)
    return dihedral(theta, offset, sin_coeffs, cos_coeffs)
def _init_parameters(n_degs):
    """Helper method for guessing initial parameter values.

    Returns a flat list ``[v_0, k1_1..k1_n, k2_1..k2_n]`` in which every
    entry starts at one.
    """
    # constant offset first, then the sin block, then the cos block
    return [1.00] + [1] * n_degs + [1] * n_degs


def _init_parameter_dict(n_degs):
    """Helper method for initializing the parameter dictionary.

    The result has the keys ``"k1s"``, ``"k2s"`` and ``"v_0"``. The two
    coefficient sub-dictionaries hold one empty placeholder per degree under
    the names ``k1_<deg>`` / ``k2_<deg>`` for ``deg`` in ``1..n_degs``.
    """
    stat = {"k1s": {}, "k2s": {}, "v_0": 0.00}
    for deg in range(1, n_degs + 1):
        stat["k1s"]["k1_" + str(deg)] = {}
        stat["k2s"]["k2_" + str(deg)] = {}
    return stat


def _degree_from_name(name):
    """Return the integer degree encoded in a key such as ``"k1_10"``."""
    return int(name.rsplit("_", 1)[1])


def _make_parameter_dict(stat, popt, n_degs):
    """Helper method for constructing a fitted parameter dictionary.

    Parameters
    ----------
    stat :
        Dictionary as produced by `_init_parameter_dict`; modified in place
        and also returned.
    popt :
        Flat fitted parameters ``[v_0, k1_1..k1_m, k2_1..k2_m]``.
    n_degs :
        Number of degrees stored in `stat`. Degrees beyond the ``m`` present
        in `popt` are set to zero.
    """
    v_0 = popt[0]
    k_popt = popt[1:]
    num_k1s = len(k_popt) // 2
    # Order the keys by their numeric degree. A plain string sort would put
    # "k1_10" before "k1_2" and store coefficients under the wrong degree as
    # soon as n_degs reaches 10.
    k1_names = sorted(stat["k1s"], key=_degree_from_name)
    k2_names = sorted(stat["k2s"], key=_degree_from_name)
    for ii in range(n_degs):
        k1_name = k1_names[ii]
        k2_name = k2_names[ii]
        if ii < num_k1s:
            stat["k1s"][k1_name] = k_popt[ii]
            stat["k2s"][k2_name] = k_popt[num_k1s + ii]
        else:
            # the selected fit used fewer degrees than the dictionary holds
            stat["k1s"][k1_name] = 0
            stat["k2s"][k2_name] = 0
    stat["v_0"] = v_0
    return stat


def _compute_adjusted_R2(bin_centers_nz, dG_nz, mask, popt, free_parameters):
    """
    Method for model selection using adjusted R2

    Higher values imply better model selection. Only the entries selected
    by `mask` enter the residual and total sums of squares.
    """
    targets = dG_nz[mask]
    dG_fit = dihedral_wrapper_fit_func(bin_centers_nz[mask], popt)
    SSres = torch.sum(torch.square(targets - dG_fit))
    SStot = torch.sum(torch.square(targets - torch.mean(targets)))
    n_samples = len(targets)
    # Each sum of squares is divided by its degrees of freedom, which
    # penalises models with more free parameters.
    R2 = 1 - (SSres / (n_samples - free_parameters - 1)) / (SStot / (n_samples - 1))
    return R2


def _compute_aic(bin_centers_nz, dG_nz, mask, popt, free_parameters):
    """Method for computing the AIC.

    Evaluated as ``2 * neg_log_likelihood(target, fit) + 2 * free_parameters``
    on the entries selected by `mask`; `neg_log_likelihood` is the helper
    imported from ``.utils``.
    """
    aic = (
        2
        * neg_log_likelihood(
            dG_nz[mask],
            dihedral_wrapper_fit_func(bin_centers_nz[mask], popt),
        )
        + 2 * free_parameters
    )
    return aic


def _linear_regression(bin_centers, targets, n_degs):
    """Vanilla linear regression.

    Builds a design matrix whose columns are a constant, ``sin(n * x)`` for
    ``n = 1..n_degs`` and ``cos(n * x)`` for ``n = 1..n_degs`` (the same
    ordering as the flat parameter vector used elsewhere in this module) and
    solves the least-squares problem with `torch.linalg.lstsq`.

    Returns the object returned by `torch.linalg.lstsq`; the fitted
    coefficients are in its ``solution`` attribute.
    """
    degrees = range(1, n_degs + 1)
    columns = [torch.ones_like(bin_centers)]
    columns.extend(torch.sin(n * bin_centers) for n in degrees)
    columns.extend(torch.cos(n * bin_centers) for n in degrees)
    # stack gives (n_features, n_samples); transpose to (n_samples, n_features)
    features = torch.stack(columns).t()
    targets = targets.to(features.dtype)
    sol = torch.linalg.lstsq(features, targets.t())
    return sol
def _fit_single_degree(bin_centers, dG, deg, regression_method):
    """Fit one fixed-degree sin/cos expansion and return its flat parameters."""
    if regression_method == "linear":
        # .cpu() keeps the numpy conversion valid for tensors on other devices
        return _linear_regression(bin_centers, dG, deg).solution.cpu().numpy().tolist()
    if regression_method == "nonlinear":
        p0 = _init_parameters(deg)
        popt, _ = curve_fit(
            lambda theta, *params: dihedral_wrapper_fit_func(theta, params),
            bin_centers,
            dG,
            p0=p0,
        )
        return popt
    raise ValueError(
        "regression method {} is neither 'linear' nor 'nonlinear'".format(
            regression_method
        )
    )


def fit_dihedral_from_potential_estimates(
    bin_centers_nz: torch.Tensor,
    dG_nz: torch.Tensor,
    n_degs: int = 6,
    constrain_deg: Optional[int] = None,
    regression_method: str = "linear",
    metric: str = "aic",
    **kwargs,
) -> Dict:
    """
    Loop over n_degs basins and use either the AIC criterion or a prechosen
    degree to select best fit. Parameter fitting occurs over unmasked regions
    of the free energy only.

    Parameters
    ----------
    bin_centers_nz:
        Bin centers over which the fit is carried out
    dG_nz:
        The empirical free energy corresponding to the bin centers
    n_degs:
        The maximum number of degrees to attempt to fit if using the AIC
        criterion for prior model selection
    constrain_deg:
        If not None, a single fit is produced for the specified integer
        degree instead of using the AIC criterion for fit selection between
        multiple degrees
    regression_method:
        String specifying which regression method to use. If "nonlinear",
        the default `scipy.optimize.curve_fit` method is used. If 'linear',
        linear regression via `torch.linalg.lstsq` is used
    metric:
        If a constrain deg is not specified, this string specifies whether to
        use either AIC ('aic') or adjusted R squared ('r2') for automated
        degree selection. If the automatic degree determination fails, users
        should consider searching for a proper constrained degree.
    **kwargs:
        Accepted and ignored.

    Returns
    -------
    Dict:
        Statistics dictionary with fitted interaction parameters. If the
        automatic degree search fails, every ``k1s``/``k2s`` entry is a NaN
        tensor.

    Raises
    ------
    ValueError
        If `regression_method` is neither 'linear' nor 'nonlinear', or if
        `metric` is neither 'aic' nor 'r2' when no degree is constrained.
    """
    integral = torch.tensor(
        float(trapezoid(dG_nz.cpu().numpy(), bin_centers_nz.cpu().numpy()))
    )
    # Fit only where the free energy is non-negligible relative to its integral
    mask = torch.abs(dG_nz) > 1e-4 * torch.abs(integral)

    if constrain_deg is not None:
        assert isinstance(constrain_deg, int)
        stat = _init_parameter_dict(constrain_deg)
        popt = _fit_single_degree(
            bin_centers_nz[mask], dG_nz[mask], constrain_deg, regression_method
        )
        return _make_parameter_dict(stat, popt, constrain_deg)

    if metric == "aic":
        metric_func = _compute_aic
        best_func = min
    elif metric == "r2":
        metric_func = _compute_adjusted_R2
        best_func = max
    else:
        raise ValueError("metric {} is neither 'aic' nor 'r2'".format(metric))

    # Reject a bad method name before entering the best-effort fitting block,
    # so a configuration error is reported the same way as in the constrained
    # branch instead of being turned into a NaN result.
    if regression_method not in ("linear", "nonlinear"):
        raise ValueError(
            "regression method {} is neither 'linear' nor 'nonlinear'".format(
                regression_method
            )
        )

    # Determine best fit for unknown # of parameters
    stat = _init_parameter_dict(n_degs)
    popts = []
    metric_vals = []
    try:
        for deg in range(1, n_degs + 1):
            # one constant offset plus a sin and a cos coefficient per degree
            free_parameters = 1 + (2 * deg)
            popt = _fit_single_degree(
                bin_centers_nz[mask], dG_nz[mask], deg, regression_method
            )
            metric_val = metric_func(
                bin_centers_nz, dG_nz, mask, popt, free_parameters
            )
            popts.append(popt)
            metric_vals.append(metric_val)
        best_val = best_func(metric_vals)
        best_i_val = metric_vals.index(best_val)
        popt = popts[best_i_val]
        stat = _make_parameter_dict(stat, popt, n_degs)
    except Exception:
        # Best effort: a failed fit is reported and flagged with NaN
        # coefficients rather than aborting the caller. KeyboardInterrupt and
        # SystemExit are deliberately not intercepted.
        print("failed to fit potential estimate for Dihedral")
        stat = _init_parameter_dict(n_degs)
        for k1_name in stat["k1s"]:
            stat["k1s"][k1_name] = torch.tensor(float("nan"))
        for k2_name in stat["k2s"]:
            stat["k2s"][k2_name] = torch.tensor(float("nan"))
    return stat