Commit a903efcb authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Adding layers

parent 18fcf057
Loading
Loading
Loading
Loading
+147 −6
Original line number Diff line number Diff line
@@ -610,6 +610,52 @@ class WeightedError(Layer):
    self.out_tensor = tf.reduce_sum(entropy.out_tensor * weights.out_tensor)
    return self.out_tensor

class Cutoff(Layer):
  """Truncates interactions that are too far away.

  Entries of the input tensor are kept wherever the stored distance is
  strictly below the cutoff and are replaced by zeros everywhere else.

  Parameters
  ----------
  dist:
    Distances used to build the mask. Presumably a tensor with the same
    shape as the input layer's output (required by `tf.where`) -- TODO
    confirm against callers.
  cutoff: number, optional
    Distance threshold. Defaults to 8, the value that was previously
    hard-coded.
  **kwargs:
    Forwarded unchanged to `Layer.__init__`.
  """

  def __init__(self, dist, cutoff=8, **kwargs):
    # `self` was missing from the original signature, which made the
    # constructor unusable.
    self.d = dist
    self.cutoff = cutoff
    super(Cutoff, self).__init__(**kwargs)

  def _create_tensor(self):
    """Builds and returns the masked copy of the first input layer."""
    d = self.d
    x = self.in_layers[0].out_tensor
    # Keep x where the distance is under the cutoff, zero elsewhere.
    self.out_tensor = tf.where(d < self.cutoff, x, tf.zeros_like(x))
    return self.out_tensor

class VinaRepulsion(Layer):
  """Computes Autodock Vina's repulsion interaction term.

  The output is the square of the input where the input is negative and
  zero everywhere else.
  """

  def _create_tensor(self):
    """Builds and returns the repulsion term from the first input layer."""
    dist = self.in_layers[0].out_tensor
    is_negative = dist < 0
    squared = dist**2
    zeros = tf.zeros_like(dist)
    # Only negative inputs contribute; everything else is zeroed.
    self.out_tensor = tf.where(is_negative, squared, zeros)
    return self.out_tensor

class VinaHydrophobic(Layer):
  """Computes Autodock Vina's hydrophobic interaction term.

  Piecewise function of the input d:
    d < 0.5          -> 1
    0.5 <= d < 1.5   -> 1.5 - d
    otherwise        -> 0
  """

  def _create_tensor(self):
    """Builds and returns the hydrophobic term from the first input layer."""
    d = self.in_layers[0].out_tensor
    # The nested tf.where implements the linear ramp between 0.5 and 1.5.
    self.out_tensor = tf.where(d < 0.5,
                               tf.ones_like(d),
                               tf.where(d < 1.5, 1.5 - d, tf.zeros_like(d)))
    return self.out_tensor

class VinaHydrogenBond(Layer):
  """Computes Autodock Vina's hydrogen bond interaction term.

  Piecewise function of the input d:
    d < -0.7         -> 1
    -0.7 <= d < 0    -> -d / 0.7
    otherwise        -> 0
  """

  def _create_tensor(self):
    """Builds and returns the hydrogen bond term from the first input layer."""
    d = self.in_layers[0].out_tensor
    # The nested tf.where implements the linear ramp between -0.7 and 0.
    self.out_tensor = tf.where(d < -0.7,
                               tf.ones_like(d),
                               tf.where(d < 0,
                                        (1.0 / 0.7) * (0 - d),
                                        tf.zeros_like(d)))
    # Return the tensor like the other layers do; the original fell off the
    # end and returned None.
    return self.out_tensor

class VinaGaussianFirst(Layer):
  """Computes Autodock Vina's first Gaussian interaction term."""
  # Placeholder: declared as a Layer subclass (it was mistakenly a plain
  # function), but no `_create_tensor` is implemented yet.
  # TODO: implement the Gaussian term.
  pass

