Commit a1ad561e authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #791 from peastman/session

TensorGraph uses a persistent Session
parents 3b801508 3e3438cc
Loading
Loading
Loading
Loading
+32 −38
Original line number Diff line number Diff line
@@ -605,9 +605,6 @@ class GraphConvTensorGraph(TensorGraph):
    if not self.built:
      self.build()
    with self._get_tf("Graph").as_default():
      with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, self.last_checkpoint)
      out_tensors = [x.out_tensor for x in self.outputs]
      results = []
      for feed_dict in generator:
@@ -616,7 +613,7 @@ class GraphConvTensorGraph(TensorGraph):
            for k, v in six.iteritems(feed_dict)
        }
        feed_dict[self._training_placeholder] = 1.0  ##
          result = np.array(sess.run(out_tensors, feed_dict=feed_dict))
        result = np.array(self.session.run(out_tensors, feed_dict=feed_dict))
        if len(result.shape) == 3:
          result = np.transpose(result, axes=[1, 0, 2])
        if len(transformers) > 0:
@@ -868,9 +865,6 @@ class MPNNTensorGraph(TensorGraph):
    if not self.built:
      self.build()
    with self._get_tf("Graph").as_default():
      with tf.Session() as sess:
        saver = tf.train.Saver()
        self._initialize_weights(sess, saver)
      out_tensors = [x.out_tensor for x in self.outputs]
      results = []
      for feed_dict in generator:
@@ -881,7 +875,7 @@ class MPNNTensorGraph(TensorGraph):
            for k, v in six.iteritems(feed_dict)
        }
        feed_dict[self._training_placeholder] = 0.0
          result = np.array(sess.run(out_tensors, feed_dict=feed_dict))
        result = np.array(self.session.run(out_tensors, feed_dict=feed_dict))
        if len(result.shape) == 3:
          result = np.transpose(result, axes=[1, 0, 2])
        result = undo_transforms(result, transformers)
+124 −94
Original line number Diff line number Diff line
@@ -81,7 +81,6 @@ class TensorGraph(Model):
    self.tensorboard_log_frequency = tensorboard_log_frequency
    self.tensorboard_step = 0
    self.global_step = 0
    self.last_checkpoint = None
    self.use_queue = use_queue

    self.batch_size = batch_size
@@ -116,16 +115,52 @@ class TensorGraph(Model):
          nb_epoch=10,
          max_checkpoints_to_keep=5,
          checkpoint_interval=1000,
          deterministic=False):
          deterministic=False,
          restore=False):
    """Train this model on a dataset.

    Parameters
    ----------
    dataset: Dataset
      the Dataset to train on
    nb_epoch: int
      the number of epochs to train for
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    checkpoint_interval: int
      the frequency at which to write checkpoints, measured in training steps.
    deterministic: bool
      if True, the samples are processed in order.  If False, a different random
      order is used for each epoch.
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    """
    return self.fit_generator(
        self.default_generator(
            dataset, epochs=nb_epoch, deterministic=deterministic),
        max_checkpoints_to_keep, checkpoint_interval)
        max_checkpoints_to_keep, checkpoint_interval, restore)

  def fit_generator(self,
                    feed_dict_generator,
                    max_checkpoints_to_keep=5,
                    checkpoint_interval=1000):
                    checkpoint_interval=1000,
                    restore=False):
    """Train this model on data from a generator.

    Parameters
    ----------
    feed_dict_generator: generator
      this should generate batches, each represented as a dict that maps
      Layers to values.
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    checkpoint_interval: int
      the frequency at which to write checkpoints, measured in training steps.
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    """

    def create_feed_dict():
      if self.use_queue:
@@ -142,36 +177,36 @@ class TensorGraph(Model):
      time1 = time.time()
      train_op = self._get_tf('train_op')
      saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      with tf.Session() as sess:
        self._initialize_weights(sess, saver)
      self.session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      avg_loss, n_batches = 0.0, 0.0
      coord = tf.train.Coordinator()
      n_samples = 0
      if self.use_queue:
        enqueue_thread = threading.Thread(
            target=_enqueue_batch,
              args=(self, feed_dict_generator, self._get_tf("Graph"), sess,
                    coord))
            args=(self, feed_dict_generator, self._get_tf("Graph"),
                  self.session, coord))
        enqueue_thread.start()
      output_tensors = [x.out_tensor for x in self.outputs]
      fetches = output_tensors + [train_op, self.loss.out_tensor]
      for feed_dict in create_feed_dict():
        try:
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
          fetched_values = self.session.run(fetches, feed_dict=feed_dict)
          loss = fetched_values[-1]
          avg_loss += loss
          n_batches += 1
          self.global_step += 1
          n_samples += 1
          if self.tensorboard and n_samples % self.tensorboard_log_frequency == 0:
              summary = sess.run(
            summary = self.session.run(
                self._get_tf("summary_op"), feed_dict=feed_dict)
            self._log_tensorboard(summary)
        except OutOfRangeError:
          break
        if self.global_step % checkpoint_interval == checkpoint_interval - 1:
            saver.save(sess, self.save_file, global_step=self.global_step)
            self.last_checkpoint = saver.last_checkpoints[-1]
          saver.save(self.session, self.save_file, global_step=self.global_step)
          avg_loss = float(avg_loss) / n_batches
          print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                            avg_loss))
