Unverified Commit 5820f4c9 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1308 from rbharath/sweet

Sweetlead example
parents 2041df1b ae13294c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ from deepchem.molnet.load_function.qm8_datasets import load_qm8
from deepchem.molnet.load_function.qm9_datasets import load_qm9
from deepchem.molnet.load_function.sampl_datasets import load_sampl
from deepchem.molnet.load_function.sider_datasets import load_sider
from deepchem.molnet.load_function.sweetlead_datasets import load_sweet
from deepchem.molnet.load_function.tox21_datasets import load_tox21
from deepchem.molnet.load_function.toxcast_datasets import load_toxcast

+69 −0
Original line number Diff line number Diff line
"""
SWEET dataset loader.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import numpy as np
import shutil
import logging
import deepchem as dc

logger = logging.getLogger(__name__)


def load_sweet(featurizer='ECFP', split='index', reload=True, frac_train=.8):
  """Load the Sweetlead dataset.

  Parameters
  ----------
  featurizer: str
    Only 'ECFP' (circular fingerprints) is currently supported.
  split: str or None
    One of 'index', 'random', 'scaffold', 'task'; if None the full
    featurized dataset is returned unsplit.
  reload: bool
    If True, reuse a previously featurized/split copy saved under the
    data directory when available, and save the result for future calls.
  frac_train: float
    Fraction of data assigned to the training set; the remainder is
    divided evenly between validation and test.

  Returns
  -------
  tasks, (train, valid, test), transformers
  """
  logger.info("About to load Sweetlead dataset.")
  data_dir = dc.utils.get_data_dir()
  SWEET_tasks = ["task"]

  if reload:
    save_dir = os.path.join(data_dir,
                            "sweetlead/" + featurizer + "/" + str(split))
    # Reuse cached datasets instead of re-featurizing from scratch;
    # without this check the `reload` flag never short-circuits anything.
    loaded, all_dataset, transformers = dc.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      return SWEET_tasks, all_dataset, transformers

  dataset_file = os.path.join(data_dir, "sweet.csv.gz")
  if not os.path.exists(dataset_file):
    # download_url saves into the deepchem data directory by default.
    dc.utils.download_url(
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/sweet.csv.gz'
    )

  # Featurize SWEET dataset. Use a separate name for the featurizer object
  # so the `featurizer` string argument is not clobbered.
  logger.info("About to featurize SWEET dataset.")
  if featurizer == 'ECFP':
    featurizer_obj = dc.feat.CircularFingerprint(size=1024)
  else:
    raise ValueError("Other featurizations not supported")

  loader = dc.data.CSVLoader(
      tasks=SWEET_tasks, smiles_field="smiles", featurizer=featurizer_obj)
  dataset = loader.featurize(dataset_file)

  # Initialize transformers (re-balance example weights per class).
  transformers = [
      dc.trans.BalancingTransformer(transform_w=True, dataset=dataset)
  ]
  logger.info("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  if split is None:
    return SWEET_tasks, (dataset, None, None), transformers

  splitters = {
      'index': dc.splits.IndexSplitter(),
      'random': dc.splits.RandomSplitter(),
      'scaffold': dc.splits.ScaffoldSplitter(),
      'task': dc.splits.TaskSplitter()
  }
  if split not in splitters:
    raise ValueError("Unknown split: %s" % str(split))
  splitter = splitters[split]
  # Honor frac_train (previously accepted but ignored); the remaining
  # fraction is split evenly between validation and test.
  frac_rest = (1. - frac_train) / 2.
  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_rest,
      frac_test=frac_rest)

  if reload:
    dc.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                       transformers)

  return SWEET_tasks, (train, valid, test), transformers

examples/sweetlead/sweet.csv.gz

deleted (mode 100644 → 0)
−53.4 KiB

File deleted.

+27 −93
Original line number Diff line number Diff line
"""
Script that loads random forest models trained on the sider and toxcast datasets, predicts on sweetlead,
creates covariance matrix
Script that loads random forest models trained on the sider and tox21 datasets,
predicts on sweetlead, creates covariance matrix

