Commit 80117fff authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Basic phantom atom test passing

parent d38ecf10
Loading
Loading
Loading
Loading
+55 −20
Original line number Diff line number Diff line
@@ -860,7 +860,7 @@ class WeightedLinearCombo(Layer):


class NeighborList(Layer):
  """Computes a neighbor-list on the GPU.
  """Computes a neighbor-list in Tensorflow.

  Neighbor-lists (also called Verlet Lists) are a tool for grouping atoms which
  are close to each other spatially
@@ -905,8 +905,41 @@ class NeighborList(Layer):
    self.out_tensor = nbr_list
    return nbr_list

  #def compute_nbr_list(self, coords):
  #  """Computes a neighbor list from atom coordinates.

  #  Parameters
  #  ----------
  #  coords: tf.Tensor
  #    Shape (N_atoms, ndim)

  #  Returns
  #  -------
  #  nbr_list: tf.Tensor
  #    Shape (N_atoms, M_nbrs) of atom indices
  #  """
  #  N_atoms, M_nbrs, n_cells, ndim = (self.N_atoms, self.M_nbrs, self.n_cells,
  #                                    self.ndim)
  #  nbr_cutoff = self.nbr_cutoff
  #  coords = tf.to_float(coords)

  #  nbrs, closest_nbrs = self.get_closest_nbrs(coords)

  #  # N_atoms elts of size (M_nbrs,) each 
  #  neighbor_list = [
  #      tf.gather(atom_nbrs, closest_nbr_ind)
  #      for (atom_nbrs, closest_nbr_ind) in zip(nbrs, closest_nbrs)
  #  ]

  #  # Shape (N_atoms, M_nbrs)
  #  nbr_list = tf.stack(neighbor_list)

  #  return nbr_list

  def compute_nbr_list(self, coords):
    """Computes a neighbor list from atom coordinates.
    """Get closest neighbors for atoms.

    Needs to handle padding for atoms with no neighbors.

    Parameters
    ----------
@@ -918,45 +951,47 @@ class NeighborList(Layer):
    nbr_list: tf.Tensor
      Shape (N_atoms, M_nbrs) of atom indices
    """
    N_atoms, M_nbrs, n_cells, ndim = (self.N_atoms, self.M_nbrs, self.n_cells,
                                      self.ndim)
    nbr_cutoff = self.nbr_cutoff
    coords = tf.to_float(coords)
    # Shape (n_cells, ndim)
    cells = self.get_cells()

    # List of length N_atoms, each element of different length uniques_i
    nbrs = self.get_atoms_in_nbrs(coords, cells)
    padding = tf.fill((self.M_nbrs,), -1)
    padded_nbrs = [tf.concat([unique_nbrs, padding], 0) for unique_nbrs in nbrs]

    # List of length N_atoms, each element of different length uniques_i
    # List of length N_atoms, each a tensor of shape
    # (uniques_i, ndim)
    nbr_coords = [tf.gather(coords, atom_nbrs) for atom_nbrs in nbrs]

    # Add phantom atoms that exist far outside the box
    coord_padding = tf.to_float(tf.fill((self.M_nbrs, self.ndim), 2*self.stop))
    padded_nbr_coords = [tf.concat([nbr_coord, coord_padding], 0)
                         for nbr_coord in nbr_coords]

    # List of length N_atoms, each of shape (1, ndim)
    atom_coords = tf.split(coords, N_atoms)
    atom_coords = tf.split(coords, self.N_atoms)
    # TODO(rbharath): How does distance need to be modified here to   
    # account for periodic boundary conditions?   
    # Shape (N_atoms, n_nbr_cells, M_nbrs)
    dists = [
        tf.reduce_sum((atom_coord - nbr_coord)**2, axis=1)
        for (atom_coord, nbr_coord) in zip(atom_coords, nbr_coords)
    # List of length N_atoms each of shape (M_nbrs)
    padded_dists = [
        tf.reduce_sum((atom_coord - padded_nbr_coord)**2, axis=1)
        for (atom_coord, padded_nbr_coord) in zip(atom_coords, padded_nbr_coords)
    ]

    # TODO(rbharath): What if uniques_i < M_nbrs? Will crash
    # List of length N_atoms each of size M_nbrs
    closest_nbr_inds = [tf.nn.top_k(-dist, k=M_nbrs)[1] for dist in dists]
    padded_closest_nbrs = [tf.nn.top_k(-padded_dist, k=self.M_nbrs)[1] for
                           padded_dist in padded_dists]

    # N_atoms elts of size (M_nbrs,) each 
    neighbor_list = [
        tf.gather(atom_nbrs, closest_nbr_ind)
        for (atom_nbrs, closest_nbr_ind) in zip(nbrs, closest_nbr_inds)
    padded_neighbor_list = [
        tf.gather(padded_atom_nbrs, padded_closest_nbr)
        for (padded_atom_nbrs, padded_closest_nbr)
        in zip(padded_nbrs, padded_closest_nbrs)
    ]

    # Shape (N_atoms, M_nbrs)
    nbr_list = tf.stack(neighbor_list)
    neighbor_list = tf.stack(padded_neighbor_list)

    return nbr_list
    return neighbor_list

  def get_atoms_in_nbrs(self, coords, cells):
    """Get the atoms in neighboring cells for each cells.
