Commit 97b22c5d authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fixes to tox21 sklearn example

parent 181c6938
Loading
Loading
Loading
Loading
+82 −0
Original line number Diff line number Diff line
"""
Tox21 dataset loader.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import numpy as np
import shutil
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from deepchem.utils.save import load_from_disk
from deepchem.datasets import Dataset
from deepchem.featurizers.featurize import DataFeaturizer
from deepchem.featurizers.fingerprints import CircularFingerprint
from deepchem.splits import ScaffoldSplitter
from deepchem.splits import RandomSplitter
from deepchem.datasets import Dataset
from deepchem.transformers import BalancingTransformer
from deepchem.hyperparameters import HyperparamOpt
from deepchem.models.multitask import SingletaskToMultitask
from deepchem import metrics
from deepchem.metrics import Metric
from deepchem.metrics import to_one_hot
from deepchem.models.sklearn_models import SklearnModel
from deepchem.utils.evaluate import relative_difference
from deepchem.utils.evaluate import Evaluator

def load_tox21(base_dir, reload=True):
  """Load Tox21 datasets. Does not do train/test split"""
  # Set some global variables up top
  reload = True
  verbosity = "high"
  model = "logistic"

  # Create some directories for analysis
  # The base_dir holds the results of all analysis
  if not reload:
    if os.path.exists(base_dir):
      shutil.rmtree(base_dir)
  if not os.path.exists(base_dir):
    os.makedirs(base_dir)
  current_dir = os.path.dirname(os.path.realpath(__file__))
  #Make directories to store the raw and featurized datasets.
  samples_dir = os.path.join(base_dir, "samples")
  data_dir = os.path.join(base_dir, "dataset")

  # Load Tox21 dataset
  print("About to load Tox21 dataset.")
  dataset_file = os.path.join(
      current_dir, "../../datasets/tox21.csv.gz")
  dataset = load_from_disk(dataset_file)
  print("Columns of dataset: %s" % str(dataset.columns.values))
  print("Number of examples in dataset: %s" % str(dataset.shape[0]))

  # Featurize Tox21 dataset
  print("About to featurize Tox21 dataset.")
  featurizers = [CircularFingerprint(size=1024)]
  all_tox21_tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER',
                     'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5',
                     'SR-HSE', 'SR-MMP', 'SR-p53']

  if not reload or not os.path.exists(data_dir):
    featurizer = DataFeaturizer(tasks=all_tox21_tasks,
                                smiles_field="smiles",
                                featurizers=featurizers,
                                verbosity=verbosity)
    dataset = featurizer.featurize(
        dataset_file, data_dir, shard_size=8192)
  else:
    dataset = Dataset(data_dir, all_tox21_tasks, reload=True)

  # Initialize transformers 
  transformers = [
      BalancingTransformer(transform_w=True, dataset=dataset)]
  if not reload:
    print("About to transform data")
    for transformer in transformers:
        transformer.transform(dataset)
  
  return all_tox21_tasks, dataset, transformers
+4 −56
Original line number Diff line number Diff line
@@ -8,24 +8,15 @@ from __future__ import unicode_literals
import os
import shutil
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from deepchem.utils.save import load_from_disk
from deepchem.datasets import Dataset
from deepchem.featurizers.featurize import DataFeaturizer
from deepchem.featurizers.fingerprints import CircularFingerprint
from deepchem.splits import ScaffoldSplitter
from deepchem.splits import RandomSplitter
from deepchem.datasets import Dataset
from deepchem.transformers import BalancingTransformer
from deepchem.hyperparameters import HyperparamOpt
from deepchem.models.multitask import SingletaskToMultitask
from deepchem import metrics
from deepchem.metrics import Metric
from deepchem.metrics import to_one_hot
from deepchem.models.sklearn_models import SklearnModel
from deepchem.utils.evaluate import relative_difference
from deepchem.utils.evaluate import Evaluator
from deepchem.datasets.tox21_datasets import load_tox21


# Only for debug!
@@ -41,9 +32,7 @@ if not os.path.exists(base_dir):

current_dir = os.path.dirname(os.path.realpath(__file__))
#Make directories to store the raw and featurized datasets.
feature_dir = os.path.join(base_dir, "features")
samples_dir = os.path.join(base_dir, "samples")
full_dir = os.path.join(base_dir, "full_dataset")
data_dir = os.path.join(base_dir, "dataset")
train_dir = os.path.join(base_dir, "train_dataset")
valid_dir = os.path.join(base_dir, "valid_dataset")
test_dir = os.path.join(base_dir, "test_dataset")
@@ -57,46 +46,9 @@ dataset = load_from_disk(dataset_file)
print("Columns of dataset: %s" % str(dataset.columns.values))
print("Number of examples in dataset: %s" % str(dataset.shape[0]))

# Featurize tox21 dataset
print("About to featurize Tox21 dataset.")
featurizers = [CircularFingerprint(size=1024)]
all_tox21_tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
                   'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
# For debugging purposes
n_tasks = 12 
tox21_tasks = all_tox21_tasks[0:n_tasks]
valid_scores = {}

print("Using following tasks")
print(tox21_tasks)

## This is for good debug (to make sure nasty state isn't being passed around)
if os.path.exists(feature_dir):
  shutil.rmtree(feature_dir)
featurizer = DataFeaturizer(tasks=tox21_tasks,
                            smiles_field="smiles",
                            compound_featurizers=featurizers,
                            verbosity=verbosity)
featurized_samples = featurizer.featurize(
    dataset_file, feature_dir,
    samples_dir, shard_size=8192,
    reload=reload)

# Generate datasets
print("About to create datasets")
print("tox21_tasks")
print(tox21_tasks)

# This is for good debug (to make sure nasty state isn't being passed around)
if os.path.exists(full_dir):
  shutil.rmtree(full_dir)
full_dataset = Dataset(data_dir=full_dir, samples=featurized_samples, 
                        featurizers=featurizers, tasks=tox21_tasks,
                        verbosity=verbosity, reload=reload)

# Do train/valid split.
tox21_tasks, tox21_dataset, transformers = load_tox21(data_dir, reload=reload)
num_train = 7200
X, y, w, ids = full_dataset.to_numpy()
X, y, w, ids = tox21_dataset.to_numpy()
X_train, X_valid = X[:num_train], X[num_train:]
y_train, y_valid = y[:num_train], y[num_train:]
w_train, w_valid = w[:num_train], w[num_train:]
@@ -113,9 +65,6 @@ if os.path.exists(valid_dir):
valid_dataset = Dataset.from_numpy(valid_dir, X_valid, y_valid,
                                   w_valid, ids_valid, tox21_tasks)

# No data transformations for now
transformers = []

# Fit models
tox21_task_types = {task: "classification" for task in tox21_tasks}

@@ -133,7 +82,6 @@ if os.path.exists(model_dir):
  shutil.rmtree(model_dir)
def model_builder(tasks, task_types, model_params, model_dir, verbosity=None):
  return SklearnModel(tasks, task_types, model_params, model_dir,
                      #model_instance=LogisticRegression(class_weight="balanced"),
                      model_instance=RandomForestClassifier(
                          class_weight="balanced",
                          n_estimators=500),