Commit 90021fac authored by Bharath Ramsundar
Browse files

First impl of Vina in Tensorgraph

parent d009ccb9
Loading
Loading
Loading
Loading
+37 −16
Original line number Diff line number Diff line
@@ -60,7 +60,7 @@ class Layer(object):
        else:
          raise ValueError("Layer must be invoked on layers or tensors")
      self.in_layers = layers
    self._create_tensor()
    return self._create_tensor()
  

class TensorWrapper(Layer):
@@ -305,20 +305,27 @@ class Concat(Layer):
class InteratomicL2Distances(Layer):
  """Compute (squared) L2 Distances between atoms given neighbors.

  The layer takes two input layers, in this order: atomic coordinates and a
  neighbor list of indices into those coordinates. The sizes are supplied
  explicitly to the constructor rather than read from the input tensors'
  static shapes.

  Parameters
  ----------
  N_atoms: int
    Number of atoms; first dimension used when reshaping the coordinates.
  M_nbrs: int
    Number of neighbors per atom; the tiling factor along axis 1.
  ndim: int
    Number of coordinate dimensions per atom.
  """

  def __init__(self, N_atoms, M_nbrs, ndim, **kwargs):
    self.N_atoms = N_atoms
    self.M_nbrs = M_nbrs
    self.ndim = ndim
    super(InteratomicL2Distances, self).__init__(**kwargs)

  def _create_tensor(self):
    """Build and return the tensor of squared distances to each neighbor.

    Raises
    ------
    ValueError
      If the layer was not given exactly two input layers.
    """
    if len(self.in_layers) != 2:
      raise ValueError("InteratomicDistances requires coords,nbr_list")
    coords, nbr_list = (self.in_layers[0].out_tensor,
                        self.in_layers[1].out_tensor)
    N_atoms, M_nbrs, ndim = self.N_atoms, self.M_nbrs, self.ndim
    # Look up the coordinates of every listed neighbor.
    # Shape (N_atoms, M_nbrs, ndim), assuming nbr_list is (N_atoms, M_nbrs).
    nbr_coords = tf.gather(coords, nbr_list)
    # Repeat each atom's own coordinates once per neighbor so the two
    # tensors line up elementwise. Shape (N_atoms, M_nbrs, ndim).
    tiled_coords = tf.tile(
        tf.reshape(coords, (N_atoms, 1, ndim)), (1, M_nbrs, 1))
    # Sum squared differences over the coordinate axis; no square root is
    # taken, so these are squared distances. Shape (N_atoms, M_nbrs).
    dists = tf.reduce_sum((tiled_coords - nbr_coords)**2, axis=2)
    self.out_tensor = dists
    return self.out_tensor



@@ -681,17 +688,31 @@ class WeightedError(Layer):

class Cutoff(Layer):
  """Truncates interactions that are too far away.

  Expects exactly two input layers, in this order: distances and energies.
  Wherever the distance input is not below the threshold, the corresponding
  energy is replaced by zero.
  """

  def _create_tensor(self):
    """Build and return the thresholded energy tensor.

    Raises
    ------
    ValueError
      If the layer was not given exactly two input layers.
    """
    # Validate before touching in_layers so a bad call raises ValueError
    # rather than an IndexError.
    if len(self.in_layers) != 2:
      raise ValueError("Cutoff must be given distances and energies.")
    d, x = self.in_layers[0].out_tensor, self.in_layers[1].out_tensor
    # NOTE(review): the threshold 8 is hard-coded and compared directly
    # against the first input; if callers pass squared distances, confirm
    # that 8 is the intended value in those units.
    self.out_tensor = tf.where(d < 8, x, tf.zeros_like(x))
    return self.out_tensor

class VinaNonlinearity(Layer):
  """Computes non-linearity used in Vina.

  The first input layer's tensor is divided by (1 + w * Nrot), where w is a
  variable initialized from a normal distribution.

  Parameters
  ----------
  stddev: float
    Standard deviation of the normal initializer for w.
  Nrot: int
    Number of rotatable bonds used in the denominator.
  """

  def __init__(self, stddev=.3, Nrot=1, **kwargs):
    # Number of rotatable bonds
    # TODO(rbharath): Vina actually sets this per-molecule. See if makes
    # a difference.
    self.Nrot = Nrot
    self.stddev = stddev
    super(VinaNonlinearity, self).__init__(**kwargs)

  def _create_tensor(self):
    """Build and return the scaled tensor."""
    energies = self.in_layers[0].out_tensor
    # Single-element weight, drawn once at graph construction.
    initial_weight = tf.random_normal((1,), stddev=self.stddev)
    weight = tf.Variable(initial_weight)
    denominator = 1 + weight * self.Nrot
    self.out_tensor = energies / denominator
    return self.out_tensor

class VinaRepulsion(Layer):
  """Computes Autodock Vina's repulsion interaction term."""
  
@@ -700,7 +721,7 @@ class VinaRepulsion(Layer):
    self.out_tensor = tf.where(d < 0, d**2, tf.zeros_like(d))
    return self.out_tensor

def VinaHydrophobic(Layer):
class VinaHydrophobic(Layer):
  """Computes Autodock Vina's hydrophobic interaction term."""

  def _create_tensor(self):
+79 −13
Original line number Diff line number Diff line
@@ -2,7 +2,9 @@ import unittest

import numpy as np
import os
import tensorflow as tf
from nose.tools import assert_true
from tensorflow.python.framework import test_util

