Source code for mlcg.geometry.topology

import warnings

import mdtraj
from mdtraj.core.element import Element

from ase.geometry.analysis import Analysis
from ase import Atoms
from typing import NamedTuple, List, Optional, Tuple, Dict, Callable, Union
import torch
import numpy as np
import networkx as nx
from itertools import combinations

from .utils import ase_z2name
from ..neighbor_list.neighbor_list import make_neighbor_list
from ._symmetrize import (
    _symmetrise_map,
    _symmetrise_angle_interaction,
    _symmetrise_distance_interaction,
)


class Atom(NamedTuple):
    """Immutable record describing a single atom.

    Attributes
    ----------
    type:
        Atom type/integer label
    name:
        Atom name
    resname:
        Name of the residue containing the atom
    resid:
        Number of the residue containing the atom
    charge:
        Partial charge of the atom
    """

    #: type of the atom
    type: int
    #: name of the atom
    name: Optional[str] = None
    #: name of the residue containing the atom
    resname: Optional[str] = None
    #: number of the resid containing the atom
    resid: Optional[int] = None
    #: partial charge of the atom
    charge: Optional[float] = None
class Topology(object):
    """Topology of an isolated protein.

    Per-atom attributes are stored as parallel lists with one entry per atom,
    in insertion order. Bonded terms are stored as tuples of index lists:
    entry ``k`` of every list in ``bonds`` / ``angles`` / ``dihedrals`` /
    ``impropers`` together gives the atom indices of the ``k``-th term.
    """

    #: types of the atoms
    types: List[int]
    #: name of the atoms
    names: List[str]
    #: name of the residue containing the atoms
    resnames: List[Optional[str]]
    #: number of the resid containing the atoms
    resids: List[Optional[int]]
    #: charge of the atoms
    charges: List[Optional[float]]
    #: list of bonds between the atoms. Defines the bonded topology.
    bonds: Tuple[List[int], List[int]]
    #: list of angles formed by triplets of atoms
    angles: Tuple[List[int], List[int], List[int]]
    #: list of dihedrals formed by quadruplets of atoms
    dihedrals: Tuple[List[int], List[int], List[int], List[int]]
    #: list of impropers formed by quadruplets of atoms
    impropers: Tuple[List[int], List[int], List[int], List[int]]

    def __init__(self) -> None:
        """Create an empty topology (no atoms, no bonded terms)."""
        super().__init__()
        self.types = []
        self.names = []
        self.resnames = []
        self.resids = []
        self.charges = []
        self.bonds = ([], [])
        self.angles = ([], [], [])
        self.dihedrals = ([], [], [], [])
        self.impropers = ([], [], [], [])

    def add_atom(
        self,
        type: int,
        name: str,
        resname: Optional[str] = None,
        resid: Optional[int] = None,
        charge: Optional[float] = None,
    ):
        """Append one atom to the topology.

        Parameters
        ----------
        type:
            Atom type/integer label
        name:
            Atom name
        resname:
            Name of the residue containing the atom
        resid:
            Number of the residue containing the atom
        charge:
            Partial charge of the atom
        """
        self.types.append(type)
        self.names.append(name)
        self.resnames.append(resname)
        self.resids.append(resid)
        self.charges.append(charge)

    @property
    def atoms(self):
        """Iterate over the atoms as :obj:`Atom` records, in insertion order."""
        for type, name, resname, resid, charge in zip(
            self.types, self.names, self.resnames, self.resids, self.charges
        ):
            yield Atom(
                type=type,
                name=name,
                resname=resname,
                resid=resid,
                charge=charge,
            )

    @property
    def n_atoms(self) -> int:
        """Number of atoms in the topology."""
        return len(self.types)

    def types2torch(self, device: str = "cpu") -> torch.Tensor:
        """Return the atom types as a long tensor on ``device``."""
        return torch.tensor(self.types, dtype=torch.long, device=device)

    def bonds2torch(self, device: str = "cpu") -> torch.Tensor:
        """Return the bonds as a long tensor built from ``self.bonds``."""
        return torch.tensor(self.bonds, dtype=torch.long, device=device)

    def angles2torch(self, device: str = "cpu") -> torch.Tensor:
        """Return the angles as a long tensor built from ``self.angles``."""
        return torch.tensor(self.angles, dtype=torch.long, device=device)

    def dihedrals2torch(self, device: str = "cpu") -> torch.Tensor:
        """Return the dihedrals as a long tensor built from ``self.dihedrals``."""
        return torch.tensor(self.dihedrals, dtype=torch.long, device=device)

    def impropers2torch(self, device: str = "cpu") -> torch.Tensor:
        """Return the impropers as a long tensor built from ``self.impropers``."""
        return torch.tensor(self.impropers, dtype=torch.long, device=device)

    def fully_connected2torch(self, device: str = "cpu") -> torch.Tensor:
        """Return every ordered pair of distinct atoms as a ``(2, n_pairs)``
        tensor on ``device``."""
        # Build the indices on the requested device so the mapping is returned
        # there, like the other ``*2torch`` helpers.
        ids = torch.arange(self.n_atoms, device=device)
        mapping = torch.cartesian_prod(ids, ids).t()
        # Drop the (i, i) self-pairs.
        return mapping[:, mapping[0] != mapping[1]]

    def neighbor_list(self, type: str, device: str = "cpu") -> Dict:
        """Build a neighbor list from one kind of term stored in the topology.

        Parameters
        ----------
        type:
            kind of information to extract (should be in ["bonds", "angles",
            "dihedrals", "impropers", "fully connected"]).
        device:
            device upon which the neighborlist is returned

        Returns
        -------
        Dict:
            Neighborlist dictionary
        """
        builders = {
            "bonds": self.bonds2torch,
            "angles": self.angles2torch,
            "dihedrals": self.dihedrals2torch,
            "impropers": self.impropers2torch,
            "fully connected": self.fully_connected2torch,
        }
        assert type in builders, f"type should be any of {list(builders)}"
        mapping = builders[type](device)
        return make_neighbor_list(
            tag=type,
            order=mapping.shape[0],
            index_mapping=mapping,
            self_interaction=False,
        )

    def add_bond(self, idx1: int, idx2: int) -> None:
        """Define a bond between two atoms.

        Parameters
        ----------
        idx1:
            The index of the first atom in the bond
        idx2:
            The index of the second atom in the bond
        """
        self.bonds[0].append(idx1)
        self.bonds[1].append(idx2)

    def add_angle(self, idx1: int, idx2: int, idx3: int) -> None:
        r"""Define an angle between three atoms. `idx2` represents the
        apex/central atom of the angle:

        .. code::

              2---3
             /
            1

        Parameters
        ----------
        idx1:
            The index of the first atom defining the angle
        idx2:
            The index of the central atom defining the angle
        idx3:
            The index of the last atom defining the angle
        """
        self.angles[0].append(idx1)
        self.angles[1].append(idx2)
        self.angles[2].append(idx3)

    def add_dihedral(self, idx1: int, idx2: int, idx3: int, idx4: int) -> None:
        r"""Define a dihedral between four atoms.

        The dihedral angle formed by a quadruplet of indices (1,2,3,4) is
        defined around the axis connecting index 2 and 3 (i.e., the angle
        between the planes spanned by indices (1,2,3) and (2,3,4)):

        .. code::

                    4
                    |
              2-----3
             /
            1

        Parameters
        ----------
        idx1:
            The index of the first atom defining the dihedral
        idx2:
            The index of the second atom defining the dihedral
        idx3:
            The index of the third atom defining the dihedral
        idx4:
            The index of the last atom defining the dihedral
        """
        self.dihedrals[0].append(idx1)
        self.dihedrals[1].append(idx2)
        self.dihedrals[2].append(idx3)
        self.dihedrals[3].append(idx4)

    @staticmethod
    def _edge_index_to_columns(edge_index, order: int, message: str) -> tuple:
        """Validate the leading dimension of ``edge_index`` and convert it to
        a tuple of Python lists (one per row)."""
        if edge_index.shape[0] != order:
            raise ValueError(message)
        # ``tolist`` is called on the tensor directly rather than through
        # ``.numpy()``, which raises for tensors that live on an accelerator
        # or require grad.
        return tuple(edge_index.tolist())

    def bonds_from_edge_index(self, edge_index: torch.Tensor) -> None:
        """Overwrites the internal bond list with the bonds defined in the
        supplied bond edge_index

        Parameters
        ----------
        edge_index:
            Edge index tensor of shape (2, n_bonds)

        Raises
        ------
        ValueError:
            If the first dimension of ``edge_index`` is not 2
        """
        self.bonds = self._edge_index_to_columns(
            edge_index, 2, "Bond edge index must have shape (2, n_bonds)"
        )

    def angles_from_edge_index(self, edge_index: torch.Tensor) -> None:
        """Overwrites the internal angle list with the angles defined in the
        supplied angle edge_index

        Parameters
        ----------
        edge_index:
            Edge index tensor of shape (3, n_angles)

        Raises
        ------
        ValueError:
            If the first dimension of ``edge_index`` is not 3
        """
        self.angles = self._edge_index_to_columns(
            edge_index, 3, "Angle edge index must have shape (3, n_angles)"
        )

    def dihedrals_from_edge_index(self, edge_index: torch.Tensor) -> None:
        """Overwrites the internal dihedral list with the dihedrals defined in
        the supplied dihedral edge_index

        Parameters
        ----------
        edge_index:
            Edge index tensor of shape (4, n_dihedrals)

        Raises
        ------
        ValueError:
            If the first dimension of ``edge_index`` is not 4
        """
        self.dihedrals = self._edge_index_to_columns(
            edge_index,
            4,
            "Dihedral edge index must have shape (4, n_dihedrals)",
        )

    def impropers_from_edge_index(self, edge_index: torch.Tensor) -> None:
        """Overwrites the internal improper list with the impropers defined in
        the supplied improper edge_index

        Parameters
        ----------
        edge_index:
            Edge index tensor of shape (4, n_impropers)

        Raises
        ------
        ValueError:
            If the first dimension of ``edge_index`` is not 4
        """
        self.impropers = self._edge_index_to_columns(
            edge_index,
            4,
            "improper edge index must have shape (4, n_impropers)",
        )

    @staticmethod
    def _remove_terms(columns, removal_list) -> None:
        """Remove, in place, the first stored term matching each entry of
        ``removal_list`` from the parallel index lists ``columns``."""
        order = len(columns)
        for term in removal_list:
            target = tuple(term[k] for k in range(order))
            for position, candidate in enumerate(zip(*columns)):
                if candidate == target:
                    for column in columns:
                        column.pop(position)
                    # Only the first occurrence is removed per requested term.
                    break

    def remove_bond(self, bond_removal_list) -> None:
        r"""Method to remove bonds given list of bonds to be removed. The
        changes are made in place to the Topology.bonds attribute.

        .. warning::
            The order of the removal list matters, e.g., [1,2] and [2,1] are
            treated differently.

        Parameters
        ----------
        bond_removal_list : list
            List of bonds to be removed from the current bond list. Format:
            [[index1, index2], ..., [index1, index2]] where index1 and index2
            are the indices of the first and second atom involved in bonding,
            respectively. Entries that match no stored bond are ignored; if a
            bond is stored several times only its first occurrence is removed.
        """
        self._remove_terms(self.bonds, bond_removal_list)

    def remove_angle(self, angle_removal_list) -> None:
        r"""Method to remove angles given list of angles to be removed. The
        changes are made in place to the Topology.angles attribute.

        .. warning::
            The order of the removal list matters, e.g., [1,2,3] and [3,2,1]
            are treated differently

        Parameters
        ----------
        angle_removal_list : list
            List of angles to be removed from the current angle list. Format:
            [[index1, index2, index3], ..., [index1, index2, index3]] where
            index1, index2, index3 are the indices of the first, second, and
            third atom involved in angle formation, respectively. Entries that
            match no stored angle are ignored; if an angle is stored several
            times only its first occurrence is removed.
        """
        self._remove_terms(self.angles, angle_removal_list)

    def to_mdtraj(self) -> mdtraj.Topology:
        r"""Convert to mdtraj format. Every atom is placed in its own residue
        of a single chain. If the topology does not have a resids attribute
        (``self.resids is None``), no residue number is passed to mdtraj.

        Returns
        -------
        mdtraj.Topology:
            MDTraj topology instance from mlcg Topology
        """
        topo = mdtraj.Topology()
        chain = topo.add_chain()
        for i_at in range(self.n_atoms):
            name = self.names[i_at]
            if name.strip().upper() not in Element._elements_by_symbol:
                # Unknown symbol: create an ad-hoc element named after the
                # atom. The trailing 10 and 2 are placeholder values
                # (presumably mass and radius -- confirm against mdtraj's
                # Element signature).
                element = Element(self.types[i_at], name, name, 10, 2)
            else:
                element = Element.getBySymbol(name)
            if self.resids is None:
                residue = topo.add_residue(self.resnames[i_at], chain=chain)
            else:
                residue = topo.add_residue(
                    self.resnames[i_at], chain=chain, resSeq=self.resids[i_at]
                )
            topo.add_atom(name, element, residue)
        for idx1, idx2 in zip(self.bonds[0], self.bonds[1]):
            topo.add_bond(topo.atom(idx1), topo.atom(idx2))
        return topo

    @staticmethod
    def from_mdtraj(topology) -> "Topology":
        r"""Build topology from an existing mdtraj topology.

        Atoms are copied in iteration order with their atomic number, name,
        residue name and residue index; bonds are copied by atom index. Chain
        membership is not recorded.

        Parameters
        ----------
        topology:
            Input MDTraj topology

        Returns
        -------
        Topology:
            Topology instance created from the input MDTraj topology
        """
        topo = Topology()
        for at in topology.atoms:
            topo.add_atom(
                at.element.atomic_number,
                at.name,
                at.residue.name,
                at.residue.index,
            )
        for at1, at2 in topology.bonds:
            topo.add_bond(at1.index, at2.index)
        return topo

    @staticmethod
    def from_ase(mol: Atoms, unique=True) -> "Topology":
        r"""Build topology from an ASE Atoms instance

        .. warning::
            The minimum image convention is applied to build the topology.

        Parameters
        ----------
        mol:
            ASE atoms instance
        unique:
            If True, only the unique bonds and angles will be added to the
            resulting Topology object. If False, all redundant (backwards)
            bonds and angles will be added as well.

        Returns
        -------
        Topology:
            Topology instance based on the ASE input
        """
        analysis = Analysis(mol)
        topo = Topology()
        types = mol.get_atomic_numbers()
        for atom_type in types:
            topo.add_atom(atom_type, ase_z2name[atom_type])

        bond_list = analysis.unique_bonds if unique else analysis.all_bonds
        # Only the first entry of the analysis result is used.
        for atom, neighbors in enumerate(bond_list[0]):
            for bonded_neighbor in neighbors:
                topo.add_bond(atom, bonded_neighbor)

        angle_list = analysis.unique_angles if unique else analysis.all_angles
        for atom, end_point_list in enumerate(angle_list[0]):
            for end_points in end_point_list:
                # ``atom`` is the apex; the listed pair are the two end points.
                topo.add_angle(end_points[0], atom, end_points[1])
        return topo

    @staticmethod
    def from_file(filename: str) -> "Topology":
        """Uses mdtraj reader to read the input topology."""
        topo = mdtraj.load(filename).topology
        return Topology.from_mdtraj(topo)

    def draw(
        self,
        layout: Callable = nx.drawing.layout.spring_layout,
        layout_kwargs: Optional[Dict] = None,
        drawing_kwargs: Optional[Dict] = None,
    ) -> None:
        r"""Use NetworkX to draw the current molecular topology. By default,
        node labels correspond to atom types and nodes are colored by type.

        Parameters
        ----------
        layout:
            NetworkX layout drawing function (from networkx.drawing.layout)
            that determines the positions of the nodes
        layout_kwargs:
            keyword arguments for the node layout drawing function
        drawing_kwargs:
            keyword arguments for nx.draw. The dictionary passed in is not
            modified; ``pos`` is always overridden by the computed layout.
        """
        from matplotlib.pyplot import get_cmap

        layout_kwargs = {} if layout_kwargs is None else layout_kwargs
        # Work on a copy so the defaults filled in below do not leak into the
        # caller's dictionary.
        drawing_kwargs = {} if drawing_kwargs is None else dict(drawing_kwargs)

        connectivity = get_connectivity_matrix(self)
        graph = nx.Graph(connectivity.numpy())
        drawing_kwargs["pos"] = layout(graph, **layout_kwargs)
        if "labels" not in drawing_kwargs:
            drawing_kwargs["labels"] = {
                node: str(self.types[node]) for node in graph.nodes
            }
        if "node_color" not in drawing_kwargs:
            # One color per integer type value in 0..max(types).
            num_colors = max(self.types) + 1
            cmap = get_cmap("viridis", num_colors)
            drawing_kwargs["node_color"] = [
                cmap.colors[node_type, :3] for node_type in self.types
            ]
        nx.draw(graph, **drawing_kwargs)
