Commit 6c327cb7 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

First nbr-list impl. Tests fail

parent 9ecdd1be
Loading
Loading
Loading
Loading
+55 −2
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ from deepchem.models.tf_new_models.vina_model import get_cells
from deepchem.models.tf_new_models.vina_model import put_atoms_in_cells
from deepchem.models.tf_new_models.vina_model import compute_neighbor_cells
from deepchem.models.tf_new_models.vina_model import compute_closest_neighbors
from deepchem.models.tf_new_models.vina_model import get_cells_for_atoms
from deepchem.models.tf_new_models.vina_model import compute_neighbor_list
import deepchem.utils.rdkit_util as rdkit_util
from deepchem.utils.save import load_sdf_files

@@ -56,6 +58,28 @@ class TestVinaModel(test_util.TensorFlowTestCase):

      # TODO(rbharath): Check that this operation is differentiable.

  def test_compute_neighbor_list(self):
    """Test that neighbor list can be computed with tensorflow"""
    N = 10
    start = 0
    stop = 12 
    nbr_cutoff = 3 
    ndim = 3
    M = 6
    k = 5
    # The number of cells which we should theoretically have
    n_cells = ((stop - start)/nbr_cutoff)**ndim
    ################################################### DEBUG
    print("n_cells")
    print(n_cells)
    ################################################### DEBUG

    with self.test_session() as sess:
      coords = start + np.random.rand(N, ndim)*(stop-start)
      nbr_list = compute_neighbor_list(coords, nbr_cutoff, N, M, ndim, k)
      nbr_list = nbr_list.eval()
      assert nbr_list.shape == (N, M)

  def test_put_atoms_in_cells(self):
    """Test that atoms can be partitioned into spatial cells."""
    N = 10
@@ -116,8 +140,37 @@ class TestVinaModel(test_util.TensorFlowTestCase):
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim)
      coords = np.random.rand(N, ndim)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
      nbrs = compute_closest_neighbors(coords, cells, atoms_in_cells, nbr_cells, N)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, n_cells,
                                          ndim, k)
      nbrs = compute_closest_neighbors(coords, cells, atoms_in_cells,
                                       nbr_cells, N, n_cells)

  def test_get_cells_for_atoms(self):
    """Test that atoms are placed in the correct cells."""
    N = 10
    start = 0
    stop = 4
    nbr_cutoff = 1
    ndim = 3
    k = 5
    # The number of cells which we should theoretically have
    n_cells = ((stop - start)/nbr_cutoff)**ndim

    # TODO(rbharath): The test below only checks that shapes work out.
    # Need to do a correctness implementation vs. a simple CPU impl.

    with self.test_session() as sess:
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      coords = np.random.rand(N, ndim)
      cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
      cells_for_atoms = cells_for_atoms.eval()
      ################################################################## DEBUG
      print("cells_for_atoms")
      print(cells_for_atoms)
      print("cells_for_atoms.shape")
      print(cells_for_atoms.shape)
      ################################################################## DEBUG
      assert cells_for_atoms.shape == (N, 1)

  def test_vina_generate_confs(self):
    """Test that vina model can generate meaningful conformations."""
+82 −51
Original line number Diff line number Diff line
@@ -15,11 +15,13 @@ from deepchem.models import Model
from deepchem.nn import model_ops
import deepchem.utils.rdkit_util as rdkit_util

