Commit fa10ed11 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Continued neighbors work

parent 29248837
Loading
Loading
Loading
Loading
+134 −0
Original line number Diff line number Diff line
"""
Testing construction of Vina models. 
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import os
import unittest
import tensorflow as tf
import deepchem as dc
import numpy as np
from tensorflow.python.framework import test_util
from deepchem.models.tf_new_models.vina_model import VinaModel 
from deepchem.models.tf_new_models.vina_model import get_cells
from deepchem.models.tf_new_models.vina_model import put_atoms_in_cells
from deepchem.models.tf_new_models.vina_model import compute_neighbor_cells
from deepchem.models.tf_new_models.vina_model import compute_closest_neighbors
import deepchem.utils.rdkit_util as rdkit_util
from deepchem.utils.save import load_sdf_files


class TestVinaModel(test_util.TensorFlowTestCase):
  """
  Tests for VinaModel and the neighbor-list helpers in vina_model.

  The helper tests build TensorFlow ops on a small grid and check only
  the number and shape of the returned tensors; none of them compare
  values against a reference implementation.
  """

  def setUp(self):
    super(TestVinaModel, self).setUp()
    self.root = '/tmp'

  def test_vina_model(self):
    """Simple test that a vina model can be initialized."""
    VinaModel()

  def test_get_cells(self):
    """Test that tensorflow can compute grid cells."""
    start = 0
    stop = 4
    nbr_cutoff = 1
    # Number of grid points along one axis. Floor division keeps this an
    # int even though the module enables true division.
    cells_per_dim = (stop - start) // nbr_cutoff
    with self.test_session():
      for ndim in (3, 2):
        cells = get_cells(start, stop, nbr_cutoff, ndim=ndim).eval()
        assert len(cells.shape) == 2
        assert cells.shape[0] == cells_per_dim**ndim

      # TODO(rbharath): Check that this operation is differentiable.

  def test_put_atoms_in_cells(self):
    """Test that atoms can be partitioned into spatial cells."""
    N = 10
    start = 0
    stop = 4
    nbr_cutoff = 1
    ndim = 3
    k = 5
    # The number of cells which we should theoretically have
    n_cells = ((stop - start) // nbr_cutoff)**ndim

    with self.test_session():
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      coords = np.random.rand(N, ndim)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
      atoms_in_cells = [atoms.eval() for atoms in atoms_in_cells]
      assert len(atoms_in_cells) == n_cells
      # Each atom neighbors tensor should be (k, ndim) shaped.
      for atoms in atoms_in_cells:
        assert atoms.shape == (k, ndim)

  def test_compute_neighbor_cells(self):
    """Test that indices of neighboring cells can be computed."""
    start = 0
    stop = 4
    nbr_cutoff = 1
    ndim = 3
    # The number of cells which we should theoretically have
    n_cells = ((stop - start) // nbr_cutoff)**ndim

    # TODO(rbharath): The test below only checks that shapes work out.
    # Need to do a correctness implementation vs. a simple CPU impl.

    with self.test_session():
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim)
      assert len(nbr_cells) == n_cells
      nbr_cells = [nbr_cell.eval() for nbr_cell in nbr_cells]
      # One index per surrounding cell in a 3x3x3 block, excluding the
      # center: 3**3 - 1 == 26.
      for nbr_cell in nbr_cells:
        assert nbr_cell.shape == (26,)

  def test_compute_closest_neighbors(self):
    """Test that closest neighbors can be computed properly."""
    N = 10
    start = 0
    stop = 4
    nbr_cutoff = 1
    ndim = 3
    k = 5

    # TODO(rbharath): The test below only checks that shapes work out.
    # Need to do a correctness implementation vs. a simple CPU impl.

    with self.test_session():
      cells = get_cells(start, stop, nbr_cutoff, ndim=ndim)
      nbr_cells = compute_neighbor_cells(cells, ndim)
      coords = np.random.rand(N, ndim)
      atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
      nbrs = compute_closest_neighbors(coords, cells, atoms_in_cells,
                                       nbr_cells, N)
      # One entry of neighbor indices is produced per atom.
      assert len(nbrs) == N

  def test_vina_generate_confs(self):
    """Test that vina model can generate meaningful conformations."""
    data_dir = os.path.dirname(os.path.realpath(__file__))
    protein_file = os.path.join(data_dir, "1jld_protein.pdb")
    ligand_file = os.path.join(data_dir, "1jld_ligand.pdb")

    print("Loading protein file")
    protein_mol = rdkit_util.load_molecule(protein_file)
    print("Loading ligand file")
    ligand_mol = rdkit_util.load_molecule(ligand_file)

    # TODO: the loaded molecules are not yet passed to the model; this
    # test currently only checks that loading and construction succeed.
    vina_model = VinaModel()
+103 −42
Original line number Diff line number Diff line
@@ -32,47 +32,101 @@ def compute_neighbor_list(coords, nbr_cutoff, N, M, ndim=3, k=5):
  start = tf.reduce_min(coords)
  stop = tf.reduce_max(coords)
  cells = get_cells(start, stop, nbr_cutoff, ndim)
  # Associate each atom with cell it belongs to. O(N*n_cells)
  atoms_in_cells = put_atoms_in_cells(coords, cells, N, ndim, k)
  
  # Associate each atom with cell it belongs to. O(N)
  cell_to_atoms, atom_to_cell = put_atoms_in_cells(
      coords, x_bins, y_bins, z_bins)
  # Associate each cell with its neighbor cells. Assumes periodic boundary   
  # conditions, so does wraparound. O(constant)
  N_x, N_y, N_z = len(x_bins), len(y_bins), len(z_bins)   
  neighbor_cell_map = compute_neighbor_cell_map(N_x, N_y, N_z)    
  neighbor_cells = compute_neighbor_cells(cells, ndim)    

  # For each atom, loop through all atoms in its cell and neighboring cells.    
  # Accept as neighbors only those within threshold. This computation should be   
  # O(Nm), where m is the number of atoms within a set of neighboring-cells.
  neighbor_list = {}
def get_cells_for_atoms(coords, cells, N, ndim=3):
  """Return, for each atom, the cell that atom belongs to.

  TODO(rbharath): Move this past a stub implementation.

  This is currently a placeholder: `coords`, `cells` and `ndim` are
  accepted but not read, and the result is an all-zeros tensor.

  Parameters
  ----------
  coords: tf.Tensor
    Shape (N, ndim)
  cells: tf.Tensor
    (box_size**ndim, ndim) shape.
  N: int
    Number of atoms; sets the first dimension of the output.
  ndim: int
    Number of spatial dimensions. Unused by the stub.

  Returns
  -------
  cells_for_atoms: tf.Tensor
    Shape (N, 1)
  """
  output_shape = (N, 1)
  return tf.zeros(output_shape)
    

def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N, ndim=3, k=5):
  """Computes nearest neighbors from neighboring cells.

  TODO(rbharath): Make this pass test

  Parameters
  ---------
  atoms_in_cells: list
    Of length n_cells. Each entry tensor of shape (k, ndim)
  neighbor_cells: list
    Of length n_cells. Each entry tensor of shape (26,)
  N: int
    Number atoms
  """
  n_cells = len(atoms_in_cells)
  # Tensor of shape (n_cells, k, ndim)
  atoms_in_cells = tf.pack(atoms_in_cells)
  ## Tensor of shape (n_cells, 26)
  neighbor_cells = tf.pack(neighbor_cells)

  cells_for_atoms = get_cells_for_atoms(coords, cells, N, ndim)
  all_closest = []
  for atom in range(N):
    cell = atom_to_cell[atom]
    neighbor_cells = neighbor_cell_map[cell]
    # For smaller systems especially, the periodic boundary conditions can
    # result in neighboring cells being seen multiple times. Use a set() to
    # make sure duplicate neighbors are ignored. Convert back to list before
    # returning. 
    neighbor_list[atom] = set()
    for neighbor_cell in neighbor_cells:
      atoms_in_cell = cell_to_atoms[neighbor_cell]
      for neighbor_atom in atoms_in_cell:
        if neighbor_atom == atom:    
           continue    
        # TODO(rbharath): How does distance need to be modified here to   
        # account for periodic boundary conditions?   
        dist = np.linalg.norm(coords[atom] - coords[neighbor_atom])   
        if dist < neighbor_cutoff:    
          neighbor_list[atom].add((neighbor_atom, dist))    
             
    # Sort neighbors by distance    
    closest_neighbors = sorted(   
        list(neighbor_list[atom]), key=lambda elt: elt[1])    
    closest_neighbors = [nbr for (nbr, dist) in closest_neighbors]    
    # Pick up to max_num_neighbors    
    closest_neighbors = closest_neighbors[:max_num_neighbors]   
    neighbor_list[atom] = closest_neighbors
  return neighbor_list   
    atom_vec = coords[atom]
    cell = cells_for_atoms[atom] 
    nbr_inds = neighbor_cells[cell]
    # Tensor of shape (26, k, ndim)
    nbr_atoms = tf.gather(atoms_in_cells, nbr_inds)
    # Reshape to (26*k, ndim)
    nbr_atoms = tf.reshape(nbr_atoms, (-1, 3))
    # Subtract out atom vector. Still of shape (26*k, ndim) due to broadcast.
    nbr_atoms = nbr_atoms - atom_vec
    # Dists of shape (26*k, 1)
    nbr_dists = tf.reduce_sum(nbr_atoms**2, axis=1)
    # Of shape (k, ndim)
    closest_inds = tf.nn.top_k(nbr_dists, k=k)[1]
    all_closest.append(closest_inds)
  return all_closest
  ## For each atom, loop through all atoms in its cell and neighboring cells.    
  ## Accept as neighbors only those within threshold. This computation should be   
  ## O(Nm), where m is the number of atoms within a set of neighboring-cells.
  #neighbor_list = {}
  #for atom in range(N):
  #  cell = atom_to_cell[atom]
  #  neighbor_cells = neighbor_cell_map[cell]
  #  # For smaller systems especially, the periodic boundary conditions can
  #  # result in neighboring cells being seen multiple times. Use a set() to
  #  # make sure duplicate neighbors are ignored. Convert back to list before
  #  # returning. 
  #  neighbor_list[atom] = set()
  #  for neighbor_cell in neighbor_cells:
  #    atoms_in_cell = cell_to_atoms[neighbor_cell]
  #    for neighbor_atom in atoms_in_cell:
  #      if neighbor_atom == atom:    
  #         continue    
  #      # TODO(rbharath): How does distance need to be modified here to   
  #      # account for periodic boundary conditions?   
  #      dist = np.linalg.norm(coords[atom] - coords[neighbor_atom])   
  #      if dist < neighbor_cutoff:    
  #        neighbor_list[atom].add((neighbor_atom, dist))    
  #           
  #  # Sort neighbors by distance    
  #  closest_neighbors = sorted(   
  #      list(neighbor_list[atom]), key=lambda elt: elt[1])    
  #  closest_neighbors = [nbr for (nbr, dist) in closest_neighbors]    
  #  # Pick up to max_num_neighbors    
  #  closest_neighbors = closest_neighbors[:max_num_neighbors]   
  #  neighbor_list[atom] = closest_neighbors
  #return neighbor_list   

def get_cells(start, stop, nbr_cutoff, ndim=3):
  """Returns the locations of all grid points in box.
