Commit a437991e authored by Bharath Ramsundar
Browse files

Progress debugging 1D, but still buggy

parent 29a42789
Loading
Loading
Loading
Loading
+66 −91
Original line number Diff line number Diff line
@@ -848,43 +848,48 @@ class NeighborList(Layer):
    Parameters
    ----------
    coords: tf.Tensor
      Shape (N, ndim)
      Shape (N_atoms, ndim)

    Returns
    -------
    nbr_list: tf.Tensor
      Shape (N, M) of atom indices
      Shape (N_atoms, M) of atom indices
    """
    N, M, n_cells, ndim, k = self.N, self.M, self.n_cells, self.ndim, self.k
    N_atoms, M, n_cells, ndim, k = self.N, self.M, self.n_cells, self.ndim, self.k
    nbr_cutoff = self.nbr_cutoff
    # Shape (n_cells, ndim)
    cells = self.get_cells()
    # Find k atoms closest to each cell 
    # Shape (n_cells, k)
    atoms_in_cells = self.get_closest_atoms(coords, cells)
    # Shape (N, 1)
    closest_atoms = self.get_closest_atoms(coords, cells)
    # Shape (N_atoms, 1)
    cells_for_atoms = self.get_cells_for_atoms(coords, cells)

    # Associate each cell with its neighbor cells. Assumes periodic boundary
    # conditions, so it wraps around. O(constant)
    # Shape (n_cells, n_nbrs)
    # Shape (n_cells, n_nbr_cells)
    neighbor_cells = self.get_neighbor_cells(cells)

    # Shape (N, n_nbrs)
    # Shape (N_atoms, n_nbr_cells)
    neighbor_cells = tf.squeeze(tf.gather(neighbor_cells, cells_for_atoms))
    ############################################### DEBUG
    print("neighbor_cells")
    print(neighbor_cells)
    ############################################### DEBUG

    # coords of shape (N, ndim)
    # Shape (N, n_nbrs, k, ndim)
    n_nbrs = self._get_num_nbrs()
    tiled_coords = tf.tile(tf.reshape(coords, (N, 1, 1, ndim)), (1, n_nbrs, k, 1))
    # coords of shape (N_atoms, ndim)
    # Shape (N_atoms, n_nbr_cells, k, ndim)
    n_nbr_cells = self._get_num_nbrs()
    tiled_coords = tf.tile(tf.reshape(coords, (N_atoms, 1, 1, ndim)),
                                      (1, n_nbr_cells, k, 1))

    # Shape (N, n_nbrs, k)
    nbr_inds = tf.gather(atoms_in_cells, neighbor_cells)
    # Shape (N_atoms, n_nbr_cells, k)
    nbr_inds = tf.gather(closest_atoms, neighbor_cells)

    # Shape (N, n_nbrs, k)
    atoms_in_nbr_cells = tf.gather(atoms_in_cells, neighbor_cells)
    # Shape (N_atoms, n_nbr_cells, k)
    atoms_in_nbr_cells = tf.gather(closest_atoms, neighbor_cells)

    # Shape (N, n_nbrs, k, ndim)
    # Shape (N_atoms, n_nbr_cells, k, ndim)
    nbr_coords = tf.gather(coords, atoms_in_nbr_cells)

    # For smaller systems especially, the periodic boundary conditions can
@@ -893,35 +898,35 @@ class NeighborList(Layer):

    # TODO(rbharath): How does distance need to be modified here to   
    # account for periodic boundary conditions?   
    # Shape (N, n_nbrs, k)
    # Shape (N_atoms, n_nbr_cells, k)
    dists = tf.reduce_sum((tiled_coords - nbr_coords)**2, axis=3)

    # Shape (N, n_nbrs*k)
    dists = tf.reshape(dists, [N, -1])
    # Shape (N_atoms, n_nbr_cells*k)
    dists = tf.reshape(dists, [N_atoms, -1])

    # TODO(rbharath): This will cause an issue with duplicates!
    # Shape (N, M)
    closest_nbr_locs = tf.nn.top_k(dists, k=M)[1]
    # Shape (N_atoms, M)
    closest_nbr_locs = tf.nn.top_k(-dists, k=M)[1]

    # N elts of size (M,) each
    # N_atoms elts of size (M,) each
    split_closest_nbr_locs = [
        tf.squeeze(locs) for locs in tf.split(closest_nbr_locs, N)
        tf.squeeze(locs) for locs in tf.split(closest_nbr_locs, N_atoms)
    ]

    # Shape (N, n_nbrs*k)
    nbr_inds = tf.reshape(nbr_inds, [N, -1])
    # Shape (N_atoms, n_nbr_cells*k)
    nbr_inds = tf.reshape(nbr_inds, [N_atoms, -1])

    # N elts of size (n_nbrs*k,) each
    split_nbr_inds = [tf.squeeze(split) for split in tf.split(nbr_inds, N)]
    # N_atoms elts of size (n_nbr_cells*k,) each
    split_nbr_inds = [tf.squeeze(split) for split in tf.split(nbr_inds, N_atoms)]

    # N elts of size (M,) each 
    # N_atoms elts of size (M,) each 
    neighbor_list = [
        tf.gather(nbr_inds, closest_nbr_locs)
        for (nbr_inds, closest_nbr_locs
            ) in zip(split_nbr_inds, split_closest_nbr_locs)
    ]

    # Shape (N, M)
    # Shape (N_atoms, M)
    neighbor_list = tf.stack(neighbor_list)

    return neighbor_list
@@ -929,12 +934,12 @@ class NeighborList(Layer):
  def get_closest_atoms(self, coords, cells):
    """For each cell, find k closest atoms.
    
    Let N be the number of atoms.
    Let N_atoms be the number of atoms.
        
    Parameters    
    ----------    
    coords: tf.Tensor 
      (N, ndim) shape.
      (N_atoms, ndim) shape.
    cells: tf.Tensor
      (n_cells, ndim) shape.

