# Source code for mlcg_tk.scripts.produce_delta_forces

import os.path as osp
import sys

import mdtraj as md
import numpy as np
import torch

from mlcg.data.atomic_data import AtomicData
from mlcg.data._keys import FORCE_KEY
from mlcg.nn import SumOut

from mlcg_tk.input_generator.raw_dataset import *
from mlcg_tk.input_generator.utils import get_output_tag

from tqdm import tqdm

from time import ctime

from typing import Dict, List, Union, Callable, Optional
from jsonargparse import CLI


def remove_baseline_forces_collated(
    collated_data: AtomicData, model: SumOut
) -> AtomicData:
    """Subtract the forces predicted by :obj:`model` from the reference forces.

    The model is evaluated on the whole collated batch in a single call, so
    the batch should not be too large.

    Parameters
    ----------
    collated_data:
        Collated batch of AtomicData instances whose ``forces`` field holds
        the full reference forces.
    model:
        SumOut object containing the models that compute the prior/baseline
        forces.

    Returns
    -------
    AtomicData:
        The object returned by ``model(collated_data)``. Its ``forces`` field
        now holds the delta forces (original forces minus the baseline/prior
        forces) and an additional ``baseline_forces`` field holds the
        baseline/prior forces themselves.
    """
    # Put the model in evaluation mode before computing the baseline.
    model.eval()
    evaluated = model(collated_data)

    # Detach so the stored baseline carries no autograd history.
    prior_forces = evaluated.out[FORCE_KEY].detach()

    evaluated.forces -= prior_forces
    evaluated.baseline_forces = prior_forces
    return evaluated
def produce_delta_forces(
    dataset_name: str,
    names: List[str],
    tag: str,
    save_dir: str,
    prior_tag: str,
    prior_fn: str,
    device: str,
    batch_size: int,
    force_tag: Optional[str] = None,
    mol_num_batches: Optional[int] = 1,
):
    """
    Removes prior energy terms from input forces to produce delta force input for training

    Parameters
    ----------
    dataset_name : str
        Name given to specific dataset
    names : List[str]
        List of sample names
    tag : str
        Label given to all output files produced from dataset
    save_dir : str
        Path to directory from which input will be loaded and to which output will be saved
    prior_tag : str
        String identifying the specific combination of prior terms
    prior_fn : str
        Path to filename in which prior model is saved
    device: str
        Device on which to run delta force calculations
    batch_size : int
        Number of frames to take per batch; must be at least 1
    force_tag: str
        Optional tag to identify input for a particular run of delta force calculation
    mol_num_batches : int
        If greater than 1, will load each molecule data from the specified number
        of batches that were treated as different samples

    Raises
    ------
    ValueError
        If :obj:`batch_size` is smaller than 1.
    """
    # Fail early, before any file is touched, with a clear message.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    # Close the file handle deterministically once the model is deserialized.
    with open(prior_fn, "rb") as prior_file:
        prior_model = torch.load(prior_file).to(device)

    dataset = RawDataset(dataset_name, names, tag, n_batches=mol_num_batches)
    for samples in tqdm(
        dataset, f"Processing delta forces for {dataset_name} dataset..."
    ):
        if not samples.has_saved_cg_output(save_dir, prior_tag):
            continue
        coords, forces, embeds, _pdb, prior_nls = samples.load_cg_output(
            save_dir=save_dir, prior_tag=prior_tag
        )
        num_frames = coords.shape[0]
        if num_frames == 0:
            # Nothing to compute; report it without breaking the progress bar.
            tqdm.write(
                f"Skipping {samples.name}: no frames found for delta forces."
            )
            continue

        # Template frames used only to build the collated batch structure;
        # positions and forces are overwritten for every chunk below. Never
        # index past the frames that actually exist.
        template_size = min(batch_size, num_frames)
        aux_data_list = [
            AtomicData.from_points(
                pos=torch.tensor(coords[i]),
                forces=torch.tensor(forces[i]),
                atom_types=torch.tensor(embeds),
                masses=None,
                neighborlist=prior_nls,
            )
            for i in range(template_size)
        ]

        delta_forces = []
        collated_data = None
        collated_size = 0
        for start in range(0, num_frames, batch_size):
            stop = min(start + batch_size, num_frames)
            chunk_size = stop - start

            # Collate only when the chunk size changes: once for the full
            # batches and once more for a shorter trailing chunk, if any.
            if chunk_size != collated_size:
                collated_data, _, _ = collate(
                    aux_data_list[0].__class__, aux_data_list[:chunk_size]
                )
                collated_data = collated_data.to(device)
                collated_size = chunk_size

            collated_data.pos = torch.tensor(
                coords[start:stop, :, :].reshape(-1, 3),
                device=device,
            )
            collated_data.forces = torch.tensor(
                forces[start:stop, :, :].reshape(-1, 3),
                device=device,
            )
            # The helper updates collated_data.forces; its return value is
            # not needed here.
            remove_baseline_forces_collated(collated_data, prior_model)
            delta_force = (
                collated_data.forces.detach().cpu().reshape(chunk_size, -1, 3)
            )
            delta_forces.append(delta_force.numpy())

        output_tag = get_output_tag(
            [tag, samples.name, prior_tag, force_tag], placement="before"
        )
        # Use the module's `osp` alias: a bare `os` is not imported here.
        fnout = osp.join(save_dir, f"{output_tag}delta_forces.npy")
        np.save(
            fnout,
            np.concatenate(delta_forces, axis=0).reshape(*coords.shape),
        )
def main():
    """Run the ``produce_delta_forces`` command-line interface.

    Prints a timestamped banner before and after the CLI call.
    """
    start_banner = "Start produce_delta_forces.py: {}".format(ctime())
    print(start_banner)

    CLI([produce_delta_forces])

    finish_banner = "Finish produce_delta_forces.py: {}".format(ctime())
    print(finish_banner)


if __name__ == "__main__":
    main()