@@ -81,7 +135,10 @@ def get_cells(start, stop, nbr_cutoff, ndim=3):
  Then would return a list of length 20^3 whose entries would be
  [(-10, -10, -10), (-10, -10, -9), ..., (9, 9, 9)]

  TODO(rbharath): Make this work in more than 3 dimensions.
  Returns
  -------
  cells: tf.Tensor
    (box_size**ndim, ndim) shape.
  """
  return tf.reshape(tf.transpose(tf.pack(tf.meshgrid(
      *[tf.range(start, stop, nbr_cutoff) for _ in range(ndim)]))), (-1, ndim))
@@ -137,11 +194,14 @@ def put_atoms_in_cells(coords, cells, N, ndim, k=5):
  #   - Return N lists corresponding to neighbors for every atom.
  
        
def compute_neighbor_cell_map(cells, ndim):
def compute_neighbor_cells(cells, ndim):
  """Compute neighbors of cells in grid.    

  # TODO(rbharath): Do we need to handle periodic boundary conditions
  properly here?
  # TODO(rbharath): This doesn't handle boundaries well. We hard-code
  # looking for 26 neighbors, which isn't right for boundary cells in
  # the cube.
      
  Parameters    
  ----------    
@@ -158,10 +218,14 @@ def compute_neighbor_cell_map(cells, ndim):
  # Tile cells to form arrays of size (n_cells*n_cells, ndim)
  # Two tilings (a, b, c, a, b, c, ...) vs. (a, a, a, b, b, b, etc.)
  # Tile (a, a, a, b, b, b, etc.)
  tiled_centers = tf.reshape(tf.tile(cells, (1, N)), (n_cells*N, ndim))
  tiled_centers = tf.reshape(tf.tile(cells, (1, n_cells)), (n_cells*n_cells, ndim))
  # Tile (a, b, c, a, b, c, ...)
  tiled_cells = tf.tile(cells, (n_cells, 1))

  # Lists of n_cells tensors of shape (N, 1)
  tiled_centers = tf.split_v(tiled_centers, n_cells)
  tiled_cells = tf.split_v(tiled_cells, n_cells)

  # Lists of length n_cells
  coords_rel = [tf.to_float(cells) - tf.to_float(centers)
                for (cells, centers) in zip(tiled_centers, tiled_cells)]
@@ -169,7 +233,7 @@ def compute_neighbor_cell_map(cells, ndim):

  # Lists of length n_cells
  # Get indices of k atoms closest to each cell point
  # n_cells tensors of shape (26, ndim)
  # n_cells tensors of shape (26,)
  closest_inds = [tf.nn.top_k(norm, k=k)[1] for norm in coords_norm]

  return closest_inds
@@ -335,6 +399,3 @@ class VinaModel(Model):
        if loc_score < best_score:
          best_conf = loc_conf
    return best_conf