Unverified Commit 6a74b9d4 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1369 from rbharath/atomic_conv_feat

Unit tests for Atomic conv featurization.
parents ce4860a5 db437763
Loading
Loading
Loading
Loading
+17 −0
Original line number Diff line number Diff line
@@ -188,6 +188,23 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):


class ComplexNeighborListFragmentAtomicCoordinates(ComplexFeaturizer):
  """This class computes the featurization that corresponds to AtomicConvModel.

  This class computes featurizations needed for AtomicConvModel. Given
  two molecular structures, it computes a number of useful geometric
  features. In particular, for each molecule and the global complex, it
  computes a coordinates matrix of size (N_atoms, 3) where N_atoms is the
  number of atoms. It also computes a neighbor-list, a dictionary with
  N_atoms elements where neighbor-list[i] is a list of the atoms the i-th
  atom has as neighbors. In addition, it computes a z-matrix for the
  molecule which is an array of shape (N_atoms,) that contains the atomic
  number of that atom.

  Since the featurization computes these three quantities for each of the
  two molecules and the complex, a total of 9 quantities are returned for
  each complex. Note that for efficiency, fragments of the molecules can be
  provided rather than the full molecules themselves.
  """

  def __init__(self,
               frag1_num_atoms,
+38 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from deepchem.feat.atomic_coordinates import get_coords
from deepchem.feat.atomic_coordinates import AtomicCoordinates
from deepchem.feat.atomic_coordinates import NeighborListAtomicCoordinates
from deepchem.feat.atomic_coordinates import NeighborListComplexAtomicCoordinates
from deepchem.feat.atomic_coordinates import ComplexNeighborListFragmentAtomicCoordinates

logger = logging.getLogger(__name__)

@@ -156,3 +157,40 @@ class TestAtomicCoordinates(unittest.TestCase):
    assert len(system_neighbor_list.keys()) == N
    for atom in range(N):
      assert len(system_neighbor_list[atom]) <= max_num_neighbors

  def test_full_complex_featurization(self):
    """Check output sizes of ComplexNeighborListFragmentAtomicCoordinates.

    Featurizes the 3zso ligand/protein PDB pair and verifies, for each of
    the two fragments and for the combined complex, that:
      * the coordinate array has shape (num_atoms, 3),
      * the neighbor list is keyed by exactly 0..num_atoms-1,
      * the z array has shape (num_atoms,).
    """
    test_dir = os.path.dirname(os.path.realpath(__file__))
    ligand_file = os.path.join(test_dir, "data/3zso_ligand_hyd.pdb")
    protein_file = os.path.join(test_dir, "data/3zso_protein.pdb")

    # Atom counts were read off the PDB files. A larger dataset with many
    # PDBs would pass the maximum atom count instead of the exact one.
    ligand_atoms = 44
    protein_atoms = 2336
    total_atoms = 2380
    neighbors_per_atom = 4
    cutoff_angstroms = 4

    featurizer = ComplexNeighborListFragmentAtomicCoordinates(
        ligand_atoms, protein_atoms, total_atoms, neighbors_per_atom,
        cutoff_angstroms)
    features = featurizer._featurize_complex(ligand_file, protein_file)

    # The featurizer yields nine values: (coords, neighbor list, z) for
    # fragment 1, fragment 2 and the full complex, in that order.
    (lig_coords, lig_nbrs, lig_z, prot_coords, prot_nbrs, prot_z, all_coords,
     all_nbrs, all_z) = features

    expectations = (
        (lig_coords, lig_nbrs, lig_z, ligand_atoms),
        (prot_coords, prot_nbrs, prot_z, protein_atoms),
        (all_coords, all_nbrs, all_z, total_atoms),
    )
    for coords, nbr_list, z, expected_atoms in expectations:
      self.assertEqual(coords.shape, (expected_atoms, 3))
      self.assertEqual(sorted(nbr_list.keys()), list(range(expected_atoms)))
      self.assertEqual(z.shape, (expected_atoms,))