def compute_neighbor_list(coords, nbr_cutoff, N, M, ndim=3, k=5):
def compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells, ndim=3, k=5):
  """Computes a neighbor list from atom coordinates.

  Parameters
  ----------
  coords: tf.Tensor
    Shape (N, ndim)
  N: int
    Max number atoms
  M: int
@@ -28,22 +30,64 @@ def compute_neighbor_list(coords, nbr_cutoff, N, M, ndim=3, k=5):
    Dimensionality of space.
  k: int
    Number of nearest neighbors to pull down.

  Returns
  -------
  nbr_list: tf.Tensor
    Shape (N, M) of atom indices
  """
  start = tf.reduce_min(coords)
  stop = tf.reduce_max(coords)
  cells = get_cells(start, stop, nbr_cutoff, ndim)
  start = tf.to_int32(tf.reduce_min(coords))
  stop = tf.to_int32(tf.reduce_max(coords))
  cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
  ##################################################### DEBUG
  print("start.eval()")
  print(start.eval())
  print("stop.eval()")
  print(stop.eval())
  print("cells.eval().shape")
  print(cells.eval().shape)
  ##################################################### DEBUG
  # Associate each atom with cell it belongs to. O(N*n_cells)
  atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
  atoms_in_cells = put_atoms_in_cells(coords, cells, N, n_cells, ndim, k)
  cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
  
  # Associate each cell with its neighbor cells. Assumes periodic boundary   
  # conditions, so does wrapround. O(constant)    
  neighbor_cells = compute_neighbor_cells(cells, ndim)    

def get_cells_for_atoms(coords, cells, N, ndim=3):
  ## For each atom, loop through all atoms in its cell and neighboring cells.    
  ## Accept as neighbors only those within threshold. This computation should be   
  ## O(Nm), where m is the number of atoms within a set of neighboring-cells.
  neighbor_list = {}
  for atom in range(N):
    cell = atoms_for_cells[atom]
    neighbor_cells = neighbor_cells[cell]
    # For smaller systems especially, the periodic boundary conditions can
    # result in neighboring cells being seen multiple times. Use a set() to
    # make sure duplicate neighbors are ignored. Convert back to list before
    # returning. 
    #  neighbor_list[atom] = set()
    for neighbor_cell in neighbor_cells:
      atoms_in_cell = atoms_in_cells[neighbor_cell]
      atoms_in_cell = tf.unique(atoms_in_cell)
      for neighbor_atom in atoms_in_cell:
        # TODO(rbharath): How does distance need to be modified here to   
        # account for periodic boundary conditions?   
        dist = tf.reduce_sum((coords[atom] - coords[neighbor_atom])**2, axis=1)
        if dist < neighbor_cutoff:    
          neighbor_list[atom].add((neighbor_atom, dist))    
             
    # Sort neighbors by distance    
    closest_neighbors = sorted(   
        list(neighbor_list[atom]), key=lambda elt: elt[1])    
    closest_neighbors = [nbr for (nbr, dist) in closest_neighbors]    
    # Pick up to max_num_neighbors    
    closest_neighbors = closest_neighbors[:max_num_neighbors]   
    neighbor_list[atom] = closest_neighbors
  return neighbor_list   

def get_cells_for_atoms(coords, cells, N, n_cells, ndim=3):
  """Compute the cells each atom belongs to.

  TODO(rbharath): Move this past a stub implementation.

  Parameters
  ----------
  coords: tf.Tensor
@@ -55,10 +99,34 @@ def get_cells_for_atoms(coords, cells, N, ndim=3):
  cells_for_atoms: tf.Tensor
    Shape (N, 1)
  """ 
  return tf.zeros((N, 1))
  #n_cells = int(cells.get_shape()[0])
  # Tile both cells and coords to form arrays of size (n_cells*N, ndim)
  tiled_cells = tf.tile(cells, (N, 1))
  # N tensors of shape (n_cells, 1)
  tiled_cells = tf.split_v(tiled_cells, N)

  # Shape (N*n_cells, 1) after tile
  tiled_coords = tf.reshape(tf.tile(coords, (1, n_cells)), (n_cells*N, ndim))
  # List of N tensors of shape (n_cells, 1)
  tiled_coords = tf.split_v(tiled_coords, N)


def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N, ndim=3, k=5):
  # Lists of length N 
  coords_rel = [tf.to_float(coords) - tf.to_float(cells)
                for (coords, cells) in zip(tiled_coords, tiled_cells)]
  coords_norm = [tf.reduce_sum(rel**2, axis=1) for rel in coords_rel]

  # Lists of length n_cells
  # Get indices of k atoms closest to each cell point
  closest_inds = [tf.nn.top_k(-norm, k=1)[1]
                  for norm in coords_norm]

  # TODO(rbharath): tf.stack for tf 1.0
  return tf.pack(closest_inds)
    

