Commit fa101221 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

More debugging

parent a437991e
Loading
Loading
Loading
Loading
+75 −82
Original line number Diff line number Diff line
@@ -798,7 +798,7 @@ class NeighborList(Layer):
  are close to each other spatially
  """

  def __init__(self, N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop,
  def __init__(self, N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop,
               **kwargs):
    """
    Parameters
@@ -810,19 +810,15 @@ class NeighborList(Layer):
    ndim: int
      Dimensionality of space atoms live in. (Typically 3D, but sometimes will
      want to use higher dimensional descriptors for atoms).
    k: int
      Number of nearest neighbors to pull in using tf.nn.top_k.
      TODO(rbharath): Are both k and M_nbrs needed?
    nbr_cutoff: float
      Length in Angstroms (?) at which atom boxes are gridded.
    """
    self.N = N_atoms
    self.M = M_nbrs
    self.N_atoms = N_atoms
    self.M_nbrs = M_nbrs
    self.ndim = ndim
    # Number of grid cells
    n_cells = int(((stop - start) / nbr_cutoff)**ndim)
    self.n_cells = n_cells
    self.k = k
    self.nbr_cutoff = nbr_cutoff
    self.start = start
    self.stop = stop
@@ -838,11 +834,11 @@ class NeighborList(Layer):
      # TODO(rbharath): Support batching
      raise ValueError("Parent tensor must be (num_atoms, ndum)")
    coords = parent.out_tensor
    nbr_list = self._compute_nbr_list(coords)
    nbr_list = self.compute_nbr_list(coords)
    self.out_tensor = nbr_list
    return nbr_list

  def _compute_nbr_list(self, coords):
  def compute_nbr_list(self, coords):
    """Computes a neighbor list from atom coordinates.

    Parameters
