Commit 20862551, authored by Bharath Ramsundar and committed by GitHub
Browse files

Merge pull request #488 from peastman/queue

Use queue to improve efficiency during training
parents d9869f00 0c1884fc
Loading
Loading
Loading
Loading
+24 −5
Original line number Diff line number Diff line
@@ -95,20 +95,39 @@ class TensorflowMultiTaskIRVClassifier(TensorflowLogisticRegression):
    with graph.as_default():
      output = []
      with placeholder_scope:
        self.features = tf.placeholder(
        mol_features = tf.placeholder(
            tf.float32, shape=[None, self.n_features], name='mol_features')
      with tf.name_scope('variable'):
        V = tf.Variable(tf.constant([0.01, 1.]), name="vote", dtype=tf.float32)
        W = tf.Variable(tf.constant([1., 1.]), name="w", dtype=tf.float32)
        b = tf.Variable(tf.constant([0.01]), name="b", dtype=tf.float32)
        b2 = tf.Variable(tf.constant([0.01]), name="b2", dtype=tf.float32)

      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
        features = queue_outputs[0]
      else:
        labels = label_placeholders
        weights = weight_placeholders
        features = mol_features

      for count in range(self.n_tasks):
        similarity = self.features[:, 2 * K * count:(2 * K * count + K)]
        ys = tf.to_int32(
            self.features[:, (2 * K * count + K):2 * K * (count + 1)])
        similarity = features[:, 2 * K * count:(2 * K * count + K)]
        ys = tf.to_int32(features[:, (2 * K * count + K):2 * K * (count + 1)])
        R = b + W[0] * similarity + W[1] * tf.constant(
            np.arange(K) + 1, dtype=tf.float32)
        R = tf.sigmoid(R)
        z = tf.reduce_sum(R * tf.gather(V, ys), axis=1) + b2
        output.append(tf.reshape(z, shape=[-1, 1]))
    return output
    return (output, labels, weights)
+47 −24
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import numpy as np
import pandas as pd
import tensorflow as tf
import tempfile
import threading
from deepchem.models import Model
from deepchem.metrics import from_one_hot
from deepchem.nn import model_ops
@@ -239,9 +240,7 @@ class TensorflowGraphModel(Model):
    with graph.as_default():
      if seed is not None:
        tf.set_random_seed(seed)
      output = self.build(graph, name_scopes, training)
      labels = self.add_label_placeholders(graph, name_scopes)
      weights = self.add_example_weight_placeholders(graph, name_scopes)
      (output, labels, weights) = self.build(graph, name_scopes, training)

    if training:
      loss = self.add_training_cost(graph, name_scopes, output, labels, weights)
@@ -335,33 +334,57 @@ class TensorflowGraphModel(Model):
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        # Save an initial checkpoint.
        saver.save(sess, self._save_path, global_step=0)

        # Define the code that runs on a separate thread to feed data into the queue.
        def enqueue(sess, dataset, nb_epoch, epoch_end_indices):
          # Producer half of the queue-fed training loop; it is started as the
          # target of a threading.Thread right after this definition.
          # For each of nb_epoch passes over dataset it builds a feed dict per
          # batch and runs the training graph's enqueue op, then records where
          # the epoch ended in epoch_end_indices (a list shared with the caller).
          # NOTE(review): this text is a diff rendering without +/- markers. The
          # avg_loss/n_batches initialisation, the enumerate(...) loop header and
          # the "On batch" logging below look like the removed (pre-change) side
          # interleaved with the new code -- confirm against the actual commit.
          index = 0
          for epoch in range(nb_epoch):
          avg_loss, n_batches = 0., 0
          for ind, (X_b, y_b, w_b, ids_b) in enumerate(
              # Turns out there are valid cases where we don't want pad-batches
              # on by default.
              #dataset.iterbatches(batch_size, pad_batches=True)):
              dataset.iterbatches(
                  self.batch_size, pad_batches=self.pad_batches)):
            if ind % log_every_N_batches == 0:
              log("On batch %d" % ind, self.verbose)
            # Run training op.
            for X_b, y_b, w_b, ids_b in dataset.iterbatches(
                self.batch_size, pad_batches=self.pad_batches):
              feed_dict = self.construct_feed_dict(X_b, y_b, w_b, ids_b)
              # Push one batch (features, labels, weights placeholders) into the
              # queue via the enqueue op stored on the training graph.
              sess.run(self.train_graph.graph.enqueue, feed_dict=feed_dict)
              index += 1
            # index is never reset, so this is the cumulative number of batches
            # enqueued at the moment this epoch finished.
            epoch_end_indices.append(index)
          # All epochs fed: close the queue. The consumer loop that follows
          # catches tf.errors.OutOfRangeError as its end-of-data signal --
          # presumably raised by dequeue once the closed queue drains; confirm.
          sess.run(self.train_graph.graph.queue.close())

        epoch_end_indices = []
        enqueue_thread = threading.Thread(
            target=enqueue, args=[sess, dataset, nb_epoch, epoch_end_indices])
        enqueue_thread.daemon = True
        enqueue_thread.start()

        # Main training loop.
        try:
          epoch = 0
          index = 0
          index_in_epoch = 0
          avg_loss = 0.0
          while True:
            if index_in_epoch % log_every_N_batches == 0:
              log("On batch %d" % index_in_epoch, self.verbose)
            # Run training op.
            fetches = self.train_graph.output + [
                train_op, self.train_graph.loss
            ]
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            output = fetched_values[:len(self.train_graph.output)]
            fetched_values = sess.run(fetches)
            loss = fetched_values[-1]
            avg_loss += loss
            y_pred = np.squeeze(np.array(output))
            y_b = y_b.flatten()
            n_batches += 1
            index += 1
            index_in_epoch += 1
            if len(epoch_end_indices) > 0 and index >= epoch_end_indices[0]:
              # We have reached the end of an epoch.
              if epoch % checkpoint_interval == checkpoint_interval - 1:
                saver.save(sess, self._save_path, global_step=epoch)
          avg_loss = float(avg_loss) / n_batches
              avg_loss = float(avg_loss) / index_in_epoch
              log('Ending epoch %d: Average loss %g' % (epoch, avg_loss),
                  self.verbose)
              epoch += 1
              index_in_epoch = 0
              avg_loss = 0.0
              del epoch_end_indices[0]
        except tf.errors.OutOfRangeError:
          # We have reached the end of the data.
          pass
        # Always save a final checkpoint when complete.
        saver.save(sess, self._save_path, global_step=epoch + 1)
    ############################################################## TIMING
+90 −26
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from __future__ import unicode_literals
import time
import numpy as np
import tensorflow as tf
import threading

import deepchem as dc
from deepchem.nn import model_ops
@@ -33,7 +34,7 @@ class TensorflowMultiTaskClassifier(TensorflowClassifier):
    n_features = self.n_features
    with graph.as_default():
      with placeholder_scope:
        self.mol_features = tf.placeholder(
        mol_features = tf.placeholder(
            tf.float32, shape=[None, n_features], name='mol_features')

      layer_sizes = self.layer_sizes
@@ -50,7 +51,25 @@ class TensorflowMultiTaskClassifier(TensorflowClassifier):
      n_layers = lengths_set.pop()
      assert n_layers > 0, 'Must have some layers defined.'

      prev_layer = self.mol_features
      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
        prev_layer = queue_outputs[0]
      else:
        labels = label_placeholders
        weights = weight_placeholders
        prev_layer = mol_features

      prev_layer_size = n_features
      for i in range(n_layers):
        layer = tf.nn.relu(
@@ -67,7 +86,7 @@ class TensorflowMultiTaskClassifier(TensorflowClassifier):
        prev_layer_size = layer_sizes[i]

      output = model_ops.multitask_logits(layer, self.n_tasks)
    return output
    return (output, labels, weights)

  def construct_feed_dict(self, X_b, y_b=None, w_b=None, ids_b=None):
    """Construct a feed dictionary from minibatch data.
@@ -112,7 +131,7 @@ class TensorflowMultiTaskRegressor(TensorflowRegressor):
                                                              name_scopes)
    with graph.as_default():
      with placeholder_scope:
        self.mol_features = tf.placeholder(
        mol_features = tf.placeholder(
            tf.float32, shape=[None, n_features], name='mol_features')

      layer_sizes = self.layer_sizes
@@ -129,7 +148,25 @@ class TensorflowMultiTaskRegressor(TensorflowRegressor):
      n_layers = lengths_set.pop()
      assert n_layers > 0, 'Must have some layers defined.'

      prev_layer = self.mol_features
      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
        prev_layer = queue_outputs[0]
      else:
        labels = label_placeholders
        weights = weight_placeholders
        prev_layer = mol_features

      prev_layer_size = n_features
      for i in range(n_layers):
        layer = tf.nn.relu(
@@ -157,7 +194,7 @@ class TensorflowMultiTaskRegressor(TensorflowRegressor):
                        stddev=weight_init_stddevs[i]),
                    bias_init=tf.constant(value=bias_init_consts[i], shape=[1
                                                                           ]))))
      return output
    return (output, labels, weights)

  def construct_feed_dict(self, X_b, y_b=None, w_b=None, ids_b=None):
    """Construct a feed dictionary from minibatch data.
@@ -343,32 +380,59 @@ class TensorflowMultiTaskFitTransformRegressor(TensorflowMultiTaskRegressor):
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        # Save an initial checkpoint.
        saver.save(sess, self._save_path, global_step=0)

        # Define the code that runs on a separate thread to feed data into the queue.
        def enqueue(sess, dataset, nb_epoch, epoch_end_indices):
          # Producer half of the queue-fed training loop; it is started as the
          # target of a threading.Thread right after this definition.
          # For each of nb_epoch passes over dataset it applies every
          # fit-transformer to the batch features, builds a feed dict, and runs
          # the training graph's enqueue op; at the end of each pass it records
          # the running batch count in epoch_end_indices (shared with the caller).
          # NOTE(review): this text is a diff rendering without +/- markers. The
          # avg_loss/n_batches initialisation, the enumerate(...) loop header and
          # the "On batch" logging below look like the removed (pre-change) side
          # interleaved with the new code -- confirm against the actual commit.
          index = 0
          for epoch in range(nb_epoch):
          avg_loss, n_batches = 0., 0
          for ind, (X_b, y_b, w_b, ids_b) in enumerate(
              dataset.iterbatches(
                  self.batch_size, pad_batches=self.pad_batches)):
            if ind % log_every_N_batches == 0:
              log("On batch %d" % ind, self.verbose)
            for X_b, y_b, w_b, ids_b in dataset.iterbatches(
                self.batch_size, pad_batches=self.pad_batches):
              # Transformers are applied in list order, each one receiving the
              # output of the previous one; only X_b is transformed here.
              for transformer in self.fit_transformers:
                X_b = transformer.X_transform(X_b)
            # Run training op.
              feed_dict = self.construct_feed_dict(X_b, y_b, w_b, ids_b)
              # Push one batch into the queue via the enqueue op stored on the
              # training graph.
              sess.run(self.train_graph.graph.enqueue, feed_dict=feed_dict)
              index += 1
            # index is never reset, so this is the cumulative number of batches
            # enqueued at the moment this epoch finished.
            epoch_end_indices.append(index)
          # All epochs fed: close the queue. The consumer loop that follows
          # catches tf.errors.OutOfRangeError as its end-of-data signal --
          # presumably raised by dequeue once the closed queue drains; confirm.
          sess.run(self.train_graph.graph.queue.close())

        epoch_end_indices = []
        enqueue_thread = threading.Thread(
            target=enqueue, args=[sess, dataset, nb_epoch, epoch_end_indices])
        enqueue_thread.daemon = True
        enqueue_thread.start()

        # Main training loop.
        try:
          epoch = 0
          index = 0
          index_in_epoch = 0
          avg_loss = 0.0
          while True:
            if index_in_epoch % log_every_N_batches == 0:
              log("On batch %d" % index_in_epoch, self.verbose)
            # Run training op.
            fetches = self.train_graph.output + [
                train_op, self.train_graph.loss
            ]
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            output = fetched_values[:len(self.train_graph.output)]
            fetched_values = sess.run(fetches)
            loss = fetched_values[-1]
            avg_loss += loss
            y_pred = np.squeeze(np.array(output))
            y_b = y_b.flatten()
            n_batches += 1
            index += 1
            index_in_epoch += 1
            if len(epoch_end_indices) > 0 and index >= epoch_end_indices[0]:
              # We have reached the end of an epoch.
              if epoch % checkpoint_interval == checkpoint_interval - 1:
                saver.save(sess, self._save_path, global_step=epoch)
          avg_loss = float(avg_loss) / n_batches
              avg_loss = float(avg_loss) / index_in_epoch
              log('Ending epoch %d: Average loss %g' % (epoch, avg_loss),
                  self.verbose)
              epoch += 1
              index_in_epoch = 0
              avg_loss = 0.0
              del epoch_end_indices[0]
        except tf.errors.OutOfRangeError:
          # We have reached the end of the data.
          pass
        # Always save a final checkpoint when complete.
        saver.save(sess, self._save_path, global_step=epoch + 1)
    ############################################################## TIMING
+23 −3
Original line number Diff line number Diff line
@@ -54,22 +54,42 @@ class TensorflowLogisticRegression(TensorflowGraphModel):
    n_features = self.n_features
    with graph.as_default():
      with placeholder_scope:
        self.mol_features = tf.placeholder(
        mol_features = tf.placeholder(
            tf.float32, shape=[None, n_features], name='mol_features')

      weight_init_stddevs = self.weight_init_stddevs
      bias_init_consts = self.bias_init_consts
      lg_list = []

      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
        prev_layer = queue_outputs[0]
      else:
        labels = label_placeholders
        weights = weight_placeholders
        prev_layer = mol_features

      for task in range(self.n_tasks):
        #setting up n_tasks nodes(output nodes)
        lg = model_ops.fully_connected_layer(
            tensor=self.mol_features,
            tensor=prev_layer,
            size=1,
            weight_init=tf.truncated_normal(
                shape=[self.n_features, 1], stddev=weight_init_stddevs[0]),
            bias_init=tf.constant(value=bias_init_consts[0], shape=[1]))
        lg_list.append(lg)
    return lg_list
    return (lg_list, labels, weights)

  def add_label_placeholders(self, graph, name_scopes):
    #label placeholders with size batch_size * 1
+46 −8
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ class RobustMultitaskClassifier(TensorflowMultiTaskClassifier):
                                                              name_scopes)
    with graph.as_default():
      with placeholder_scope:
        self.mol_features = tf.placeholder(
        mol_features = tf.placeholder(
            tf.float32, shape=[None, num_features], name='mol_features')

      layer_sizes = self.layer_sizes
@@ -79,7 +79,26 @@ class RobustMultitaskClassifier(TensorflowMultiTaskClassifier):
          "All bypass_layer params" + " must have same length.")
      num_bypass_layers = bypass_lengths_set.pop()

      prev_layer = self.mol_features
      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
        prev_layer = queue_outputs[0]
      else:
        labels = label_placeholders
        weights = weight_placeholders
        prev_layer = mol_features

      top_layer = prev_layer
      prev_layer_size = num_features
      for i in range(num_layers):
        # layer has shape [None, layer_sizes[i]]
@@ -105,7 +124,7 @@ class RobustMultitaskClassifier(TensorflowMultiTaskClassifier):
        # TODO(rbharath): Might want to make it feasible to have multiple
        # bypass layers.
        # Construct task bypass layer
        prev_bypass_layer = self.mol_features
        prev_bypass_layer = top_layer
        prev_bypass_layer_size = num_features
        for i in range(num_bypass_layers):
          # bypass_layer has shape [None, bypass_layer_sizes[i]]
@@ -147,7 +166,7 @@ class RobustMultitaskClassifier(TensorflowMultiTaskClassifier):
                        stddev=weight_init_stddevs[-1]),
                    bias_init=tf.constant(
                        value=bias_init_consts[-1], shape=[2]))))
      return output
      return (output, labels, weights)


