Commit 61450d4f authored by Bharath Ramsundar
Browse files

Yapf

parent a9419708
Loading
Loading
Loading
Loading
+18 −11
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ import tempfile
import deepchem as dc
from deepchem.molnet.load_function.bace_features import bace_user_specified_features


def load_bace(mode="regression", transform=True, split="20-80"):
  """Load BACE-1 dataset as regression/classification problem."""
  assert split in ["20-80", "80-20"]
@@ -18,11 +19,11 @@ def load_bace(mode="regression", transform=True, split="20-80"):

  current_dir = os.path.dirname(os.path.realpath(__file__))
  if split == "20-80":
    dataset_file = os.path.join(
        current_dir, "../../datasets/desc_canvas_aug30.csv")
    dataset_file = os.path.join(current_dir,
                                "../../datasets/desc_canvas_aug30.csv")
  elif split == "80-20":
    dataset_file = os.path.join(
        current_dir, "../../datasets/rev8020split_desc.csv")
    dataset_file = os.path.join(current_dir,
                                "../../datasets/rev8020split_desc.csv")

  crystal_dataset_file = os.path.join(
      current_dir, "../../datasets/crystal_desc_canvas_aug30.csv")
@@ -33,7 +34,9 @@ def load_bace(mode="regression", transform=True, split="20-80"):
    bace_tasks = ["Class"]
  featurizer = dc.feat.UserDefinedFeaturizer(bace_user_specified_features)
  loader = dc.data.UserCSVLoader(
      tasks=bace_tasks, smiles_field="mol", id_field="CID",
      tasks=bace_tasks,
      smiles_field="mol",
      id_field="CID",
      featurizer=featurizer)
  dataset = loader.featurize(dataset_file)
  crystal_dataset = loader.featurize(crystal_dataset_file)
@@ -55,11 +58,15 @@ def load_bace(mode="regression", transform=True, split="20-80"):
  print(len(crystal_dataset))

  transformers = [
      dc.trans.NormalizationTransformer(transform_X=True, dataset=train_dataset),
      dc.trans.ClippingTransformer(transform_X=True, dataset=train_dataset)]
      dc.trans.NormalizationTransformer(
          transform_X=True, dataset=train_dataset),
      dc.trans.ClippingTransformer(transform_X=True, dataset=train_dataset)
  ]
  if mode == "regression":
    transformers += [
      dc.trans.NormalizationTransformer(transform_y=True, dataset=train_dataset)]
        dc.trans.NormalizationTransformer(
            transform_y=True, dataset=train_dataset)
    ]

  for dataset in [train_dataset, valid_dataset, test_dataset, crystal_dataset]:
    if len(dataset) > 0:
+9 −6
Original line number Diff line number Diff line
@@ -13,10 +13,11 @@ from deepchem import metrics
from deepchem.metrics import Metric
from deepchem.utils.evaluate import Evaluator


def bace_rf_model(mode="classification", split="20-80"):
  """Train random forests on BACE dataset."""
  (bace_tasks, (train, valid, test, crystal),
   transformers) = load_bace(mode=mode, transform=False, split=split)
  (bace_tasks, (train, valid, test, crystal), transformers) = load_bace(
      mode=mode, transform=False, split=split)

  if mode == "regression":
    r2_metric = Metric(metrics.r2_score)
@@ -25,6 +26,7 @@ def bace_rf_model(mode="classification", split="20-80"):
    all_metrics = [r2_metric, rms_metric, mae_metric]
    metric = r2_metric
    model_class = RandomForestRegressor

    def rf_model_builder(model_params, model_dir):
      sklearn_model = RandomForestRegressor(**model_params)
      return SklearnModel(sklearn_model, model_dir)
@@ -37,6 +39,7 @@ def bace_rf_model(mode="classification", split="20-80"):
    model_class = RandomForestClassifier
    all_metrics = [accuracy_metric, mcc_metric, recall_metric, roc_auc_metric]
    metric = roc_auc_metric

    def rf_model_builder(model_params, model_dir):
      sklearn_model = RandomForestClassifier(**model_params)
      return SklearnModel(sklearn_model, model_dir)
@@ -50,8 +53,7 @@ def bace_rf_model(mode="classification", split="20-80"):

  optimizer = HyperparamOpt(rf_model_builder)
  best_rf, best_rf_hyperparams, all_rf_results = optimizer.hyperparam_search(
      params_dict, train, valid, transformers,
      metric=metric)
      params_dict, train, valid, transformers, metric=metric)

  if len(train) > 0:
    rf_train_evaluator = Evaluator(best_rf, train, transformers)
@@ -85,6 +87,7 @@ def bace_rf_model(mode="classification", split="20-80"):
        all_metrics, csv_out=csv_out, stats_out=stats_out)
    print("RF Crystal set: %s" % (str(rf_crystal_score)))


if __name__ == "__main__":
  print("Classifier RF 20-80:")
  print("--------------------------------")