Commit 48d2084a authored by Bharath Ramsundar
Browse files

Changes

parent 39f73650
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -148,7 +148,6 @@ class GaussianProcessHyperparamOpt(HyperparamOpt):
    if len(params_dict) > 20:
      raise ValueError(
          "This class can only search over 20 parameters in one invocation.")
    data_dir = deepchem.utils.get_data_dir()
    # Specify logfile
    if logfile:
      log_file = logfile
+12 −4
Original line number Diff line number Diff line
@@ -11,8 +11,16 @@ train, valid, test = delaney_datasets
# Fit models
regression_metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)

# TODO(rbharath): I don't like this awkward string/class divide. Maybe clean up?
optimizer = dc.hyper.GaussianProcessHyperparamOpt('tf_regression')
def rf_model_builder(**model_params):
  """Build a random forest regressor wrapped in `dc.models.SklearnModel`.

  The `model_dir` keyword is split off and handed to the `SklearnModel`
  wrapper; every other keyword is forwarded unchanged to
  `sklearn.ensemble.RandomForestRegressor`.

  Raises:
    KeyError: if `model_dir` is not among the keyword arguments.
  """
  # Work on a copy so the directory can be popped off the estimator kwargs.
  forest_kwargs = dict(model_params)
  save_dir = forest_kwargs.pop('model_dir')
  forest = sklearn.ensemble.RandomForestRegressor(**forest_kwargs)
  return dc.models.SklearnModel(forest, save_dir)

optimizer = dc.hyper.GaussianProcessHyperparamOpt(rf_model_builder)
best_hyper_params, best_performance = optimizer.hyperparam_search(
    dc.molnet.preset_hyper_parameters.hps['tf_regression'], train, valid,
    transformers, [regression_metric])
    params_dict,
    train_dataset,
    valid_dataset,
    transformers,
    metric)