# Source code for mlcg_tk.input_generator.prior_nls

import torch

import mdtraj as md
from typing import Any, List, Union, Tuple, Optional
import numpy as np

from mlcg.geometry._symmetrize import _symmetrise_distance_interaction
from networkx.algorithms.shortest_paths.unweighted import (
    bidirectional_shortest_path,
)
import networkx as nx
from mlcg.geometry.topology import (
    Topology,
    get_connectivity_matrix,
    get_n_paths,
)

from .utils import get_dihedral_groups, split_bulk_termini
from .embedding_maps import all_residues


def check_graph_distance(
    graph: "nx.Graph",
    conn_comp: List[set],
    node_1: int,
    node_2: int,
    min_distance: int,
) -> bool:
    """Check whether two nodes are at least `min_distance` apart in `graph`.

    The connected components are passed in pre-computed so that pairs lying
    in different components can be accepted without running a shortest-path
    search, which saves computation time.

    Parameters
    ----------
    graph:
        Graph in which the shortest path is searched.
    conn_comp:
        Connected components of `graph`, each given as a set of nodes.
    node_1:
        First node of the pair.
    node_2:
        Second node of the pair.
    min_distance:
        Threshold compared against the length of the shortest path.

    Returns
    -------
    bool
        True if the nodes do not share a connected component, or if the
        shortest path between them has at least `min_distance` entries;
        False otherwise. A node that belongs to none of the supplied
        components is treated as not sharing a component.

    Notes
    -----
    The path length used here is ``len(shortest_path)``, i.e. the number of
    nodes on the path (number of edges plus one).
    NOTE(review): callers describe the threshold in terms of bond edges;
    confirm whether counting nodes rather than edges is intended.
    """
    same_component = any(
        node_1 in comp and node_2 in comp for comp in conn_comp
    )
    if not same_component:
        # No path can exist, so the pair is as far apart as possible.
        return True

    shortest_path = bidirectional_shortest_path(graph, node_1, node_2)
    return len(shortest_path) >= min_distance
class StandardBonds:
    """
    Pairwise interactions corresponding to physically bonded atoms

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        If `separate_termini` is False, only bonds are returned,
        otherwise atom groups are split based on interactions between only
        bulk atoms or interactions with atoms in terminal residues.
    """

    nl_names = ["n_term_bonds", "bulk_bonds", "c_term_bonds", "bonds"]

    @staticmethod
    def _or_empty(edges, order: int):
        """Return `edges`, or an empty ``(order, 0)`` tensor if it has length zero."""
        if len(edges) == 0:
            return torch.tensor([]).reshape(order, 0)
        return edges

    def __call__(
        self, topology: "md.Topology", separate_termini: bool = True, **kwargs
    ) -> Union[List[Tuple[str, int, torch.Tensor]], Tuple[str, int, torch.Tensor]]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which atom groups defining each prior
            term will be created.
        separate_termini:
            Whether atom groups should be split between bulk interactions and
            those involving atoms in terminal residues
        **kwargs:
            Must contain `n_term_atoms` and `c_term_atoms` when
            `separate_termini` is True.

        Returns
        -------
        A single ``("bonds", 2, edges)`` tuple when `separate_termini` is
        False; otherwise a list of three ``(name, 2, edges)`` tuples for the
        N-terminal, bulk and C-terminal groups, in that order.
        """
        mlcg_top = Topology.from_mdtraj(topology)
        conn_mat = get_connectivity_matrix(mlcg_top).numpy()
        bond_edges = get_n_paths(conn_mat, n=2).numpy()

        if not separate_termini:
            return ("bonds", 2, bond_edges)

        n_term_bonds, c_term_bonds, bulk_bonds = split_bulk_termini(
            kwargs["n_term_atoms"], kwargs["c_term_atoms"], bond_edges
        )
        # Each group is normalised on its own: an empty group becomes a
        # (2, 0) tensor while non-empty groups are passed through untouched,
        # so one missing terminus never causes the other to be dropped.
        return [
            ("n_term_bonds", 2, self._or_empty(n_term_bonds, 2)),
            ("bulk_bonds", 2, self._or_empty(bulk_bonds, 2)),
            ("c_term_bonds", 2, self._or_empty(c_term_bonds, 2)),
        ]

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name` (none for bonds)."""
        return {}