class NeighborList(Layer):
  """Computes a neighbor-list on the GPU.

@@ -641,6 +687,7 @@ class NeighborList(Layer):
    self.ndim = ndim
    self.n_cells = n_cells
    self.k = k
    self.nbr_cutoff = nbr_cutoff
    super(NeighborList, self).__init__(**kwargs)

  def _create_tensor(self):
@@ -652,7 +699,9 @@ class NeighborList(Layer):
      # TODO(rbharath): Support batching
      raise ValueError("Parent tensor must be (num_atoms, ndum)")
    coords = parent.out_tensor
    return self._compute_nbr_list(coords)
    nbr_list = self._compute_nbr_list(coords)
    self.out_tensor = nbr_list
    return nbr_list
  

  def _compute_nbr_list(self, coords):
@@ -672,17 +721,17 @@ class NeighborList(Layer):
    nbr_cutoff = self.nbr_cutoff
    start = tf.to_int32(tf.reduce_min(coords))
    stop = tf.to_int32(tf.reduce_max(coords))
    cells = self._get_cells(start, stop, nbr_cutoff, ndim=ndim)
    cells = self._get_cells(start, stop)
    # Associate each atom with cell it belongs to. O(N*n_cells)
    # Shape (n_cells, k)
    atoms_in_cells, _ = put_atoms_in_cells(coords, cells, N, n_cells, ndim, k)
    atoms_in_cells, _ = self._put_atoms_in_cells(coords, cells)
    # Shape (N, 1)
    cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
    cells_for_atoms = self._get_cells_for_atoms(coords, cells)

    # Associate each cell with its neighbor cells. Assumes periodic boundary   
    # conditions, so does wrapround. O(constant)    
    # Shape (n_cells, 26)
    neighbor_cells = compute_neighbor_cells(cells, ndim, n_cells)
    neighbor_cells = self._compute_neighbor_cells(cells)

    # Shape (N, 26)
    neighbor_cells = tf.squeeze(tf.gather(neighbor_cells, cells_for_atoms))
@@ -791,6 +840,98 @@ class NeighborList(Layer):

    return closest_inds, closest_atoms

  def _get_cells_for_atoms(self, coords, cells):
    """Find, for every atom, the index of the grid cell closest to it.

    Parameters
    ----------
    coords: tf.Tensor
      Shape (N, ndim)
    cells: tf.Tensor
      (box_size**ndim, ndim) shape.

    Returns
    -------
    cells_for_atoms: tf.Tensor
      Shape (N, 1)
    """
    num_atoms, ndim = self.N, self.ndim
    num_cells = int(self.n_cells)

    # Stack num_atoms copies of the cell grid, then cut the result into one
    # (num_cells, ndim) chunk per atom.
    stacked_cells = tf.tile(cells, (num_atoms, 1))
    cells_per_atom = tf.split(stacked_cells, num_atoms)

    # Repeat each atom's coordinates num_cells times in a row, giving a
    # (num_cells * num_atoms, ndim) tensor, and cut it the same way.
    repeated_coords = tf.tile(coords, (1, num_cells))
    repeated_coords = tf.reshape(repeated_coords,
                                 (num_cells * num_atoms, ndim))
    coords_per_atom = tf.split(repeated_coords, num_atoms)

    nearest_cells = []
    for atom_block, cell_block in zip(coords_per_atom, cells_per_atom):
      offsets = tf.to_float(atom_block) - tf.to_float(cell_block)
      sq_dists = tf.reduce_sum(offsets**2, axis=1)
      # top_k returns the largest entries, so negate the squared distances
      # to pick the single closest cell.
      nearest_cells.append(tf.nn.top_k(-sq_dists, k=1)[1])

    return tf.stack(nearest_cells)

  def _compute_neighbor_cells(self, cells):
    """Compute the indices of the cells closest to each cell in the grid.

    Note n_cells is box_size**ndim. 26 is the number of neighbors of a cube in
    a grid (including diagonals).

    Each cell is at distance zero from itself, so when grid points are
    distinct a cell's own index is among the 26 indices returned for it.

    Parameters
    ----------
    cells: tf.Tensor
      (n_cells, ndim) shape.

    Returns
    -------
    closest_inds: tf.Tensor
      (n_cells, 26) shape. Row i holds the indices of the 26 cells with the
      smallest squared Euclidean distance to cell i.

    Raises
    ------
    ValueError
      If the layer was configured with ndim != 3.
    """
    # TODO(rbharath): Do we need to handle periodic boundary conditions
    # properly here?
    # TODO(rbharath): This doesn't handle boundaries well. We hard-code
    # looking for 26 neighbors, which isn't right for boundary cells in
    # the cube.
    ndim, n_cells = self.ndim, self.n_cells
    n_cells = int(n_cells)
    if ndim != 3:
      raise ValueError("Not defined for dimensions besides 3")
    # Number of neighbors of central cube in 3-space is
    # 3^2 (top-face) + 3^2 (bottom-face) + (3^2-1) (middle-band)
    k = 9 + 9 + 8  # (26 faces on Rubik's cube for example)

    # Tile cells to form arrays of size (n_cells*n_cells, ndim)
    # Two tilings (a, b, c, a, b, c, ...) vs. (a, a, a, b, b, b, etc.)
    # Tile (a, a, a, b, b, b, etc.)
    tiled_centers = tf.reshape(
        tf.tile(cells, (1, n_cells)), (n_cells * n_cells, ndim))
    # Tile (a, b, c, a, b, c, ...)
    tiled_cells = tf.tile(cells, (n_cells, 1))

    # Lists of n_cells tensors of shape (n_cells, ndim)
    tiled_centers = tf.split(tiled_centers, n_cells)
    tiled_cells = tf.split(tiled_cells, n_cells)

    # Lists of length n_cells: offsets from one center to every cell, and
    # the corresponding squared distances.
    coords_rel = [
        tf.to_float(others) - tf.to_float(center)
        for (center, others) in zip(tiled_centers, tiled_cells)
    ]
    coords_norm = [tf.reduce_sum(rel**2, axis=1) for rel in coords_rel]

    # tf.nn.top_k returns the k LARGEST entries, so negate the squared
    # distances to get the k closest cells (the original picked the k
    # farthest ones).
    # n_cells tensors of shape (26,)
    closest_inds = tf.stack(
        [tf.nn.top_k(-norm, k=k)[1] for norm in coords_norm])

    return closest_inds

  def _get_cells(self, start, stop):
    """Returns the locations of all grid points in box.

