Commit 7f751d94 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Starting to convert layers

parent 88581403
Loading
Loading
Loading
Loading
+105 −94
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ class Layer(object):
  def set_tensors(self, tensor):
    # Directly install a precomputed output tensor on this layer, bypassing
    # create_tensor().
    self.out_tensor = tensor

  def create_tensor(self, in_layers=None):
    """Build this layer's output tensor.

    Subclasses must override this to construct their TensorFlow ops.

    Parameters
    ----------
    in_layers: list of Layer, optional
      Input layers to build from; when None, implementations fall back to
      self.in_layers.
    """
    # The leftover `def _create_tensor(self):` line from the old API left a
    # bodyless def above this one (a syntax error); only the new signature
    # remains.
    raise NotImplementedError("Subclasses must implement for themselves")

  def __key(self):
@@ -70,17 +70,7 @@ class Layer(object):
    raise ValueError("Each Layer must implement shared for itself")

  def __call__(self, *in_layers):
    """Invoke the layer on its inputs and return the built tensor.

    Parameters
    ----------
    *in_layers: Layer or tf.Tensor
      Inputs forwarded to create_tensor, which handles wrapping raw
      tensors via convert_to_layers.
    """
    # The old inline conversion body (which also referenced an undefined
    # `layer` variable) was superseded; conversion now happens inside
    # create_tensor, and the first `return` made the second unreachable.
    return self.create_tensor(in_layers=in_layers)


class TensorWrapper(Layer):
@@ -89,6 +79,22 @@ class TensorWrapper(Layer):
  def __init__(self, out_tensor):
    # Wrap an already-built tf.Tensor so it can participate in the layer
    # graph like any other Layer.
    self.out_tensor = out_tensor

  def create_tensor(self, in_layers=None):
    """No-op: the wrapped tensor was already materialized in __init__."""
    return None


def convert_to_layers(in_layers):
  """Wrap all inputs into layers, converting raw tensors if necessary.

  Parameters
  ----------
  in_layers: iterable of Layer or tf.Tensor
    Inputs to normalize.

  Returns
  -------
  list of Layer
    The inputs with every raw tf.Tensor wrapped in a TensorWrapper.

  Raises
  ------
  ValueError
    If any input is neither a Layer nor a tf.Tensor.
  """
  layers = []
  for in_layer in in_layers:
    if isinstance(in_layer, Layer):
      # Fix: appended the loop variable `in_layer`; the original referenced
      # an undefined name `layer`, raising NameError for any Layer input.
      layers.append(in_layer)
    elif isinstance(in_layer, tf.Tensor):
      layers.append(TensorWrapper(in_layer))
    else:
      raise ValueError("convert_to_layers must be invoked on layers or tensors")
  return layers

class Conv1DLayer(Layer):

@@ -98,10 +104,13 @@ class Conv1DLayer(Layer):
    self.out_tensor = None
    super(Conv1DLayer, self).__init__(**kwargs)

  def _create_tensor(self):
    if len(self.in_layers) != 1:
  def create_tensor(self, in_layers=None):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
    if len(in_layers) != 1:
      raise ValueError("Only One Parent to conv1D over")
    parent = self.in_layers[0]
    parent = in_layers[0]
    if len(parent.out_tensor.get_shape()) != 3:
      raise ValueError("Parent tensor must be (batch, width, channel)")
    parent_shape = parent.out_tensor.get_shape()