@Author Aneesh Pappu
"""
from __future__ import print_function
@@ -11,110 +12,44 @@ import os
import sys
import numpy as np
import pandas as pd
import deepchem as dc
from sklearn.ensemble import RandomForestClassifier
from deepchem.models.multitask import SingletaskToMultitask
from deepchem import metrics
from deepchem.metrics import Metric
from deepchem.models.sklearn_models import SklearnModel
from deepchem.splits import StratifiedSplitter, RandomSplitter
from sweetlead_datasets import load_sweet

sys.path.append('./../toxcast')
sys.path.append('./../sider')

from tox_datasets import load_tox
from sider_datasets import load_sider

"""
Load toxicity models now
"""

# Set some global variables up top
reload = False
verbosity = "high"
tox_tasks, (tox_train, tox_valid,
            tox_test), tox_transformers = dc.molnet.load_tox21()

base_tox_data_dir = "/home/apappu/deepchem-models/toxcast_models/toxcast/toxcast_data"
classification_metric = Metric(
    metrics.roc_auc_score, np.mean, mode="classification")

tox_tasks, tox_dataset, tox_transformers = load_tox(
    base_tox_data_dir, reload=reload)

#removes directory if present -- warning
base_tox_dir = "/home/apappu/deepchem-models/toxcast_models/toxcast/toxcast_analysis"
def model_builder(model_dir):
  sklearn_model = RandomForestClassifier(
      class_weight="balanced", n_estimators=500, n_jobs=-1)
  return dc.models.SklearnModel(sklearn_model, model_dir)

tox_train_dir = os.path.join(base_tox_dir, "train_dataset")
tox_valid_dir = os.path.join(base_tox_dir, "valid_dataset")
tox_test_dir = os.path.join(base_tox_dir, "test_dataset")
tox_model_dir = os.path.join(base_tox_dir, "model")

tox_splitter = StratifiedSplitter()
print(tox_train.get_task_names())
print(tox_tasks)
tox_model = SingletaskToMultitask(tox_tasks, model_builder)
tox_model.fit(tox_train)

#default split is 80-10-10 train-valid-test split
tox_train_dataset, tox_valid_dataset, tox_test_dataset = tox_splitter.train_valid_test_split(
  tox_dataset, tox_train_dir, tox_valid_dir, tox_test_dir)
# Load sider models now

# Fit Logistic Regression models
tox_task_types = {task: "classification" for task in tox_tasks}


classification_metric = Metric(metrics.roc_auc_score, np.mean,
                               verbosity=verbosity,
                               mode="classification")
params_dict = {
    "batch_size": None,
    "data_shape": tox_train_dataset.get_data_shape(),
}

def model_builder(tasks, task_types, model_params, model_dir, verbosity=None):
  return SklearnModel(tasks, task_types, model_params, model_dir,
                      model_instance=RandomForestClassifier(
                          class_weight="balanced",
                          n_estimators=500,
                          n_jobs=-1),
                      verbosity=verbosity)
tox_model = SingletaskToMultitask(tox_tasks, tox_task_types, params_dict, tox_model_dir,
                              model_builder, verbosity=verbosity)
tox_model.reload()

"""
Load sider models now
"""
sider_tasks, (
    sider_train, sider_valid,
    sider_test), sider_transformers = dc.molnet.load_sider(split="random")

base_sider_data_dir = "/home/apappu/deepchem-models/toxcast_models/sider/sider_data"
sider_model = SingletaskToMultitask(sider_tasks, model_builder)
sider_model.fit(sider_train)

sider_tasks, sider_dataset, sider_transformers = load_sider(
    base_sider_data_dir, reload=reload)
# Load sweetlead dataset now. Pass in dataset object and appropriate
# transformers to predict functions

base_sider_dir = "/home/apappu/deepchem-models/toxcast_models/sider/sider_analysis"

sider_train_dir = os.path.join(base_sider_dir, "train_dataset")
sider_valid_dir = os.path.join(base_sider_dir, "valid_dataset")
sider_test_dir = os.path.join(base_sider_dir, "test_dataset")
sider_model_dir = os.path.join(base_sider_dir, "model")

sider_splitter = RandomSplitter()
sider_train_dataset, sider_valid_dataset, sider_test_dataset = sider_splitter.train_valid_test_split(
  sider_dataset, sider_train_dir, sider_valid_dir, sider_test_dir)

# Fit Logistic Regression models
sider_task_types = {task: "classification" for task in sider_tasks}

params_dict = {
  "batch_size": None,
  "data_shape": sider_train_dataset.get_data_shape(),
}

sider_model = SingletaskToMultitask(sider_tasks, sider_task_types, params_dict, sider_model_dir,
                              model_builder, verbosity=verbosity)
sider_model.reload()

"""
Load sweetlead dataset now. Pass in dataset object and appropriate transformers to predict functions
"""

base_sweet_data_dir = "/home/apappu/deepchem-models/toxcast_models/sweetlead/sweet_data"

sweet_dataset, sweet_transformers = load_sweet(
    base_sweet_data_dir, reload=reload)
sweet_tasks, (sweet_dataset, _, _), sweet_transformers = dc.molnet.load_sweet()

sider_predictions = sider_model.predict(sweet_dataset, sweet_transformers)

@@ -134,4 +69,3 @@ for i in range(tox_predictions.shape[0]):
df = pd.DataFrame(confusion_matrix)

df.to_csv("./tox_sider_matrix.csv")
+0 −41
Original line number Diff line number Diff line
"""
SWEET dataset loader.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import numpy as np
import shutil
import deepchem as dc

def load_sweet(base_dir, frac_train=.8):
  """Load the sweet dataset and return index-split train/valid/test sets.

  Parameters
  ----------
  base_dir: str
    Kept for interface compatibility; not used by this loader (the CSV is
    read from the directory containing this file).
  frac_train: float
    Fraction of data assigned to training; the remainder is divided
    evenly between validation and test.

  Returns
  -------
  tasks, (train, valid, test), transformers
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))

  # Load SWEET dataset
  dataset_file = os.path.join(current_dir, "./sweet.csv.gz")

  # Featurize SWEET dataset
  print("About to featurize SWEET dataset.")
  featurizer = dc.feat.CircularFingerprint(size=1024)
  # Task names are every CSV column after the leading "smiles" column.
  # (Previously this read `dataset.columns` before `dataset` existed,
  # which raised NameError.)
  import pandas as pd
  raw_df = pd.read_csv(dataset_file)
  SWEET_tasks = raw_df.columns.values[1:].tolist()

  loader = dc.data.CSVLoader(
      tasks=SWEET_tasks, smiles_field="smiles", featurizer=featurizer)
  dataset = loader.featurize(dataset_file)

  # Initialize transformers (re-balance example weights per class).
  transformers = [
      dc.trans.BalancingTransformer(transform_w=True, dataset=dataset)]
  print("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  # Fixed `spliter`/`splitter` typo that raised NameError, and honor
  # frac_train (previously accepted but ignored).
  splitter = dc.splits.IndexSplitter()
  frac_rest = (1. - frac_train) / 2.
  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_rest,
      frac_test=frac_rest)

  return SWEET_tasks, (train, valid, test), transformers