Commit 44b2c473 authored by leswing's avatar leswing
Browse files

Better GraphConvs

parent 8a35de2e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
 No newline at end of file
from deepchem.models.tensorgraph.tensor_graph import TensorGraphfrom deepchem.models.tensorgraph import models
 No newline at end of file
+230 −42
Original line number Diff line number Diff line
@@ -3,9 +3,9 @@ import string

import tensorflow as tf

from deepchem.nn import model_ops, initializations

class Layer(object):

  def __init__(self, **kwargs):
    if "name" not in kwargs:
      self.name = "%s%s" % (self.__class__.__name__, self._random_name())
@@ -30,7 +30,6 @@ class Layer(object):


class Conv1DLayer(Layer):

  def __init__(self, width, out_channels, **kwargs):
    self.width = width
    self.out_channels = out_channels
@@ -55,7 +54,6 @@ class Conv1DLayer(Layer):


class Dense(Layer):

  def __init__(self, out_channels, activation_fn=None, **kwargs):
    self.out_channels = out_channels
    self.out_tensor = None
@@ -78,7 +76,6 @@ class Dense(Layer):


class Flatten(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

@@ -96,7 +93,6 @@ class Flatten(Layer):


class Reshape(Layer):

  def __init__(self, shape, **kwargs):
    self.shape = shape
    super().__init__(**kwargs)
@@ -107,7 +103,6 @@ class Reshape(Layer):


class CombineMeanStd(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

@@ -122,7 +117,6 @@ class CombineMeanStd(Layer):


class Repeat(Layer):

  def __init__(self, n_times, **kwargs):
    self.n_times = n_times
    super().__init__(**kwargs)
@@ -137,7 +131,6 @@ class Repeat(Layer):


class GRU(Layer):

  def __init__(self, n_hidden, out_channels, batch_size, **kwargs):
    self.n_hidden = n_hidden
    self.out_channels = out_channels
@@ -160,7 +153,6 @@ class GRU(Layer):


class TimeSeriesDense(Layer):

  def __init__(self, out_channels, **kwargs):
    self.out_channels = out_channels
    super().__init__(**kwargs)
@@ -175,10 +167,10 @@ class TimeSeriesDense(Layer):


class Input(Layer):

  def __init__(self, shape, pre_queue=False, **kwargs):
  def __init__(self, shape, dtype=tf.float32, pre_queue=False, **kwargs):
    self.shape = shape
    self.pre_queue = pre_queue
    self.dtype = dtype
    super().__init__(**kwargs)

  def __call__(self, *parents):
@@ -187,14 +179,14 @@ class Input(Layer):
      placeholder = queue.out_tensors[self.get_pre_q_name()]
      self.out_tensor = tf.placeholder_with_default(placeholder, self.shape)
      return self.out_tensor
    self.out_tensor = tf.placeholder(tf.float32, shape=self.shape)
    self.out_tensor = tf.placeholder(dtype=self.dtype, shape=self.shape)
    return self.out_tensor

  def create_pre_q(self, batch_size):
    if self.pre_queue:
      raise ValueError("Input is already pre_q")
    q_shape = (batch_size,) + self.shape[1:]
    return Input(shape=q_shape, name="%s_pre_q" % self.name, pre_queue=True)
    return Input(shape=q_shape, name="%s_pre_q" % self.name, dtype=self.dtype, pre_queue=True)

  def get_pre_q_name(self):
    if self.pre_queue:
@@ -203,7 +195,6 @@ class Input(Layer):


class LossLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

@@ -215,7 +206,6 @@ class LossLayer(Layer):


class SoftMax(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

@@ -228,7 +218,6 @@ class SoftMax(Layer):


class Concat(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

@@ -243,7 +232,6 @@ class Concat(Layer):


class SoftMaxCrossEntropy(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

@@ -258,7 +246,6 @@ class SoftMaxCrossEntropy(Layer):


class ReduceMean(Layer):

  def __call__(self, *parents):
    parent_tensor = parents[0].out_tensor
    self.out_tensor = tf.reduce_mean(parent_tensor)
@@ -266,7 +253,6 @@ class ReduceMean(Layer):


class Conv2d(Layer):

  def __init__(self, num_outputs, kernel_size=5, **kwargs):
    self.num_outputs = num_outputs
    self.kernel_size = kernel_size
@@ -285,7 +271,6 @@ class Conv2d(Layer):


class MaxPool(Layer):

  def __init__(self,
               ksize=[1, 2, 2, 1],
               strides=[1, 2, 2, 1],
@@ -335,3 +320,206 @@ class InputFifoQueue(Layer):

  def close(self):
    """Close the wrapped queue by delegating to ``self.queue.close()``."""
    self.queue.close()


class GraphConvLayer(Layer):
  """Graph convolution applied to atoms that are grouped by degree.

  For every degree from 1 to ``max_degree`` the layer sums the features of
  each atom's neighbours, applies one affine map to that sum and a second
  affine map to the atom's own features, and adds the two results.  When
  ``min_degree`` is 0, the degree-0 atoms only receive a self affine map.
  The per-degree results are concatenated along axis 0.

  Parameters
  ----------
  out_channel
    Size of the last dimension of the output features.
  min_deg, max_deg
    Degree range handled by the layer (stored as ``min_degree`` and
    ``max_degree``).
  activation_fn
    Optional callable applied to the concatenated output.
  """
  def __init__(self, out_channel, min_deg=0, max_deg=10, activation_fn=None, **kwargs):
    self.out_channel = out_channel
    self.min_degree = min_deg
    self.max_degree = max_deg
    # Number of (W, b) pairs to create: two per degree in 1..max_deg plus one
    # for degree 0.  NOTE(review): this matches what __call__ consumes only
    # when min_deg is 0 or 1 -- confirm other values are not expected.
    self.num_deg = 2 * max_deg + (1 - min_deg)
    self.activation_fn = activation_fn
    super().__init__(**kwargs)

  def __call__(self, *parents):
    """Build the convolution ops and return the new atom feature tensor.

    ``parents`` is expected in the order
    ``[atom_features, deg_slice, membership, deg_adj_list placeholders...]``;
    ``membership`` (index 2) is not read here.
    """
    #   parents = [atom_features, deg_slice, membership, deg_adj_list placeholders...]
    in_channels = parents[0].out_tensor.get_shape()[-1].value

    # Generate the nb_affine weights and biases
    self.W_list = [
      initializations.glorot_uniform([in_channels, self.out_channel])
      for k in range(self.num_deg)
    ]
    self.b_list = [
      model_ops.zeros(shape=[
        self.out_channel,
      ]) for k in range(self.num_deg)
    ]

    # Extract atom_features
    atom_features = parents[0].out_tensor

    # Extract graph topology
    deg_slice = parents[1].out_tensor
    deg_adj_lists = [x.out_tensor for x in parents[3:]]

    # Perform the mol conv (inlined below rather than calling a shared
    # graph_conv helper).

    # Weights and biases are consumed strictly in creation order: for each
    # degree a neighbour pair and then a self pair, and finally the degree-0
    # self pair.
    W = iter(self.W_list)
    b = iter(self.b_list)

    # Sum all neighbors using adjacency matrix
    deg_summed = self.sum_neigh(atom_features, deg_adj_lists)

    # Get collection of modified atom features
    new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

    for deg in range(1, self.max_degree + 1):
      # Obtain relevant atoms for this degree
      rel_atoms = deg_summed[deg - 1]

      # Get self atoms: row (deg - min_degree) of deg_slice supplies the
      # start offset (column 0) and the row count (column 1) for tf.slice.
      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)

      # Apply hidden affine to relevant atoms and append
      rel_out = tf.matmul(rel_atoms, next(W)) + next(b)
      self_out = tf.matmul(self_atoms, next(W)) + next(b)
      out = rel_out + self_out

      new_rel_atoms_collection[deg - self.min_degree] = out

    # Determine the min_deg=0 case
    if self.min_degree == 0:
      deg = 0

      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)

      # Only use the self layer
      out = tf.matmul(self_atoms, next(W)) + next(b)

      new_rel_atoms_collection[deg - self.min_degree] = out

    # Combine all atoms back into the list
    atom_features = tf.concat(axis=0, values=new_rel_atoms_collection)

    if self.activation_fn is not None:
      atom_features = self.activation_fn(atom_features)

    self.out_tensor = atom_features
    return atom_features

  def sum_neigh(self, atoms, deg_adj_lists):
    """Return a list holding, per degree 1..max_degree, the summed neighbours.

    Entry ``deg - 1`` is ``tf.gather(atoms, deg_adj_lists[deg - 1])`` reduced
    with ``tf.reduce_sum`` over axis 1.
    """
    deg_summed = self.max_degree * [None]

    # Tensorflow correctly processes empty lists when using concat
    for deg in range(1, self.max_degree + 1):
      gathered_atoms = tf.gather(atoms, deg_adj_lists[deg - 1])
      # Sum the gathered features over axis 1 and store.
      # NOTE(review): only gathered neighbours are summed here; the atom's own
      # features are added separately in __call__.
      summed_atoms = tf.reduce_sum(gathered_atoms, 1)
      deg_summed[deg - 1] = summed_atoms

    return deg_summed

  def none_tensors(self):
    """Detach and return ``(out_tensor, W_list, b_list)``, leaving them None."""
    out_tensor, W_list, b_list = self.out_tensor, self.W_list, self.b_list
    self.out_tensor, self.W_list, self.b_list = None, None, None
    return out_tensor, W_list, b_list

  def set_tensors(self, tensors):
    """Restore the triple previously returned by ``none_tensors``."""
    self.out_tensor, self.W_list, self.b_list = tensors


class GraphPoolLayer(Layer):
  """Max-pool each atom's features with those of its neighbours.

  Atoms are processed one degree at a time.  For degrees 1..max_degree the
  atom's own features are stacked with its gathered neighbours along axis 1
  and reduced with ``tf.reduce_max``.  When ``min_degree`` is 0 the degree-0
  atoms are passed through unchanged.  Results are concatenated on axis 0.
  """

  def __init__(self, min_degree=0, max_degree=10, **kwargs):
    self.min_degree, self.max_degree = min_degree, max_degree
    super().__init__(**kwargs)

  @staticmethod
  def _atoms_in_row(features, slice_table, row):
    """Slice out the atoms described by ``slice_table[row]``.

    Column 0 of the row is used as the start offset and column 1 as the
    number of rows to take; all feature columns are kept.
    """
    start = tf.stack([slice_table[row, 0], 0])
    extent = tf.stack([slice_table[row, 1], -1])
    return tf.slice(features, start, extent)

  def __call__(self, *parents):
    """Build the pooling ops; ``parents`` follows the order
    ``[atom_features, deg_slice, membership, deg_adj_list placeholders...]``.
    """
    features = parents[0].out_tensor
    slice_table = parents[1].out_tensor
    adjacency = [layer.out_tensor for layer in parents[3:]]

    # One output slot per degree in min_degree..max_degree.
    pooled = [None] * (self.max_degree + 1 - self.min_degree)

    for degree in range(1, self.max_degree + 1):
      row = degree - self.min_degree

      # Atom's own features, given a neighbour axis so they can be stacked.
      own = tf.expand_dims(self._atoms_in_row(features, slice_table, row), 1)

      # Adjacency lists are stored starting at degree 1, hence degree - 1.
      neighbours = tf.gather(features, adjacency[degree - 1])
      stacked = tf.concat(axis=1, values=[own, neighbours])

      pooled[row] = tf.reduce_max(stacked, 1)

    if self.min_degree == 0:
      # Degree-0 atoms have nothing to pool with.
      pooled[0] = self._atoms_in_row(features, slice_table, 0)

    self.out_tensor = tf.concat(axis=0, values=pooled)
    return self.out_tensor


class GraphGather(Layer):
  """Pool per-atom features into one feature row per molecule.

  Atom rows are partitioned by ``membership`` into ``batch_size`` groups.
  Each group is reduced with both a sum and a max over axis 0, and the two
  reductions are concatenated along axis 1.

  Parameters
  ----------
  batch_size
    Number of partitions passed to ``tf.dynamic_partition``; must be > 1.
  activation_fn
    Optional callable applied to the concatenated molecule features.
  """

  def __init__(self, batch_size, activation_fn=None, **kwargs):
    self.batch_size = batch_size
    self.activation_fn = activation_fn
    super().__init__(**kwargs)

  def __call__(self, *parents):
    """Build the gather ops and return the molecule feature tensor.

    ``parents`` follows the order
    ``[atom_features, deg_slice, membership, deg_adj_list placeholders...]``;
    only indices 0 and 2 are read.

    Raises
    ------
    ValueError
      If ``batch_size`` is not larger than 1.
    """
    # The previous `assert (cond, msg)` tested a non-empty tuple and so could
    # never fail; enforce the documented requirement explicitly.
    if self.batch_size <= 1:
      raise ValueError("graph_gather requires batches larger than 1")

    atom_features = parents[0].out_tensor

    # Extract graph topology
    membership = parents[2].out_tensor

    # Obtain the partitions for each of the molecules
    activated_par = tf.dynamic_partition(atom_features, membership,
                                         self.batch_size)

    # Sum and max over atoms for each molecule
    sparse_reps = [
      tf.reduce_sum(activated, 0, keep_dims=True) for activated in activated_par
    ]
    max_reps = [
      tf.reduce_max(activated, 0, keep_dims=True) for activated in activated_par
    ]

    # Get the final sparse representations
    sparse_reps = tf.concat(axis=0, values=sparse_reps)
    max_reps = tf.concat(axis=0, values=max_reps)
    mol_features = tf.concat(axis=1, values=[sparse_reps, max_reps])

    if self.activation_fn is not None:
      mol_features = self.activation_fn(mol_features)
    self.out_tensor = mol_features
    return mol_features


class BatchNormLayer(Layer):
  """Apply ``tf.layers.batch_normalization`` to the first parent's output."""

  def __call__(self, *parents):
    normalized = tf.layers.batch_normalization(parents[0].out_tensor)
    self.out_tensor = normalized
    return normalized


class WeightedError(Layer):
  """Sum of the element-wise product of two parent tensors.

  The first parent supplies the per-element errors and the second the
  weights they are multiplied by.
  """

  def __call__(self, *parents):
    error_layer = parents[0]
    weight_layer = parents[1]
    weighted = error_layer.out_tensor * weight_layer.out_tensor
    total = tf.reduce_sum(weighted)
    self.out_tensor = total
    return total
+1 −0
Original line number Diff line number Diff line
from deepchem.models.tensorgraph.models.graph_conv import graph_conv_model
 No newline at end of file
+157 −0
Original line number Diff line number Diff line
import tensorflow as tf
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph.layers import Input, Dense, Concat, SoftMax, SoftMaxCrossEntropy, Layer, \
  GraphConvLayer, BatchNormLayer, GraphPoolLayer, GraphGather, WeightedError
from deepchem.metrics import to_one_hot
from deepchem.feat.mol_graphs import ConvMol
import time



class GraphConvTensorGraph(TensorGraph):
  """TensorGraph fed from batches of ConvMol objects.

  The model's features are expected in the order
  ``[atom_features, deg_slice, membership, adjacency list for degree 1, ...,
  adjacency list for degree max_degree]``; ``_construct_feed_dict`` fills
  them from the agglomerated batch in that order.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.min_degree = 0
    self.max_degree = 10

  def _construct_feed_dict(self, X_b, y_b, w_b, ids_b):
    """Map one batch onto the graph's label, weight and feature tensors.

    Labels are one-hot encoded per task column; ``ids_b`` is unused.
    """
    feed_dict = dict()
    if y_b is not None:
      for index, label in enumerate(self.labels):
        feed_dict[label.out_tensor] = to_one_hot(y_b[:, index])
    if self.task_weights is not None and w_b is not None:
      feed_dict[self.task_weights[0].out_tensor] = w_b
    if self.features is not None:
      multiConvMol = ConvMol.agglomerate_mols(X_b)
      feed_dict[self.features[0].out_tensor] = multiConvMol.get_atom_features()
      feed_dict[self.features[1].out_tensor] = multiConvMol.deg_slice
      feed_dict[self.features[2].out_tensor] = multiConvMol.membership
      # Fetch the adjacency lists once instead of once per degree.
      deg_adj_lists = multiConvMol.get_deg_adjacency_lists()
      for i in range(self.max_degree):
        # Entry 0 is skipped; features[3 + i] receives the list for degree i + 1.
        feed_dict[self.features[i + 3].out_tensor] = deg_adj_lists[i + 1]
    return feed_dict

  def fit(self,
          dataset,
          nb_epoch=10,
          max_checkpoints_to_keep=5,
          log_every_N_batches=50,
          checkpoint_interval=10):
    """Train the model with a feed_dict per batch (no input queue).

    TODO(LESWING) put this logic into tensor_graph or figure out how to use an
    input queue.

    Parameters
    ----------
    dataset
      Object providing ``iterbatches(batch_size, deterministic, pad_batches)``.
    nb_epoch
      Number of passes over ``dataset``.
    max_checkpoints_to_keep
      Forwarded to ``tf.train.Saver(max_to_keep=...)``.
    log_every_N_batches
      Currently unused.
    checkpoint_interval
      A checkpoint is written, and the average loss since the previous report
      is printed, every ``checkpoint_interval`` epochs.

    Returns
    -------
    None
    """
    if not self.built:
      self.build()
    with self._get_tf("Graph").as_default():
      time1 = time.time()
      train_op = self._get_tf('train_op')
      saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      with tf.Session() as sess:
        self._initialize_weights(sess, saver)
        # The fetch list does not change between batches.
        output_tensors = [x.out_tensor for x in self.outputs]
        fetches = output_tensors + [train_op, self.loss.out_tensor]
        loss_sum, n_batches = 0.0, 0
        for epoch in range(nb_epoch):
          for X_b, y_b, w_b, ids_b in dataset.iterbatches(
              self.batch_size, deterministic=True, pad_batches=True):
            feed_dict = self._construct_feed_dict(X_b, y_b, w_b, ids_b)
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            loss_sum += fetched_values[-1]
            n_batches += 1
            self.global_step += 1
          if epoch % checkpoint_interval == checkpoint_interval - 1:
            saver.save(sess, self.save_file, global_step=self.global_step)
            # Guard against an empty dataset, then start a fresh window so the
            # next report is not polluted by the previous average.
            if n_batches > 0:
              avg_loss = float(loss_sum) / n_batches
              print('Ending epoch %d: Average loss %g' % (epoch, avg_loss))
            loss_sum, n_batches = 0.0, 0
        saver.save(sess, self.save_file, global_step=self.global_step)
        self.last_checkpoint = saver.last_checkpoints[-1]
      ############################################################## TIMING
      time2 = time.time()
      print("TIMING: model fitting took %0.3f s" % (time2 - time1))
      ############################################################## TIMING


def graph_conv_model(batch_size, num_tasks):
  """Assemble a multi-task graph-convolution classifier.

  The network is two GraphConv -> BatchNorm -> GraphPool stages, a dense
  layer with batch norm, a GraphGather, and one two-class softmax head per
  task.  The loss is the weighted sum of the per-task cross entropies.

  Parameters
  ----------
  batch_size
    Batch size of the model; also the number of molecules GraphGather
    partitions atoms into.
  num_tasks
    Number of classification heads to create.

  Returns
  -------
  GraphConvTensorGraph
    The model with features, labels, task weights, outputs and loss set.
  """
  model = GraphConvTensorGraph(batch_size=batch_size,
                               use_queue=False)
  atom_features = Input(shape=(None, 75))
  model.add_feature(atom_features)

  degree_slice = Input(shape=(None, 2), dtype=tf.int32)
  model.add_feature(degree_slice)

  membership = Input(shape=(None,), dtype=tf.int32)
  model.add_feature(membership)

  # One adjacency placeholder per degree 1..max_degree, each with `degree`
  # neighbour columns.  This matches what _construct_feed_dict feeds and what
  # the graph layers index (deg_adj_lists[degree - 1]); the earlier loop over
  # range(min_degree, max_degree + 1) created one extra placeholder that was
  # never fed or read.
  deg_adjs = []
  for degree in range(1, model.max_degree + 1):
    deg_adj = Input(shape=(None, degree), dtype=tf.int32)
    model.add_feature(deg_adj)
    deg_adjs.append(deg_adj)

  gc1 = GraphConvLayer(64, activation_fn=tf.nn.relu)
  model.add_layer(gc1, parents=[atom_features, degree_slice, membership] + deg_adjs)

  batch_norm1 = BatchNormLayer()
  model.add_layer(batch_norm1, parents=[gc1])

  gp1 = GraphPoolLayer()
  model.add_layer(gp1, parents=[batch_norm1, degree_slice, membership] + deg_adjs)

  gc2 = GraphConvLayer(64, activation_fn=tf.nn.relu)
  model.add_layer(gc2, parents=[gp1, degree_slice, membership] + deg_adjs)

  batch_norm2 = BatchNormLayer()
  model.add_layer(batch_norm2, parents=[gc2])

  gp2 = GraphPoolLayer()
  model.add_layer(gp2, parents=[batch_norm2, degree_slice, membership] + deg_adjs)

  dense = Dense(out_channels=128, activation_fn=None)
  model.add_layer(dense, parents=[gp2])

  batch_norm3 = BatchNormLayer()
  model.add_layer(batch_norm3, parents=[dense])

  gg1 = GraphGather(batch_size=batch_size, activation_fn=tf.nn.tanh)
  model.add_layer(gg1, parents=[batch_norm3, degree_slice, membership] + deg_adjs)

  # One classification head (logits, softmax output, label, cost) per task.
  costs = []
  for task in range(num_tasks):
    classification = Dense(out_channels=2, name="GUESS%s" % task, activation_fn=None)
    model.add_layer(classification, parents=[gg1])

    softmax = SoftMax(name="SOFTMAX%s" % task)
    model.add_layer(softmax, parents=[classification])
    model.add_output(softmax)

    label = Input(shape=(None, 2), name="LABEL%s" % task)
    model.add_label(label)

    cost = SoftMaxCrossEntropy(name="COST%s" % task)
    model.add_layer(cost, parents=[label, classification])
    costs.append(cost)

  entropy = Concat(name="ENT")
  model.add_layer(entropy, parents=costs)

  task_weights = Input(shape=(None, num_tasks), name="W")
  model.add_task_weight(task_weights)

  loss = WeightedError(name="ERROR")
  model.add_layer(loss, parents=[entropy, task_weights])
  model.set_loss(loss)
  return model
+30 −20
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ class TensorGraph(Model):
               tensorboard_log_frequency=100,
               learning_rate=0.001,
               batch_size=100,
               use_queue=True,
               mode="classification",
               **kwargs):
    """
@@ -61,6 +62,7 @@ class TensorGraph(Model):
    self.global_step = 0
    self.last_checkpoint = None
    self.input_queue = None
    self.use_queue = use_queue

    self.learning_rate = learning_rate
    self.batch_size = batch_size
@@ -174,25 +176,27 @@ class TensorGraph(Model):
    return retval

  def predict_proba_on_batch(self, X, sess=None):
    if not self.built:
      self.build()
    close_session = sess is None
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      if sess is None:
        sess = tf.Session()
        saver.restore(sess, self.last_checkpoint)
    def predict():
      out_tensors = [x.out_tensor for x in self.outputs]
      fetches = out_tensors
      feed_dict = self._construct_feed_dict(X, None, None, None)
      fetched_values = sess.run(fetches, feed_dict=feed_dict)
      retval = np.array(fetched_values)
      return np.array(fetched_values)

    if not self.built:
      self.build()
    if sess is None:
      saver = tf.train.Saver()
      with tf.Session() as sess:
        saver.restore(sess, self.last_checkpoint)
        with self._get_tf("Graph").as_default():
          retval = predict()
    else:
      retval = predict()
    if self.mode == 'classification':  # sample, task, class
      retval = np.transpose(retval, axes=[1, 0, 2])
    elif self.mode == 'regression':  # sample, task
      retval = np.transpose(retval, axes=[1, 0])
      if close_session:
        sess.close()
    return retval

  def predict(self, dataset, transformers=[], batch_size=None):
@@ -204,6 +208,8 @@ class TensorGraph(Model):
    """
    if not self.built:
      self.build()
    if batch_size is None:
      batch_size = self.batch_size
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      with tf.Session() as sess:
@@ -233,6 +239,8 @@ class TensorGraph(Model):
    """
    if not self.built:
      self.build()
    if batch_size is None:
      batch_size = self.batch_size
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      with tf.Session() as sess:
@@ -267,6 +275,7 @@ class TensorGraph(Model):
        with tf.name_scope(node):
          node_layer.__call__(*parents)
      self.built = True
      if self.use_queue:
        self.input_queue.out_tensors = None

    for layer in self.layers.values():
@@ -282,8 +291,9 @@ class TensorGraph(Model):
      writer.close()

  def _install_queue(self):
    if self.input_queue is not None:
      return
    if not self.use_queue or self.input_queue is not None:
      for feature in self.features:
        feature.pre_queue = True
    names = []
    shapes = []
    pre_q_inputs = []
@@ -366,7 +376,7 @@ class TensorGraph(Model):
      self.tensor_objects['FileWriter'] = tf.summary.FileWriter(self.model_dir)
    elif obj == 'train_op':
      self.tensor_objects['train_op'] = tf.train.AdamOptimizer(
          self.learning_rate).minimize(self.loss.out_tensor)
          self.learning_rate,beta1=.9, beta2=.999).minimize(self.loss.out_tensor)
    elif obj == 'summary_op':
      self.tensor_objects['summary_op'] = tf.summary.merge_all(
          key=tf.GraphKeys.SUMMARIES)
Loading