@@ -853,86 +849,82 @@ class NeighborList(Layer):
    Returns
    -------
    nbr_list: tf.Tensor
      Shape (N_atoms, M) of atom indices
      Shape (N_atoms, M_nbrs) of atom indices
    """
    N_atoms, M, n_cells, ndim, k = self.N, self.M, self.n_cells, self.ndim, self.k
    N_atoms, M_nbrs, n_cells, ndim = (
        self.N_atoms, self.M_nbrs, self.n_cells, self.ndim)
    nbr_cutoff = self.nbr_cutoff
    # Shape (n_cells, ndim)
    cells = self.get_cells()
    # Find k atoms closest to each cell 
    # Shape (n_cells, k)
    closest_atoms = self.get_closest_atoms(coords, cells)
    # Shape (N_atoms, 1)
    cells_for_atoms = self.get_cells_for_atoms(coords, cells)

    # Associate each cell with its neighbor cells. Assumes periodic boundary   
    # conditions, so does wrapround. O(constant)    
    # Shape (n_cells, n_nbr_cells)
    neighbor_cells = self.get_neighbor_cells(cells)
    # List of length N_atoms, each element of different length uniques_i
    nbrs = self.get_atoms_in_nbrs(coords, cells)

    # Shape (N_atoms, n_nbr_cells)
    neighbor_cells = tf.squeeze(tf.gather(neighbor_cells, cells_for_atoms))
    ############################################### DEBUG
    print("neighbor_cells")
    print(neighbor_cells)
    ############################################### DEBUG
    # List of length N_atoms, each element of different length uniques_i
    # List of length N_atoms, each a tensor of shape
    # (uniques_i, ndim)
    nbr_coords = [tf.gather(coords, atom_nbrs) for atom_nbrs in nbrs]

    # coords of shape (N_atoms, ndim)
    # Shape (N_atoms, n_nbr_cells, k, ndim)
    n_nbr_cells = self._get_num_nbrs()
    tiled_coords = tf.tile(tf.reshape(coords, (N_atoms, 1, 1, ndim)),
                                      (1, n_nbr_cells, k, 1))
    # List of length N_atoms, each of shape (1, ndim)
    atom_coords = tf.split(coords, N_atoms)
    # TODO(rbharath): How does distance need to be modified here to   
    # account for periodic boundary conditions?   
    # Shape (N_atoms, n_nbr_cells, M_nbrs)
    dists = [tf.reduce_sum((atom_coord - nbr_coord)**2, axis=1)
             for (atom_coord, nbr_coord) in zip(atom_coords, nbr_coords)]
  
    # Shape (N_atoms, n_nbr_cells, k)
    nbr_inds = tf.gather(closest_atoms, neighbor_cells)
    # TODO(rbharath): What if uniques_i < M_nbrs? Will crash
    # List of length N_atoms each of size M_nbrs
    closest_nbr_inds = [tf.nn.top_k(-dist, k=M_nbrs)[1] for dist in dists]

    # Shape (N_atoms, n_nbr_cells, k)
    atoms_in_nbr_cells = tf.gather(closest_atoms, neighbor_cells)
    # N_atoms elts of size (M_nbrs,) each 
    neighbor_list = [
        tf.gather(atom_nbrs, closest_nbr_ind)
        for (atom_nbrs, closest_nbr_ind)
        in zip(nbrs, closest_nbr_inds)
    ]

    # Shape (N_atoms, n_nbr_cells, k, ndim)
    nbr_coords = tf.gather(coords, atoms_in_nbr_cells)
    # Shape (N_atoms, M_nbrs)
    nbr_list = tf.stack(neighbor_list)

    # For smaller systems especially, the periodic boundary conditions can
    # result in neighboring cells being seen multiple times. Maybe use tf.unique to
    # make sure duplicate neighbors are ignored?
    return nbrs, nbr_coords, atom_coords, dists, closest_nbr_inds, neighbor_list, nbr_list

    # TODO(rbharath): How does distance need to be modified here to   
    # account for periodic boundary conditions?   
    # Shape (N_atoms, n_nbr_cells, k)
    dists = tf.reduce_sum((tiled_coords - nbr_coords)**2, axis=3)
  def get_atoms_in_nbrs(self, coords, cells):
    """Get the atoms in neighboring cells for each cells.

    # Shape (N_atoms, n_nbr_cells*k)
    dists = tf.reshape(dists, [N_atoms, -1])
    Returns
    -------
    atoms_in_nbrs = (N_atoms, n_nbr_cells, M_nbrs)
    """
    # Shape (N_atoms, 1)
    cells_for_atoms = self.get_cells_for_atoms(coords, cells)

    # TODO(rbharath): This will cause an issue with duplicates!
    # Shape (N_atoms, M)
    closest_nbr_locs = tf.nn.top_k(-dists, k=M)[1]
    # Find M_nbrs atoms closest to each cell 
    # Shape (n_cells, M_nbrs)
    closest_atoms = self.get_closest_atoms(coords, cells)

    # N_atoms elts of size (M,) each
    split_closest_nbr_locs = [
        tf.squeeze(locs) for locs in tf.split(closest_nbr_locs, N_atoms)
    ]
    # Associate each cell with its neighbor cells. Assumes periodic boundary   
    # conditions, so does wrapround. O(constant)    
    # Shape (n_cells, n_nbr_cells)
    neighbor_cells = self.get_neighbor_cells(cells)

    # Shape (N_atoms, n_nbr_cells*k)
    nbr_inds = tf.reshape(nbr_inds, [N_atoms, -1])
    # Shape (N_atoms, n_nbr_cells)
    neighbor_cells = tf.squeeze(tf.gather(neighbor_cells, cells_for_atoms))

    # N_atoms elts of size (n_nbr_cells*k,) each
    split_nbr_inds = [tf.squeeze(split) for split in tf.split(nbr_inds, N_atoms)]
    # Shape (N_atoms, n_nbr_cells, M_nbrs)
    atoms_in_nbrs = tf.gather(closest_atoms, neighbor_cells)

    # N_atoms elts of size (M,) each 
    neighbor_list = [
        tf.gather(nbr_inds, closest_nbr_locs)
        for (nbr_inds, closest_nbr_locs
            ) in zip(split_nbr_inds, split_closest_nbr_locs)
    ]
    # Shape (N_atoms, n_nbr_cells*M_nbrs)
    atoms_in_nbrs = tf.reshape(atoms_in_nbrs, [self.N_atoms, -1])

    # Shape (N_atoms, M)
    neighbor_list = tf.stack(neighbor_list)
    # List of length N_atoms, each element length uniques_i
    nbrs_per_atom = tf.split(atoms_in_nbrs, self.N_atoms)
    uniques = [tf.unique(tf.squeeze(atom_nbrs))[0] for atom_nbrs in nbrs_per_atom]
    
    return neighbor_list
    return uniques

  def get_closest_atoms(self, coords, cells):
    """For each cell, find k closest atoms.
    """For each cell, find M_nbrs closest atoms.
    
    Let N_atoms be the number of atoms.
        
@@ -946,11 +938,13 @@ class NeighborList(Layer):
    Returns
    -------
    closest_inds: tf.Tensor 
      Of shape (n_cells, k)
      Of shape (n_cells, M_nbrs)
    """
    N_atoms, n_cells, ndim, k = self.N, self.n_cells, self.ndim, self.k
    N_atoms, n_cells, ndim, M_nbrs = (
        self.N_atoms, self.n_cells, self.ndim, self.M_nbrs)
    # Tile both cells and coords to form arrays of size (N_atoms*n_cells, ndim)
    tiled_cells = tf.reshape(tf.tile(cells, (1, N_atoms)), (N_atoms * n_cells, ndim))
    tiled_cells = tf.reshape(tf.tile(cells, (1, N_atoms)),
                             (N_atoms * n_cells, ndim))

    # Shape (N_atoms*n_cells, ndim) after tile
    tiled_coords = tf.tile(coords, (n_cells, 1))
