Source code for mlcg_tk.scripts.fit_priors

import os.path as osp
import sys

from mlcg_tk.input_generator.raw_dataset import RawDataset
from mlcg_tk.input_generator.embedding_maps import CGEmbeddingMap
from mlcg_tk.input_generator.prior_gen import PriorBuilder
from mlcg_tk.input_generator.prior_fit import HistogramsNL
from mlcg_tk.input_generator.prior_fit.fit_potentials import fit_potentials
from mlcg_tk.input_generator.utils import get_output_tag
from mlcg_tk.input_generator.prior_fit.utils import compute_nl_unique_keys

from tqdm import tqdm
import torch
from time import ctime
import numpy as np
import pickle as pck
from typing import Dict, List, Union, Callable, Optional
from jsonargparse import CLI
from scipy.integrate import trapezoid
from collections import defaultdict
from copy import deepcopy
import warnings

# import seaborn as sns

from mlcg.nn.gradients import SumOut
from mlcg.utils import makedirs


def compute_statistics(
    dataset_name: str,
    names: List[str],
    tag: str,
    save_dir: str,
    stride: int,
    batch_size: int,
    prior_tag: str,
    prior_builders: List[PriorBuilder],
    embedding_map: CGEmbeddingMap,
    statistics_tag: Optional[str] = None,
    device: str = "cpu",
    save_figs: bool = True,
    save_sample_statistics: bool = False,
    weights_template_fn: Optional[str] = None,
    mol_num_batches: Optional[int] = 1,
):
    """
    Computes structural features and accumulates statistics on dataset samples.

    Parameters
    ----------
    dataset_name : str
        Name given to the specific dataset
    names : List[str]
        List of sample names
    tag : str
        Label given to all output files produced from the dataset
    save_dir : str
        Path to directory from which input will be loaded and to which output
        will be saved
    stride : int
        Integer by which to stride frames
    batch_size : int
        Number of frames to take per batch
    prior_tag : str
        String identifying the specific combination of prior terms
    prior_builders : List[PriorBuilder]
        List of PriorBuilder objects and their corresponding parameters
    embedding_map : CGEmbeddingMap
        Mapping object
    statistics_tag : Optional[str]
        String differentiating parameters used for statistics computation
    device : str
        Device on which to run delta force calculations
    save_figs : bool
        Whether to plot histograms of computed statistics
    save_sample_statistics : bool
        If True, saves an individual list of prior builders with the
        accumulated statistics of each molecule
    weights_template_fn : Optional[str]
        Template file location of weights to use for accumulating statistics
    mol_num_batches : Optional[int]
        If greater than 1, splits each molecule's data into the specified
        number of batches, which are treated as different samples
    """
    all_nl_names = set()
    nl_name2prior_builder = {}
    for prior_builder in prior_builders:
        for nl_name in prior_builder.nl_builder.nl_names:
            all_nl_names.add(nl_name)
            nl_name2prior_builder[nl_name] = prior_builder

    dataset = RawDataset(dataset_name, names, tag, n_batches=mol_num_batches)

    tmp_batch = dataset[0].load_cg_output_into_batches(
        save_dir, prior_tag, 1, 1, weights_template_fn=weights_template_fn
    )[0]
    nl_names = set(tmp_batch.neighbor_list.keys())
    assert nl_names.issubset(all_nl_names), (
        f"some of the NL names '{nl_names}' in {dataset_name}:{tmp_batch.name} "
        f"have not been registered in the nl_builder '{all_nl_names}'"
    )
    nl_names_key_list = {}
    atom_types = tmp_batch.atom_types
    for nl_name in nl_names:
        mapping = tmp_batch.neighbor_list[nl_name]["index_mapping"]
        nl_names_key_list[nl_name] = compute_nl_unique_keys(atom_types, mapping)

    for samples in tqdm(
        dataset, f"Compute histograms of CG data for {dataset_name} dataset..."
    ):
        if not samples.has_saved_cg_output(save_dir, prior_tag):
            continue
        if weights_template_fn is not None and not osp.exists(
            osp.join(save_dir, weights_template_fn.format(samples.name))
        ):
            warnings.warn(
                f"Could not find weights for sample {samples.name}; the file "
                f"{osp.join(save_dir, weights_template_fn.format(samples.name))} "
                "does not exist. This entry will be skipped."
            )
            continue
        batch_list = samples.load_cg_output_into_batches(
            save_dir,
            prior_tag,
            batch_size,
            stride,
            weights_template_fn=weights_template_fn,
        )
        nl_names = set(batch_list[0].neighbor_list.keys())
        assert nl_names.issubset(all_nl_names), (
            f"some of the NL names '{nl_names}' in {dataset_name}:{samples.name} "
            f"have not been registered in the nl_builder '{all_nl_names}'"
        )

        if save_sample_statistics:
            sample_fnout = osp.join(
                save_dir,
                f"{get_output_tag([samples.tag, samples.name, prior_tag, statistics_tag], placement='before')}prior_builders.pck",
            )
            sample_prior_builders = [
                deepcopy(prior_builder) for prior_builder in prior_builders
            ]
            sample_nl_name2prior_builder = {}
            for prior_builder in sample_prior_builders:
                for nl_name in prior_builder.nl_builder.nl_names:
                    if (
                        nl_name in prior_builder.histograms.data.keys()
                        and nl_name not in nl_names
                    ):
                        prior_builder.histograms.data.pop(nl_name)
                    prior_builder.histograms.data[nl_name].clear()
                    sample_nl_name2prior_builder[nl_name] = prior_builder
            for batch in tqdm(
                batch_list, f"molecule name: {samples.name}", leave=False
            ):
                batch = batch.to(device)
                for nl_name in nl_names:
                    prior_builder = sample_nl_name2prior_builder[nl_name]
                    prior_builder.accumulate_statistics(
                        nl_name, batch, nl_names_key_list[nl_name]
                    )
            with open(sample_fnout, "wb") as f:
                pck.dump(sample_prior_builders, f)
            continue  # does not save accumulated statistics if sample statistics are saved

        for batch in tqdm(batch_list, f"molecule name: {samples.name}", leave=False):
            batch = batch.to(device)
            for nl_name in nl_names:
                prior_builder = nl_name2prior_builder[nl_name]
                prior_builder.accumulate_statistics(
                    nl_name, batch, nl_names_key_list[nl_name]
                )

    key_map = {v: k for k, v in embedding_map.items()}
    if save_figs:
        for prior_builder in prior_builders:
            figs = prior_builder.histograms.plot_histograms(key_map)
            for fig_tag, fig in figs:
                makedirs(osp.join(save_dir, f"{prior_tag}_plots"))
                fig.savefig(
                    osp.join(save_dir, f"{prior_tag}_plots", f"hist_{fig_tag}.png"),
                    dpi=300,
                    bbox_inches="tight",
                )

    if not save_sample_statistics:
        # cumulative statistics are only saved if individual statistics were not saved
        fnout = osp.join(
            save_dir,
            f"{get_output_tag([samples.tag, prior_tag], placement='before')}prior_builders.pck",
        )
        with open(fnout, "wb") as f:
            pck.dump(prior_builders, f)
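# Illustrative usage sketch: calling ``compute_statistics`` directly from Python
# instead of through the jsonargparse CLI in ``main`` below. All literal values
# (dataset name, sample names, directories, tags) are placeholder assumptions,
# and ``prior_builders``/``embedding_map`` are expected to be pre-configured
# objects; in practice these are typically supplied via a YAML config.
def _example_compute_statistics(prior_builders, embedding_map):
    compute_statistics(
        dataset_name="my_dataset",           # assumed dataset label
        names=["molecule_A", "molecule_B"],  # assumed sample names
        tag="cg",                            # assumed output-file tag
        save_dir="./cg_output",              # assumed CG output directory
        stride=10,                           # keep every 10th frame
        batch_size=1000,                     # frames per batch
        prior_tag="bonded_priors",           # assumed prior-term tag
        prior_builders=prior_builders,       # list of configured PriorBuilder objects
        embedding_map=embedding_map,         # CGEmbeddingMap for the CG resolution
        device="cpu",
        save_figs=True,                      # histogram plots go under {save_dir}/{prior_tag}_plots
    )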
def fit_priors(
    save_dir: str,
    prior_tag: str,
    embedding_map: CGEmbeddingMap,
    temperature: float,
):
    """
    Fits potential energy estimates to computed statistics.

    Parameters
    ----------
    save_dir : str
        Path to directory from which input will be loaded and to which output
        will be saved
    prior_tag : str
        String identifying the specific combination of prior terms
    embedding_map : CGEmbeddingMap
        Mapping object
    temperature : float
        Temperature from which the beta value will be computed
    """
    prior_fn = osp.join(save_dir, f"{prior_tag}_prior_builders.pck")
    fnout = osp.join(save_dir, f"{prior_tag}_prior_model.pt")

    with open(prior_fn, "rb") as f:
        prior_builders = pck.load(f)

    nl_names = []
    nl_name2prior_builder = {}
    for prior_builder in prior_builders:
        for nl_name in list(prior_builder.histograms.data.keys()):
            nl_names.append(nl_name)
            nl_name2prior_builder[nl_name] = prior_builder

    prior_models = {}
    pbar = tqdm(nl_names)
    for nl_name in pbar:
        pbar.set_description(f"Fitting prior {nl_name}")
        prior_builder = nl_name2prior_builder[nl_name]
        prior_model = fit_potentials(
            nl_name=nl_name,
            prior_builder=prior_builder,
            embedding_map=embedding_map,
            temperature=temperature,
        )
        prior_models[nl_name] = prior_model

    modules = torch.nn.ModuleDict(prior_models)
    full_prior_model = SumOut(modules, targets=["energy", "forces"])
    torch.save(full_prior_model, fnout)
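# Illustrative usage sketch: ``fit_priors`` reads the ``{prior_tag}_prior_builders.pck``
# file written by ``compute_statistics`` and saves the fitted model as
# ``{prior_tag}_prior_model.pt`` in the same directory. The directory, tag, and
# temperature below are placeholder assumptions.
def _example_fit_priors(embedding_map):
    fit_priors(
        save_dir="./cg_output",       # same directory used by compute_statistics
        prior_tag="bonded_priors",    # assumed prior-term tag
        embedding_map=embedding_map,  # same CGEmbeddingMap used for the statistics
        temperature=300.0,            # simulation temperature (K) used to compute beta
    )
    # The resulting SumOut module can then be reloaded, e.g. with
    # ``torch.load("./cg_output/bonded_priors_prior_model.pt")``, and used as a
    # prior energy/force term for a CG model.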
def main():
    print("Start fit_priors.py: {}".format(ctime()))
    CLI([compute_statistics, fit_priors])
    print("Finish fit_priors.py: {}".format(ctime()))


if __name__ == "__main__":
    main()
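# Illustrative command lines (script path and config file names are assumptions).
# jsonargparse's ``CLI`` exposes each function above as a subcommand and can read
# its arguments from a YAML config file, e.g.:
#
#   python fit_priors.py compute_statistics --config compute_statistics.yaml
#   python fit_priors.py fit_priors --config fit_priors.yaml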