Commit 0fb31315 authored by peastman's avatar peastman
Browse files

Initial implementation of make_estimator()

parent 560ea1bb
Loading
Loading
Loading
Loading
+84 −0
Original line number Diff line number Diff line
@@ -839,6 +839,90 @@ class TensorGraph(Model):
      feed_dict[self._training_placeholder] = train_value
      yield feed_dict

  def make_estimator(self, feature_columns, weight_column=None, model_dir=None):
    """Construct a Tensorflow Estimator from this model.

    tf.estimator.Estimator is the standard Tensorflow API for representing models.
    This method provides interoperability between DeepChem and other Tensorflow
    based tools by allowing any model to be used as an Estimator.

    Once this method returns, the Estimator it created is independent of the model
    it was created from.  They do not share tensors, variables, save files, or any
    other resources.  The Estimator is a self contained object with its own methods
    for training, evaluation, prediction, checkpointing, etc.

    Parameters
    ----------
    feature_columns: list of tf.feature_column objects
      this describes the input features to the models.  There must be one entry
      for each Feature layer in this model's features field.
    weight_column: tf.feature_column or None
      if this model includes a Weights layer, this describes the input weights.
      Otherwise, this should be None.
    model_dir: str
      the directory in which the Estimator should save files.  If None, this
      defaults to the model's model_dir.

    Returns
    -------
    a tf.estimator.Estimator whose model function rebuilds this model's layers.
    In PREDICT mode its predictions are a dict keyed by the integer index of
    each entry in this model's outputs field.

    Raises
    ------
    ValueError
      if the number of feature columns does not match the number of Feature
      layers, if the model does not have exactly one Label input, if it has more
      than one Weights input, or if weight_column is inconsistent with whether
      the model has a Weights input.
    """
    # Check the inputs.

    if len(feature_columns) != len(self.features):
      raise ValueError('This model requires %d feature column(s)' % len(self.features))
    if len(self.labels) != 1:
      raise ValueError('Can only create an Estimator from a model with exactly one Label input')
    if len(self.task_weights) > 1:
      raise ValueError('Cannot create an Estimator from a model with multiple Weight inputs')
    if weight_column is None:
      if len(self.task_weights) > 0:
        raise ValueError('This model requires a weight column')
    else:
      if len(self.task_weights) == 0:
        raise ValueError('Cannot specify weight_column for a model with no Weight inputs')
    if model_dir is None:
      model_dir = self.model_dir

    # Define a function that recursively creates tensors from layers.  The
    # tensors dict doubles as a cache, so a layer that feeds several other
    # layers is only instantiated once, and as the place where the input
    # layers are pre-seeded with the Estimator's own input tensors.

    def create_tensors(layer, tensors, training):
      if layer in tensors:
        return tensors[layer]
      inputs = [create_tensors(in_layer, tensors, training) for in_layer in layer.in_layers]
      # set_tensors=False is passed so the layer is asked not to record the new
      # tensor on itself; the result is tracked only in the local dict.
      tensor = layer.create_tensor(in_layers=inputs, set_tensors=False, training=training)
      tensors[layer] = tensor
      return tensor

    # Define the model function.

    def model_fn(features, labels, mode):
      # Define the inputs.

      tensors = {}
      for layer, column in zip(self.features, feature_columns):
        tensors[layer] = tf.feature_column.input_layer(features, [column])
      if weight_column is not None:
        tensors[self.task_weights[0]] = tf.feature_column.input_layer(features, [weight_column])
      tensors[self.labels[0]] = labels

      # Create the correct outputs, based on the mode.  The training flag passed
      # to the layers is 1 only when building the training graph.

      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {}
        for i, output in enumerate(self.outputs):
          predictions[i] = create_tensors(output, tensors, 0)
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
      if mode == tf.estimator.ModeKeys.EVAL:
        loss = create_tensors(self.loss, tensors, 0)
        return tf.estimator.EstimatorSpec(mode, loss=loss)
      if mode == tf.estimator.ModeKeys.TRAIN:
        loss = create_tensors(self.loss, tensors, 1)
        global_step = tf.train.get_global_step()
        optimizer = self.optimizer._create_optimizer(global_step)
        train_op = optimizer.minimize(loss, global_step=global_step)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
      raise ValueError('Unknown mode')

    # Create the Estimator.  model_dir must be forwarded explicitly; otherwise
    # the Estimator would pick its own temporary directory and ignore both the
    # argument and the model's model_dir default resolved above.

    return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)

def _enqueue_batch(tg, generator, graph, sess, n_enqueued, final_sample):
  """