def get_connectivity_matrix(
    topology: Topology, directed: bool = False
) -> torch.Tensor:
    """Produces a full connectivity matrix from the graph structure
    implied by Topology.bonds

    Parameters
    ----------
    topology:
        Topology for which a connectivity matrix will be constructed
    directed:
        If True, an asymmetric connectivity matrix will be returned
        corresponding to a directed graph. If False, the connectivity matrix
        will be symmetric and the corresponding graph will be undirected.

    Returns
    -------
    torch.Tensor:
        Torch tensor of shape (n_atoms, n_atoms) representing the
        connectivity/adjacency matrix from the bonded graph.

    Raises
    ------
    ValueError:
        If the topology holds no bonds or no atoms.
    """
    if len(topology.bonds[0]) == 0 and len(topology.bonds[1]) == 0:
        raise ValueError("No bonds in the topology.")
    n_atoms = topology.n_atoms
    if n_atoms == 0:
        raise ValueError("n_atoms is not specified in the topology")

    connectivity_matrix = torch.zeros(n_atoms, n_atoms)
    bonds = topology.bonds2torch()
    sources, targets = bonds[0, :], bonds[1, :]
    connectivity_matrix[sources, targets] = 1
    if not directed:
        # Mirror every bond so the matrix is symmetric.
        connectivity_matrix[targets, sources] = 1
    return connectivity_matrix
