Commit 3f7c89a2 authored by Peter Eastman's avatar Peter Eastman
Browse files

Added queue to TensorflowMultiTaskFitTransformRegressor

parent ed7782f0
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -377,6 +377,7 @@ class TensorflowGraphModel(Model):
              avg_loss = float(avg_loss) / index_in_epoch
              log('Ending epoch %d: Average loss %g' % (epoch, avg_loss),
                  self.verbose)
              epoch += 1
              index_in_epoch = 0
              avg_loss = 0.0
              del epoch_end_indices[0]
+46 −20
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from __future__ import unicode_literals
import time
import numpy as np
import tensorflow as tf
import threading

import deepchem as dc
from deepchem.nn import model_ops
@@ -369,32 +370,57 @@ class TensorflowMultiTaskFitTransformRegressor(TensorflowMultiTaskRegressor):
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        # Save an initial checkpoint.
        saver.save(sess, self._save_path, global_step=0)

        # Define the code that runs on a separate thread to feed data into
        # the queue. (This span was garbled by the diff view: removed lines
        # from the old synchronous batch loop were interleaved here without
        # '-' markers; below is the reconstructed added-side implementation.)
        def enqueue(sess, dataset, nb_epoch, epoch_end_indices):
          """Push nb_epoch passes over `dataset` into the input queue.

          Runs on a daemon thread. For each batch, applies every fit-time
          transformer to the features, builds a feed dict, and runs the
          graph's enqueue op. After each epoch it appends the cumulative
          batch count to `epoch_end_indices` so the consuming (training)
          loop can detect epoch boundaries. When all epochs are enqueued,
          the queue is closed so the trainer's dequeue eventually raises
          tf.errors.OutOfRangeError, ending training cleanly.
          """
          index = 0
          for epoch in range(nb_epoch):
            for X_b, y_b, w_b, ids_b in dataset.iterbatches(
                self.batch_size, pad_batches=self.pad_batches):
              # Transform features only; labels/weights pass through as-is.
              for transformer in self.fit_transformers:
                X_b = transformer.X_transform(X_b)
              feed_dict = self.construct_feed_dict(X_b, y_b, w_b, ids_b)
              sess.run(self.train_graph.graph.enqueue, feed_dict=feed_dict)
              index += 1
            epoch_end_indices.append(index)
          # Closing the queue signals end-of-data to the dequeueing thread.
          sess.run(self.train_graph.graph.queue.close())

        epoch_end_indices = []
        enqueue_thread = threading.Thread(target=enqueue, args=[sess, dataset, nb_epoch, epoch_end_indices])
        enqueue_thread.daemon = True
        enqueue_thread.start()

        # Main training loop.
        try:
          epoch = 0
          index = 0
          index_in_epoch = 0
          avg_loss = 0.0
          while True:
            if index_in_epoch % log_every_N_batches == 0:
              log("On batch %d" % index_in_epoch, self.verbose)
            # Run training op.
            fetches = self.train_graph.output + [
                train_op, self.train_graph.loss
            ]
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            output = fetched_values[:len(self.train_graph.output)]
            fetched_values = sess.run(fetches)
            loss = fetched_values[-1]
            avg_loss += loss
            y_pred = np.squeeze(np.array(output))
            y_b = y_b.flatten()
            n_batches += 1
            index += 1
            index_in_epoch += 1
            if len(epoch_end_indices) > 0 and index >= epoch_end_indices[0]:
              # We have reached the end of an epoch.
              if epoch % checkpoint_interval == checkpoint_interval - 1:
                saver.save(sess, self._save_path, global_step=epoch)
          avg_loss = float(avg_loss) / n_batches
              avg_loss = float(avg_loss) / index_in_epoch
              log('Ending epoch %d: Average loss %g' % (epoch, avg_loss),
                  self.verbose)
              epoch += 1
              index_in_epoch = 0
              avg_loss = 0.0
              del epoch_end_indices[0]
        except tf.errors.OutOfRangeError:
          # We have reached the end of the data.
          pass
        # Always save a final checkpoint when complete.
        saver.save(sess, self._save_path, global_step=epoch + 1)
    ############################################################## TIMING