@@ -943,17 +948,17 @@ class NeighborList(Layer):
    closest_inds: tf.Tensor 
      Of shape (n_cells, k)
    """
    N, n_cells, ndim, k = self.N, self.n_cells, self.ndim, self.k
    # Tile both cells and coords to form arrays of size (N*n_cells, ndim)
    tiled_cells = tf.reshape(tf.tile(cells, (1, N)), (N * n_cells, ndim))
    N_atoms, n_cells, ndim, k = self.N, self.n_cells, self.ndim, self.k
    # Tile both cells and coords to form arrays of size (N_atoms*n_cells, ndim)
    tiled_cells = tf.reshape(tf.tile(cells, (1, N_atoms)), (N_atoms * n_cells, ndim))

    # Shape (N*n_cells, ndim) after tile
    # Shape (N_atoms*n_cells, ndim) after tile
    tiled_coords = tf.tile(coords, (n_cells, 1))

    # Shape (N*n_cells)
    # Shape (N_atoms*n_cells)
    coords_vec = tf.reduce_sum((tiled_coords - tiled_cells)**2, axis=1)
    # Shape (n_cells, N)
    coords_norm = tf.reshape(coords_vec, (n_cells, N))
    # Shape (n_cells, N_atoms)
    coords_norm = tf.reshape(coords_vec, (n_cells, N_atoms))

    # Find k atoms closest to this cell. Notice negative sign since
    # tf.nn.top_k returns *largest* not smallest.
@@ -968,55 +973,42 @@ class NeighborList(Layer):
    Parameters
    ----------
    coords: tf.Tensor
      Shape (N, ndim)
      Shape (N_atoms, ndim)
    cells: tf.Tensor
      (box_size**ndim, ndim) shape.
      (n_cells, ndim) shape.
    Returns
    -------
    cells_for_atoms: tf.Tensor
      Shape (N, 1)
      Shape (N_atoms, 1)
    """
    N, n_cells, ndim = self.N, self.n_cells, self.ndim
    N_atoms, n_cells, ndim = self.N, self.n_cells, self.ndim
    n_cells = int(n_cells)
    # Tile both cells and coords to form arrays of size (n_cells*N, ndim)
    tiled_cells = tf.tile(cells, (N, 1))
    # N tensors of shape (n_cells, 1)
    tiled_cells = tf.split(tiled_cells, N)
    # Tile both cells and coords to form arrays of size (N_atoms*n_cells, ndim)
    tiled_cells = tf.tile(cells, (N_atoms, 1))

    # Shape (N*n_cells, 1) after tile
    # Shape (N_atoms*n_cells, 1) after tile
    tiled_coords = tf.reshape(
        tf.tile(coords, (1, n_cells)), (n_cells * N, ndim))
    # List of N tensors of shape (n_cells, 1)
    tiled_coords = tf.split(tiled_coords, N)

    # Lists of length N 
    coords_rel = [
        tf.to_float(coords) - tf.to_float(cells)
        for (coords, cells) in zip(tiled_coords, tiled_cells)
    ]
    coords_norm = [tf.reduce_sum(rel**2, axis=1) for rel in coords_rel]

    # Lists of length n_cells
    # Get indices of k atoms closest to each cell point
    closest_inds = [tf.nn.top_k(-norm, k=1)[1] for norm in coords_norm]
        tf.tile(coords, (1, n_cells)), (n_cells * N_atoms, ndim))
    coords_vec = tf.reduce_sum((tiled_coords-tiled_cells)**2, axis=1)
    coords_norm = tf.reshape(coords_vec, (N_atoms, n_cells))

    # TODO(rbharath): tf.stack for tf 1.0
    return tf.stack(closest_inds)
    closest_inds = tf.nn.top_k(-coords_norm,k=1)[1]
    return closest_inds

  def _get_num_nbrs(self):
    """Get the number of neighbor cells of a grid cell for this layer's ndim.

    Reads ``self.ndim`` and returns a fixed count of adjacent cells
    (diagonals included): 2 for ndim == 1, 8 for ndim == 2 and 26 for
    ndim >= 3.

    Returns
    -------
    n_nbr_cells: int
      Number of neighbor cells.

    Raises
    ------
    ValueError
      If ``self.ndim`` matches none of the supported cases (e.g. ndim < 1).
    """
    ndim = self.ndim
    if ndim == 1:
      n_nbr_cells = 2
    elif ndim == 2:
      # 8 neighbors in 2-space
      n_nbr_cells = 8
    # TODO(rbharath): Shoddy handling of higher dimensions...
    elif ndim >= 3:
      # Number of neighbors of central cube in 3-space is
      # 3^2 (top-face) + 3^2 (bottom-face) + (3^2-1) (middle-band).
      # The 3-space count is reused unchanged for any ndim > 3.
      n_nbr_cells = 9 + 9 + 8  # (the 26 outer cubies of a Rubik's cube)
    else:
      # Without this branch the function fell through to the return statement
      # with the result variable never assigned.
      raise ValueError("ndim must be at least 1, got %r" % (ndim,))
    return n_nbr_cells

  def get_neighbor_cells(self, cells):
    """Compute neighbors of cells in grid.    
@@ -1024,25 +1016,20 @@ class NeighborList(Layer):
    # TODO(rbharath): Do we need to handle periodic boundary conditions
    properly here?
    # TODO(rbharath): This doesn't handle boundaries well. We hard-code
    # looking for n_nbrs neighbors, which isn't right for boundary cells in
    # looking for n_nbr_cells neighbors, which isn't right for boundary cells in
    # the cube.
        
    Note n_cells is box_size**ndim. n_nbrs  is the number of neighbors of a
    cube in a grid (including diagonals).

    Parameters    
    ----------    
    cells: tf.Tensor
      (n_cells, n_nbrs) shape.
      (n_cells, ndim) shape.
    Returns
    -------
    nbr_cells: tf.Tensor
      (n_cells, n_nbrs)
      (n_cells, n_nbr_cells)
    """
    ndim, n_cells = self.ndim, self.n_cells
    n_nbrs = self._get_num_nbrs()
    # TODO(rbharath)
    #n_cells = int(cells.get_shape()[0])
    n_nbr_cells = self._get_num_nbrs()
    # Tile cells to form arrays of size (n_cells*n_cells, ndim)
    # Two tilings (a, b, c, a, b, c, ...) vs. (a, a, a, b, b, b, etc.)
    # Tile (a, a, a, b, b, b, etc.)