class StandardAngles:
    """
    Interactions corresponding to angles formed between three physically
    bonded atoms

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        If `separate_termini` is False, only angles are returned,
        otherwise atom groups are split based on interactions between only
        bulk atoms or interactions with atoms in terminal residues.
    """

    nl_names = ["n_term_angles", "bulk_angles", "c_term_angles", "angles"]

    @staticmethod
    def _or_empty(edges, order: int):
        """Return `edges`, or an empty ``(order, 0)`` tensor if it has length zero."""
        if len(edges) == 0:
            return torch.tensor([]).reshape(order, 0)
        return edges

    def __call__(
        self, topology: "md.Topology", separate_termini: bool = True, **kwargs
    ) -> Union[List[Tuple[str, int, torch.Tensor]], Tuple[str, int, torch.Tensor]]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which atom groups defining each prior
            term will be created.
        separate_termini:
            Whether atom groups should be split between bulk interactions and
            those involving atoms in terminal residues
        **kwargs:
            Must contain `n_term_atoms` and `c_term_atoms` when
            `separate_termini` is True.

        Returns
        -------
        A single ``("angles", 3, edges)`` tuple when `separate_termini` is
        False; otherwise a list of three ``(name, 3, edges)`` tuples for the
        N-terminal, bulk and C-terminal groups, in that order.
        """
        mlcg_top = Topology.from_mdtraj(topology)
        conn_mat = get_connectivity_matrix(mlcg_top).numpy()
        angle_edges = get_n_paths(conn_mat, n=3).numpy()

        if not separate_termini:
            return ("angles", 3, angle_edges)

        n_term_angles, c_term_angles, bulk_angles = split_bulk_termini(
            kwargs["n_term_atoms"], kwargs["c_term_atoms"], angle_edges
        )
        # Each group is normalised on its own: an empty group becomes a
        # (3, 0) tensor while non-empty groups are passed through untouched,
        # so one missing terminus never causes the other to be dropped.
        return [
            ("n_term_angles", 3, self._or_empty(n_term_angles, 3)),
            ("bulk_angles", 3, self._or_empty(bulk_angles, 3)),
            ("c_term_angles", 3, self._or_empty(c_term_angles, 3)),
        ]

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name` (none for angles)."""
        return {}
