Source code for mlcg_tk.prior_tools.utils

from itertools import combinations_with_replacement
import numpy as np
import torch
from typing import List, Tuple

from mlcg.nn.prior import _Prior, Harmonic, Repulsion, Dihedral


def symmetrized_keys_generator(order: int, emb_max: int = 20) -> List[Tuple]:
    r"""
    Auxiliary function to generate the symmetric keys of size `order`
    with integers from 1 to `emb_max` (both inclusive).

    Parameters
    ----------
    order:
        Number of entries in each key. Supported values are 1, 2, 3 and 4.
    emb_max:
        Largest integer that may appear in a key.

    Returns
    -------
    List[Tuple]
        * ``order == 1``: the bare integers ``1, ..., emb_max``. Note that
          these are plain ints, not 1-tuples.
        * ``order == 2``: pairs ``(a, b)`` with ``a <= b``.
        * ``order == 3``: triples ``(a, i, b)`` with ``a <= b`` and any ``i``;
          the middle entry varies slowest.
        * ``order == 4``: ``(a, b, c, d)`` with ``a <= d`` and ``b <= c``;
          the inner pair ``(b, c)`` varies slowest.

    Raises
    ------
    ValueError
        If `order` is not one of the supported values.
    """
    lim = emb_max + 1
    if order == 1:
        # Entries are intentionally left as bare ints (the historical return
        # value of this function), so existing callers keep working.
        return list(range(1, lim))
    if order not in (2, 3, 4):
        raise ValueError(f"Not implemented for order {order}")

    # Unordered pairs (a <= b); these provide the mirror-symmetric positions
    # of every longer key.
    pairs = list(combinations_with_replacement(range(1, lim), 2))
    if order == 2:
        return pairs
    if order == 3:
        return [
            (ends[0], mid, ends[1]) for mid in range(1, lim) for ends in pairs
        ]
    return [
        (outer[0], inner[0], inner[1], outer[1])
        for inner in pairs
        for outer in pairs
    ]
def get_nonzero_keys(prior_module: _Prior) -> torch.Tensor:
    r"""
    Return the indices at which the characteristic parameter of a prior is
    non-zero.

    The parameter inspected depends on the prior type: ``k`` for
    `Harmonic`, ``sigma`` for `Repulsion` and ``v_0`` for `Dihedral`.
    The result is whatever ``.nonzero()`` returns on that parameter.

    Raises a ``ValueError`` for any other prior type.
    """
    prior_type = type(prior_module)

    if issubclass(prior_type, Harmonic):
        return prior_module.k.nonzero()
    if isinstance(prior_module, Repulsion):
        return prior_module.sigma.nonzero()
    if issubclass(prior_type, Dihedral):
        return prior_module.v_0.nonzero()

    raise ValueError(f"Prior of type {prior_module.__class__} not supported")
def optimal_offset(fit_arr: np.ndarray, data_arr: np.ndarray) -> float:
    r"""Find the constant offset that best aligns two curves.

    This function returns the solution to the least-squares problem

    .. math::

        \min_{\lambda \in R} \sum_{k=1}^{n}
        (data\_arr[k] - fit\_arr[k] + \lambda)^2

    which is the mean of ``fit_arr - data_arr``. This is useful to plot two
    curves as overlapping as possible.

    Positions where either array is not finite (``inf`` or ``nan``) are
    left out of the mean. If no position is usable, ``nan`` is returned.

    Parameters
    ----------
    fit_arr:
        Values of the fitted curve (any array-like).
    data_arr:
        Values of the reference curve (any array-like), same length as
        `fit_arr`.

    Raises
    ------
    ValueError
        If the two inputs do not have the same length.
    """
    # Accept plain sequences as well as arrays; boolean-mask indexing below
    # only works on ndarrays.
    fit = np.asarray(fit_arr)
    data = np.asarray(data_arr)
    if len(data) != len(fit):
        raise ValueError("Arrays should be of the same length")

    # A single inf/nan on either side would otherwise poison the whole mean.
    valid = np.isfinite(data) & np.isfinite(fit)
    n_valid = np.sum(valid)
    if n_valid == 0:
        # Same value the 0/0 division would give, without the RuntimeWarning.
        return float("nan")
    return np.sum(fit[valid] - data[valid]) / n_valid
def get_prior_domain(name: str, n: int = 201) -> torch.Tensor:
    r"""
    Function to return a tensor with the domain where some common priors
    are defined.

    Parameters
    ----------
    name:
        Name of the prior. The interval is selected, in this order, by:
        containing ``"angles"``; containing ``"bonds"``; being exactly
        ``"non_bonded"`` or ``"repulsion"``; containing ``"dihedral"``.
    n:
        Number of evenly spaced points in the returned tensor.

    Returns
    -------
    torch.Tensor
        ``n`` points evenly spaced over the selected interval, end points
        included.

    Raises
    ------
    ValueError
        If `name` does not match any of the known prior names.
    """
    if "angles" in name:
        # we use the cosine to parametrize the angles
        a, b = -1.1, 1.1
    elif "bonds" in name:
        # usual range for CA-bonds
        a, b = 3.4, 4.4
    elif name == "non_bonded" or name == "repulsion":
        # standard range for an excluded value
        a, b = 0.1, 7
    elif "dihedral" in name:
        # dihedrals can range over the full circle
        a, b = -torch.pi, torch.pi
    else:
        # Without this branch an unknown name would surface as an opaque
        # UnboundLocalError on `a`/`b` in the return statement.
        raise ValueError(f"No default domain is defined for prior name {name!r}")
    return torch.linspace(a, b, n)
def prior_evaluator(
    prior_module: _Prior, key: Tuple, x: torch.Tensor
) -> torch.Tensor:
    r"""
    Evaluate `prior_module` on the tensor `x` using the parameters stored
    for the bead combination `key`.

    The parameters looked up at `key` and handed to ``prior_module.compute``
    depend on the prior type:

    * `Harmonic`: ``x_0`` and ``k``, as Python scalars.
    * `Repulsion`: ``sigma`` indexed with the first two entries of `key`,
      as a Python scalar.
    * `Dihedral`: ``v_0`` plus one ``k1s`` and one ``k2s`` entry per degree
      (``n_degs`` of them), each packed as a ``(1, n_degs)`` tensor; `x` is
      passed reshaped to a single row.

    Raises a ``ValueError`` for any other prior type.
    """
    prior_type = type(prior_module)

    if issubclass(prior_type, Harmonic):
        equilibrium = prior_module.x_0[key].item()
        stiffness = prior_module.k[key].item()
        return prior_module.compute(x, equilibrium, stiffness)

    if isinstance(prior_module, Repulsion):
        first, second = key[0], key[1]
        sigma_value = prior_module.sigma[first, second].item()
        return prior_module.compute(x, sigma_value)

    if issubclass(prior_type, Dihedral):
        offset = prior_module.v_0[key]
        # Gather the per-degree coefficients for this key into row tensors.
        k1_row = torch.tensor(
            [prior_module.k1s[deg][key] for deg in range(prior_module.n_degs)]
        ).view(1, -1)
        k2_row = torch.tensor(
            [prior_module.k2s[deg][key] for deg in range(prior_module.n_degs)]
        ).view(1, -1)
        return prior_module.compute(x.view(1, -1), offset, k1_row, k2_row)

    raise ValueError(f"Prior of type {prior_module.__class__} is not supported")