@@ -808,6 +949,6 @@ class NeighborList(Layer):
            tf.stack(
                tf.meshgrid(
                    * [tf.range(start, stop, self.nbr_cutoff) for _ in range(self.ndim)]))),
        (-1, ndim))
        (-1, self.ndim))

  
+17 −9
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ from deepchem.data.datasets import Databag
from deepchem.models.tensorgraph.layers import Dense, SoftMaxCrossEntropy, ReduceMean, SoftMax
from deepchem.models.tensorgraph.layers import Feature, Label
from deepchem.models.tensorgraph.layers import ReduceSquareDifference
from deepchem.models.tensorgraph.layers import NeighborList
from deepchem.models.tensorgraph.tensor_graph import TensorGraph


@@ -209,7 +210,7 @@ class TestTensorGraph(unittest.TestCase):
    assert_true(np.all(np.isclose(prediction, prediction2, atol=0.01)))

  def test_neighbor_list(self):
    N = 10
    N_atoms = 10
    start = 0
    stop = 12
    nbr_cutoff = 3
@@ -219,12 +220,19 @@ class TestTensorGraph(unittest.TestCase):
    # The number of cells which we should theoretically have
    n_cells = int(((stop - start) / nbr_cutoff)**ndim)

    nbr_list = NeighborList(N, M, ndim, n_cells, k, nbr_cutoff)
    X = np.random.rand(N_atoms, ndim)
    y = np.random.rand(N_atoms, 1)
    dataset = NumpyDataset(X, y)

    features = Feature(shape=(N_atoms, ndim))
    labels = Label(shape=(N_atoms,))
    nbr_list = NeighborList(N_atoms, M, ndim, n_cells, k, nbr_cutoff,
                            in_layers=[features])
    # This isn't a meaningful loss, but just for test
    loss = ReduceMean(in_layers=[nbr_list])
    tg = dc.models.TensorGraph(use_queue=False)
    tg.add_output(nbr_list)
    tg.set_loss(loss)

    tg.fit(dataset, nb_epoch=1)
    #with self.test_session() as sess:
    #  coords = start + np.random.rand(N, ndim) * (stop - start)
    #  coords = tf.stack(coords)
    #  nbr_list = compute_neighbor_list(
    #      coords, nbr_cutoff, N, M, n_cells, ndim=ndim, k=k)
    #  nbr_list = nbr_list.eval()
    #  assert nbr_list.shape == (N, M)