class Non_Bonded:
    """
    Pairwise interactions corresponding to nonbonded atoms

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        If `separate_termini` is False, only the `non_bonded` group is
        returned, otherwise atom groups are split based on interactions
        between only bulk atoms or interactions with atoms in terminal
        residues.
    """

    nl_names = ["n_term_nonbonded", "bulk_nonbonded", "c_term_nonbonded", "non_bonded"]

    @staticmethod
    def _pair_set(edges, rows=(0, 1)) -> set:
        """Collect the ordered ``(i, j)`` pairs stored column-wise in `edges`.

        `rows` selects which two rows of `edges` hold the pair members.
        ``None`` and empty inputs yield an empty set.
        """
        if edges is None:
            return set()
        edges = np.asarray(edges)
        if edges.size == 0:
            return set()
        first, second = rows
        return set(zip(edges[first].tolist(), edges[second].tolist()))

    def __call__(
        self,
        topology: "md.Topology",
        bond_edges: Union[np.ndarray, List, None] = None,
        angle_edges: Union[np.ndarray, List, None] = None,
        min_pair: int = 6,
        res_exclusion: int = 1,
        separate_termini: bool = False,
        **kwargs,
    ) -> Union[List[Tuple[str, int, torch.Tensor]], Tuple[str, int, torch.Tensor]]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which atom groups defining each prior
            term will be created.
        bond_edges:
            All edges associated with bond atom groups already defined
        angle_edges:
            All edges associated with angle atom groups already defined;
            only the first and third atom of each angle are excluded as a pair
        min_pair:
            Threshold passed to `check_graph_distance`; pairs in the same
            connected component whose shortest path is shorter than this are
            removed from the non-bonded set
        res_exclusion:
            Pairs whose residue indices differ by less than `res_exclusion`
            are removed from the non-bonded set
        separate_termini:
            Whether atom groups should be split between bulk interactions and
            those involving atoms in terminal residues
        **kwargs:
            When `separate_termini` is True, must contain `n_term_atoms` and
            `c_term_atoms` if `use_terminal_res` is truthy, otherwise
            `n_atoms` and `c_atoms`.

        Returns
        -------
        A single ``("non_bonded", 2, edges)`` tuple when `separate_termini`
        is False; otherwise a list of three ``(name, 2, edges)`` tuples for
        the N-terminal, bulk and C-terminal groups, in that order.
        """
        mlcg_top = Topology.from_mdtraj(topology)
        fully_connected_edges = _symmetrise_distance_interaction(
            mlcg_top.fully_connected2torch()
        ).numpy()
        conn_mat = get_connectivity_matrix(mlcg_top).numpy()
        graph = nx.Graph(conn_mat)
        conn_comps = list(nx.connected_components(graph))

        # Loop invariants: exclusion pairs as hash sets and residue indices
        # per atom, so each candidate pair costs O(1) to screen.
        bonded_pairs = self._pair_set(bond_edges, rows=(0, 1))
        angle_pairs = self._pair_set(angle_edges, rows=(0, 2))
        residue_of = [
            topology.atom(i).residue.index for i in range(topology.n_atoms)
        ]

        kept = []
        for pair in fully_connected_edges.T:
            i, j = int(pair[0]), int(pair[1])
            if abs(residue_of[i] - residue_of[j]) < res_exclusion:
                continue
            if (i, j) in bonded_pairs or (i, j) in angle_pairs:
                continue
            if graph.has_edge(i, j):
                continue
            # Most expensive test last: it may run a shortest-path search.
            if not check_graph_distance(graph, conn_comps, i, j, min_pair):
                continue
            kept.append(pair)

        if kept:
            non_bonded_edges = torch.tensor(np.array(kept).T)
            non_bonded_edges = torch.unique(
                _symmetrise_distance_interaction(non_bonded_edges), dim=1
            ).numpy()
        else:
            # Nothing survived the filters: return a well-formed empty edge
            # list instead of failing on a shapeless array.
            non_bonded_edges = np.empty((2, 0), dtype=fully_connected_edges.dtype)

        if not separate_termini:
            return ("non_bonded", 2, non_bonded_edges)

        if kwargs.get("use_terminal_res", False):
            n_atoms = kwargs["n_term_atoms"]
            c_atoms = kwargs["c_term_atoms"]
        else:
            n_atoms = kwargs["n_atoms"]
            c_atoms = kwargs["c_atoms"]
        n_term_nonbonded, c_term_nonbonded, bulk_nonbonded = split_bulk_termini(
            n_atoms, c_atoms, non_bonded_edges
        )
        return [
            ("n_term_nonbonded", 2, n_term_nonbonded),
            ("bulk_nonbonded", 2, bulk_nonbonded),
            ("c_term_nonbonded", 2, c_term_nonbonded),
        ]

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name` (none for non-bonded)."""
        return {}
class Phi:
    """
    Phi (proper) dihedral angle formed by the following atoms:
    C_{n-1} - N_{n} - CA_{n} - C_{n}
    where n represents the amino acid for which the angle is defined

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        Atom groups of phi angles of each amino acid are recorded separately
    """

    nl_names = [f"{res}_phi" for res in all_residues]

    def __call__(
        self, topology: "md.Topology", **kwargs
    ) -> Union[List[Tuple[str, int, torch.Tensor]], Tuple[str, int, torch.Tensor]]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which the phi atom groups are built.

        Returns
        -------
        A list with one ``(f"{res}_phi", 4, atom_groups)`` tuple per entry of
        `all_residues`; residues without any phi group get an empty
        ``(4, 0)`` tensor.
        """
        dihedral_dict = get_dihedral_groups(
            topology,
            atoms_needed=["C", "N", "CA", "C"],
            offset=[-1.0, 0.0, 0.0, 0.0],
            tag="_phi",
        )
        dihedrals = []
        for res in all_residues:
            dihedral_tag = f"{res}_phi"
            if dihedral_tag in dihedral_dict:
                atom_groups = torch.tensor(np.array(dihedral_dict[dihedral_tag])).T
            else:
                atom_groups = torch.tensor([]).reshape(4, 0)
            dihedrals.append((dihedral_tag, 4, atom_groups))
        return dihedrals

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name`.

        `PRO_phi` gets its own setting; every other name shares one.
        """
        if nl_name == "PRO_phi":
            return {"n_degs": 1, "constrain_deg": 1}
        return {"n_degs": 3, "constrain_deg": 3}