def add_chain_bonds(topology: Topology) -> None:
    r"""Bond consecutive atoms of the topology, treating it as a linear chain.

    Atoms are linked following their insertion order, so a four atom chain
    ends up bonded as `1-2-3-4`.

    Parameters
    ----------
    topology:
        Topology instance to which the bonds should be added
    """
    atom_ids = range(topology.n_atoms)
    # Pair every atom with its successor.
    for left, right in zip(atom_ids, atom_ids[1:]):
        topology.add_bond(left, right)
def add_chain_angles(topology: Topology) -> None:
    r"""Add the angles of a linear chain to the topology.

    Angles are defined over consecutive atom triplets, following the
    insertion order of the atoms. A four atom chain `1-2-3-4` yields the
    angles `1-2-3` and `2-3-4`.

    Parameters
    ----------
    topology:
        Topology instance to which the angles should be added
    """
    atom_ids = range(topology.n_atoms)
    # Sliding window of three consecutive atoms; the middle one is the apex.
    for first, apex, last in zip(atom_ids, atom_ids[1:], atom_ids[2:]):
        topology.add_angle(first, apex, last)
def add_chain_dihedrals(topology: Topology) -> None:
    r"""Add the dihedrals of a linear chain to the topology.

    Dihedrals are defined over consecutive atom quadruplets, following the
    insertion order of the atoms. A four atom chain `1-2-3-4` yields the
    single dihedral `1-2-3-4`.

    Parameters
    ----------
    topology:
        Topology instance to which the dihedrals should be added
    """
    atom_ids = range(topology.n_atoms)
    # Sliding window of four consecutive atoms.
    windows = zip(atom_ids, atom_ids[1:], atom_ids[2:], atom_ids[3:])
    for first, second, third, fourth in windows:
        topology.add_dihedral(first, second, third, fourth)
