Commit c2ee2927 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

partial fixes, but still broken

parent 6c327cb7
Loading
Loading
Loading
Loading
+7 −5
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ class TestVinaModel(test_util.TensorFlowTestCase):
    M = 6
    k = 5
    # The number of cells which we should theoretically have
    n_cells = ((stop - start)/nbr_cutoff)**ndim
    n_cells = int(((stop - start)/nbr_cutoff)**ndim)
    ################################################### DEBUG
    print("n_cells")
    print(n_cells)
@@ -76,7 +76,8 @@ class TestVinaModel(test_util.TensorFlowTestCase):

    with self.test_session() as sess:
      coords = start + np.random.rand(N, ndim)*(stop-start)
      nbr_list = compute_neighbor_list(coords, nbr_cutoff, N, M, ndim, k)
      nbr_list = compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells,
                                       ndim=ndim, k=k)
      nbr_list = nbr_list.eval()
      assert nbr_list.shape == (N, M)

@@ -95,7 +96,7 @@ class TestVinaModel(test_util.TensorFlowTestCase):
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      coords = np.random.rand(N, ndim)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
      atoms_in_cells = [atoms.eval() for atoms in atoms_in_cells]
      atoms_in_cells = atoms_in_cells.eval()
      assert len(atoms_in_cells) == n_cells
      # Each atom neighbors tensor should be (k, ndim) shaped.
      for atoms in atoms_in_cells:
@@ -116,7 +117,8 @@ class TestVinaModel(test_util.TensorFlowTestCase):

    with self.test_session() as sess:
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim, n_cells)
      nbr_cells = nbr_cells.eval()
      assert len(nbr_cells) == n_cells
      nbr_cells = [nbr_cell.eval() for nbr_cell in nbr_cells]
      for nbr_cell in nbr_cells:
@@ -138,7 +140,7 @@ class TestVinaModel(test_util.TensorFlowTestCase):

    with self.test_session() as sess:
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim, n_cells)
      coords = np.random.rand(N, ndim)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, n_cells,
                                          ndim, k)