@@ -962,8 +956,8 @@ class NeighborList(Layer):

    # Find k atoms closest to this cell. Notice negative sign since
    # tf.nn.top_k returns *largest* not smallest.
    # Tensor of shape (n_cells, k)
    closest_inds = tf.nn.top_k(-coords_norm,k=k)[1]
    # Tensor of shape (n_cells, M_nbrs)
    closest_inds = tf.nn.top_k(-coords_norm,k=M_nbrs)[1]

    return closest_inds

@@ -981,7 +975,7 @@ class NeighborList(Layer):
    cells_for_atoms: tf.Tensor
      Shape (N_atoms, 1)
    """
    N_atoms, n_cells, ndim = self.N, self.n_cells, self.ndim
    N_atoms, n_cells, ndim = self.N_atoms, self.n_cells, self.ndim
    n_cells = int(n_cells)
    # Tile both cells and coords to form arrays of size (N_atoms*n_cells, ndim)
    tiled_cells = tf.tile(cells, (N_atoms, 1))
@@ -999,15 +993,14 @@ class NeighborList(Layer):
    """Get number of neighbors in current dimensionality space."""
    ndim = self.ndim
    if ndim == 1:
      n_nbr_cells = 2
      n_nbr_cells = 3
    elif ndim == 2:
      # 8 neighbors in 2-space
      n_nbr_cells = 8
      # 9 neighbors in 2-space
      n_nbr_cells = 9
    # TODO(rbharath): Shoddy handling of higher dimensions...
    elif ndim >= 3:
      # Number of neighbors of central cube in 3-space is
      # 3^2 (top-face) + 3^2 (bottom-face) + (3^2-1) (middle-band)
      n_nbr_cells = 9 + 9 + 8  # (26 faces on Rubik's cube for example)
      # Number of cells for cube in 3-space is
      n_nbr_cells = 27  # (26 faces on Rubik's cube for example)
    return n_nbr_cells

  def get_neighbor_cells(self, cells):
+53 −23
Original line number Diff line number Diff line
@@ -43,7 +43,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    nbr_cutoff = 3
    ndim = 3
    M = 6
    k = 5
    X = np.random.rand(N_atoms, ndim)
    y = np.random.rand(N_atoms, 1)
    dataset = NumpyDataset(X, y)
@@ -51,7 +50,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    features = Feature(shape=(N_atoms, ndim))
    labels = Label(shape=(N_atoms,))
    nbr_list = NeighborList(
        N_atoms, M, ndim, k, nbr_cutoff, in_layers=[features])
        N_atoms, M, ndim, nbr_cutoff, in_layers=[features])
    nbr_list = ToFloat(in_layers=[nbr_list])
    # This isn't a meaningful loss, but just for test
    loss = ReduceSum(in_layers=[nbr_list])
@@ -157,7 +156,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    nbr_cutoff = 3
    ndim = 3
    M_nbrs = 2
    k = 5

    with self.test_session() as sess:
      coords = start + np.random.rand(N_atoms, ndim) * (stop - start)
@@ -175,7 +173,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
@@ -197,13 +194,12 @@ class TestDocking(test_util.TensorFlowTestCase):
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      closest_atoms = nbr_list_layer.get_closest_atoms(coords, cells)
      atoms_in_cells_eval = closest_atoms.eval() 
@@ -218,14 +214,13 @@ class TestDocking(test_util.TensorFlowTestCase):
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      nbr_cells = nbr_list_layer.get_neighbor_cells(cells)
      nbr_cells_eval = nbr_cells.eval()
@@ -250,14 +245,13 @@ class TestDocking(test_util.TensorFlowTestCase):
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      cells_for_atoms = nbr_list_layer.get_cells_for_atoms(coords, cells)
      cells_for_atoms_eval = cells_for_atoms.eval()
@@ -269,6 +263,34 @@ class TestDocking(test_util.TensorFlowTestCase):
      np.testing.assert_array_almost_equal(cells_for_atoms_eval, true_cells_for_atoms)


  def test_get_atoms_in_nbrs_1D(self):
    """Test get_atoms_in_brs in 1D"""
    N_atoms = 4
    start = 0
    stop = 10
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))

    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      uniques = nbr_list_layer.get_atoms_in_nbrs(coords, cells)
      uniques_eval = [unique.eval() for unique in uniques]
      uniques_eval = np.array(uniques_eval)

      true_uniques = np.array(
        [[0, 1],
         [1, 0],
         [2, 3],
         [3, 2]])
      np.testing.assert_array_almost_equal(uniques_eval, true_uniques)


  def test_neighbor_list_1D(self):
    """Test neighbor list on 1D example."""
    N_atoms = 4
@@ -277,7 +299,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    nbr_cutoff = 1
    ndim = 1
    M_nbrs = 1
    k = 1
    # 1 and 2 are nbrs. 8 and 9 are nbrs
    coords = np.array([1.0, 2.0, 8.0, 9.0])
    coords = np.reshape(coords, (N_atoms, M_nbrs))
@@ -285,18 +306,29 @@ class TestDocking(test_util.TensorFlowTestCase):
    with self.test_session() as sess:
      coords = tf.convert_to_tensor(coords, dtype=tf.float32)
      ###################################################################### DEBUG
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start, stop)
      cells = nbr_list_layer.get_cells()
      closest_atoms = nbr_list_layer.get_closest_atoms(coords, cells)
      atoms_in_cells_eval = closest_atoms.eval() 
      print("atoms_in_cells_eval")
      print(atoms_in_cells_eval)
      nbr_list = nbr_list_layer(coords)
      nbr_list_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start, stop)
      nbrs, nbr_coords, atom_coords, dists, closest_nbr_inds, neighbor_list, nbr_list = nbr_list_layer.compute_nbr_list(coords)
      nbrs_eval = [nbr.eval() for nbr in nbrs]
      print("nbrs_eval")
      print(nbrs_eval)
      nbr_coords_eval = [nbr_coord.eval() for nbr_coord in nbr_coords]
      print("nbr_coords_eval")
      print(nbr_coords_eval)
      atom_coords_eval = [atom_coord.eval() for atom_coord in atom_coords]
      print("atom_coords_eval")
      print(atom_coords_eval)
      dists_eval = [dist.eval() for dist in dists]
      print("dists_eval")
      print(dists_eval)
      closest_nbr_inds_eval = [closest_nbr_ind.eval() for closest_nbr_ind in closest_nbr_inds]
      print("closest_nbr_inds_eval")
      print(closest_nbr_inds_eval)
      neighbor_list_eval = [neighbor.eval() for neighbor in neighbor_list]
      print("neighbor_list_eval")
      print(neighbor_list_eval)
      nbr_list = nbr_list.eval()
      print("nbr_list")
      print(nbr_list)
      print("nbr_list.shape")
      print(nbr_list.shape)
      ###################################################################### DEBUG
      np.testing.assert_array_almost_equal(nbr_list, np.array([1, 0, 3, 2]))

@@ -305,7 +337,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    N_atoms = 5
    M_nbrs = 2
    ndim = 3
    k = 5
    start = 0
    stop = 4
    nbr_cutoff = 1
@@ -315,7 +346,7 @@ class TestDocking(test_util.TensorFlowTestCase):
    coords = Feature(shape=(N_atoms, ndim))

    # Now an (N, M) shape
    nbr_list = NeighborList(N_atoms, M_nbrs, ndim, k, nbr_cutoff, start,
    nbr_list = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
                            stop, in_layers=[coords])

    nbr_list = ToFloat(in_layers=[nbr_list])
@@ -336,7 +367,6 @@ class TestDocking(test_util.TensorFlowTestCase):
    N_atoms = 5
    M_nbrs = 2
    ndim = 3
    k = 5
    start = 0
    stop = 4
    nbr_cutoff = 1