def get_n_pairs(
    connectivity_matrix: Union[torch.Tensor, nx.Graph, np.array],
    n: int = 3,
    unique: bool = True,
) -> torch.Tensor:
    r"""Identify, with networkx, the atom pairs that are exactly n atoms
    apart. Paths are found using Dijkstra's algorithm.

    Parameters
    ----------
    connectivity_matrix:
        Connectivity/adjacency matrix of the molecular graph of shape
        (n_atoms, n_atoms) or a networkx graph object
    n:
        Number of atoms to count away from the starting atom, with the
        starting atom counting as n=1
    unique:
        If True, the returned pairs will be unique and symmetrised.

    Returns
    -------
    torch.Tensor:
        Edge index tensor of shape (2, n_pairs)
    """
    if isinstance(connectivity_matrix, nx.Graph):
        molecular_graph = connectivity_matrix
    else:
        adjacency = connectivity_matrix
        if isinstance(adjacency, torch.Tensor):
            adjacency = adjacency.numpy()
        molecular_graph = nx.Graph(adjacency)

    sources, targets = [], []
    for start in molecular_graph.nodes:
        routes = nx.single_source_dijkstra_path(
            molecular_graph, start, cutoff=n
        )
        # Keep the end point of every route that spans exactly n atoms.
        for route in routes.values():
            if len(route) == n:
                sources.append(start)
                targets.append(route[-1])

    edge_index = torch.tensor((sources, targets))
    if unique:
        edge_index = _symmetrise_distance_interaction(edge_index)
        edge_index = torch.unique(edge_index, dim=1)
    return edge_index
