Commit 99206097 authored by Thomas Blaschke's avatar Thomas Blaschke
Browse files

Add configproto to TensorGraph kwargs to create custom sessions.

parent 52d41150
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -52,6 +52,7 @@ class TensorGraph(Model):
    learning_rate: float or LearningRateSchedule
      the learning rate to use for optimization
    kwargs
      "configproto": a tf.ConfigProto() object used to create tf.Session()
    """

    # Layer Management
@@ -66,6 +67,7 @@ class TensorGraph(Model):
    self.queue_installed = False
    self.optimizer = Adam(
        learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-7)
    self.configproto = kwargs.pop("configproto", tf.ConfigProto())

    # Singular place to hold Tensor objects which don't serialize
    # These have to be reconstructed on restoring from pickle
@@ -470,7 +472,7 @@ class TensorGraph(Model):
          self.rnn_final_states += layer.rnn_final_states
          self.rnn_zero_states += layer.rnn_zero_states
          layer.add_summary_to_tg()
      self.session = tf.Session()
      self.session = tf.Session(config=self.configproto)
      self.built = True

      # Ensure all training operators have been created.