@@ -139,10 +148,13 @@ class Dense(Layer):
      scope_name = self.name
    self.scope_name = scope_name

  def _create_tensor(self):
    if len(self.in_layers) != 1:
      raise ValueError("Only One Parent to Dense over %s" % self.in_layers)
    parent = self.in_layers[0]
  def create_tensor(self, in_layers=None):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
    if len(in_layers) != 1:
      raise ValueError("Only One Parent to Dense over %s" % in_layers)
    parent = in_layers[0]
    if not self.time_series:
      self.out_tensor = tf.contrib.layers.fully_connected(
          parent.out_tensor,
@@ -182,10 +194,13 @@ class Flatten(Layer):
  def __init__(self, **kwargs):
    # No state of its own; all configuration is handled by the base Layer.
    super(Flatten, self).__init__(**kwargs)

  def _create_tensor(self):
    if len(self.in_layers) != 1:
      raise ValueError("Only One Parent to conv1D over")
    parent = self.in_layers[0]
  def create_tensor(self, in_layers=None):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
    if len(in_layers) != 1:
      raise ValueError("Only One Parent to Flatten")
    parent = in_layers[0]
    parent_shape = parent.out_tensor.get_shape()
    vector_size = 1
    for i in range(1, len(parent_shape)):
@@ -769,8 +784,8 @@ class Cutoff(Layer):
    return self.out_tensor


class VinaNonlinearity(Layer):
  """Computes non-linearity used in Vina."""
class VinaFreeEnergy(Layer):
  """Computes free-energy as defined by Autodock Vina."""

  def __init__(self, stddev=.3, Nrot=1, **kwargs):
    """Initialize the Vina free-energy layer.

    Parameters
    ----------
    stddev: float
      Standard deviation of the random initializer for the nonlinearity
      weight (see nonlinearity()).
    Nrot: int
      Number of rotatable bonds used in the nonlinearity denominator.
    """
    self.stddev = stddev
    # TODO(rbharath): Vina actually sets this per-molecule. See if makes
    # a difference.
    self.Nrot = Nrot
    # The leftover `super(VinaNonlinearity, ...)` call referenced a class
    # removed by the rename and would raise NameError; only the renamed
    # class's super call remains.
    super(VinaFreeEnergy, self).__init__(**kwargs)

  def nonlinearity(self, c):
    """Computes non-linearity used in Vina.

    Parameters
    ----------
    c: tf.Tensor
      Raw interaction energies to be damped.

    Returns
    -------
    tf.Tensor
      c / (1 + w * Nrot) with w a learned scalar weight.
    """
    w = tf.Variable(tf.random_normal((1,), stddev=self.stddev))
    # Helper form: returns the tensor directly rather than stashing it on
    # self.out_tensor (the old standalone-layer residue did the latter).
    out_tensor = c / (1 + w * self.Nrot)
    return out_tensor


class VinaRepulsion(Layer):
  def repulsion(self, d):
    """Computes Autodock Vina's repulsion interaction term."""

  def _create_tensor(self):
    d = self.in_layers[0].out_tensor
    self.out_tensor = tf.where(d < 0, d**2, tf.zeros_like(d))
    return self.out_tensor
    out_tensor = tf.where(d < 0, d**2, tf.zeros_like(d))
    return out_tensor


class VinaHydrophobic(Layer):
  def hydrophobic(self, d):
    """Computes Autodock Vina's hydrophobic interaction term."""

  def _create_tensor(self):
    d = self.in_layers[0].out_tensor
    self.out_tensor = tf.where(d < 0.5,
    out_tensor = tf.where(d < 0.5,
                               tf.ones_like(d),
                               tf.where(d < 1.5, 1.5 - d, tf.zeros_like(d)))
    return self.out_tensor
    return out_tensor


class VinaHydrogenBond(Layer):
  def hydrogen_bond(self, d):
    """Computes Autodock Vina's hydrogen bond interaction term."""

  def _create_tensor(self):
    d = self.in_layers[0].out_tensor
    self.out_tensor = tf.where(d < -0.7,
    out_tensor = tf.where(d < -0.7,
                               tf.ones_like(d),
                               tf.where(d < 0, (1.0 / 0.7) * (0 - d),
                                        tf.zeros_like(d)))
    return self.out_tensor
    return out_tensor


class VinaGaussianFirst(Layer):
  def gaussian_first(self, d):
    """Computes Autodock Vina's first Gaussian interaction term."""

  def _create_tensor(self):
    d = self.in_layers[0].out_tensor
    self.out_tensor = tf.exp(-(d / 0.5)**2)
    return self.out_tensor
    out_tensor = tf.exp(-(d / 0.5)**2)
    return out_tensor


class VinaGaussianSecond(Layer):
  def gaussian_second(self, d):
    """Computes Autodock Vina's second Gaussian interaction term."""
    out_tensor = tf.exp(-((d - 3) / 2)**2)
    return out_tensor

  def create_tensor(self, in_layers=None):
    """Computes free-energy as defined by Autodock Vina.

    Parameters
    ----------
    X: tf.Tensor of shape (B, N, d)
      Coordinates/features.
    Z: tf.Tensor of shape (B, N)
      Atomic numbers of neighbor atoms.

    Returns
    -------
    layer: tf.Tensor of shape (B)
      The free energy of each complex in batch
    """
    # Follow the in_layers convention used by the other converted
    # create_tensor implementations in this file.
    if in_layers is None:
      in_layers = self.in_layers
    X = in_layers[0].out_tensor
    # NOTE(review): in_layers[1] is skipped here and Z is never used below —
    # confirm the intended inputs against the caller.
    Z = in_layers[2].out_tensor

    # NOTE(review): self.N_atoms/M_nbrs/ndim/nbr_cutoff/start/stop are not
    # assigned in the visible __init__ — confirm they are set elsewhere.
    # Fix: the original referenced an undefined name `coords`; the
    # coordinate tensor here is X.
    nbr_list = NeighborList(
        self.N_atoms, self.M_nbrs, self.ndim, self.nbr_cutoff, self.start,
        self.stop)(X)

    # Shape (N, M)
    dists = InteratomicL2Distances(
        self.N_atoms, self.M_nbrs, self.ndim)(X, nbr_list)

    # Individual Vina interaction terms, each evaluated on the distances.
    repulsion = self.repulsion(dists)
    hydrophobic = self.hydrophobic(dists)
    hbond = self.hydrogen_bond(dists)
    gauss_1 = self.gaussian_first(dists)
    gauss_2 = self.gaussian_second(dists)

    # Shape (N, M)
    interactions = WeightedLinearCombo()(
        repulsion, hydrophobic, hbond, gauss_1, gauss_2)

    # Shape (N, M)
    thresholded = Cutoff()(dists, interactions)

    free_energies = self.nonlinearity(thresholded)
    free_energy = ReduceSum()(free_energies)

    # Fix: was assigned to `self.output_tensor` (typo) while returning
    # `self.out_tensor`, which would have been undefined or stale.
    self.out_tensor = free_energy
    return self.out_tensor


@@ -905,37 +949,6 @@ class NeighborList(Layer):
    self.out_tensor = nbr_list
    return nbr_list

  #def compute_nbr_list(self, coords):
  #  """Computes a neighbor list from atom coordinates.

  #  Parameters
  #  ----------
  #  coords: tf.Tensor
  #    Shape (N_atoms, ndim)

  #  Returns
  #  -------
  #  nbr_list: tf.Tensor
  #    Shape (N_atoms, M_nbrs) of atom indices
  #  """
  #  N_atoms, M_nbrs, n_cells, ndim = (self.N_atoms, self.M_nbrs, self.n_cells,
  #                                    self.ndim)
  #  nbr_cutoff = self.nbr_cutoff
  #  coords = tf.to_float(coords)

  #  nbrs, closest_nbrs = self.get_closest_nbrs(coords)

  #  # N_atoms elts of size (M_nbrs,) each 
  #  neighbor_list = [
  #      tf.gather(atom_nbrs, closest_nbr_ind)
  #      for (atom_nbrs, closest_nbr_ind) in zip(nbrs, closest_nbrs)
  #  ]

  #  # Shape (N_atoms, M_nbrs)
  #  nbr_list = tf.stack(neighbor_list)

  #  return nbr_list

  def compute_nbr_list(self, coords):
    """Get closest neighbors for atoms.

@@ -1207,8 +1220,6 @@ class AtomicConvolution(Layer):
    Nbrs_Z: tf.Tensor of shape (B, N, M)
      Atomic numbers of neighbor atoms.
    
    
    
    Returns
    -------
    layer: tf.Tensor of shape (l, B, N)
+1 −74
Original line number Diff line number Diff line
@@ -20,12 +20,7 @@ from deepchem.models.tensorgraph.layers import ReduceSquareDifference
from deepchem.models.tensorgraph.layers import WeightedLinearCombo
from deepchem.models.tensorgraph.layers import InteratomicL2Distances
from deepchem.models.tensorgraph.layers import Cutoff
from deepchem.models.tensorgraph.layers import VinaRepulsion
from deepchem.models.tensorgraph.layers import VinaNonlinearity
from deepchem.models.tensorgraph.layers import VinaHydrophobic
from deepchem.models.tensorgraph.layers import VinaHydrogenBond
from deepchem.models.tensorgraph.layers import VinaGaussianFirst
from deepchem.models.tensorgraph.layers import VinaGaussianSecond
from deepchem.models.tensorgraph.layers import VinaFreeEnergy
from deepchem.models.tensorgraph.layers import L2LossLayer
from deepchem.models.tensorgraph.tensor_graph import TensorGraph

@@ -378,74 +373,6 @@ class TestDocking(test_util.TensorFlowTestCase):
      nbr_list = np.squeeze(nbr_list.eval())
      np.testing.assert_array_almost_equal(nbr_list, np.array([-1, -1, 3, 2]))

  #def test_get_closest_nbrs_3D_empty_cells(self):
  #  """Test get_closest_nbrs in 3D with empty nbrs.
  #  Stresses the failure mode where the neighboring cells are empty
  #  so top_k will throw a failure.
  #  """
  #  N_atoms = 4
  #  start = 0
  #  stop = 10
  #  nbr_cutoff = 1
  #  ndim = 3
  #  M_nbrs = 1
  #  # 1 and 2 are nbrs. 8 and 9 are nbrs
  #  coords = np.array(
  #    [[1.0, 0.0, 1.0],
  #     [2.0, 5.0, 2.0],
  #     [8.0, 8.0, 8.0],
  #     [9.0, 9.0, 9.0]])
  #  coords = np.reshape(coords, (N_atoms, ndim))

  #  with self.test_session() as sess:
  #    coords = tf.convert_to_tensor(coords, dtype=tf.float32)
  #    nbr_layer = NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
  #                                  stop)

  #    neighbor_list, padded_neighbor_list, padded_closest_nbrs, padded_dists, dists, padded_nbr_coords, nbr_coords, padded_nbrs, nbrs, closest_nbrs = nbr_layer.get_closest_nbrs(coords)

  #    neighbor_list_eval = neighbor_list.eval()
  #    print("neighbor_list_eval")
  #    print(neighbor_list_eval)

  #    padded_neighbor_list_eval = [padded_nbr.eval() for padded_nbr in padded_neighbor_list]
  #    print("padded_neighbor_list_eval")
  #    print(padded_neighbor_list_eval)

  #    padded_dists_eval = [padded_dist.eval() for padded_dist in padded_dists]
  #    print("padded_dists_eval")
  #    print(padded_dists_eval)

  #    dists_eval = [dist.eval() for dist in dists]
  #    print("dists_eval")
  #    print(dists_eval)
  #
  #    nbr_coords_eval = [nbr_coord.eval() for nbr_coord in nbr_coords]
  #    print("nbr_coords_eval")
  #    print(nbr_coords_eval)

  #    padded_nbr_coords_eval = [padded_nbr_coord.eval() for padded_nbr_coord in padded_nbr_coords]
  #    print("padded_nbr_coords_eval")
  #    print(padded_nbr_coords_eval)

  #    nbrs_eval = [nbr.eval() for nbr in nbrs]
  #    print("nbrs_eval")
  #    print(nbrs_eval)
  #
  #    padded_nbrs_eval = [padded_nbr.eval() for padded_nbr in padded_nbrs]
  #    print("padded_nbrs_eval")
  #    print(padded_nbrs_eval)

  #    padded_closest_nbrs_eval = [padded_closest_nbr.eval() for padded_closest_nbr in padded_closest_nbrs] 
  #    print("padded_closest_nbrs_eval")
  #    print(padded_closest_nbrs_eval)

  #    #closest_nbrs_eval = [closest_nbr.eval() for closest_nbr in closest_nbrs]
  #    #print("closest_nbrs_eval")
  #    #print(closest_nbrs_eval)

  #    assert 0 == 1

  def test_neighbor_list_vina(self):
    """Test under conditions closer to Vina usage."""
    N_atoms = 5
+49 −0
Original line number Diff line number Diff line
import unittest

import numpy as np
import os
import tensorflow as tf
from nose.tools import assert_true
from tensorflow.python.framework import test_util
from deepchem.models.tensorgraph.layers import Conv1DLayer
from deepchem.models.tensorgraph.layers import Dense

import deepchem as dc

class TestLayers(test_util.TensorFlowTestCase):
  """Unit tests checking that individual layers can be invoked on tensors."""

  def test_conv_1D_layer(self):
    """Test that Conv1D can be invoked."""
    width, in_channels, out_channels = 5, 2, 3
    batch_size = 10
    data = np.random.rand(batch_size, width, in_channels)
    with self.test_session() as sess:
      input_tensor = tf.convert_to_tensor(data, dtype=tf.float32)
      output = Conv1DLayer(width, out_channels)(input_tensor)
      sess.run(tf.global_variables_initializer())
      result = output.eval()
      assert result.shape == (batch_size, width, out_channels)

  def test_dense(self):
    """Test that Dense can be invoked."""
    in_dim, out_dim = 2, 3
    batch_size = 10
    data = np.random.rand(batch_size, in_dim)
    with self.test_session() as sess:
      input_tensor = tf.convert_to_tensor(data, dtype=tf.float32)
      output = Dense(out_dim)(input_tensor)
      sess.run(tf.global_variables_initializer())
      result = output.eval()
      assert result.shape == (batch_size, out_dim)

  def test_flatten(self):
    """Test that Flatten can be invoked."""
    # Placeholder: Flatten coverage not yet implemented.
    pass