Commit 4decac9c authored by Joseph Gomes's avatar Joseph Gomes
Browse files

Update FitTransformRegressor for changes in GraphModel

parent 8bdf015c
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -196,17 +196,16 @@ class TensorflowMultiTaskFitTransformRegressor(TensorflowMultiTaskRegressor):
               weight_init_stddevs=[.02], bias_init_consts=[1.], penalty=0.0,
               penalty_type="l2", dropouts=[0.5], learning_rate=.001,
               momentum=.9, optimizer="adam", batch_size=50, n_classes=2,
               fit_transformers=[], verbose=True, seed=None, **kwargs):
               fit_transformers=[], n_random_samples=10, verbose=True, seed=None, **kwargs):

    self.pad_batches = False
    self.fit_transformers = fit_transformers
    self.n_random_samples = n_random_samples
    # Run fit transformers on dummy dataset to determine n_features after transformation
    # JSG This could be generalized by passing in init_data_shape rather than n_features
    # JSG for now this only works with full CoulombMatrix featurizer
    X_b = np.random.rand(batch_size, n_features, n_features)
    for transformer in self.fit_transformers:
      X_b = transformer.X_transform(X_b)
    print(X_b.shape) 
    n_features = X_b.shape[1]
    print("n_features after fit_transform: %d" % int(n_features))

@@ -283,7 +282,7 @@ class TensorflowMultiTaskFitTransformRegressor(TensorflowMultiTaskRegressor):
          self.verbose)
    ############################################################## TIMING

  def predict_on_batch(self, X, pad_batch=False):
  def predict_on_batch(self, X):
    """Return model output for the provided input.

    Restore(checkpoint) must have previously been called on this object.