Commit 9b079e51 authored by Peter Eastman's avatar Peter Eastman
Browse files

Converted multitask fit regressor to tensorgraph

parent 65764c86
Loading
Loading
Loading
Loading
+156 −3
Original line number Diff line number Diff line
@@ -34,6 +34,33 @@ class TensorGraphMultiTaskClassifier(TensorGraph):
               dropouts=[0.5],
               n_classes=2,
               **kwargs):
    """Create a TensorGraphMultiTaskClassifier.

    In addition to the following arguments, this class also accepts all the keyword arguments
    from TensorGraph.

    Parameters
    ----------
    n_tasks: int
      number of tasks
    n_features: int
      number of features
    layer_sizes: list
      the size of each dense layer in the network.  The length of this list determines the number of layers.
    weight_init_stddevs: list
      the standard deviation of the distribution to use for weight initialization of each layer.  The length
      of this list should equal len(layer_sizes).
    bias_init_consts: list
      the value to initialize the biases in each layer to.  The length of this list should equal len(layer_sizes).
    weight_decay_penalty: float
      the magnitude of the weight decay penalty to use
    weight_decay_penalty_type: str
      the type of penalty to use for weight decay, either 'l1' or 'l2'
    dropouts: list
      the dropout probability to use for each layer.  The length of this list should equal len(layer_sizes).
    n_classes: int
      the number of classes
    """
    super().__init__(mode='classification', **kwargs)
    self.n_tasks = n_tasks
    self.n_features = n_features
@@ -94,12 +121,38 @@ class TensorGraphMultiTaskRegressor(TensorGraph):
               n_tasks,
               n_features,
               layer_sizes=[1000],
               weight_init_stddevs=[0.02],
               bias_init_consts=[1.0],
               weight_init_stddevs=[0.02, 0.02],
               bias_init_consts=[1.0, 1.0],
               weight_decay_penalty=0.0,
               weight_decay_penalty_type="l2",
               dropouts=[0.5],
               **kwargs):
    """Create a TensorGraphMultiTaskRegressor.

    In addition to the following arguments, this class also accepts all the keyword arguments
    from TensorGraph.

    Parameters
    ----------
    n_tasks: int
      number of tasks
    n_features: int
      number of features
    layer_sizes: list
      the size of each dense layer in the network.  The length of this list determines the number of layers.
    weight_init_stddevs: list
      the standard deviation of the distribution to use for weight initialization of each layer.  The length
      of this list should equal len(layer_sizes)+1.  The final element corresponds to the output layer.
    bias_init_consts: list
      the value to initialize the biases in each layer to.  The length of this list should equal len(layer_sizes)+1.
      The final element corresponds to the output layer.
    weight_decay_penalty: float
      the magnitude of the weight decay penalty to use
    weight_decay_penalty_type: str
      the type of penalty to use for weight decay, either 'l1' or 'l2'
    dropouts: list
      the dropout probability to use for each layer.  The length of this list should equal len(layer_sizes).
    """
    super().__init__(mode='regression', **kwargs)
    self.n_tasks = n_tasks
    self.n_features = n_features
@@ -121,7 +174,9 @@ class TensorGraphMultiTaskRegressor(TensorGraph):

    # Compute the loss function for each label.

    output = Reshape(shape=(-1, n_tasks, 1), in_layers=[Dense(in_layers=[prev_layer], out_channels=n_tasks)])
    output = Reshape(shape=(-1, n_tasks, 1), in_layers=[Dense(in_layers=[prev_layer], out_channels=n_tasks,
                    weights_initializer=TFWrapper(tf.truncated_normal_initializer, stddev=weight_init_stddevs[-1]),
                    biases_initializer=TFWrapper(tf.constant_initializer, value=bias_init_consts[-1]))])
    self.add_output(output)
    labels = Label(shape=(None, n_tasks, 1))
    weights = Weights(shape=(None, n_tasks))
@@ -154,6 +209,104 @@ class TensorGraphMultiTaskRegressor(TensorGraph):



