Commit 8a1762d4 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Unit test beginnings

parent ce4860a5
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -188,6 +188,9 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):


class ComplexNeighborListFragmentAtomicCoordinates(ComplexFeaturizer):
  """This class computes the featurization that corresponds to AtomicConvModel.

  This class computes featurizations needed for AtomicConvModel. Given a two molecular structures, it computes a number of useful geometric features.

  def __init__(self,
               frag1_num_atoms,
@@ -228,6 +231,12 @@ class ComplexNeighborListFragmentAtomicCoordinates(ComplexFeaturizer):
           system_coords, system_neighbor_list, system_z

  def get_Z_matrix(self, mol, max_atoms):
    ################################################### DEBUG
    print("len(mol.GetAtoms())")
    print(len(mol.GetAtoms()))
    print("max_atoms")
    print(max_atoms)
    ################################################### DEBUG
    return pad_array(
        np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()]), max_atoms)

+32 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from deepchem.feat.atomic_coordinates import get_coords
from deepchem.feat.atomic_coordinates import AtomicCoordinates
from deepchem.feat.atomic_coordinates import NeighborListAtomicCoordinates
from deepchem.feat.atomic_coordinates import NeighborListComplexAtomicCoordinates
from deepchem.feat.atomic_coordinates import ComplexNeighborListFragmentAtomicCoordinates

logger = logging.getLogger(__name__)

@@ -156,3 +157,34 @@ class TestAtomicCoordinates(unittest.TestCase):
    assert len(system_neighbor_list.keys()) == N
    for atom in range(N):
      assert len(system_neighbor_list[atom]) <= max_num_neighbors

  def test_full_complex_featurization(self):
    """Unit test for ComplexNeighborListFragmentAtomicCoordinates."""
    dir_path = os.path.dirname(os.path.realpath(__file__))
    ligand_file = os.path.join(dir_path, "data/3zso_ligand_hyd.pdb")
    protein_file = os.path.join(dir_path, "data/3zso_protein.pdb")
    # Pulled from PDB files. For larger datasets with more PDBs, would use
    # max num atoms instead of exact.
    frag1_num_atoms = 44 # for ligand atoms 
    frag2_num_atoms = 2336 # for protein atoms
    complex_num_atoms = 2380 # in total
    max_num_neighbors = 4
    # Cutoff in angstroms
    neighbor_cutoff = 4
    complex_featurizer = ComplexNeighborListFragmentAtomicCoordinates(frag1_num_atoms, frag2_num_atoms, complex_num_atoms, max_num_neighbors, neighbor_cutoff)
    (frag1_coords, frag1_neighbor_list, frag1_z, frag2_coords,
    frag2_neighbor_list, frag2_z, complex_coords, complex_neighbor_list,
    complex_z) = complex_featurizer._featurize_complex(ligand_file, protein_file)

    self.assertEqual(frag1_coords.shape, (frag1_num_atoms, 3))
    self.assertEqual(sorted(list(frag1_neighbor_list.keys())), list(range(frag1_num_atoms)))
    print("type(frag1_z)")
    print(type(frag1_z))
    print("frag1_z.shape")
    print(frag1_z.shape)

    self.assertEqual(frag2_coords.shape, (frag2_num_atoms, 3))
    self.assertEqual(sorted(list(frag2_neighbor_list.keys())), list(range(frag2_num_atoms)))

    self.assertEqual(complex_coords.shape, (complex_num_atoms, 3))
    self.assertEqual(sorted(list(complex_neighbor_list.keys())), list(range(complex_num_atoms)))