import deepchem as dc
from deepchem.data import NumpyDataset
@@ -10,14 +12,23 @@ from deepchem.data.datasets import Databag
from deepchem.models.tensorgraph.layers import ReduceSum 
from deepchem.models.tensorgraph.layers import Feature, Label
from deepchem.models.tensorgraph.layers import ToFloat
from deepchem.models.tensorgraph.layers import Concat
from deepchem.models.tensorgraph.layers import NeighborList
from deepchem.models.tensorgraph.layers import ReduceSquareDifference
from deepchem.models.tensorgraph.layers import WeightedLinearCombo
from deepchem.models.tensorgraph.layers import InteratomicL2Distances
from deepchem.models.tensorgraph.layers import Cutoff
from deepchem.models.tensorgraph.layers import VinaRepulsion
from deepchem.models.tensorgraph.layers import VinaNonlinearity
from deepchem.models.tensorgraph.layers import VinaHydrophobic
from deepchem.models.tensorgraph.layers import VinaHydrogenBond
from deepchem.models.tensorgraph.layers import VinaGaussianFirst
from deepchem.models.tensorgraph.layers import VinaGaussianSecond
from deepchem.models.tensorgraph.layers import L2LossLayer
from deepchem.models.tensorgraph.tensor_graph import TensorGraph


class TestDocking(unittest.TestCase):
class TestDocking(test_util.TensorFlowTestCase):
  """
  Test that tensorgraph docking-style models work. 
  """
@@ -76,31 +87,86 @@ class TestDocking(unittest.TestCase):

  def test_vina(self):
    """Test that vina graph can be constructed in TensorGraph.

    Builds the graph from protein and ligand coordinates through the
    neighbor list, squared distances, the Vina interaction terms, the
    cutoff and the nonlinearity, then fits it for one epoch on random data.
    """
    N_protein = 4
    N_ligand = 1
    N_atoms = N_protein + N_ligand
    M_nbrs = 2
    ndim = 3
    k = 5
    start = 0
    stop = 4
    nbr_cutoff = 1
    # The number of cells which we should theoretically have. int() keeps
    # the count an integer even when `/` is true division.
    n_cells = int(((stop - start) / nbr_cutoff)**ndim)

    X_prot = NumpyDataset(np.random.rand(N_protein, ndim))
    X_ligand = NumpyDataset(np.random.rand(N_ligand, ndim))
    y = NumpyDataset(np.random.rand(1,))

    # TODO(rbharath): Mysteriously, the actual atom types aren't
    # used in the current implementation. This is obviously wrong, but need
    # to dig out why this is happening.
    prot_coords = Feature(shape=(N_protein, ndim))
    ligand_coords = Feature(shape=(N_ligand, ndim))
    labels = Label(shape=(1,))

    coords = Concat(in_layers=[prot_coords, ligand_coords], axis=0)

    #prot_Z = Feature(shape=(N_protein,), dtype=tf.int32)
    #ligand_Z = Feature(shape=(N_ligand,), dtype=tf.int32)
    #Z = Concat(in_layers=[prot_Z, ligand_Z], axis=0)

    # Now an (N, M) shape
    nbr_list = NeighborList(N_atoms, M_nbrs, ndim, n_cells, k,
                            nbr_cutoff, in_layers=[coords])

    # Shape (N, M)
    dists = InteratomicL2Distances(N_atoms, M_nbrs, ndim,
                                   in_layers=[coords, nbr_list])

    repulsion = VinaRepulsion(in_layers=[dists])
    hydrophobic = VinaHydrophobic(in_layers=[dists])
    hbond = VinaHydrogenBond(in_layers=[dists])
    gauss_1 = VinaGaussianFirst(in_layers=[dists])
    gauss_2 = VinaGaussianSecond(in_layers=[dists])

    # Shape (N, M)
    interactions = WeightedLinearCombo(
        in_layers=[repulsion, hydrophobic, hbond, gauss_1, gauss_2])

    # Shape (N, M)
    thresholded = Cutoff(in_layers=[dists, interactions])

    # Shape (N, M)
    free_energies = VinaNonlinearity(in_layers=[thresholded])
    free_energy = ReduceSum(in_layers=[free_energies])

    loss = L2LossLayer(in_layers=[free_energy, labels])

    databag = Databag({prot_coords: X_prot, ligand_coords: X_ligand,
                       labels: y})

    tg = dc.models.TensorGraph(learning_rate=0.1, use_queue=False)
    tg.set_loss(loss)
    tg.fit_generator(databag.iterbatches(epochs=1))
    
    

  def test_interatomic_distances(self):
    """Test that the interatomic distance calculation works.

    Feeds random coordinates and a random neighbor list through
    InteratomicL2Distances and checks the shape of the result.
    """
    N_atoms = 5
    M_nbrs = 2
    ndim = 3

    with self.test_session() as sess:
      coords = np.random.rand(N_atoms, ndim)
      nbr_list = np.random.randint(0, N_atoms, size=(N_atoms, M_nbrs))

      coords_tensor = tf.convert_to_tensor(coords)
      nbr_list_tensor = tf.convert_to_tensor(nbr_list)

      # The layer is invoked directly on tensors instead of on in_layers.
      dist_tensor = InteratomicL2Distances(N_atoms, M_nbrs, ndim)(
          coords_tensor, nbr_list_tensor)

      dists = dist_tensor.eval()
      assert dists.shape == (N_atoms, M_nbrs)