class Psi:
    """
    Psi (proper) dihedral angle formed by the following atoms:
    N_{n} - CA_{n} - C_{n} - N_{n+1}
    where n represents the amino acid for which the angle is defined

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        Atom groups of psi angles of each amino acid are recorded separately
    """

    nl_names = [f"{res}_psi" for res in all_residues]

    def __call__(
        self, topology: "md.Topology", **kwargs
    ) -> Union[List[Tuple[str, int, torch.Tensor]], Tuple[str, int, torch.Tensor]]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which the psi atom groups are built.

        Returns
        -------
        A list with one ``(f"{res}_psi", 4, atom_groups)`` tuple per entry of
        `all_residues`; residues without any psi group get an empty
        ``(4, 0)`` tensor.
        """
        dihedral_dict = get_dihedral_groups(
            topology,
            atoms_needed=["N", "CA", "C", "N"],
            offset=[0.0, 0.0, 0.0, 1.0],
            tag="_psi",
        )
        dihedrals = []
        for res in all_residues:
            dihedral_tag = f"{res}_psi"
            if dihedral_tag in dihedral_dict:
                atom_groups = torch.tensor(np.array(dihedral_dict[dihedral_tag])).T
            else:
                atom_groups = torch.tensor([]).reshape(4, 0)
            dihedrals.append((dihedral_tag, 4, atom_groups))
        return dihedrals

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name` (same for every name)."""
        return {"n_degs": 3, "constrain_deg": 3}
class Omega:
    """
    Omega (proper) dihedral angle formed by the following atoms:
    CA_{n-1} - C_{n-1} - N_{n} - CA_{n}
    where n represents the amino acid for which the angle is defined

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        Atom groups of omega angles are recorded separately only for proline
    """

    nl_names = ["pro_omega", "non_pro_omega"]

    replace_gly_ca_stats = True

    def __call__(
        self, topology: "md.Topology", **kwargs
    ) -> List[Tuple[str, int, torch.Tensor]]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which the omega atom groups are built.

        Returns
        -------
        A list of two ``(name, 4, atom_groups)`` tuples, `pro_omega` first
        and `non_pro_omega` second; a group without members is an empty
        ``(4, 0)`` tensor.
        """
        dihedral_dict = get_dihedral_groups(
            topology,
            atoms_needed=["CA", "C", "N", "CA"],
            offset=[-1, -1, 0, 0],
            tag="_omega",
        )

        # Insertion order fixes the order of the returned list.
        grouped = {"pro_omega": [], "non_pro_omega": []}
        for dihedral_tag, members in dihedral_dict.items():
            target = "pro_omega" if dihedral_tag == "PRO_omega" else "non_pro_omega"
            grouped[target].extend(np.array(members))

        dihedrals = []
        for name, atom_groups in grouped.items():
            if len(atom_groups) == 0:
                dihedrals.append((name, 4, torch.tensor([]).reshape(4, 0)))
            else:
                dihedrals.append((name, 4, torch.tensor(np.array(atom_groups)).T))
        return dihedrals

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name`.

        `pro_omega` gets its own setting; every other name shares one.
        """
        if nl_name == "pro_omega":
            return {"n_degs": 2, "constrain_deg": 2}
        return {"n_degs": 1, "constrain_deg": 1}