@@ -1051,21 +1038,9 @@ class NeighborList(Layer):
    # Tile (a, b, c, a, b, c, ...)
    tiled_cells = tf.tile(cells, (n_cells, 1))

    # Lists of n_cells tensors of shape (N, 1)
    tiled_centers = tf.split(tiled_centers, n_cells)
    tiled_cells = tf.split(tiled_cells, n_cells)

    # Lists of length n_cells
    coords_rel = [
        tf.to_float(cells) - tf.to_float(centers)
        for (cells, centers) in zip(tiled_centers, tiled_cells)
    ]
    coords_norm = [tf.reduce_sum(rel**2, axis=1) for rel in coords_rel]

    # Lists of length n_cells
    # Get indices of n_nbrs atoms closest to each cell point
    # n_cells tensors of shape (n_nbrs,)
    closest_inds = tf.stack([tf.nn.top_k(norm, k=n_nbrs)[1] for norm in coords_norm])
    coords_vec = tf.reduce_sum((tiled_centers - tiled_cells)**2, axis=1)
    coords_norm = tf.reshape(coords_vec, (n_cells, n_cells))
    closest_inds = tf.nn.top_k(-coords_norm, k=n_nbr_cells)[1]

    return closest_inds

+47 −8
Original line number Diff line number Diff line
@@ -205,13 +205,13 @@ class TestDocking(test_util.TensorFlowTestCase):
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      atoms_in_cells = nbr_list_layer.get_closest_atoms(coords, cells)
      atoms_in_cells_eval = atoms_in_cells.eval() 
      closest_atoms = nbr_list_layer.get_closest_atoms(coords, cells)
      atoms_in_cells_eval = closest_atoms.eval() 
      true_atoms_in_cells = np.reshape(np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 3]), (n_cells, k))
      np.testing.assert_array_almost_equal(atoms_in_cells_eval, true_atoms_in_cells)

  def test_compute_neighbor_cells_1D(self):
    """Test that computation of get_neighbor_cells works in 1D"""
  def test_get_neighbor_cells_1D(self):
    """Test that get_neighbor_cells works in 1D"""
    N_atoms = 4
    start = 0
    stop = 10
