Commit d38c917e authored by peastman's avatar peastman
Browse files

make_estimator() supports metrics

parent 0fb31315
Loading
Loading
Loading
Loading
+16 −2
Original line number Diff line number Diff line
@@ -839,7 +839,7 @@ class TensorGraph(Model):
      feed_dict[self._training_placeholder] = train_value
      yield feed_dict

  def make_estimator(self, feature_columns, weight_column=None, model_dir=None):
  def make_estimator(self, feature_columns, weight_column=None, model_dir=None, metrics={}):
    """Construct a Tensorflow Estimator from this model.

    tf.estimator.Estimator is the standard Tensorflow API for representing models.
@@ -859,6 +859,12 @@ class TensorGraph(Model):
    weight_column: tf.feature_column or None
      if this model includes a Weights layer, this describes the input weights.
      Otherwise, this should be None.
    metrics: map
      metrics that should be computed in calls to evaluate().  For each entry,
      the key is the name to report for the metric, and the value is a function
      of the form f(labels, predictions, weights) that returns the tensors for
      computing the metric.  Any of the functions in tf.metrics can be used, as
      can other functions that satisfy the same interface.
    model_dir: str
      the directory in which the Estimator should save files.  If None, this
      defaults to the model's model_dir.
@@ -911,7 +917,15 @@ class TensorGraph(Model):
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
      if mode == tf.estimator.ModeKeys.EVAL:
        loss = create_tensors(self.loss, tensors, 0)
        return tf.esimator.EstimatorSpec(mode, loss=loss)
        predictions = create_tensors(self.outputs[0], tensors, 0)
        if len(self.task_weights) == 0:
          weights = None
        else:
          weights = tensors[self.task_weights[0]]
        eval_metric_ops = {}
        for name, function in metrics.items():
          eval_metric_ops[name] = function(labels, predictions, weights)
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=eval_metric_ops)
      if mode == tf.estimator.ModeKeys.TRAIN:
        loss = create_tensors(self.loss, tensors, 1)
        global_step = tf.train.get_global_step()