Commit a6d9b817 authored by peastman's avatar peastman
Browse files

make_estimator() supports Tensorboard

parent d38c917e
Loading
Loading
Loading
Loading
+16 −10
Original line number Diff line number Diff line
@@ -152,9 +152,8 @@ class Layer(object):

  def set_summary(self, summary_op, summary_description=None, collections=None):
    """Annotates a tensor with a tf.summary operation
    Collects data from self.out_tensor by default but can be changed by setting
    self.tb_input to another tensor in create_tensor

    This causes self.out_tensor to be logged to Tensorboard.

    Parameters
    ----------
@@ -175,21 +174,28 @@ class Layer(object):
    self.collections = collections
    self.tensorboard = True

  def add_summary_to_tg(self):
  def add_summary_to_tg(self, tb_input=None):
    """
    Can only be called after self.create_layer to gaurentee that name is not none
    Create the summary operation for this layer, if set_summary() has been called on it.

    Can only be called after self.create_layer to guarantee that name is not None.

    Parameters
    ----------
    tb_input: tensor
      the tensor to log to Tensorboard.  If None, self.out_tensor is used.
    """
    if self.tensorboard == False:
      return
    if self.tb_input == None:
      self.tb_input = self.out_tensor
    if tb_input == None:
      tb_input = self.out_tensor
    if self.summary_op == "tensor_summary":
      tf.summary.tensor_summary(self.name, self.tb_input,
                                self.summary_description, self.collections)
      tf.summary.tensor_summary(self.name, tb_input, self.summary_description,
                                self.collections)
    elif self.summary_op == 'scalar':
      tf.summary.scalar(self.name, self.tb_input, self.collections)
      tf.summary.scalar(self.name, tb_input, self.collections)
    elif self.summary_op == 'histogram':
      tf.summary.histogram(self.name, self.tb_input, self.collections)
      tf.summary.histogram(self.name, tb_input, self.collections)

  def copy(self, replacements={}, variables_graph=None, shared=False):
    """Duplicate this Layer and all its inputs.
+4 −1
Original line number Diff line number Diff line
@@ -894,6 +894,7 @@ class TensorGraph(Model):
      inputs = [create_tensors(in_layer, tensors, training) for in_layer in layer.in_layers]
      tensor = layer.create_tensor(in_layers=inputs, set_tensors=False, training=training)
      tensors[layer] = tensor
      layer.add_summary_to_tg(tensor)
      return tensor

    # Define the model function.
@@ -907,6 +908,8 @@ class TensorGraph(Model):
      if weight_column is not None:
        tensors[self.task_weights[0]] = tf.feature_column.input_layer(features, [weight_column])
      tensors[self.labels[0]] = labels
      for layer, tensor in tensors.items():
        layer.add_summary_to_tg(tensor)

      # Create the correct outputs, based on the mode.

@@ -936,7 +939,7 @@ class TensorGraph(Model):

    # Create the Estimator.

    return tf.estimator.Estimator(model_fn=model_fn)
    return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)

def _enqueue_batch(tg, generator, graph, sess, n_enqueued, final_sample):
  """