@@ -227,7 +227,46 @@ class TestDocking(test_util.TensorFlowTestCase):
      coords = tf.convert_to_tensor(coords)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      nbr_cells = nbr_list_layer.get_neighbor_cells(cells)
      nbr_cells_eval = nbr_cells.eval()
      true_nbr_cells = np.array(
          [[0, 1],
           [1, 0],
           [2, 1],
           [3, 2],
           [4, 3],
           [5, 4],
           [6, 5],
           [7, 6],
           [8, 7],
           [9, 8]])
      np.testing.assert_array_almost_equal(nbr_cells_eval, true_nbr_cells)

  def test_get_cells_for_atoms_1D(self):
    """Test that get_cells_for_atoms works in 1D.

    Uses four atoms with 1D coordinates 1, 2, 8 and 9 and checks that
    get_cells_for_atoms assigns them the cell indices 1, 2, 8 and 9
    respectively, returned with shape (N_atoms, 1).
    """
    N_atoms = 4
    start = 0
    stop = 10
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    # Coordinates are laid out as (N_atoms, ndim). The neighbor count M_nbrs
    # is unrelated to the coordinate shape and only matched by coincidence.
    coords = np.reshape(coords, (N_atoms, ndim))

    # The session is entered only to provide a default session for .eval().
    with self.test_session():
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      cells_for_atoms = nbr_list_layer.get_cells_for_atoms(coords, cells)
      cells_for_atoms_eval = cells_for_atoms.eval()
      true_cells_for_atoms = np.array(
          [[1],
           [2],
           [8],
           [9]])
      np.testing.assert_array_almost_equal(cells_for_atoms_eval, true_cells_for_atoms)


  def test_neighbor_list_1D(self):
@@ -244,12 +283,12 @@ class TestDocking(test_util.TensorFlowTestCase):
    coords = np.reshape(coords, (N_atoms, M_nbrs))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords)
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      ###################################################################### DEBUG
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      atoms_in_cells, _ = nbr_list_layer._put_atoms_in_cells(coords, cells)
      atoms_in_cells_eval = atoms_in_cells.eval() 
      closest_atoms = nbr_list_layer.get_closest_atoms(coords, cells)
      atoms_in_cells_eval = closest_atoms.eval() 
      print("atoms_in_cells_eval")
      print(atoms_in_cells_eval)
      nbr_list = nbr_list_layer(coords)