class Gamma1:
    """
    Improper dihedral angle formed by the following atoms:
    N_{n} - CB_{n} - C_{n} - CA_{n}
    where n represents the amino acid for which the angle is defined;
    gamma_1 angle is measured between the plane formed by the first, third,
    and fourth atom and the vector from the first to second atom.

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        Atom groups of gamma_1 angles are not separated by amino acid type
    """

    nl_names = ["gamma_1"]

    def __call__(
        self, topology: "md.Topology", **kwargs
    ) -> Tuple[str, int, torch.Tensor]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which the gamma_1 atom groups are
            built.

        Returns
        -------
        A single ``("gamma_1", 4, atom_groups)`` tuple pooling the groups of
        all residues; an empty ``(4, 0)`` tensor if none were found.
        """
        dihedral_dict = get_dihedral_groups(
            topology,
            atoms_needed=["N", "CB", "C", "CA"],
            offset=[0, 0, 0, 0],
            tag="_gamma_1",
        )
        atom_groups = []
        for members in dihedral_dict.values():
            atom_groups.extend(members)

        if len(atom_groups) == 0:
            return ("gamma_1", 4, torch.tensor([]).reshape(4, 0))
        return ("gamma_1", 4, torch.tensor(np.array(atom_groups)).T)

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name` (same for every name)."""
        return {"n_degs": 1, "constrain_deg": 1}
class Gamma2:
    """
    Improper dihedral angle formed by the following atoms:
    CA_{n} - O_{n} - N_{n+1} - C_{n}
    where n represents the amino acid for which the angle is defined;
    gamma_2 angle is measured between the plane formed by the first, third,
    and fourth atom and the vector from the first to second atom.

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        Atom groups of gamma_2 angles are not separated by amino acid type
    """

    nl_names = ["gamma_2"]

    def __call__(
        self, topology: "md.Topology", **kwargs
    ) -> Tuple[str, int, torch.Tensor]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which the gamma_2 atom groups are
            built.

        Returns
        -------
        A single ``("gamma_2", 4, atom_groups)`` tuple pooling the groups of
        all residues; an empty ``(4, 0)`` tensor if none were found.
        """
        dihedral_dict = get_dihedral_groups(
            topology,
            atoms_needed=["CA", "O", "N", "C"],
            offset=[0, 0, 1, 0],
            tag="_gamma_2",
        )
        atom_groups = []
        for members in dihedral_dict.values():
            atom_groups.extend(members)

        if len(atom_groups) == 0:
            return ("gamma_2", 4, torch.tensor([]).reshape(4, 0))
        return ("gamma_2", 4, torch.tensor(np.array(atom_groups)).T)

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name` (same for every name)."""
        return {"n_degs": 1, "constrain_deg": 1}
class CA_pseudo_dihedral:
    """
    Proper dihedral angle formed by four consecutive CA atoms:
    CA_{n} - CA_{n+1} - CA_{n+2} - CA_{n+3}

    Attributes
    ----------
    nl_names
        All possible outputs of bonded neighbourlist;
        Atom groups of all CA pseudo-dihedrals are pooled into a single
        group, not separated by amino acid type
    """

    nl_names = ["pseudo_ca_dihedral"]

    def __call__(
        self, topology: "md.Topology", **kwargs
    ) -> Union[List[Tuple[str, int, torch.Tensor]], Tuple[str, int, torch.Tensor]]:
        """
        Parameters
        ----------
        topology:
            MDTraj topology object from which the CA pseudo-dihedral atom
            groups are built.

        Returns
        -------
        A one-element list holding ``("pseudo_ca_dihedral", 4, atom_groups)``
        with the groups ordered by their first atom index; an empty
        ``(4, 0)`` tensor if no group was found.
        """
        dihedral_dict = get_dihedral_groups(
            topology,
            atoms_needed=["CA", "CA", "CA", "CA"],
            offset=[0.0, 1.0, 2.0, 3.0],
            tag="",
        )
        all_dihedrals = []
        for members in dihedral_dict.values():
            all_dihedrals.extend(members)

        if len(all_dihedrals) == 0:
            # Keep the (4, 0) layout used for empty groups instead of
            # returning a shapeless 1-D tensor.
            return [("pseudo_ca_dihedral", 4, torch.tensor([]).reshape(4, 0))]

        all_dihedrals_np = np.array(sorted(all_dihedrals, key=lambda arr: arr[0]))
        return [("pseudo_ca_dihedral", 4, torch.tensor(all_dihedrals_np).T)]

    def get_fit_kwargs(self, nl_name):
        """Return the fit keyword arguments for `nl_name` (same for every name)."""
        return {"n_degs": 5, "constrain_deg": 5}