Commit 45724315 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Neighbor list shape tests now pass

parent c2ee2927
Loading
Loading
Loading
Loading
+4 −7
Original line number Diff line number Diff line
@@ -69,15 +69,12 @@ class TestVinaModel(test_util.TensorFlowTestCase):
    k = 5
    # The number of cells which we should theoretically have
    n_cells = int(((stop - start)/nbr_cutoff)**ndim)
    ################################################### DEBUG
    print("n_cells")
    print(n_cells)
    ################################################### DEBUG

    with self.test_session() as sess:
      coords = start + np.random.rand(N, ndim)*(stop-start)
      coords = tf.pack(coords)
      nbr_list = compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells,
                                       ndim=ndim, k=k)
                                       ndim=ndim, k=k, sess=sess)
      nbr_list = nbr_list.eval()
      assert nbr_list.shape == (N, M)

@@ -95,7 +92,7 @@ class TestVinaModel(test_util.TensorFlowTestCase):
    with self.test_session() as sess:
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      coords = np.random.rand(N, ndim)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
      _, atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
      atoms_in_cells = atoms_in_cells.eval()
      assert len(atoms_in_cells) == n_cells
      # Each atom neighbors tensor should be (k, ndim) shaped.
@@ -142,7 +139,7 @@ class TestVinaModel(test_util.TensorFlowTestCase):
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim, n_cells)
      coords = np.random.rand(N, ndim)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, n_cells,
      _, atoms_in_cells = put_atoms_in_cells(coords, cells, N, n_cells,
                                          ndim, k)
      nbrs = compute_closest_neighbors(coords, cells, atoms_in_cells,
                                       nbr_cells, N, n_cells)
+70 −57
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ from deepchem.models import Model
from deepchem.nn import model_ops
import deepchem.utils.rdkit_util as rdkit_util

def compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells, ndim=3, k=5):
def compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells, ndim=3, k=5, sess=None):
  """Computes a neighbor list from atom coordinates.

  Parameters
@@ -39,56 +39,66 @@ def compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells, ndim=3, k=5):
  start = tf.to_int32(tf.reduce_min(coords))
  stop = tf.to_int32(tf.reduce_max(coords))
  cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
  ##################################################### DEBUG
  print("ndim")
  print(ndim)
  print("cells.eval().shape")
  print(cells.eval().shape)
  print("start.eval()")
  print(start.eval())
  print("stop.eval()")
  print(stop.eval())
  ##################################################### DEBUG
  # Associate each atom with cell it belongs to. O(N*n_cells)
  # Shape (n_cells, k, ndim)
  atoms_in_cells = put_atoms_in_cells(coords, cells, N, n_cells, ndim, k)
  # Shape (n_cells, k)
  atoms_in_cells, _ = put_atoms_in_cells(coords, cells, N, n_cells, ndim, k)
  # Shape (N, 1)
  cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
  
  # Associate each cell with its neighbor cells. Assumes periodic boundary   
  # conditions, so does wraparound. O(constant)
  # Shape (n_cells, 26)
  neighbor_cells = compute_neighbor_cells(cells, ndim, n_cells)    
  ## For each atom, loop through all atoms in its cell and neighboring cells.    
  ## Accept as neighbors only those within threshold. This computation should be   
  ## O(Nm), where m is the number of atoms within a set of neighboring-cells.
  neighbor_list = {}
  for atom in range(N):
    cell = cells_for_atoms[atom]
    neighbor_cells = tf.gather(neighbor_cells, cell)

  # Shape (N, 26)
  neighbor_cells = tf.squeeze(tf.gather(neighbor_cells, cells_for_atoms))

  # coords of shape (N, ndim)
  # Shape (N, 26, k, ndim)
  tiled_coords = tf.tile(tf.reshape(coords, (N, 1, 1, ndim)), (1, 26, k, 1))

  # Shape (N, 26, k)
  nbr_inds = tf.gather(atoms_in_cells, neighbor_cells)

  # Shape (N, 26, k)
  atoms_in_nbr_cells = tf.gather(atoms_in_cells, neighbor_cells)

  # Shape (N, 26, k, ndim)
  nbr_coords = tf.gather(coords, atoms_in_nbr_cells)

  # For smaller systems especially, the periodic boundary conditions can
    # result in neighboring cells being seen multiple times. Use a set() to
    # make sure duplicate neighbors are ignored. Convert back to list before
    # returning. 
    #  neighbor_list[atom] = set()
    def nbr_process(neighbor_cell):
      # Shape (1, k, ndim) (?)
      atoms_in_cell = tf.gather(atoms_in_cells, neighbor_cell)
      atoms_in_cell = tf.unique(atoms_in_cell)
      nbr_coords = tf.gather(coords, atoms_in_cells)
      for neighbor_atom in atoms_in_cell:
  # result in neighboring cells being seen multiple times. Maybe use tf.unique to
  # make sure duplicate neighbors are ignored?

  # TODO(rbharath): How does distance need to be modified here to   
  # account for periodic boundary conditions?   
        dist = tf.reduce_sum((coords[atom] - coords[neighbor_atom])**2, axis=1)
        if dist < neighbor_cutoff:    
          neighbor_list[atom].add((neighbor_atom, dist))    
    atom_nbr_list = tf.map_fn(nbr_process, neighbor_cells)
             
    # Sort neighbors by distance    
    closest_neighbors = sorted(   
        list(neighbor_list[atom]), key=lambda elt: elt[1])    
    closest_neighbors = [nbr for (nbr, dist) in closest_neighbors]    
    # Pick up to max_num_neighbors    
    closest_neighbors = closest_neighbors[:max_num_neighbors]   
    neighbor_list[atom] = closest_neighbors
  # Shape (N, 26, k)
  dists = tf.reduce_sum((tiled_coords - nbr_coords)**2, axis=3)

  # Shape (N, 26*k)
  dists = tf.reshape(dists, [N, -1])

  # TODO(rbharath): This will cause an issue with duplicates!
  # Shape (N, M)
  closest_nbr_locs = tf.nn.top_k(dists, k=M)[1]

  # N elts of size (M,) each
  split_closest_nbr_locs = [tf.squeeze(locs) for locs in tf.split_v(closest_nbr_locs, N)]

  # Shape (N, 26*k)
  nbr_inds = tf.reshape(nbr_inds, [N, -1])

  # N elts of size (26*k,) each
  split_nbr_inds = [tf.squeeze(split) for split in tf.split_v(nbr_inds, N)]

  # N elts of size (M,) each 
  neighbor_list = [tf.gather(nbr_inds, closest_nbr_locs)
                   for (nbr_inds, closest_nbr_locs)
                   in zip(split_nbr_inds, split_closest_nbr_locs)]

  # Shape (N, M)
  neighbor_list = tf.pack(neighbor_list)

  return neighbor_list   

def get_cells_for_atoms(coords, cells, N, n_cells, ndim=3):
@@ -207,15 +217,6 @@ def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
  closest_atoms: tf.Tensor 
    Of shape (n_cells, k, ndim)
  """   
  ################################################# DEBUG
  print("coords")
  print(coords)
  print("cells")
  print(cells)
  print("n_cells, N, ndim")
  print(n_cells, N, ndim)
  ################################################# DEBUG

  # Tile both cells and coords to form arrays of size (n_cells*N, ndim)
  tiled_cells = tf.reshape(tf.tile(cells, (1, N)), (n_cells*N, ndim))
  # TODO(rbharath): Change this for tf 1.0
@@ -237,8 +238,17 @@ def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
  closest_inds = [tf.nn.top_k(norm, k=k)[1] for norm in coords_norm]
  # n_cells tensors of shape (k, ndim)
  closest_atoms = tf.pack([tf.gather(coords, inds) for inds in closest_inds])

  return closest_atoms
  # Tensor of shape (n_cells, k)
  closest_inds = tf.pack(closest_inds)
  ########################################################################## DEBUG
  #print("put_atoms_in_cells")
  #print("closest_inds")
  #print(closest_inds)
  #print("closest_atoms")
  #print(closest_atoms)
  ########################################################################## DEBUG

  return closest_inds, closest_atoms

  # TODO(rbharath):
  #   - Need to find neighbors of the cells (+/- 1 in every dimension).
@@ -256,10 +266,13 @@ def compute_neighbor_cells(cells, ndim, n_cells):
  # looking for 26 neighbors, which isn't right for boundary cells in
  # the cube.
      
  Note n_cells is box_size**ndim. 26 is the number of neighbors of a cube in
  a grid (including diagonals).

  Parameters    
  ----------    
  cells: tf.Tensor
    (box_size**ndim, ndim) shape.
    (n_cells, 26) shape.
  """   
  if ndim != 3:
    raise ValueError("Not defined for dimensions besides 3")