@@ -179,8 +214,7 @@ class TensorGraph(Model):
      avg_loss = float(avg_loss) / n_batches
      print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                        avg_loss))
        saver.save(sess, self.save_file, global_step=self.global_step)
        self.last_checkpoint = saver.last_checkpoints[-1]
      saver.save(self.session, self.save_file, global_step=self.global_step)
      time2 = time.time()
      print("TIMING: model fitting took %0.3f s" % (time2 - time1))

@@ -256,9 +290,6 @@ class TensorGraph(Model):
    elif not isinstance(outputs, collections.Sequence):
      outputs = [outputs]
    with self._get_tf("Graph").as_default():
      with tf.Session() as sess:
        saver = tf.train.Saver()
        self._initialize_weights(sess, saver)
      out_tensors = [x.out_tensor for x in self.outputs]
      # Gather results for each output
      results = [[] for out in out_tensors]
@@ -268,7 +299,7 @@ class TensorGraph(Model):
            for k, v in six.iteritems(feed_dict)
        }
        feed_dict[self._training_placeholder] = 0.0
          feed_results = sess.run(out_tensors, feed_dict=feed_dict)
        feed_results = self.session.run(out_tensors, feed_dict=feed_dict)
        if len(feed_results) > 1:
          if len(transformers):
            raise ValueError("Does not support transformations "
@@ -394,6 +425,7 @@ class TensorGraph(Model):
          self.rnn_final_states += node_layer.rnn_final_states
          self.rnn_zero_states += node_layer.rnn_zero_states
          node_layer.add_summary_to_tg()
      self.session = tf.Session()

      self.built = True

@@ -488,10 +520,12 @@ class TensorGraph(Model):
    rnn_initial_states = self.rnn_initial_states
    rnn_final_states = self.rnn_final_states
    rnn_zero_states = self.rnn_zero_states
    session = self.session
    self.tensor_objects = {}
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    self.session = None
    out_tensors = []
    if self.built:
      must_restore = True
@@ -526,6 +560,7 @@ class TensorGraph(Model):
    self.rnn_initial_states = rnn_initial_states
    self.rnn_final_states = rnn_final_states
    self.rnn_zero_states = rnn_zero_states
    self.session = session

  def evaluate_generator(self,
                         feed_dict_generator,
@@ -563,7 +598,7 @@ class TensorGraph(Model):
      self.build()
    with self._get_tf("Graph").as_default():
      return tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=layer.variable_scope)
          tf.GraphKeys.TRAINABLE_VARIABLES, scope=layer.variable_scope)

  def get_global_step(self):
    return self._get_tf("GlobalStep")
@@ -610,25 +645,16 @@ class TensorGraph(Model):
        self.tensor_objects['GlobalStep'] = tf.Variable(0, trainable=False)
    return self._get_tf(obj)

  def _initialize_weights(self, sess, saver):
    """
    Parameters
    ----------
    sess: tf.Session
      The Session must be open
    saver: tf.train.Saver
      A saver object to save/restore checkpoints

    Returns
    -------

    """
    if self.last_checkpoint is None:
      sess.run(tf.global_variables_initializer())
      saver.save(sess, self.save_file, global_step=self.global_step)
      self.last_checkpoint = saver.last_checkpoints[-1]
    else:
      saver.restore(sess, self.last_checkpoint)
  def restore(self):
    """Reload the values of all variables from the most recent checkpoint file."""
    if not self.built:
      self.build()
    last_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      saver.restore(self.session, last_checkpoint)

  def get_num_tasks(self):
    return len(self.outputs)
@@ -644,6 +670,10 @@ class TensorGraph(Model):
    with open(pickle_name, 'rb') as fout:
      tensorgraph = pickle.load(fout)
      tensorgraph.built = False
      try:
        tensorgraph.restore()
      except ValueError:
        pass  # No checkpoint to load
      return tensorgraph

  def __del__(self):