def get_n_paths(
    connectivity_matrix: Union[torch.Tensor, nx.Graph, np.ndarray],
    n: int = 3,
    unique: bool = True,
) -> torch.Tensor:
    r"""Use networkx to collect the paths spanning exactly n atoms. Paths
    are found using Dijkstra's algorithm.

    NOTE(review): ``nx.single_source_dijkstra_path`` reports one shortest
    path per reachable target, so in graphs with cycles presumably not every
    n-atom path is enumerated -- confirm this is the intended behavior.

    Parameters
    ----------
    connectivity_matrix:
        Connectivity/adjacency matrix of the molecular graph of shape
        (n_atoms, n_atoms) or a networkx graph object
    n:
        Number of atoms to count away from the starting atom, with the
        starting atom counting as n=1
    unique:
        If True, the returned paths are symmetrised with the helper
        registered for ``n`` in ``_symmetrise_map`` and de-duplicated.
        Only supported for n in (2, 3, 4).

    Returns
    -------
    torch.Tensor:
        Path index tensor of shape (n, n_paths)

    Raises
    ------
    NotImplementedError:
        If ``unique`` is requested for an ``n`` other than 2, 3 or 4.
    """
    if unique and n not in (2, 3, 4):
        raise NotImplementedError("Unique currently only works for n=2,3,4")

    if isinstance(connectivity_matrix, nx.Graph):
        graph = connectivity_matrix
    elif isinstance(connectivity_matrix, torch.Tensor):
        graph = nx.Graph(connectivity_matrix.numpy())
    else:
        graph = nx.Graph(connectivity_matrix)

    # One list per position along the path; entry k of every list together
    # describes the k-th path.
    final_paths = [[] for _ in range(n)]
    for atom in graph.nodes:
        n_hop_paths = nx.single_source_dijkstra_path(graph, atom, cutoff=n)
        for path in n_hop_paths.values():
            if len(path) != n:
                continue
            for column, sub_atom in zip(final_paths, path):
                column.append(sub_atom)

    final_paths = torch.tensor(final_paths)
    if unique:
        final_paths = _symmetrise_map[n](final_paths)
        final_paths = torch.unique(final_paths, dim=1)
    return final_paths
