Commit b5a3f977 authored by Bharath Ramsundar, committed by GitHub

Merge pull request #505 from lilleswing/tg-graphconv-cr

TensorGraph GraphConvs: port the graph convolution layers (GraphConvLayer, GraphPoolLayer, GraphGather) to the TensorGraph API, add a GraphConvTensorGraph model, make Input dtype-aware, and make the input FIFO queue optional via use_queue.
parents 472eeece 79655a7e
+1 −1
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
 No newline at end of file
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph import models
 No newline at end of file
+226 −3
@@ -3,6 +3,8 @@ import string

import tensorflow as tf

from deepchem.nn import model_ops, initializations


class Layer(object):

@@ -176,9 +178,10 @@ class TimeSeriesDense(Layer):

class Input(Layer):

  def __init__(self, shape, pre_queue=False, **kwargs):
  def __init__(self, shape, dtype=tf.float32, pre_queue=False, **kwargs):
    self.shape = shape
    self.pre_queue = pre_queue
    self.dtype = dtype
    super().__init__(**kwargs)

  def __call__(self, *parents):
@@ -187,14 +190,18 @@ class Input(Layer):
      placeholder = queue.out_tensors[self.get_pre_q_name()]
      self.out_tensor = tf.placeholder_with_default(placeholder, self.shape)
      return self.out_tensor
    self.out_tensor = tf.placeholder(tf.float32, shape=self.shape)
    self.out_tensor = tf.placeholder(dtype=self.dtype, shape=self.shape)
    return self.out_tensor

  def create_pre_q(self, batch_size):
    if self.pre_queue:
      raise ValueError("Input is already pre_q")
    q_shape = (batch_size,) + self.shape[1:]
    return Input(shape=q_shape, name="%s_pre_q" % self.name, pre_queue=True)
    return Input(
        shape=q_shape,
        name="%s_pre_q" % self.name,
        dtype=self.dtype,
        pre_queue=True)

  def get_pre_q_name(self):
    if self.pre_queue:
@@ -335,3 +342,219 @@ class InputFifoQueue(Layer):

  def close(self):
    self.queue.close()


class GraphConvLayer(Layer):

  def __init__(self,
               out_channel,
               min_deg=0,
               max_deg=10,
               activation_fn=None,
               **kwargs):
    self.out_channel = out_channel
    self.min_degree = min_deg
    self.max_degree = max_deg
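    # 2 * max_deg (W, b) pairs: one neighbor transform and one self transform
    # per degree 1..max_deg, plus an extra self transform when min_deg == 0.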
    self.num_deg = 2 * max_deg + (1 - min_deg)
    self.activation_fn = activation_fn
    super().__init__(**kwargs)

  def __call__(self, *parents):
    #   parents = [atom_features, deg_slice, membership, deg_adj_list placeholders...]
    in_channels = parents[0].out_tensor.get_shape()[-1].value

    # Generate the nb_affine weights and biases
    self.W_list = [
        initializations.glorot_uniform([in_channels, self.out_channel])
        for k in range(self.num_deg)
    ]
    self.b_list = [
        model_ops.zeros(shape=[
            self.out_channel,
        ]) for k in range(self.num_deg)
    ]

    # Extract atom_features
    atom_features = parents[0].out_tensor

    # Extract graph topology
    deg_slice = parents[1].out_tensor
    deg_adj_lists = [x.out_tensor for x in parents[3:]]

    # Perform the mol conv
    # atom_features = graph_conv(atom_features, deg_adj_lists, deg_slice,
    #                            self.max_deg, self.min_deg, self.W_list,
    #                            self.b_list)

    W = iter(self.W_list)
    b = iter(self.b_list)

    # Sum all neighbors using adjacency matrix
    deg_summed = self.sum_neigh(atom_features, deg_adj_lists)

    # Get collection of modified atom features
    new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

    for deg in range(1, self.max_degree + 1):
      # Obtain relevant atoms for this degree
      rel_atoms = deg_summed[deg - 1]

      # Get self atoms
      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)

      # Apply hidden affine to relevant atoms and append
      rel_out = tf.matmul(rel_atoms, next(W)) + next(b)
      self_out = tf.matmul(self_atoms, next(W)) + next(b)
      out = rel_out + self_out

      new_rel_atoms_collection[deg - self.min_degree] = out

    # Determine the min_deg=0 case
    if self.min_degree == 0:
      deg = 0

      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)

      # Only use the self layer
      out = tf.matmul(self_atoms, next(W)) + next(b)

      new_rel_atoms_collection[deg - self.min_degree] = out

    # Combine all atoms back into the list
    atom_features = tf.concat(axis=0, values=new_rel_atoms_collection)

    if self.activation_fn is not None:
      atom_features = self.activation_fn(atom_features)

    self.out_tensor = atom_features
    return atom_features

  def sum_neigh(self, atoms, deg_adj_lists):
    """Store the summed atoms by degree"""
    deg_summed = self.max_degree * [None]

    # Tensorflow correctly processes empty lists when using concat
    for deg in range(1, self.max_degree + 1):
      gathered_atoms = tf.gather(atoms, deg_adj_lists[deg - 1])
      # Sum along neighbors as well as self, and store
      summed_atoms = tf.reduce_sum(gathered_atoms, 1)
      deg_summed[deg - 1] = summed_atoms

    return deg_summed

  def none_tensors(self):
    out_tensor, W_list, b_list = self.out_tensor, self.W_list, self.b_list
    self.out_tensor, self.W_list, self.b_list = None, None, None
    return out_tensor, W_list, b_list

  def set_tensors(self, tensors):
    self.out_tensor, self.W_list, self.b_list = tensors


class GraphPoolLayer(Layer):

  def __init__(self, min_degree=0, max_degree=10, **kwargs):
    self.min_degree = min_degree
    self.max_degree = max_degree
    super().__init__(**kwargs)

  def __call__(self, *parents):
    atom_features = parents[0].out_tensor
    deg_slice = parents[1].out_tensor
    deg_adj_lists = [x.out_tensor for x in parents[3:]]

    # Perform the mol gather
    # atom_features = graph_pool(atom_features, deg_adj_lists, deg_slice,
    #                            self.max_degree, self.min_degree)

    deg_maxed = (self.max_degree + 1 - self.min_degree) * [None]

    # Tensorflow correctly processes empty lists when using concat

    for deg in range(1, self.max_degree + 1):
      # Get self atoms
      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)

      # Expand dims
      self_atoms = tf.expand_dims(self_atoms, 1)

      # always deg-1 for deg_adj_lists
      gathered_atoms = tf.gather(atom_features, deg_adj_lists[deg - 1])
      gathered_atoms = tf.concat(axis=1, values=[self_atoms, gathered_atoms])

      maxed_atoms = tf.reduce_max(gathered_atoms, 1)
      deg_maxed[deg - self.min_degree] = maxed_atoms

    if self.min_degree == 0:
      begin = tf.stack([deg_slice[0, 0], 0])
      size = tf.stack([deg_slice[0, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      deg_maxed[0] = self_atoms

    self.out_tensor = tf.concat(axis=0, values=deg_maxed)
    return self.out_tensor


class GraphGather(Layer):

  def __init__(self, batch_size, activation_fn=None, **kwargs):
    self.batch_size = batch_size
    self.activation_fn = activation_fn
    super().__init__(**kwargs)

  def __call__(self, *parents):
    # parents = [atom_features, deg_slice, membership, deg_adj_list placeholders...]
    atom_features = parents[0].out_tensor

    # Extract graph topology
    membership = parents[2].out_tensor

    # Perform the mol gather

    assert self.batch_size > 1, "graph_gather requires batches larger than 1"

    # Obtain the partitions for each of the molecules
    activated_par = tf.dynamic_partition(atom_features, membership,
                                         self.batch_size)

    # Sum over atoms for each molecule
    sparse_reps = [
        tf.reduce_sum(activated, 0, keep_dims=True)
        for activated in activated_par
    ]
    max_reps = [
        tf.reduce_max(activated, 0, keep_dims=True)
        for activated in activated_par
    ]

    # Get the final sparse representations
    sparse_reps = tf.concat(axis=0, values=sparse_reps)
    max_reps = tf.concat(axis=0, values=max_reps)
    mol_features = tf.concat(axis=1, values=[sparse_reps, max_reps])

    if self.activation_fn is not None:
      mol_features = self.activation_fn(mol_features)
    self.out_tensor = mol_features
    return mol_features


class BatchNormLayer(Layer):

  def __call__(self, *parents):
    parent_tensor = parents[0].out_tensor
    self.out_tensor = tf.layers.batch_normalization(parent_tensor)
    return self.out_tensor


class WeightedError(Layer):

  def __call__(self, *parents):
    entropy, weights = parents[0], parents[1]
    self.out_tensor = tf.reduce_sum(entropy.out_tensor * weights.out_tensor)
    return self.out_tensor
+86 −0
import tensorflow as tf
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph.layers import Input, Dense, Concat, SoftMax, SoftMaxCrossEntropy, Layer, \
  GraphConvLayer, BatchNormLayer, GraphPoolLayer, GraphGather, WeightedError
from deepchem.metrics import to_one_hot
from deepchem.feat.mol_graphs import ConvMol
import time


class GraphConvTensorGraph(TensorGraph):
  """
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.min_degree = 0
    self.max_degree = 10

  def _construct_feed_dict(self, X_b, y_b, w_b, ids_b):
    feed_dict = dict()
    if y_b is not None:
      for index, label in enumerate(self.labels):
        feed_dict[label.out_tensor] = to_one_hot(y_b[:, index])
    if self.task_weights is not None and w_b is not None:
      feed_dict[self.task_weights[0].out_tensor] = w_b
    if self.features is not None:
      multiConvMol = ConvMol.agglomerate_mols(X_b)
      feed_dict[self.features[0].out_tensor] = multiConvMol.get_atom_features()
      feed_dict[self.features[1].out_tensor] = multiConvMol.deg_slice
      feed_dict[self.features[2].out_tensor] = multiConvMol.membership
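      # Degree-0 atoms have no neighbors, so the adjacency placeholders at
      # features[3:] are fed degrees 1..max_degree (hence the i + 1 offset).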
      for i in range(self.max_degree):
        feed_dict[self.features[i + 3]
                  .out_tensor] = multiConvMol.get_deg_adjacency_lists()[i + 1]
    return feed_dict

  def fit(self,
          dataset,
          nb_epoch=10,
          max_checkpoints_to_keep=5,
          log_every_N_batches=50,
          checkpoint_interval=10):
    """
    TODO(LESWING) put this logic into tensor_graph or figure out how to use an input queue.
    Parameters
    ----------
    dataset
    nb_epoch
    max_checkpoints_to_keep
    log_every_N_batches
    checkpoint_interval

    Returns
    -------

    """
    if not self.built:
      self.build()
    with self._get_tf("Graph").as_default():
      time1 = time.time()
      train_op = self._get_tf('train_op')
      saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      with tf.Session() as sess:
        self._initialize_weights(sess, saver)
        avg_loss, n_batches = 0.0, 0.0
        for epoch in range(nb_epoch):
          for ind, (X_b, y_b, w_b, ids_b) in enumerate(
              dataset.iterbatches(
                  self.batch_size, deterministic=True, pad_batches=True)):
            feed_dict = self._construct_feed_dict(X_b, y_b, w_b, ids_b)
            output_tensors = [x.out_tensor for x in self.outputs]
            fetches = output_tensors + [train_op, self.loss.out_tensor]
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            loss = fetched_values[-1]
            avg_loss += loss
            n_batches += 1
            self.global_step += 1
          if epoch % checkpoint_interval == checkpoint_interval - 1:
            saver.save(sess, self.save_file, global_step=self.global_step)
            avg_loss = float(avg_loss) / n_batches
            print('Ending epoch %d: Average loss %g' % (epoch, avg_loss))
            # Reset the running average so later intervals aren't skewed
            avg_loss, n_batches = 0.0, 0.0
        saver.save(sess, self.save_file, global_step=self.global_step)
        self.last_checkpoint = saver.last_checkpoints[-1]
      ############################################################## TIMING
      time2 = time.time()
      print("TIMING: model fitting took %0.3f s" % (time2 - time1))
      ############################################################## TIMING
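
A hypothetical end-to-end use of this class (load_tox21 with the GraphConv featurizer is standard DeepChem API; the layer wiring is elided here and would follow the sketch after layers.py above):

import deepchem as dc

tasks, (train, valid, test), transformers = dc.molnet.load_tox21(
    featurizer='GraphConv')

model = GraphConvTensorGraph(
    batch_size=50, learning_rate=0.001, mode="classification")
# ... graph conv layers, labels, task weights, and a loss are attached here ...
model.fit(train, nb_epoch=10, checkpoint_interval=10)
probs = model.predict_proba_on_batch(next(train.iterbatches(50))[0])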
+32 −19
@@ -22,6 +22,7 @@ class TensorGraph(Model):
               tensorboard_log_frequency=100,
               learning_rate=0.001,
               batch_size=100,
               use_queue=True,
               mode="classification",
               **kwargs):
    """
@@ -61,6 +62,7 @@ class TensorGraph(Model):
    self.global_step = 0
    self.last_checkpoint = None
    self.input_queue = None
    self.use_queue = use_queue

    self.learning_rate = learning_rate
    self.batch_size = batch_size
@@ -174,25 +176,28 @@ class TensorGraph(Model):
    return retval

  def predict_proba_on_batch(self, X, sess=None):
    if not self.built:
      self.build()
    close_session = sess is None
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      if sess is None:
        sess = tf.Session()
        saver.restore(sess, self.last_checkpoint)

    def predict():
      out_tensors = [x.out_tensor for x in self.outputs]
      fetches = out_tensors
      feed_dict = self._construct_feed_dict(X, None, None, None)
      fetched_values = sess.run(fetches, feed_dict=feed_dict)
      retval = np.array(fetched_values)
      return np.array(fetched_values)

    if not self.built:
      self.build()
    if sess is None:
      saver = tf.train.Saver()
      with tf.Session() as sess:
        saver.restore(sess, self.last_checkpoint)
        with self._get_tf("Graph").as_default():
          retval = predict()
    else:
      retval = predict()
    if self.mode == 'classification':  # sample, task, class
      retval = np.transpose(retval, axes=[1, 0, 2])
    elif self.mode == 'regression':  # sample, task
      retval = np.transpose(retval, axes=[1, 0])
      if close_session:
        sess.close()
    return retval

  def predict(self, dataset, transformers=[], batch_size=None):
@@ -204,6 +209,8 @@ class TensorGraph(Model):
    """
    if not self.built:
      self.build()
    if batch_size is None:
      batch_size = self.batch_size
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      with tf.Session() as sess:
@@ -233,6 +240,8 @@ class TensorGraph(Model):
    """
    if not self.built:
      self.build()
    if batch_size is None:
      batch_size = self.batch_size
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      with tf.Session() as sess:
@@ -267,6 +276,7 @@ class TensorGraph(Model):
        with tf.name_scope(node):
          node_layer.__call__(*parents)
      self.built = True
      if self.use_queue:
        self.input_queue.out_tensors = None

    for layer in self.layers.values():
@@ -282,7 +292,9 @@ class TensorGraph(Model):
      writer.close()

  def _install_queue(self):
    if self.input_queue is not None:
    if not self.use_queue:
      for layer in self.features + self.labels + self.task_weights:
        layer.pre_queue = True
      return
    names = []
    shapes = []
@@ -366,7 +378,8 @@ class TensorGraph(Model):
      self.tensor_objects['FileWriter'] = tf.summary.FileWriter(self.model_dir)
    elif obj == 'train_op':
      self.tensor_objects['train_op'] = tf.train.AdamOptimizer(
          self.learning_rate).minimize(self.loss.out_tensor)
          self.learning_rate, beta1=.9,
          beta2=.999).minimize(self.loss.out_tensor)
    elif obj == 'summary_op':
      self.tensor_objects['summary_op'] = tf.summary.merge_all(
          key=tf.GraphKeys.SUMMARIES)
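
The new use_queue flag makes the input FIFO queue optional: when it is False, _install_queue() simply marks every feature, label, and task-weight Input as pre_queue and returns, so batches are fed through plain placeholders via feed_dict, which is what GraphConvTensorGraph relies on. A short sketch of the two constructor modes (arguments from this diff; everything else assumed):

# Default: inputs flow through an InputFifoQueue built by _install_queue().
tg_queued = TensorGraph(use_queue=True, batch_size=100, learning_rate=0.001)

# Queue-free: every Input becomes a plain placeholder and batches are
# supplied directly with feed_dict, as GraphConvTensorGraph.fit() does.
tg_plain = TensorGraph(use_queue=False, batch_size=100, learning_rate=0.001)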