+33 −17
Original line number Diff line number Diff line
@@ -40,41 +40,47 @@ def compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells, ndim=3, k=5):
  stop = tf.to_int32(tf.reduce_max(coords))
  cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
  ##################################################### DEBUG
  print("ndim")
  print(ndim)
  print("cells.eval().shape")
  print(cells.eval().shape)
  print("start.eval()")
  print(start.eval())
  print("stop.eval()")
  print(stop.eval())
  print("cells.eval().shape")
  print(cells.eval().shape)
  ##################################################### DEBUG
  # Associate each atom with cell it belongs to. O(N*n_cells)
  # Shape (n_cells, k, ndim)
  atoms_in_cells = put_atoms_in_cells(coords, cells, N, n_cells, ndim, k)
  cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
  
  # Associate each cell with its neighbor cells. Assumes periodic boundary   
  # conditions, so does wrapround. O(constant)    
  neighbor_cells = compute_neighbor_cells(cells, ndim)    
  neighbor_cells = compute_neighbor_cells(cells, ndim, n_cells)    
  ## For each atom, loop through all atoms in its cell and neighboring cells.    
  ## Accept as neighbors only those within threshold. This computation should be   
  ## O(Nm), where m is the number of atoms within a set of neighboring-cells.
  neighbor_list = {}
  for atom in range(N):
    cell = atoms_for_cells[atom]
    neighbor_cells = neighbor_cells[cell]
    cell = cells_for_atoms[atom]
    neighbor_cells = tf.gather(neighbor_cells, cell)
    # For smaller systems especially, the periodic boundary conditions can
    # result in neighboring cells being seen multiple times. Use a set() to
    # make sure duplicate neighbors are ignored. Convert back to list before
    # returning. 
    #  neighbor_list[atom] = set()
    for neighbor_cell in neighbor_cells:
      atoms_in_cell = atoms_in_cells[neighbor_cell]
    def nbr_process(neighbor_cell):
      # Shape (1, k, ndim) (?)
      atoms_in_cell = tf.gather(atoms_in_cells, neighbor_cell)
      atoms_in_cell = tf.unique(atoms_in_cell)
      nbr_coords = tf.gather(coords, atoms_in_cells)
      for neighbor_atom in atoms_in_cell:
        # TODO(rbharath): How does distance need to be modified here to   
        # account for periodic boundary conditions?   
        dist = tf.reduce_sum((coords[atom] - coords[neighbor_atom])**2, axis=1)
        if dist < neighbor_cutoff:    
          neighbor_list[atom].add((neighbor_atom, dist))    
    atom_nbr_list = tf.map_fn(nbr_process, neighbor_cells)
             
    # Sort neighbors by distance    
    closest_neighbors = sorted(   
@@ -135,16 +141,14 @@ def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N,
  ---------
  atoms_in_cells: list
    Of length n_cells. Each entry tensor of shape (k, ndim)
  neighbor_cells: list
    Of length n_cells. Each entry tensor of shape (26,)
  neighbor_cells: tf.Tensor 
    Of shape (n_cells, 26).
  N: int
    Number atoms
  """
  n_cells = len(atoms_in_cells)
  # Tensor of shape (n_cells, k, ndim)
  atoms_in_cells = tf.pack(atoms_in_cells)
  ## Tensor of shape (n_cells, 26)
  neighbor_cells = tf.pack(neighbor_cells)
  #atoms_in_cells = tf.pack(atoms_in_cells)

  cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
  all_closest = []
@@ -197,8 +201,20 @@ def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
    Dimensionality of input space
  k: int
    Number of nearest neighbors.

  Returns
  -------
  closest_atoms: tf.Tensor 
    Of shape (n_cells, k, ndim)
  """   
  #n_cells = int(cells.get_shape()[0])
  ################################################# DEBUG
  print("coords")
  print(coords)
  print("cells")
  print(cells)
  print("n_cells, N, ndim")
  print(n_cells, N, ndim)
  ################################################# DEBUG

  # Tile both cells and coords to form arrays of size (n_cells*N, ndim)
  tiled_cells = tf.reshape(tf.tile(cells, (1, N)), (n_cells*N, ndim))
@@ -220,7 +236,7 @@ def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
  # Get indices of k atoms closest to each cell point
  closest_inds = [tf.nn.top_k(norm, k=k)[1] for norm in coords_norm]
  # n_cells tensors of shape (k, ndim)
  closest_atoms = [tf.gather(coords, inds) for inds in closest_inds]
  closest_atoms = tf.pack([tf.gather(coords, inds) for inds in closest_inds])

  return closest_atoms

@@ -231,7 +247,7 @@ def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
  #   - Return N lists corresponding to neighbors for every atom.
  
        
def compute_neighbor_cells(cells, ndim):
def compute_neighbor_cells(cells, ndim, n_cells):
  """Compute neighbors of cells in grid.    

  # TODO(rbharath): Do we need to handle periodic boundary conditions
@@ -251,7 +267,7 @@ def compute_neighbor_cells(cells, ndim):
  # 3^2 (top-face) + 3^2 (bottom-face) + (3^2-1) (middle-band)
  # TODO(rbharath)
  k = 9 + 9 + 8 # (26 faces on Rubik's cube for example)
  n_cells = int(cells.get_shape()[0])
  #n_cells = int(cells.get_shape()[0])
  # Tile cells to form arrays of size (n_cells*n_cells, ndim)
  # Two tilings (a, b, c, a, b, c, ...) vs. (a, a, a, b, b, b, etc.)
  # Tile (a, a, a, b, b, b, etc.)
@@ -271,7 +287,7 @@ def compute_neighbor_cells(cells, ndim):
  # Lists of length n_cells
  # Get indices of k atoms closest to each cell point
  # n_cells tensors of shape (26,)
  closest_inds = [tf.nn.top_k(norm, k=k)[1] for norm in coords_norm]
  closest_inds = tf.pack([tf.nn.top_k(norm, k=k)[1] for norm in coords_norm])

  return closest_inds