Commit 29a42789 authored by Bharath Ramsundar
Browse files

more tests

parent bb03256f
Loading
Loading
Loading
Loading
+69 −69
Original line number Diff line number Diff line
@@ -798,7 +798,7 @@ class NeighborList(Layer):
  are close to each other spatially
  """

  def __init__(self, N_atoms, M_nbrs, ndim, n_cells, k, nbr_cutoff, start, stop,
  def __init__(self, N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop,
               **kwargs):
    """
    Parameters
@@ -810,8 +810,6 @@ class NeighborList(Layer):
    ndim: int
      Dimensionality of space atoms live in. (Typically 3D, but sometimes will
      want to use higher dimensional descriptors for atoms).
    n_cells: int
      Number of grid cells in the simulation box.
    k: int
      Number of nearest neighbors to pull in using tf.nn.top_k.
      TODO(rbharath): Are both k and M_nbrs needed?
@@ -821,6 +819,8 @@ class NeighborList(Layer):
    self.N = N_atoms
    self.M = M_nbrs
    self.ndim = ndim
    # Number of grid cells
    n_cells = int(((stop - start) / nbr_cutoff)**ndim)
    self.n_cells = n_cells
    self.k = k
    self.nbr_cutoff = nbr_cutoff
@@ -857,32 +857,34 @@ class NeighborList(Layer):
    """
    N, M, n_cells, ndim, k = self.N, self.M, self.n_cells, self.ndim, self.k
    nbr_cutoff = self.nbr_cutoff
    cells = self._get_cells()
    # Associate each atom with cell it belongs to. O(N*n_cells)
    # Shape (n_cells, ndim)
    cells = self.get_cells()
    # Find k atoms closest to each cell 
    # Shape (n_cells, k)
    atoms_in_cells, _ = self._put_atoms_in_cells(coords, cells)
    atoms_in_cells = self.get_closest_atoms(coords, cells)
    # Shape (N, 1)
    cells_for_atoms = self._get_cells_for_atoms(coords, cells)
    cells_for_atoms = self.get_cells_for_atoms(coords, cells)

    # Associate each cell with its neighbor cells. Assumes periodic boundary   
    # conditions, so does wraparound. O(constant)
    # Shape (n_cells, 26)
    neighbor_cells = self._compute_neighbor_cells(cells)
    # Shape (n_cells, n_nbrs)
    neighbor_cells = self.get_neighbor_cells(cells)

    # Shape (N, 26)
    # Shape (N, n_nbrs)
    neighbor_cells = tf.squeeze(tf.gather(neighbor_cells, cells_for_atoms))

    # coords of shape (N, ndim)
    # Shape (N, 26, k, ndim)
    tiled_coords = tf.tile(tf.reshape(coords, (N, 1, 1, ndim)), (1, 26, k, 1))
    # Shape (N, n_nbrs, k, ndim)
    n_nbrs = self._get_num_nbrs()
    tiled_coords = tf.tile(tf.reshape(coords, (N, 1, 1, ndim)), (1, n_nbrs, k, 1))

    # Shape (N, 26, k)
    # Shape (N, n_nbrs, k)
    nbr_inds = tf.gather(atoms_in_cells, neighbor_cells)

    # Shape (N, 26, k)
    # Shape (N, n_nbrs, k)
    atoms_in_nbr_cells = tf.gather(atoms_in_cells, neighbor_cells)

    # Shape (N, 26, k, ndim)
    # Shape (N, n_nbrs, k, ndim)
    nbr_coords = tf.gather(coords, atoms_in_nbr_cells)

    # For smaller systems especially, the periodic boundary conditions can
@@ -891,10 +893,10 @@ class NeighborList(Layer):

    # TODO(rbharath): How does distance need to be modified here to   
    # account for periodic boundary conditions?   
    # Shape (N, 26, k)
    # Shape (N, n_nbrs, k)
    dists = tf.reduce_sum((tiled_coords - nbr_coords)**2, axis=3)

    # Shape (N, 26*k)
    # Shape (N, n_nbrs*k)
    dists = tf.reshape(dists, [N, -1])

    # TODO(rbharath): This will cause an issue with duplicates!
@@ -906,10 +908,10 @@ class NeighborList(Layer):
        tf.squeeze(locs) for locs in tf.split(closest_nbr_locs, N)
    ]

    # Shape (N, 26*k)
    # Shape (N, n_nbrs*k)
    nbr_inds = tf.reshape(nbr_inds, [N, -1])

    # N elts of size (26*k,) each
    # N elts of size (n_nbrs*k,) each
    split_nbr_inds = [tf.squeeze(split) for split in tf.split(nbr_inds, N)]

    # N elts of size (M,) each 
@@ -924,59 +926,43 @@ class NeighborList(Layer):

    return neighbor_list

  def _put_atoms_in_cells(self, coords, cells):
    """Place each atom into cells. O(N) runtime.    
  def get_closest_atoms(self, coords, cells):
    """For each cell, find k closest atoms.
    
    Let N be the number of atoms.
        
    Parameters    
    ----------    
    coords: tf.Tensor 
      (N, 3) shape.
      (N, ndim) shape.
    cells: tf.Tensor
      (n_cells, ndim) shape.
    N: int
      Number atoms
    ndim: int
      Dimensionality of input space
    k: int
      Number of nearest neighbors.

    Returns
    -------
    closest_atoms: tf.Tensor 
      Of shape (n_cells, k, ndim)
    closest_inds: tf.Tensor 
      Of shape (n_cells, k)
    """
    N, n_cells, ndim, k = self.N, self.n_cells, self.ndim, self.k
    # Tile both cells and coords to form arrays of size (n_cells*N, ndim)
    tiled_cells = tf.reshape(tf.tile(cells, (1, N)), (n_cells * N, ndim))
    # TODO(rbharath): Change this for tf 1.0
    # n_cells tensors of shape (N, 1)
    tiled_cells = tf.split(tiled_cells, n_cells)
    # Tile both cells and coords to form arrays of size (N*n_cells, ndim)
    tiled_cells = tf.reshape(tf.tile(cells, (1, N)), (N * n_cells, ndim))

    # Shape (N*n_cells, 1) after tile
    # Shape (N*n_cells, ndim) after tile
    tiled_coords = tf.tile(coords, (n_cells, 1))
    # List of n_cells tensors of shape (N, 1)
    tiled_coords = tf.split(tiled_coords, n_cells)

    # Lists of length n_cells
    coords_rel = [
        tf.to_float(coords) - tf.to_float(cells)
        for (coords, cells) in zip(tiled_coords, tiled_cells)
    ]
    coords_norm = [tf.reduce_sum(rel**2, axis=1) for rel in coords_rel]
    # Shape (N*n_cells)
    coords_vec = tf.reduce_sum((tiled_coords - tiled_cells)**2, axis=1)
    # Shape (n_cells, N)
    coords_norm = tf.reshape(coords_vec, (n_cells, N))

    # Lists of length n_cells
    # Get indices of k atoms closest to each cell point
    closest_inds = [tf.nn.top_k(norm, k=k)[1] for norm in coords_norm]
    # n_cells tensors of shape (k, ndim)
    closest_atoms = tf.stack([tf.gather(coords, inds) for inds in closest_inds])
    # Find k atoms closest to this cell. Notice negative sign since
    # tf.nn.top_k returns *largest* not smallest.
    # Tensor of shape (n_cells, k)
    closest_inds = tf.stack(closest_inds)
    closest_inds = tf.nn.top_k(-coords_norm,k=k)[1]

    return closest_inds, closest_atoms
    return closest_inds

  def _get_cells_for_atoms(self, coords, cells):
  def get_cells_for_atoms(self, coords, cells):
    """Compute the cells each atom belongs to.

    Parameters
@@ -1017,31 +1003,45 @@ class NeighborList(Layer):
    # TODO(rbharath): tf.stack for tf 1.0
    return tf.stack(closest_inds)

  def _compute_neighbor_cells(self, cells):
  def _get_num_nbrs(self):
    """Get number of neighbors in current dimensionality space."""
    ndim = self.ndim
    if ndim == 1:
      n_nbrs = 2
    elif ndim == 2:
      # 8 neighbors in 2-space
      n_nbrs = 8
    # TODO(rbharath): Shoddy handling of higher dimensions...
    elif ndim >= 3:
      # Number of neighbors of central cube in 3-space is
      # 3^2 (top-face) + 3^2 (bottom-face) + (3^2-1) (middle-band)
      n_nbrs = 9 + 9 + 8  # (26 faces on Rubik's cube for example)
    return n_nbrs

  def get_neighbor_cells(self, cells):
    """Compute neighbors of cells in grid.    

    # TODO(rbharath): Do we need to handle periodic boundary conditions
    properly here?
    # TODO(rbharath): This doesn't handle boundaries well. We hard-code
    # looking for 26 neighbors, which isn't right for boundary cells in
    # looking for n_nbrs neighbors, which isn't right for boundary cells in
    # the cube.
        
    Note n_cells is box_size**ndim. 26 is the number of neighbors of a cube in
    a grid (including diagonals).
    Note n_cells is box_size**ndim. n_nbrs  is the number of neighbors of a
    cube in a grid (including diagonals).

    Parameters    
    ----------    
    cells: tf.Tensor
      (n_cells, 26) shape.
      (n_cells, n_nbrs) shape.
    Returns
    -------
    nbr_cells: tf.Tensor
      (n_cells, n_nbrs)
    """
    ndim, n_cells = self.ndim, self.n_cells
    n_cells = int(n_cells)
    if ndim != 3:
      raise ValueError("Not defined for dimensions besides 3")
    # Number of neighbors of central cube in 3-space is
    # 3^2 (top-face) + 3^2 (bottom-face) + (3^2-1) (middle-band)
    n_nbrs = self._get_num_nbrs()
    # TODO(rbharath)
    k = 9 + 9 + 8  # (26 faces on Rubik's cube for example)
    #n_cells = int(cells.get_shape()[0])
    # Tile cells to form arrays of size (n_cells*n_cells, ndim)
    # Two tilings (a, b, c, a, b, c, ...) vs. (a, a, a, b, b, b, etc.)
@@ -1063,13 +1063,13 @@ class NeighborList(Layer):
    coords_norm = [tf.reduce_sum(rel**2, axis=1) for rel in coords_rel]

    # Lists of length n_cells
    # Get indices of k atoms closest to each cell point
    # n_cells tensors of shape (26,)
    closest_inds = tf.stack([tf.nn.top_k(norm, k=k)[1] for norm in coords_norm])
    # Get indices of n_nbrs atoms closest to each cell point
    # n_cells tensors of shape (n_nbrs,)
    closest_inds = tf.stack([tf.nn.top_k(norm, k=n_nbrs)[1] for norm in coords_norm])

    return closest_inds

  def _get_cells(self):
  def get_cells(self):
    """Returns the locations of all grid points in box.

    Suppose start is -10 Angstrom, stop is 10 Angstrom, nbr_cutoff is 1.