class TensorGraphMultiTaskFitRegressor(TensorGraphMultiTaskRegressor):
  """A TensorGraphMultiTaskRegressor that applies fit transformers on the fly during fit/predict.

  Example:

  >>> n_samples = 10
  >>> n_features = 3
  >>> n_tasks = 1
  >>> ids = np.arange(n_samples)
  >>> X = np.random.rand(n_samples, n_features, n_features)
  >>> y = np.zeros((n_samples, n_tasks))
  >>> w = np.ones((n_samples, n_tasks))
  >>> dataset = dc.data.NumpyDataset(X, y, w, ids)
  >>> fit_transformers = [dc.trans.CoulombFitTransformer(dataset)]
  >>> model = dc.models.TensorGraphMultiTaskFitRegressor(n_tasks, [n_features, n_features],
  ...     dropouts=[0.], learning_rate=0.003, weight_init_stddevs=[np.sqrt(6)/np.sqrt(1000)],
  ...     batch_size=n_samples, fit_transformers=fit_transformers, n_evals=1)
  n_features after fit_transform: 12
  """

  def __init__(self,
               n_tasks,
               n_features,
               fit_transformers=[],
               n_evals=1,
               batch_size=50,
               **kwargs):
    """Create a TensorGraphMultiTaskFitRegressor.

    In addition to the following arguments, this class also accepts all the keyword arguments
    from TensorGraphMultiTaskRegressor.

    Parameters
    ----------
    n_tasks: int
      number of tasks
    n_features: list or int
      number of features before the fit transformers are applied
    fit_transformers: list
      List of dc.trans.FitTransformer objects
    n_evals: int
      Number of evaluations per example at predict time
    batch_size: int
      number of examples per batch
    """
    # Defensive copy so a caller-owned (or the shared default) list cannot be
    # mutated out from under this model.
    self.fit_transformers = list(fit_transformers)
    self.n_evals = n_evals

    # Run fit transformers on a dummy batch to determine n_features after
    # transformation, since that is the width the network must be built with.
    if isinstance(n_features, list):
      X_b = np.ones([batch_size] + n_features)
    elif isinstance(n_features, int):
      X_b = np.ones([batch_size, n_features])
    else:
      raise ValueError("n_features should be list or int")
    for transformer in self.fit_transformers:
      X_b = transformer.X_transform(X_b)
    n_features = X_b.shape[1]
    print("n_features after fit_transform: %d" % int(n_features))
    super().__init__(n_tasks, n_features, batch_size=batch_size, **kwargs)


  def default_generator(self,
                        dataset,
                        epochs=1,
                        predict=False,
                        pad_batches=True):
    """Generate feed dicts for training or prediction.

    During training (predict=False) the fit transformers are applied to every
    feature batch; during prediction the raw features are passed through
    untransformed, since predict_proba_on_generator applies the transformers
    itself (once per evaluation).
    """
    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
          deterministic=True,
          pad_batches=pad_batches):
        feed_dict = dict()
        if y_b is not None and not predict:
          feed_dict[self.labels[0]] = y_b.reshape(-1, self.n_tasks, 1)
        if X_b is not None:
          if not predict:
            for transformer in self.fit_transformers:
              X_b = transformer.X_transform(X_b)
          feed_dict[self.features[0]] = X_b
        if w_b is not None and not predict:
          feed_dict[self.task_weights[0]] = w_b
        yield feed_dict


  def predict_proba_on_generator(self, generator, transformers=[]):
    """Predict on batches from a generator, fit-transforming each batch first.

    Each incoming batch is yielded n_evals times, independently re-transformed
    each time, so a stochastic fit transformer produces n_evals evaluations
    per example as documented for the n_evals constructor argument.
    """
    def transform_generator():
      for feed_dict in generator:
        X = feed_dict[self.features[0]]
        for _ in range(self.n_evals):
          X_t = X
          for transformer in self.fit_transformers:
            X_t = transformer.X_transform(X_t)
          # Yield a shallow copy so batches already handed downstream are
          # not mutated by the next evaluation.
          eval_dict = dict(feed_dict)
          eval_dict[self.features[0]] = X_t
          yield eval_dict
    return super().predict_proba_on_generator(transform_generator(), transformers)



class TensorflowMultiTaskClassifier(TensorflowClassifier):
  """Implements an icml model as configured in a model_config.proto."""