class RobustMultitaskRegressor(TensorflowMultiTaskRegressor):
@@ -185,7 +204,7 @@ class RobustMultitaskRegressor(TensorflowMultiTaskRegressor):
                                                              name_scopes)
    with graph.as_default():
      with placeholder_scope:
        self.mol_features = tf.placeholder(
        mol_features = tf.placeholder(
            tf.float32, shape=[None, num_features], name='mol_features')

      layer_sizes = self.layer_sizes
@@ -218,7 +237,26 @@ class RobustMultitaskRegressor(TensorflowMultiTaskRegressor):
          "All bypass_layer params" + " must have same length.")
      num_bypass_layers = bypass_lengths_set.pop()

      prev_layer = self.mol_features
      label_placeholders = self.add_label_placeholders(graph, name_scopes)
      weight_placeholders = self.add_example_weight_placeholders(graph,
                                                                 name_scopes)
      if training:
        graph.queue = tf.FIFOQueue(
            capacity=5,
            dtypes=[tf.float32] *
            (len(label_placeholders) + len(weight_placeholders) + 1))
        graph.enqueue = graph.queue.enqueue([mol_features] + label_placeholders
                                            + weight_placeholders)
        queue_outputs = graph.queue.dequeue()
        labels = queue_outputs[1:len(label_placeholders) + 1]
        weights = queue_outputs[len(label_placeholders) + 1:]
        prev_layer = queue_outputs[0]
      else:
        labels = label_placeholders
        weights = weight_placeholders
        prev_layer = mol_features

      top_layer = prev_layer
      prev_layer_size = num_features
      for i in range(num_layers):
        # layer has shape [None, layer_sizes[i]]
@@ -244,7 +282,7 @@ class RobustMultitaskRegressor(TensorflowMultiTaskRegressor):
        # TODO(rbharath): Might want to make it feasible to have multiple
        # bypass layers.
        # Construct task bypass layer
        prev_bypass_layer = self.mol_features
        prev_bypass_layer = top_layer
        prev_bypass_layer_size = num_features
        for i in range(num_bypass_layers):
          # bypass_layer has shape [None, bypass_layer_sizes[i]]
@@ -286,4 +324,4 @@ class RobustMultitaskRegressor(TensorflowMultiTaskRegressor):
                        stddev=weight_init_stddevs[-1]),
                    bias_init=tf.constant(
                        value=bias_init_consts[-1], shape=[1]))))
      return output
      return (output, labels, weights)
Loading