Commit e3fae462 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Clean up rdkit util tests

parent cd02ba8b
Loading
Loading
Loading
Loading
+4 −136
Original line number Diff line number Diff line
@@ -177,7 +177,6 @@ def compute_charges(mol):
def load_complex(molecular_complex,
                 add_hydrogens=True,
                 calc_charges=True,
                 pdbfix=True,
                 sanitize=True):
  """Loads a molecular complex.

@@ -197,8 +196,6 @@ def load_complex(molecular_complex,
    If true, add hydrogens via pdbfixer
  calc_charges: bool, optional
    If true, add charges via rdkit
  pdbfix: bool, optional
    If true, apply pdbfixer to clean up this molecule.
  sanitize: bool, optional
    If true, sanitize molecules via rdkit

@@ -218,7 +215,6 @@ def load_complex(molecular_complex,
        mol,
        add_hydrogens=add_hydrogens,
        calc_charges=calc_charges,
        pdbfix=pdbfix,
        sanitize=sanitize)
    if isinstance(loaded, list):
      fragments += loaded
@@ -230,8 +226,7 @@ def load_complex(molecular_complex,
def load_molecule(molecule_file,
                  add_hydrogens=True,
                  calc_charges=True,
                  sanitize=True,
                  pdbfix=True):
                  sanitize=True):
  """Converts molecule file to (xyz-coords, obmol object)

  Given molecule_file, returns a tuple of xyz coords of molecule