def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N,
                              n_cells, ndim=3, k=5):
  """Computes nearest neighbors from neighboring cells.

  TODO(rbharath): Make this pass test
@@ -78,17 +146,11 @@ def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N,
  ## Tensor of shape (n_cells, 26)
  neighbor_cells = tf.pack(neighbor_cells)

  cells_for_atoms = get_cells_for_atoms(coords, cells, N, ndim)
  cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
  all_closest = []
  for atom in range(N):
    atom_vec = coords[atom]
    cell = cells_for_atoms[atom] 
    ###################################################### DEBUG
    print("cell")
    print(cell)
    print("neighbor_cells")
    print(neighbor_cells)
    ###################################################### DEBUG
    nbr_inds = tf.gather(neighbor_cells, tf.to_int32(cell))
    # Tensor of shape (26, k, ndim)
    nbr_atoms = tf.gather(atoms_in_cells, nbr_inds)
@@ -102,37 +164,6 @@ def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N,
    closest_inds = tf.nn.top_k(nbr_dists, k=k)[1]
    all_closest.append(closest_inds)
  return all_closest
  ## For each atom, loop through all atoms in its cell and neighboring cells.    
  ## Accept as neighbors only those within threshold. This computation should be   
  ## O(Nm), where m is the number of atoms within a set of neighboring-cells.
  #neighbor_list = {}
  #for atom in range(N):
  #  cell = atom_to_cell[atom]
  #  neighbor_cells = neighbor_cell_map[cell]
  #  # For smaller systems especially, the periodic boundary conditions can
  #  # result in neighboring cells being seen multiple times. Use a set() to
  #  # make sure duplicate neighbors are ignored. Convert back to list before
  #  # returning. 
  #  neighbor_list[atom] = set()
  #  for neighbor_cell in neighbor_cells:
  #    atoms_in_cell = cell_to_atoms[neighbor_cell]
  #    for neighbor_atom in atoms_in_cell:
  #      if neighbor_atom == atom:    
  #         continue    
  #      # TODO(rbharath): How does distance need to be modified here to   
  #      # account for periodic boundary conditions?   
  #      dist = np.linalg.norm(coords[atom] - coords[neighbor_atom])   
  #      if dist < neighbor_cutoff:    
  #        neighbor_list[atom].add((neighbor_atom, dist))    
  #           
  #  # Sort neighbors by distance    
  #  closest_neighbors = sorted(   
  #      list(neighbor_list[atom]), key=lambda elt: elt[1])    
  #  closest_neighbors = [nbr for (nbr, dist) in closest_neighbors]    
  #  # Pick up to max_num_neighbors    
  #  closest_neighbors = closest_neighbors[:max_num_neighbors]   
  #  neighbor_list[atom] = closest_neighbors
  #return neighbor_list   

def get_cells(start, stop, nbr_cutoff, ndim=3):
  """Returns the locations of all grid points in box.
@@ -149,7 +180,7 @@ def get_cells(start, stop, nbr_cutoff, ndim=3):
  return tf.reshape(tf.transpose(tf.pack(tf.meshgrid(
      *[tf.range(start, stop, nbr_cutoff) for _ in range(ndim)]))), (-1, ndim))
     
def put_atoms_in_cells(coords, cells, N, ndim, k=5):
def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
  """Place each atom into cells. O(N) runtime.    
  
  Let N be the number of atoms.
@@ -167,7 +198,7 @@ def put_atoms_in_cells(coords, cells, N, ndim, k=5):
  k: int
    Number of nearest neighbors.
  """   
  n_cells = int(cells.get_shape()[0])
  #n_cells = int(cells.get_shape()[0])

  # Tile both cells and coords to form arrays of size (n_cells*N, ndim)
  tiled_cells = tf.reshape(tf.tile(cells, (1, N)), (n_cells*N, ndim))