def get_improper_paths(
    connectivity_matrix: torch.Tensor, unique: bool = True
) -> torch.Tensor:
    r"""This function returns all paths defining an improper dihedral

    .. code::

            k
            |
        i - j - l

    where the order of connected nodes is given as `[i,k,l,j]` - i.e., the
    central node is reported last.

    Parameters
    ----------
    connectivity_matrix:
        Connectivity/adjacency matrix of the molecular graph of shape
        (n_atoms, n_atoms)
    unique:
        If True, the returned paths will be unique

    Returns
    -------
    torch.Tensor:
        Long tensor of path indices of shape (4, n_impropers); the shape is
        (4, 0) when no node has at least three neighbors.
    """
    n = 4
    # Convert once instead of on every loop iteration.
    adjacency = connectivity_matrix.numpy()
    neigh_counts = np.sum(adjacency, axis=0)
    final_paths = [[] for _ in range(n)]
    for i_nc, neigh_count in enumerate(neigh_counts):
        if neigh_count < 3:
            continue
        neigh_list = np.where(adjacency[i_nc] == 1)[0]
        # Every triplet of neighbors around the central node is one improper.
        for combo in combinations(neigh_list, 3):
            final_paths[-1].append(i_nc)
            for ic, ind in enumerate(combo):
                final_paths[ic].append(int(ind))
    # An explicit dtype keeps the result an index tensor even when it is empty.
    final_paths = torch.tensor(final_paths, dtype=torch.long)
    if unique and final_paths.numel() > 0:
        final_paths = torch.unique(final_paths, dim=1)
    return final_paths