@@ -1083,6 +1083,6 @@ class NeighborList(Layer):
    """
    start, stop, nbr_cutoff = self.start, self.stop, self.nbr_cutoff
    mesh_args = [tf.range(start, stop, nbr_cutoff) for _ in range(self.ndim)]
    return tf.reshape(
    return tf.to_float(tf.reshape(
        tf.transpose(tf.stack(tf.meshgrid(*mesh_args))),
        (self.n_cells, self.ndim))
        (self.n_cells, self.ndim)))
+98 −22
Original line number Diff line number Diff line
@@ -44,9 +44,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    ndim = 3
    M = 6
    k = 5
    # The number of cells which we should theoretically have
    n_cells = int(((stop - start) / nbr_cutoff)**ndim)

    X = np.random.rand(N_atoms, ndim)
    y = np.random.rand(N_atoms, 1)
    dataset = NumpyDataset(X, y)
@@ -54,7 +51,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    features = Feature(shape=(N_atoms, ndim))
    labels = Label(shape=(N_atoms,))
    nbr_list = NeighborList(
        N_atoms, M, ndim, n_cells, k, nbr_cutoff, in_layers=[features])
        N_atoms, M, ndim, k, nbr_cutoff, in_layers=[features])
    nbr_list = ToFloat(in_layers=[nbr_list])
    # This isn't a meaningful loss, but just for test
    loss = ReduceSum(in_layers=[nbr_list])
@@ -161,17 +158,109 @@ class TestDocking(test_util.TensorFlowTestCase):
    ndim = 3
    M_nbrs = 2
    k = 5
    # The number of cells which we should theoretically have
    n_cells = int(((stop - start) / nbr_cutoff)**ndim)

    with self.test_session() as sess:
      coords = start + np.random.rand(N_atoms, ndim) * (stop - start)
      coords = tf.stack(coords)
      nbr_list = NeighborList(N_atoms, M_nbrs, ndim, n_cells, k, nbr_cutoff,
      nbr_list = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff,
                              start, stop)(coords)
      nbr_list = nbr_list.eval()
      assert nbr_list.shape == (N_atoms, M_nbrs)

  def test_get_cells_1D(self):
    """Test neighbor-list method get_cells() in 1D.

    With start=0, stop=10 and nbr_cutoff=1 the grid is expected to be the ten
    points 0, 1, ..., 9, returned as a (10, 1) array.
    """
    N_atoms = 4
    start = 0
    stop = 10
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1

    # get_cells() takes no coordinates, so no atom positions are built here.
    with self.test_session():
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff,
                                    start, stop)
      cells_eval = nbr_list_layer.get_cells().eval()
      true_cells = np.reshape(np.arange(10), (10, 1))
      np.testing.assert_array_almost_equal(cells_eval, true_cells)

  def test_get_closest_atoms_1D(self):
    """Check get_closest_atoms against a hand-computed 1D example."""
    start, stop = 0, 10
    nbr_cutoff = 1
    n_cells = 10
    ndim = 1
    N_atoms = 4
    M_nbrs = 1
    k = 1
    # Four atoms on a line: 1 and 2 are nbrs. 8 and 9 are nbrs
    positions = np.reshape(np.array([1.0, 2.0, 8.0, 9.0]), (N_atoms, M_nbrs))
    # Expected atom index for every cell, one row per cell.
    expected = np.reshape(
        np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 3]), (n_cells, k))

    with self.test_session():
      coords = tf.convert_to_tensor(positions, dtype=tf.float32)
      layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      closest = layer.get_closest_atoms(coords, layer.get_cells())
      np.testing.assert_array_almost_equal(closest.eval(), expected)

  def test_compute_neighbor_cells_1D(self):
    """Test that computation of get_neighbor_cells works in 1D.

    Only the shape of the result is pinned: one row per grid cell and one
    column per neighbor, i.e. (n_cells, n_nbrs) with n_nbrs == 2 in 1D.
    """
    N_atoms = 4
    start = 0
    stop = 10
    n_cells = 10
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # In 1D every cell has a left and a right neighbor.
    n_nbrs = 2

    with self.test_session():
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff,
                                    start, stop)
      cells = nbr_list_layer.get_cells()
      nbr_cells = nbr_list_layer.get_neighbor_cells(cells)
      nbr_cells_eval = nbr_cells.eval()
      # NOTE(review): which cell indices are returned is not checked here;
      # add a value assertion once the intended boundary handling is settled.
      assert nbr_cells_eval.shape == (n_cells, n_nbrs)


  def test_neighbor_list_1D(self):
    """Test neighbor list on 1D example."""
    N_atoms = 4
    start = 0
    stop = 10
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))

    with self.test_session():
      # get_cells() returns float32 grid points, so the coordinates are fed as
      # float32 as well to keep the coords/cells subtraction type-consistent.
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff,
                                    start, stop)
      nbr_list = nbr_list_layer(coords).eval()
      # NOTE(review): the expected array is flat (4,); confirm it matches the
      # shape the layer actually returns for M_nbrs == 1.
      np.testing.assert_array_almost_equal(nbr_list, np.array([1, 0, 3, 2]))

  def test_neighbor_list_vina(self):
    """Test under conditions closer to Vina usage."""
    N_atoms = 5
@@ -181,24 +270,14 @@ class TestDocking(test_util.TensorFlowTestCase):
    start = 0
    stop = 4
    nbr_cutoff = 1
    # The number of cells which we should theoretically have
    n_cells = ((stop - start) / nbr_cutoff)**ndim

    X = NumpyDataset(start + np.random.rand(N_atoms, ndim) * (stop - start))

    coords = Feature(shape=(N_atoms, ndim))

    # Now an (N, M) shape
    nbr_list = NeighborList(
        N_atoms,
        M_nbrs,
        ndim,
        n_cells,
        k,
        nbr_cutoff,
        start,
        stop,
        in_layers=[coords])
    nbr_list = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start,
                            stop, in_layers=[coords])

    nbr_list = ToFloat(in_layers=[nbr_list])
    flattened = Flatten(in_layers=[nbr_list])
@@ -222,8 +301,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    start = 0
    stop = 4
    nbr_cutoff = 1
    # The number of cells which we should theoretically have
    n_cells = ((stop - start) / nbr_cutoff)**ndim

    X_prot = NumpyDataset(start + np.random.rand(N_protein, ndim) * (stop -
                                                                     start))
@@ -250,7 +327,6 @@ class TestDocking(test_util.TensorFlowTestCase):
        N_protein + N_ligand,
        M_nbrs,
        ndim,
        n_cells,
        k,
        nbr_cutoff,
        start,