+155 −10
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ class TestDocking(test_util.TensorFlowTestCase):
  Test that tensorgraph docking-style models work. 
  """

  def test_neighbor_list(self):
  def test_neighbor_list_simple(self):
    """Test that neighbor lists can be constructed."""
    N_atoms = 10
    start = 0
@@ -49,7 +49,8 @@ class TestDocking(test_util.TensorFlowTestCase):

    features = Feature(shape=(N_atoms, ndim))
    labels = Label(shape=(N_atoms,))
    nbr_list = NeighborList(N_atoms, M, ndim, nbr_cutoff, in_layers=[features])
    nbr_list = NeighborList(N_atoms, M, ndim, nbr_cutoff, start, stop,
                            in_layers=[features])
    nbr_list = ToFloat(in_layers=[nbr_list])
    # This isn't a meaningful loss, but just for test
    loss = ReduceSum(in_layers=[nbr_list])
@@ -147,7 +148,7 @@ class TestDocking(test_util.TensorFlowTestCase):
      gauss_2_np = gauss_2.eval()
      assert gauss_2_np.shape == (N_atoms, M_nbrs)

  def test_neighbor_list(self):
  def test_neighbor_list_shape(self):
    """Test that NeighborList works."""
    N_atoms = 5
    start = 0
@@ -158,7 +159,7 @@ class TestDocking(test_util.TensorFlowTestCase):

    with self.test_session() as sess:
      coords = start + np.random.rand(N_atoms, ndim) * (stop - start)
      coords = tf.stack(coords)
      coords = tf.to_float(tf.stack(coords))
      nbr_list = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
                              stop)(coords)
      nbr_list = nbr_list.eval()
@@ -174,7 +175,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords)
@@ -196,7 +197,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
    coords = np.reshape(coords, (N_atoms, ndim))
    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
@@ -219,7 +220,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords)
@@ -243,7 +244,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
@@ -266,7 +267,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
@@ -291,7 +292,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
@@ -301,6 +302,150 @@ class TestDocking(test_util.TensorFlowTestCase):
      nbr_list = np.squeeze(nbr_list.eval())
      np.testing.assert_array_almost_equal(nbr_list, np.array([1, 0, 3, 2]))

  def test_neighbor_list_2D(self):
    """Test neighbor list on 2D example."""
    N_atoms = 4
    start = 0
    stop = 10
    nbr_cutoff = 1
    ndim = 2
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array(
      [[1.0, 1.0],
       [2.0, 2.0],
       [8.0, 8.0],
       [9.0, 9.0]])
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
                                    stop)
      nbr_list = nbr_list_layer.compute_nbr_list(coords)
      nbr_list = np.squeeze(nbr_list.eval())
      np.testing.assert_array_almost_equal(nbr_list, np.array([1, 0, 3, 2]))

  def test_neighbor_list_3D(self):
    """Test neighbor list on 3D example."""
    N_atoms = 4
    start = 0
    stop = 10
    nbr_cutoff = 1
    ndim = 3
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array(
      [[1.0, 0.0, 1.0],
       [2.0, 2.0, 2.0],
       [8.0, 8.0, 8.0],
       [9.0, 9.0, 9.0]])
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
                                    stop)
      nbr_list = nbr_list_layer.compute_nbr_list(coords)
      nbr_list = np.squeeze(nbr_list.eval())
      np.testing.assert_array_almost_equal(nbr_list, np.array([1, 0, 3, 2]))

  def test_neighbor_list_3D_empty_cells(self):
    """Test neighbor list on 3D example where cells are empty.

    Stresses the failure mode where the neighboring cells are empty
    so top_k will throw a failure.
    """
    N_atoms = 4
    start = 0
    stop = 10
    nbr_cutoff = 1
    ndim = 3
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array(
      [[1.0, 0.0, 1.0],
       [2.0, 5.0, 2.0],
       [8.0, 8.0, 8.0],
       [9.0, 9.0, 9.0]])
    coords = np.reshape(coords, (N_atoms, ndim))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
                                    stop)
      nbr_list = nbr_list_layer.compute_nbr_list(coords)
      nbr_list = np.squeeze(nbr_list.eval())
      np.testing.assert_array_almost_equal(nbr_list, np.array([-1, -1, 3, 2]))

  #def test_get_closest_nbrs_3D_empty_cells(self):
  #  """Test get_closest_nbrs in 3D with empty nbrs.
  #  Stresses the failure mode where the neighboring cells are empty
  #  so top_k will throw a failure.
  #  """
  #  N_atoms = 4
  #  start = 0
  #  stop = 10
  #  nbr_cutoff = 1
  #  ndim = 3
  #  M_nbrs = 1
  #  # 1 and 2 are nbrs. 8 and 9 are nbrs
  #  coords = np.array(
  #    [[1.0, 0.0, 1.0],
  #     [2.0, 5.0, 2.0],
  #     [8.0, 8.0, 8.0],
  #     [9.0, 9.0, 9.0]])
  #  coords = np.reshape(coords, (N_atoms, ndim))

  #  with self.test_session() as sess:
  #    coords = tf.convert_to_tensor(coords, dtype=tf.float32)
  #    nbr_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
  #                                  stop)

  #    neighbor_list, padded_neighbor_list, padded_closest_nbrs, padded_dists, dists, padded_nbr_coords, nbr_coords, padded_nbrs, nbrs, closest_nbrs = nbr_layer.get_closest_nbrs(coords)

  #    neighbor_list_eval = neighbor_list.eval()
  #    print("neighbor_list_eval")
  #    print(neighbor_list_eval)

  #    padded_neighbor_list_eval = [padded_nbr.eval() for padded_nbr in padded_neighbor_list]
  #    print("padded_neighbor_list_eval")
  #    print(padded_neighbor_list_eval)

  #    padded_dists_eval = [padded_dist.eval() for padded_dist in padded_dists]
  #    print("padded_dists_eval")
  #    print(padded_dists_eval)

  #    dists_eval = [dist.eval() for dist in dists]
  #    print("dists_eval")
  #    print(dists_eval)
  #
  #    nbr_coords_eval = [nbr_coord.eval() for nbr_coord in nbr_coords]
  #    print("nbr_coords_eval")
  #    print(nbr_coords_eval)

  #    padded_nbr_coords_eval = [padded_nbr_coord.eval() for padded_nbr_coord in padded_nbr_coords]
  #    print("padded_nbr_coords_eval")
  #    print(padded_nbr_coords_eval)

  #    nbrs_eval = [nbr.eval() for nbr in nbrs]
  #    print("nbrs_eval")
  #    print(nbrs_eval)
  #
  #    padded_nbrs_eval = [padded_nbr.eval() for padded_nbr in padded_nbrs]
  #    print("padded_nbrs_eval")
  #    print(padded_nbrs_eval)

  #    padded_closest_nbrs_eval = [padded_closest_nbr.eval() for padded_closest_nbr in padded_closest_nbrs] 
  #    print("padded_closest_nbrs_eval")
  #    print(padded_closest_nbrs_eval)

  #    #closest_nbrs_eval = [closest_nbr.eval() for closest_nbr in closest_nbrs]
  #    #print("closest_nbrs_eval")
  #    #print(closest_nbrs_eval)

  #    assert 0 == 1

  def test_neighbor_list_vina(self):
    """Test under conditions closer to Vina usage."""
    N_atoms = 5