def _grab_atom_index_by_name(
    topology: mdtraj.Topology, atom_selection: Optional[np.ndarray] = None
) -> np.ndarray:
    r"""Helper function to select atoms indices based on atom names
    according to mdtraj scheme

    Some useful examples of possible :obj:`atom_selection` for (improper)
    dihedrals:

    Impropers: (Central atom must go last)
        GAMMA1_ATOMS = ["N", "CB", "C", "CA"]
        GAMMA2_ATOMS = ["CA", "O", "+N", "C"]
    Dihedrals: (Previous and next residue indicated by (-) and (+) sign)
        PHI_ATOMS = ["-C", "N", "CA", "C"]
        PSI_ATOMS = ["N", "CA", "C", "+N"]
        OMEGA_ATOMS = ["CA", "C", "+N", "+CA"]

    Parameters
    ----------
    topology:
        MDTraj topology instance
    atom_selection:
        Array of MDtraj atom names to select

    Returns
    -------
    np.ndarray:
        Atom indices according to name, or None when no `atom_selection`
        is given
    """
    if hasattr(topology, "topology"):
        warnings.warn(
            "Passing a Trajectory object. Please pass a Topology object",
            DeprecationWarning,
        )
        topology = topology.topology
    if atom_selection is None:
        return None
    # Only the second element of mdtraj's result is used here.
    return mdtraj.geometry.dihedral._atom_sequence(topology, atom_selection)[1]


def _residue_mapping_dictionary(
    topology: mdtraj.Topology, atom_indices=None
) -> Dict:
    r"""Helper function to assign each set of atom_indices to a specific
    residue type

    Each row is assigned to the residue name that occurs most often among
    its atoms.

    Parameters
    ----------
    topology:
        mdtraj topology object
    atom_indices:
        np.ndarray (n_instances,n_atoms). A row is a specific interaction and
        where each column are the atoms involved in it. When None, an empty
        mapping is returned.

    Returns
    -------
    Dict:
        Dictionary that maps residue names to the array of atom index groups
        assigned to them
    """
    from collections import defaultdict

    if hasattr(topology, "topology"):
        warnings.warn(
            "Passing a Trajectory object. Please pass a Topology object",
            DeprecationWarning,
        )
        topology = topology.topology

    residue_dictionary = defaultdict(list)
    if atom_indices is None:
        # Nothing to assign; avoids failing on ``None.shape`` below.
        return residue_dictionary

    resnames = np.array([atom.residue.name for atom in topology.atoms])
    for i in range(atom_indices.shape[0]):
        group = atom_indices[i]
        unique_res, counts = np.unique(resnames[group], return_counts=True)
        # Majority vote over the residue names of the atoms in the group.
        current_res = unique_res[np.argmax(counts)]
        residue_dictionary[current_res].append(group)
    for k, v in residue_dictionary.items():
        residue_dictionary[k] = np.array(v)
    return residue_dictionary