@@ -249,8 +244,6 @@ def load_molecule(molecule_file,
    If true, add charges via rdkit
  sanitize: bool, optional
    If true, sanitize molecules via rdkit
  pdbfix: bool, optional
    If true, apply pdbfixer to clean up this molecule.

  Returns
  -------
@@ -262,7 +255,6 @@ def load_molecule(molecule_file,
  This function requires RDKit to be installed.
  """
  from rdkit import Chem
  from rdkit.Chem.rdchem import AtomValenceException
  from_pdb = False
  if ".mol2" in molecule_file:
    my_mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
@@ -292,8 +284,9 @@ def load_molecule(molecule_file,
  if sanitize:
    try:
      Chem.SanitizeMol(my_mol)
    except AtomValenceException:
      logger.warn("Mol %s failed valence check" % Chem.MolToSmiles(my_mol))
    # Ideally we should catch AtomValenceException but Travis seems to choke on it for some reason.
    except:
      logger.warn("Mol %s failed sanitization" % Chem.MolToSmiles(my_mol))
  if calc_charges:
    # This updates in place
    compute_charges(my_mol)
@@ -380,128 +373,3 @@ def merge_molecules(molecules):
    for nextmol in molecules[1:]:
      combined = rdmolops.CombineMols(combined, nextmol)
    return combined


def compute_contact_centroid(molecular_complex, cutoff=4.5):
  """Computes the (x,y,z) centroid of the contact regions of this molecular complex.

  For a molecular complex, it's necessary for various featurizations
  that compute voxel grids to find a reasonable center for the
  voxelization. This function computes the centroid of all the contact
  atoms, defined as an atom that's within `cutoff` Angstroms of an
  atom from a different molecule.

  Parameters
  ----------
  molecular_complex: Object
    A representation of a molecular complex, produced by
    `rdkit_util.load_complex`.
  cutoff: float, optional
    The distance in Angstroms considered for computing contacts.
  """
  fragments = reduce_molecular_complex_to_contacts(molecular_complex, cutoff)
  coords = [frag[0] for frag in fragments]
  contact_coords = merge_molecules_xyz(coords)
  centroid = np.mean(contact_coords, axis=0)
  return (centroid)


def compute_ring_center(mol, ring_indices):
  """Computes 3D coordinates of a center of a given ring.

  Parameters:
  -----------
  mol: rdkit.rdchem.Mol
    Molecule containing a ring
  ring_indices: array-like
    Indices of atoms forming a ring

  Returns:
  --------
    ring_centroid: np.ndarray
      Position of a ring center
  """
  conformer = mol.GetConformer()
  ring_xyz = np.zeros((len(ring_indices), 3))
  for i, atom_idx in enumerate(ring_indices):
    atom_position = conformer.GetAtomPosition(atom_idx)
    ring_xyz[i] = np.array(atom_position)
  ring_centroid = compute_centroid(ring_xyz)
  return ring_centroid


def compute_ring_normal(mol, ring_indices):
  """Computes normal to a plane determined by a given ring.

  Parameters:
  -----------
  mol: rdkit.rdchem.Mol
    Molecule containing a ring
  ring_indices: array-like
    Indices of atoms forming a ring

  Returns:
  --------
  normal: np.ndarray
    Normal vector
  """
  conformer = mol.GetConformer()
  points = np.zeros((3, 3))
  for i, atom_idx in enumerate(ring_indices[:3]):
    atom_position = conformer.GetAtomPosition(atom_idx)
    points[i] = np.array(atom_position)

  v1 = points[1] - points[0]
  v2 = points[2] - points[0]
  normal = np.cross(v1, v2)
  return normal


def rotate_molecules(mol_coordinates_list):
  """Rotates provided molecular coordinates.

  Pseudocode:
  1. Generate random rotation matrix. This matrix applies a
     random transformation to any 3-vector such that, were the
     random transformation repeatedly applied, it would randomly
     sample along the surface of a sphere with radius equal to
     the norm of the given 3-vector cf.
     generate_random_rotation_matrix() for details
  2. Apply R to all atomic coordinates.
  3. Return rotated molecule

  Parameters
  ----------
  mol_coordinates_list: list
    Elements of list must be (N_atoms, 3) shaped arrays
  """
  R = generate_random_rotation_matrix()
  rotated_coordinates_list = []

  for mol_coordinates in mol_coordinates_list:
    coordinates = deepcopy(mol_coordinates)
    rotated_coordinates = np.transpose(np.dot(R, np.transpose(coordinates)))
    rotated_coordinates_list.append(rotated_coordinates)

  return (rotated_coordinates_list)


def compute_all_ecfp(mol, indices=None, degree=2):
  """Obtain molecular fragment for all atoms emanating outward to given degree.

  For each fragment, compute SMILES string (for now) and hash to
  an int. Return a dictionary mapping atom index to hashed
  SMILES.
  """

  ecfp_dict = {}
  from rdkit import Chem
  for i in range(mol.GetNumAtoms()):
    if indices is not None and i not in indices:
      continue
    env = Chem.FindAtomEnvironmentOfRadiusN(mol, degree, i, useHs=True)
    submol = Chem.PathToSubmol(mol, env)
    smile = Chem.MolToSmiles(submol)
    ecfp_dict[i] = "%s,%s" % (mol.GetAtoms()[i].GetAtomicNum(), smile)

  return ecfp_dict
+44 −15
Original line number Diff line number Diff line
@@ -19,7 +19,11 @@ class TestRdkitUtil(unittest.TestCase):
                                    '../../feat/tests/3ws9_ligand.sdf')

  def test_load_complex(self):
    pass
    complexes = rdkit_util.load_complex(
        (self.protein_file, self.ligand_file),
        add_hydrogens=False,
        calc_charges=False)
    assert len(complexes) == 2

  def test_load_molecule(self):
    # adding hydrogens and charges is tested in dc.utils
@@ -66,7 +70,25 @@ class TestRdkitUtil(unittest.TestCase):
    assert after_hydrogen_count >= original_hydrogen_count

  def test_apply_pdbfixer(self):
    pass
    current_dir = os.path.dirname(os.path.realpath(__file__))
    ligand_file = os.path.join(current_dir, "../../dock/tests/1jld_ligand.sdf")
    xyz, mol = rdkit_util.load_molecule(
        ligand_file, calc_charges=False, add_hydrogens=False)
    original_hydrogen_count = 0
    for atom_idx in range(mol.GetNumAtoms()):
      atom = mol.GetAtoms()[atom_idx]
      if atom.GetAtomicNum() == 1:
        original_hydrogen_count += 1

    assert mol is not None
    mol = rdkit_util.apply_pdbfixer(mol, hydrogenate=True, is_protein=False)
    assert mol is not None
    after_hydrogen_count = 0
    for atom_idx in range(mol.GetNumAtoms()):
      atom = mol.GetAtoms()[atom_idx]
      if atom.GetAtomicNum() == 1:
        after_hydrogen_count += 1
    assert_true(after_hydrogen_count >= original_hydrogen_count)

  def test_compute_charges(self):
    current_dir = os.path.dirname(os.path.realpath(__file__))
@@ -83,19 +105,6 @@ class TestRdkitUtil(unittest.TestCase):
        has_a_charge = True
    assert has_a_charge

  def test_rotate_molecules(self):
    # check if distances do not change
    vectors = np.random.rand(4, 2, 3)
    norms = np.linalg.norm(vectors[:, 1] - vectors[:, 0], axis=1)
    vectors_rot = np.array(rdkit_util.rotate_molecules(vectors))
    norms_rot = np.linalg.norm(vectors_rot[:, 1] - vectors_rot[:, 0], axis=1)
    self.assertTrue(np.allclose(norms, norms_rot))

    # check if it works for molecules with different numbers of atoms
    coords = [np.random.rand(n, 3) for n in (10, 20, 40, 100)]
    coords_rot = rdkit_util.rotate_molecules(coords)
    self.assertEqual(len(coords), len(coords_rot))

  def test_load_molecule(self):
    current_dir = os.path.dirname(os.path.realpath(__file__))
    ligand_file = os.path.join(current_dir, "../../dock/tests/1jld_ligand.sdf")
@@ -160,3 +169,23 @@ class TestRdkitUtil(unittest.TestCase):
      second_atom_equal = np.all(xyz[i] == merged[i + len(xyz)])
      assert first_atom_equal
      assert second_atom_equal

  def test_merge_molecules(self):
    current_dir = os.path.dirname(os.path.realpath(__file__))
    ligand_file = os.path.join(current_dir, "../../dock/tests/1jld_ligand.sdf")
    xyz, mol = rdkit_util.load_molecule(
        ligand_file, calc_charges=False, add_hydrogens=False)
    num_mol_atoms = mol.GetNumAtoms()
    # self.ligand_file is for 3ws9_ligand.sdf
    oth_xyz, oth_mol = rdkit_util.load_molecule(
        self.ligand_file, calc_charges=False, add_hydrogens=False)
    num_oth_mol_atoms = oth_mol.GetNumAtoms()
    merged = rdkit_util.merge_molecules([mol, oth_mol])
    merged_num_atoms = merged.GetNumAtoms()
    assert merged_num_atoms == num_mol_atoms + num_oth_mol_atoms

  def test_merge_molecular_fragments(self):
    pass

  def test_strip_hydrogens(self):
    pass