Commit 7656443e authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

1D tests passing

parent fa101221
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -854,6 +854,7 @@ class NeighborList(Layer):
    N_atoms, M_nbrs, n_cells, ndim = (
        self.N_atoms, self.M_nbrs, self.n_cells, self.ndim)
    nbr_cutoff = self.nbr_cutoff
    coords = tf.to_float(coords)
    # Shape (n_cells, ndim)
    cells = self.get_cells()

@@ -887,7 +888,7 @@ class NeighborList(Layer):
    # Shape (N_atoms, M_nbrs)
    nbr_list = tf.stack(neighbor_list)

    return nbrs, nbr_coords, atom_coords, dists, closest_nbr_inds, neighbor_list, nbr_list
    return nbr_list

  def get_atoms_in_nbrs(self, coords, cells):
    """Get the atoms in neighboring cells for each cells.
@@ -921,6 +922,11 @@ class NeighborList(Layer):
    nbrs_per_atom = tf.split(atoms_in_nbrs, self.N_atoms)
    uniques = [tf.unique(tf.squeeze(atom_nbrs))[0] for atom_nbrs in nbrs_per_atom]

    # TODO(rbharath): FRAGILE! Uses fact that identity seems to be the first
    # element removed to remove self from list of neighbors. Need to verify
    # this holds more broadly or come up with robust alternative.
    uniques = [unique[1:] for unique in uniques]
    
    return uniques

  def get_closest_atoms(self, coords, cells):
+23 −43
Original line number Diff line number Diff line
@@ -160,7 +160,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    with self.test_session() as sess:
      coords = start + np.random.rand(N_atoms, ndim) * (stop - start)
      coords = tf.stack(coords)
      nbr_list = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff,
      nbr_list = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff,
                              start, stop)(coords)
      nbr_list = nbr_list.eval()
      assert nbr_list.shape == (N_atoms, M_nbrs)
@@ -179,7 +179,7 @@ class TestDocking(test_util.TensorFlowTestCase):

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      cells_eval = cells.eval()
      true_cells = np.reshape(np.arange(10), (10, 1))
@@ -203,8 +203,10 @@ class TestDocking(test_util.TensorFlowTestCase):
      cells = nbr_list_layer.get_cells()
      closest_atoms = nbr_list_layer.get_closest_atoms(coords, cells)
      atoms_in_cells_eval = closest_atoms.eval() 
      true_atoms_in_cells = np.reshape(np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 3]), (n_cells, k))
      np.testing.assert_array_almost_equal(atoms_in_cells_eval, true_atoms_in_cells)
      true_atoms_in_cells = np.reshape(
          np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 3]), (n_cells, M_nbrs))
      np.testing.assert_array_almost_equal(
          atoms_in_cells_eval, true_atoms_in_cells)

  def test_get_neighbor_cells_1D(self):
    """Test that get_neighbor_cells works in 1D"""
@@ -225,16 +227,16 @@ class TestDocking(test_util.TensorFlowTestCase):
      nbr_cells = nbr_list_layer.get_neighbor_cells(cells)
      nbr_cells_eval = nbr_cells.eval()
      true_nbr_cells = np.array(
          [[0, 1],
           [1, 0],
           [2, 1],
           [3, 2],
           [4, 3],
           [5, 4],
           [6, 5],
           [7, 6],
           [8, 7],
           [9, 8]])
          [[0, 1, 2],
           [1, 0, 2],
           [2, 1, 3],
           [3, 2, 4],
           [4, 3, 5],
           [5, 4, 6],
           [6, 5, 7],
           [7, 6, 8],
           [8, 7, 9],
           [9, 8, 7]])
      np.testing.assert_array_almost_equal(nbr_cells_eval, true_nbr_cells)

  def test_get_cells_for_atoms_1D(self):
@@ -280,14 +282,15 @@ class TestDocking(test_util.TensorFlowTestCase):
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      uniques = nbr_list_layer.get_atoms_in_nbrs(coords, cells)

      uniques_eval = [unique.eval() for unique in uniques]
      uniques_eval = np.array(uniques_eval)

      true_uniques = np.array(
        [[0, 1],
         [1, 0],
         [2, 3],
         [3, 2]])
        [[1],
         [0],
         [3],
         [2]])
      np.testing.assert_array_almost_equal(uniques_eval, true_uniques)


@@ -305,31 +308,9 @@ class TestDocking(test_util.TensorFlowTestCase):

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      ###################################################################### DEBUG
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      nbrs, nbr_coords, atom_coords, dists, closest_nbr_inds, neighbor_list, nbr_list = nbr_list_layer.compute_nbr_list(coords)
      nbrs_eval = [nbr.eval() for nbr in nbrs]
      print("nbrs_eval")
      print(nbrs_eval)
      nbr_coords_eval = [nbr_coord.eval() for nbr_coord in nbr_coords]
      print("nbr_coords_eval")
      print(nbr_coords_eval)
      atom_coords_eval = [atom_coord.eval() for atom_coord in atom_coords]
      print("atom_coords_eval")
      print(atom_coords_eval)
      dists_eval = [dist.eval() for dist in dists]
      print("dists_eval")
      print(dists_eval)
      closest_nbr_inds_eval = [closest_nbr_ind.eval() for closest_nbr_ind in closest_nbr_inds]
      print("closest_nbr_inds_eval")
      print(closest_nbr_inds_eval)
      neighbor_list_eval = [neighbor.eval() for neighbor in neighbor_list]
      print("neighbor_list_eval")
      print(neighbor_list_eval)
      nbr_list = nbr_list.eval()
      print("nbr_list")
      print(nbr_list)
      ###################################################################### DEBUG
      nbr_list = nbr_list_layer.compute_nbr_list(coords)
      nbr_list = np.squeeze(nbr_list.eval())
      np.testing.assert_array_almost_equal(nbr_list, np.array([1, 0, 3, 2]))

  def test_neighbor_list_vina(self):
@@ -396,7 +377,6 @@ class TestDocking(test_util.TensorFlowTestCase):
        N_protein + N_ligand,
        M_nbrs,
        ndim,
        k,
        nbr_cutoff,
        start,
        stop,