Commit 22798d00 authored by marta-sd's avatar marta-sd
Browse files

docstrings + small stylistic changes

parent 1ae9bc95
Loading
Loading
Loading
Loading
+300 −64
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ def merge_two_dicts(x, y):


def compute_centroid(coordinates):
  """Compute compute the x,y,z centroid of provided coordinates
  """Compute the x,y,z centroid of provided coordinates

  coordinates: np.ndarray
    Shape (N, 3), where N is number atoms.
@@ -59,14 +59,13 @@ def compute_centroid(coordinates):


def generate_random__unit_vector():
  """generate a random unit vector on the 3-sphere
  """Generate a random unit vector on the 3-sphere.
  citation:
  http://mathworld.wolfram.com/SpherePointPicking.html

  a. Choose random theta \element [0, 2*pi]
  b. Choose random z \element [-1, 1]
  c. Compute output: (x,y,z) = (sqrt(1-z^2)*cos(theta), sqrt(1-z^2)*sin(theta),z)
  d. output u
  c. Compute output vector u: (x,y,z) = (sqrt(1-z^2)*cos(theta), sqrt(1-z^2)*sin(theta),z)
  """

  theta = np.random.uniform(low=0.0, high=2 * np.pi)
@@ -79,10 +78,9 @@ def generate_random__unit_vector():

def generate_random_rotation_matrix():
  """
   1. Generate a random unit vector, i.e., randomly sampled from the unit
      3-sphere
    a. see function _generate_random__unit_vector() for details
    2. Generate a second random unit vector thru the algorithm in (1), output v
    1. Generate a random unit vector u, randomly sampled from the unit
      3-sphere (see function generate_random__unit_vector() for details)
    2. Generate a second random unit vector v
      a. If absolute value of u \dot v > 0.99, repeat.
       (This is important for numerical stability. Intuition: we want them to
        be as linearly independent as possible or else the orthogonalized
@@ -196,9 +194,8 @@ def hash_ecfp(ecfp, power):


def hash_ecfp_pair(ecfp_pair, power):
  """
  Returns an int of size 2^power representing that
  ECFP pair. Input must be a tuple of strings.
  """Returns an int of size 2^power representing that ECFP pair. Input must be
  a tuple of strings.
  """
  ecfp = "%s,%s" % (ecfp_pair[0], ecfp_pair[1])
  ecfp = ecfp.encode('utf-8')
@@ -210,9 +207,7 @@ def hash_ecfp_pair(ecfp_pair, power):


def compute_all_ecfp(mol, indices=None, degree=2):
  """
  For each atom:
    Obtain molecular fragment for all atoms emanating outward to given degree.
  """Obtain molecular fragment for all atoms emanating outward to given degree.
  For each fragment, compute SMILES string (for now) and hash to an int.
  Return a dictionary mapping atom index to hashed SMILES.
  """
@@ -275,7 +270,9 @@ def featurize_binding_pocket_ecfp(protein_xyz,
  pairwise_distances: np.ndarray
    Array of pairwise protein-ligand distances (Angstroms)
  cutoff: float
    Cutoff distance for contact consideration.
    Cutoff distance for contact consideration
  ecfp_degree: int
    ECFP radius
  """

  if pairwise_distances is None:
@@ -337,7 +334,7 @@ def compute_splif_features_in_range(protein,
                                    ecfp_degree=2):
  """Computes SPLIF features for protein atoms close to ligand atoms.

  Find all protein atoms that are > contact_bin[0] and < contact_bin[1] away
  Finds all protein atoms that are > contact_bin[0] and < contact_bin[1] away
  from ligand atoms. Then, finds the ECFP fingerprints for the contacting
  atoms. Returns a dictionary mapping (protein_index_i, ligand_index_j) -->
  (protein_ecfp_i, ligand_ecfp_j)
@@ -363,7 +360,7 @@ def featurize_splif(protein_xyz, protein, ligand_xyz, ligand, contact_bins,

  For each contact range (i.e. 1 A to 2 A, 2 A to 3 A, etc.) compute a
  dictionary mapping (protein_index_i, ligand_index_j) tuples -->
  (protein_ecfp_i, ligand_ecfp_j) tuples.  return a list of such splif
  (protein_ecfp_i, ligand_ecfp_j) tuples. Return a list of such splif
  dictionaries.
  """
  if pairwise_distances is None:
@@ -378,6 +375,20 @@ def featurize_splif(protein_xyz, protein, ligand_xyz, ligand, contact_bins,


def compute_ring_center(mol, ring_indices):
  """Computes 3D coordinates of a center of a given ring.

  Parameters:
  -----------
    mol: rdkit.rdchem.Mol
      Molecule containing a ring
    ring_indices: array-like
      Indices of atoms forming a ring

  Returns:
  --------
    ring_centroid: np.ndarray
      Position of a ring center
  """
  conformer = mol.GetConformer()
  ring_xyz = np.zeros((len(ring_indices), 3))
  for i, atom_idx in enumerate(ring_indices):
@@ -388,6 +399,20 @@ def compute_ring_center(mol, ring_indices):


def compute_ring_normal(mol, ring_indices):
  """Computes normal to a plane determined by a given ring.

  Parameters:
  -----------
    mol: rdkit.rdchem.Mol
      Molecule containing a ring
    ring_indices: array-like
      Indices of atoms forming a ring

  Returns:
  --------
    normal: np.ndarray
      Normal vector
  """
  conformer = mol.GetConformer()
  points = np.zeros((3, 3))
  for i, atom_idx in enumerate(ring_indices[:3]):
@@ -406,6 +431,23 @@ def is_pi_parallel(ring1_center,
                   ring2_normal,
                   dist_cutoff=8.0,
                   angle_cutoff=30.0):
  """Check if two aromatic rings form a parallel pi-pi contact.

  Parameters:
  -----------
    ring1_center, ring2_center: np.ndarray
      Positions of centers of the two rings. Can be computed with the
      compute_ring_center function.
    ring1_normal, ring2_normal: np.ndarray
      Normals of the two rings. Can be computed with the compute_ring_normal
      function.
    dist_cutoff: float
      Distance cutoff. Max allowed distance between the ring centers (Angstroms).
    angle_cutoff: float
      Angle cutoff. Max allowed deviation from the ideal (0deg) angle between
      the rings (in degrees).
  """

  dist = np.linalg.norm(ring1_center - ring2_center)
  angle = angle_between(ring1_normal, ring2_normal) * 180 / np.pi
  if ((angle < angle_cutoff or angle > 180.0 - angle_cutoff) and
@@ -420,6 +462,22 @@ def is_pi_t(ring1_center,
            ring2_normal,
            dist_cutoff=5.5,
            angle_cutoff=30.0):
  """Check if two aromatic rings form a T-shaped pi-pi contact.

  Parameters:
  -----------
    ring1_center, ring2_center: np.ndarray
      Positions of centers of the two rings. Can be computed with the
      compute_ring_center function.
    ring1_normal, ring2_normal: np.ndarray
      Normals of the two rings. Can be computed with the compute_ring_normal
      function.
    dist_cutoff: float
      Distance cutoff. Max allowed distance between the ring centers (Angstroms).
    angle_cutoff: float
      Angle cutoff. Max allowed deviation from the ideal (90deg) angle between
      the rings (in degrees).
  """
  dist = np.linalg.norm(ring1_center - ring2_center)
  angle = angle_between(ring1_normal, ring2_normal) * 180 / np.pi
  if ((90.0 - angle_cutoff < angle < 90.0 + angle_cutoff) and
@@ -433,22 +491,40 @@ def compute_pi_stack(protein,
                     pairwise_distances=None,
                     dist_cutoff=4.4,
                     angle_cutoff=30.):
  """
  """Find aromatic rings in protein and ligand that form pi-pi contacts.
  For each atom in the contact, count number of atoms in the other molecule
  that form this contact.

  Pseudocode:

  for each ring in ligand:
    if it is aromatic:
      for each ring in protein:
        if it is aromatic:
  for each aromatic ring in protein:
    for each aromatic ring in ligand:
      compute distance between centers
          compute angle.
      compute angle between normals
      if it counts as parallel pi-pi:
            for each atom in ligand and in protein,
              add to list of atom indices
        count interacting atoms
      if it counts as pi-T:
            for each atom in ligand and in protein:
              add to list of atom indices
        count interacting atoms

  Parameters:
  -----------
    protein, ligand: rdkit.rdchem.Mol
      Two interacting molecules.
    pairwise_distances: np.ndarray (optional)
      Array of pairwise protein-ligand distances (Angstroms)
    dist_cutoff: float
      Distance cutoff. Max allowed distance between the ring centers (Angstroms).
    angle_cutoff: float
      Angle cutoff. Max allowed deviation from the ideal angle between rings.

  Returns:
  --------
    protein_pi_t, protein_pi_parallel, ligand_pi_t, ligand_pi_parallel: dict
      Dictionaries mapping atom indices to number of atoms they interact with.
      Separate dictionary is created for each type of pi stacking (parallel and
      T-shaped) and each molecule (protein and ligand).
  """

  protein_pi_parallel = Counter()
  protein_pi_t = Counter()
  ligand_pi_parallel = Counter()
@@ -460,11 +536,14 @@ def compute_pi_stack(protein,
                         (ligand, ligand_aromatic_rings)):
    aromatic_atoms = {atom.GetIdx() for atom in mol.GetAromaticAtoms()}
    for ring in Chem.GetSymmSSSR(mol):
      # if ring is aromatic
      if set(ring).issubset(aromatic_atoms):
        # save its indices, center, and normal
        ring_center = compute_ring_center(mol, ring)
        ring_normal = compute_ring_normal(mol, ring)
        ring_list.append((ring, ring_center, ring_normal))

  # remember protein-ligand pairs we already counted
  counted_pairs_parallel = set()
  counted_pairs_t = set()
  for prot_ring, prot_ring_center, prot_ring_normal in protein_aromatic_rings:
@@ -481,6 +560,7 @@ def compute_pi_stack(protein,
        for prot_atom_idx in prot_ring:
          for lig_atom_idx in lig_ring:
            if (prot_atom_idx, lig_atom_idx) not in counted_pairs_parallel:
              # if this pair is new, count atoms forming a contact
              prot_to_update.add(prot_atom_idx)
              lig_to_update.add(lig_atom_idx)
              counted_pairs_parallel.add((prot_atom_idx, lig_atom_idx))
@@ -497,12 +577,13 @@ def compute_pi_stack(protein,
          dist_cutoff=dist_cutoff):
        prot_to_update = set()
        lig_to_update = set()
        for i in prot_ring:
          for j in lig_ring:
            if (i, j) not in counted_pairs_t:
              prot_to_update.add(i)
              lig_to_update.add(j)
              counted_pairs_t.add((i, j))
        for prot_atom_idx in prot_ring:
          for lig_atom_idx in lig_ring:
            if (prot_atom_idx, lig_atom_idx) not in counted_pairs_t:
              # if this pair is new, count atoms forming a contact
              prot_to_update.add(prot_atom_idx)
              lig_to_update.add(lig_atom_idx)
              counted_pairs_t.add((prot_atom_idx, lig_atom_idx))

        protein_pi_t.update(prot_to_update)
        ligand_pi_t.update(lig_to_update)
@@ -515,6 +596,22 @@ def is_cation_pi(cation_position,
                 ring_normal,
                 dist_cutoff=6.5,
                 angle_cutoff=30.0):
  """Check if a cation and an aromatic ring form contact.

  Parameters:
  -----------
    ring_center: np.ndarray
      Positions of ring center. Can be computed with the compute_ring_center
      function.
    ring_normal: np.ndarray
      Normal of ring. Can be computed with the compute_ring_normal function.
    dist_cutoff: float
      Distance cutoff. Max allowed distance between ring center and cation
      (in Angstroms).
    angle_cutoff: float
      Angle cutoff. Max allowed deviation from the ideal (0deg) angle between
      ring normal and vector pointing from ring center to cation (in degrees).
  """
  cation_to_ring_vec = cation_position - ring_center
  dist = np.linalg.norm(cation_to_ring_vec)
  angle = angle_between(cation_to_ring_vec, ring_normal) * 180. / np.pi
@@ -525,7 +622,30 @@ def is_cation_pi(cation_position,


def compute_cation_pi(mol1, mol2, charge_tolerance=0.01, **kwargs):
  """Finds aromatic rings in mol1 interacting with cations in mol2"""
  """Finds aromatic rings in mol1 and cations in mol2 that interact with each
  other.

  Parameters:
  -----------
    mol1: rdkit.rdchem.Mol
      Molecule to look for interacting rings
    mol2: rdkit.rdchem.Mol
      Molecule to look for interacting cations
    charge_tolerance: float
      Atom is considered a cation if its formal charge is greater than
      1 - charge_tolerance
    **kwargs:
      Arguments that are passed to is_cation_pi function

  Returns:
  --------
    mol1_pi: dict
      Dictionary that maps atom indices (from mol1) to the number of cations
      (in mol2) they interact with
    mol2_cation: dict
      Dictionary that maps atom indices (from mol2) to the number of aromatic
      atoms (in mol1) they interact with
  """
  mol1_pi = Counter()
  mol2_cation = Counter()
  conformer = mol2.GetConformer()
@@ -534,23 +654,45 @@ def compute_cation_pi(mol1, mol2, charge_tolerance=0.01, **kwargs):
  rings = [list(r) for r in Chem.GetSymmSSSR(mol1)]

  for ring in rings:
    # if ring from mol1 is aromatic
    if set(ring).issubset(aromatic_atoms):
      ring_center = compute_ring_center(mol1, ring)
      ring_normal = compute_ring_normal(mol1, ring)

      for atom in mol2.GetAtoms():
        # ...and atom from mol2 is a cation
        if atom.GetFormalCharge() > 1.0 - charge_tolerance:
          cation_position = np.array(conformer.GetAtomPosition(atom.GetIdx()))
          # if angle and distance are correct
          if is_cation_pi(cation_position, ring_center, ring_normal, **kwargs):
            # count atoms forming a contact
            mol1_pi.update(ring)
            mol2_cation.update([atom.GetIndex()])
  return mol1_pi, mol2_cation


def compute_binding_pocket_cation_pi(protein, ligand, **kwargs):
  """Finds cation-pi interactions between protein and ligand.

  Parameters:
  -----------
    protein, ligand: rdkit.rdchem.Mol
      Interacting molecules
    **kwargs:
      Arguments that are passed to compute_cation_pi function

  Returns:
  --------
    protein_cation_pi, ligand_cation_pi: dict
      Dictionaries that map atom indices to the number of cations/aromatic
      atoms they interact with
  """
  # find interacting rings from protein and cations from ligand
  protein_pi, ligand_cation = compute_cation_pi(protein, ligand, **kwargs)
  # find interacting cations from protein and rings from ligand
  ligand_pi, protein_cation = compute_cation_pi(ligand, protein, **kwargs)

  # merge counters
  protein_cation_pi = Counter()
  protein_cation_pi.update(protein_pi)
  protein_cation_pi.update(protein_cation)
@@ -563,6 +705,7 @@ def compute_binding_pocket_cation_pi(protein, ligand, **kwargs):


def get_partial_charge(atom):
  """Get partial charge of a given atom (rdkit Atom object)"""
  try:
    value = atom.GetProp(str("_GasteigerCharge"))
    if value == '-nan':
@@ -579,6 +722,7 @@ def get_formal_charge(atom):


def is_salt_bridge(atom_i, atom_j):
  """Check if two atoms have correct charges to form a salt bridge"""
  if np.abs(2.0 - np.abs(
      get_partial_charge(atom_i) - get_partial_charge(atom_j))) < 0.01:
    return True
@@ -591,6 +735,26 @@ def compute_salt_bridges(protein_xyz,
                         ligand,
                         pairwise_distances,
                         cutoff=5.0):
  """Find salt bridge contacts between protein and ligand.

  Parameters:
  -----------
  protein_xyz, ligand_xyz: np.ndarray
    Arrays with atomic coordinates
  protein, ligand: rdkit.rdchem.Mol
    Interacting molecules
  pairwise_distances: np.ndarray
    Array of pairwise protein-ligand distances (Angstroms)
  cutoff: float
    Cutoff distance for contact consideration

  Returns:
  --------
    salt_bridge_contacts: list of tuples
      List of contacts. Tuple (i, j) indicates that atom i from protein
      interacts with atom j from ligand.
  """

  salt_bridge_contacts = []

  contacts = np.nonzero(pairwise_distances < cutoff)
@@ -664,27 +828,37 @@ def convert_atom_to_voxel(molecule_xyz,
                          box_width,
                          voxel_width,
                          verbose=False):
  """Converts atom coordinates to an i,j,k grid index.

  Parameters:
  -----------
    molecule_xyz: np.ndarray
      Array with coordinates of all atoms in the molecule, shape (N, 3)
    atom_index: int
      Index of an atom
    box_width: float
      Size of a box
    voxel_width: float
      Size of a voxel
    verbose: bool
      Print warnings when atom is outside of a box
  """
  Converts an atom to an i,j,k grid index.
  """
  from warnings import warn

  indices = np.floor(
      (molecule_xyz[atom_index, :] + np.array([box_width, box_width, box_width]
                                             ) / 2.0) / voxel_width).astype(int)
      (molecule_xyz[atom_index] + box_width / 2.0) / voxel_width).astype(int)
  if ((indices < 0) | (indices >= box_width / voxel_width)).any():
    if verbose:
      warn(
          'Coordinates are outside of the box (atom id = %s, coords xyz = %s, coords in box = %s'
          % (atom_index, molecule_xyz[atom_index], indices))
      warn('Coordinates are outside of the box (atom id = %s,'
           ' coords xyz = %s, coords in box = %s' %
           (atom_index, molecule_xyz[atom_index], indices))

  return ([indices])


def convert_atom_pair_to_voxel(molecule_xyz_tuple, atom_index_pair, box_width,
                               voxel_width):
  """
  Converts a pair of atoms to a list of i,j,k tuples.
  """
  """Converts a pair of atoms to a list of i,j,k tuples."""

  indices_list = []
  indices_list.append(
      convert_atom_to_voxel(molecule_xyz_tuple[0], atom_index_pair[0],
@@ -720,6 +894,9 @@ def subtract_centroid(xyz, centroid):


class RdkitGridFeaturizer(ComplexFeaturizer):
  """Featurizes protein-ligand complex using flat features or a 3D grid (in which
  each voxel is described with a vector of features).
  """

  def __init__(self,
               nb_rotations=0,
@@ -734,6 +911,56 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
               flatten=False,
               verbose=True,
               **kwargs):
    """Parameters:
    -----------
    nb_rotations: int, optional (default 0)
      Number of additional random rotations of a complex to generate.
    feature_types: list, optional (default ['ecfp_ligand'])
      Types of features to calculate. Available types are:
        flat features: 'ecfp_ligand', 'ecfp_hashed', 'splif_hashed', 'hbond_count'
        voxel features: 'ecfp', 'splif', 'sybyl', 'salt_bridge', 'charge', 'hbond',
        'pi_stack', 'cation_pi'
      There are also 3 predefined sets of features: 'flat_combined',
      'voxel_combined', and 'all_combined'. Calculated features are concatenated
      and their order is preserved (features in predefined sets are in
      alphabetical order).
    ecfp_degree: int, optional (default 2)
      ECFP radius.
    ecfp_power: int, optional (default 3)
      Number of bits to store ECFP features (resulting vector will be
      2^ecfp_power long)
    splif_power: int, optional (default 3)
      Number of bits to store SPLIF features (resulting vector will be
      2^splif_power long)
    ligand_only: bool, optional (default False)
      Do not load protein. Can speed up computations when only ligand-based
      features (e.g. 'ecfp_ligand') are used.
    box_width: float, optional (default 16.0)
      Size of a box in which voxel features are calculated. Box is centered on a
      ligand centroid.
    voxel_width: float, optional (default 1.0)
      Size of a 3D voxel in a grid.
    flatten: bool, optional (default False)
      Indicate whether calculated features should be flattened. Output is always
      flattened if flat features are specified in feature_types.
    verbose: bool, optional (default True)
      Verbosity for logging
    **kwargs: dict, optional
      Keyword arguments can be used to specify custom cutoffs and bins (see
      default values below).

    Default cutoffs and bins:
    -------------------------
      hbond_dist_bins: [(2.2, 2.5), (2.5, 3.2), (3.2, 4.0)]
      hbond_angle_cutoffs: [5, 50, 90]
      splif_contact_bins: [(0, 2.0), (2.0, 3.0), (3.0, 4.5)]
      ecfp_cutoff: 4.5
      sybyl_cutoff: 7.0
      salt_bridges_cutoff: 5.0
      pi_stack_dist_cutoff: 4.4
      pi_stack_angle_cutoff: 30.0
      cation_pi_dist_cutoff: 6.5
      cation_pi_angle_cutoff: 30.0
    """

    # check if user tries to set removed arguments
    deprecated_args = [
@@ -741,7 +968,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        'parallel', 'voxel_feature_types'
    ]
    for arg in deprecated_args:
      if arg in kwargs:
      if arg in kwargs and verbose:
        warn('%s argument was removed and it is ignored,'
             ' using it will result in error in version 1.4' % arg,
             DeprecationWarning)
@@ -789,6 +1016,9 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        "SO", "P3", "P", "P3+", "F", "Cl", "Br", "I"
    ]

    # define methods to calculate available flat features
    # all methods (flat and voxel) must have the same API:
    # f(prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances) -> list of np.ndarrays
    self.FLAT_FEATURES = {
        'ecfp_ligand': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [compute_ecfp_features(
@@ -874,6 +1104,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
          nb_channel=1)
      return [pi_parallel_tensor, pi_t_tensor]

    # define methods to calculate available voxel features
    self.VOXEL_FEATURES = {
        'ecfp': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [sum([self._voxelize(
@@ -988,13 +1219,16 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    if feature_types is None:
      feature_types = ['ecfp_ligand']

    # each entry is a tuple (is_flat, feature_name)
    self.feature_types = []

    # parse provided feature types
    for feature_type in feature_types:
      if feature_type in self.FLAT_FEATURES:
        self.feature_types.append((True, feature_type))
        if self.flatten is False:
          warn('%s feature is used, output will be flatten' % feature_type)
          if self.verbose:
            warn('%s feature is used, output will be flattened' % feature_type)
          self.flatten = True

      elif feature_type in self.VOXEL_FEATURES:
@@ -1005,7 +1239,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
            zip([True] * len(self.FLAT_FEATURES),
                sorted(self.FLAT_FEATURES.keys())))
        if self.flatten is False:
          warn('flat features are used, output will be flatten')
          if self.verbose:
            warn('flat features are used, output will be flattened')
          self.flatten = True

      elif feature_type == 'voxel_combined':
@@ -1020,9 +1255,10 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
            zip([False] * len(self.VOXEL_FEATURES),
                sorted(self.VOXEL_FEATURES.keys())))
        if self.flatten is False:
          warn('flat feature are used, output will be flatten')
          if self.verbose:
            warn('flat feature are used, output will be flattened')
          self.flatten = True
      else:
      elif self.verbose:
        warn('Ignoring unknown feature %s' % feature_type)

  def _featurize_complex(self, ligand_ext, ligand_lines, protein_pdb_lines):