Commit db437763 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Comments

parent 8135642b
Loading
Loading
Loading
Loading
+14 −7
Original line number Diff line number Diff line
@@ -190,7 +190,20 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):
class ComplexNeighborListFragmentAtomicCoordinates(ComplexFeaturizer):
  """This class computes the featurization that corresponds to AtomicConvModel.

  This class computes featurizations needed for AtomicConvModel. Given two molecular structures, it computes a number of useful geometric features.
  This class computes featurizations needed for AtomicConvModel. Given
  two molecular structures, it computes a number of useful geometric
  features. In particular, for each molecule and the global complex, it
  computes a coordinates matrix of size (N_atoms, 3) where N_atoms is the
  number of atoms. It also computes a neighbor-list, a dictionary with
  N_atoms elements where neighbor-list[i] is a list of the atoms the i-th
  atom has as neighbors. In addition, it computes a z-matrix for the
  molecule, which is an array of shape (N_atoms,) that contains the atomic
  number of each atom.

  Since the featurization computes these three quantities for each of the
  two molecules and the complex, a total of 9 quantities are returned for
  each complex. Note that for efficiency, fragments of the molecules can be
  provided rather than the full molecules themselves.
  """

  def __init__(self,
@@ -232,12 +245,6 @@ class ComplexNeighborListFragmentAtomicCoordinates(ComplexFeaturizer):
           system_coords, system_neighbor_list, system_z

  def get_Z_matrix(self, mol, max_atoms):
    """Returns the atomic numbers of the atoms in `mol`, padded to `max_atoms`.

    Parameters
    ----------
    mol: molecule object
      Must expose `GetAtoms()`, yielding atoms that each provide
      `GetAtomicNum()` (presumably an RDKit molecule -- confirm against
      callers).
    max_atoms: int
      Target size handed to `pad_array`.

    Returns
    -------
    The result of `pad_array` applied to the 1-D array of atomic numbers.
    NOTE(review): the padding value and the behavior when the molecule has
    more than `max_atoms` atoms are determined by `pad_array`; verify there.
    """
    # One atomic number per atom, in the iteration order of mol.GetAtoms().
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()])
    return pad_array(atomic_numbers, max_atoms)

+19 −13
Original line number Diff line number Diff line
@@ -171,20 +171,26 @@ class TestAtomicCoordinates(unittest.TestCase):
    max_num_neighbors = 4
    # Cutoff in angstroms
    neighbor_cutoff = 4
    complex_featurizer = ComplexNeighborListFragmentAtomicCoordinates(frag1_num_atoms, frag2_num_atoms, complex_num_atoms, max_num_neighbors, neighbor_cutoff)
    complex_featurizer = ComplexNeighborListFragmentAtomicCoordinates(
        frag1_num_atoms, frag2_num_atoms, complex_num_atoms, max_num_neighbors,
        neighbor_cutoff)
    (frag1_coords, frag1_neighbor_list, frag1_z, frag2_coords,
    frag2_neighbor_list, frag2_z, complex_coords, complex_neighbor_list,
    complex_z) = complex_featurizer._featurize_complex(ligand_file, protein_file)
     frag2_neighbor_list, frag2_z, complex_coords,
     complex_neighbor_list, complex_z) = complex_featurizer._featurize_complex(
         ligand_file, protein_file)

    self.assertEqual(frag1_coords.shape, (frag1_num_atoms, 3))
    self.assertEqual(sorted(list(frag1_neighbor_list.keys())), list(range(frag1_num_atoms)))
    print("type(frag1_z)")
    print(type(frag1_z))
    print("frag1_z.shape")
    print(frag1_z.shape)
    self.assertEqual(
        sorted(list(frag1_neighbor_list.keys())), list(range(frag1_num_atoms)))
    self.assertEqual(frag1_z.shape, (frag1_num_atoms,))

    self.assertEqual(frag2_coords.shape, (frag2_num_atoms, 3))
    self.assertEqual(sorted(list(frag2_neighbor_list.keys())), list(range(frag2_num_atoms)))
    self.assertEqual(
        sorted(list(frag2_neighbor_list.keys())), list(range(frag2_num_atoms)))
    self.assertEqual(frag2_z.shape, (frag2_num_atoms,))

    self.assertEqual(complex_coords.shape, (complex_num_atoms, 3))
    self.assertEqual(sorted(list(complex_neighbor_list.keys())), list(range(complex_num_atoms)))
    self.assertEqual(
        sorted(list(complex_neighbor_list.keys())),
        list(range(complex_num_atoms)))
    self.assertEqual(complex_z.shape, (complex_num_atoms,))