Commit 4034c5c9 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Continued debugging

parent 8b6ae492
Loading
Loading
Loading
Loading
+11 −8
Original line number Diff line number Diff line
@@ -8,19 +8,22 @@ from __future__ import unicode_literals
import os
import numpy as np
import shutil
import logging
import deepchem as dc

logger = logging.getLogger(__name__)

def load_sweet(featurizer='ECFP', split='index', reload=True, frac_train=.8):
  """Load sweet datasets."""
  # Load Sweetlead dataset
  logger.info("About to load Sweetlead dataset.")
  data_dir = deepchem.utils.get_data_dir()
  data_dir = dc.utils.get_data_dir()
  if reload:
    save_dir = os.path.join(data_dir, "sweetlead/" + featurizer + "/" + str(split))

  dataset_file = os.path.join(data_dir, "sweet.csv.gz")
  if not os.path.exists(dataset_file):
    deepchem.utils.download_url(
    dc.utils.download_url(
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/sweet.csv.gz'
    )

@@ -45,16 +48,16 @@ def load_sweet(featurizer='ECFP', split='index', reload=True, frac_train=.8):
      dataset = transformer.transform(dataset)

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'task': deepchem.splits.TaskSplitter()
      'index': dc.splits.IndexSplitter(),
      'random': dc.splits.RandomSplitter(),
      'scaffold': dc.splits.ScaffoldSplitter(),
      'task': dc.splits.TaskSplitter()
  }
  splitter = splitters[split]
  train, valid, test = splitter.train_valid_test_split(dataset)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
    dc.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                       transformers)
    all_dataset = (train, valid, test)

+11 −86
Original line number Diff line number Diff line
@@ -17,59 +17,11 @@ from deepchem.models.multitask import SingletaskToMultitask
from deepchem import metrics
from deepchem.metrics import Metric
from deepchem.models.sklearn_models import SklearnModel
#from deepchem.splits import StratifiedSplitter, RandomSplitter
#from sweetlead_datasets import load_sweet

#sys.path.append('./../toxcast')
#sys.path.append('./../sider')
#
#from tox_datasets import load_tox
#from sider_datasets import load_sider

#"""
#Load toxicity models now
#"""

## Set some global variables up top
#reload = False
#verbosity = "high"
#
#base_tox_data_dir = "/home/apappu/deepchem-models/toxcast_models/toxcast/toxcast_data"

tox_tasks, (tox_train, tox_valid, tox_test), tox_transformers = dc.molnet.load_toxcast()

#removes directory if present -- warning
#base_tox_dir = "/home/apappu/deepchem-models/toxcast_models/toxcast/toxcast_analysis"

#tox_train_dir = os.path.join(base_tox_dir, "train_dataset")
#tox_valid_dir = os.path.join(base_tox_dir, "valid_dataset")
#tox_test_dir = os.path.join(base_tox_dir, "test_dataset")
#tox_model_dir = os.path.join(base_tox_dir, "model")

#tox_splitter = StratifiedSplitter()

#default split is 80-10-10 train-valid-test split
#tox_train_dataset, tox_valid_dataset, tox_test_dataset = tox_splitter.train_valid_test_split(
#  tox_dataset, tox_train_dir, tox_valid_dir, tox_test_dir)

## Fit random forest models (see the RandomForestClassifier in model_builder)
#tox_task_types = {task: "classification" for task in tox_tasks}

tox_tasks, (tox_train, tox_valid, tox_test), tox_transformers = dc.molnet.load_tox21()

classification_metric = Metric(metrics.roc_auc_score, np.mean, mode="classification")

#params_dict = {
#    "batch_size": None,
#    "data_shape": tox_train_dataset.get_data_shape(),
#}

#def model_builder(tasks, task_types, model_params, model_dir, verbosity=None):
#  return SklearnModel(tasks, task_types, model_params, model_dir,
#                      model_instance=RandomForestClassifier(
#                          class_weight="balanced",
#                          n_estimators=500,
#                          n_jobs=-1),
#                      verbosity=verbosity)
def model_builder(model_dir):
  sklearn_model = RandomForestClassifier(
                          class_weight="balanced",
@@ -77,49 +29,22 @@ def model_builder(model_dir):
                          n_jobs=-1)
  return dc.models.SklearnModel(sklearn_model, model_dir)

print(tox_train.get_task_names())
print(tox_tasks)
tox_model = SingletaskToMultitask(tox_tasks, model_builder)
tox_model.reload()

"""
Load sider models now
"""

base_sider_data_dir = "/home/apappu/deepchem-models/toxcast_models/sider/sider_data"

sider_tasks, sider_dataset, sider_transformers = load_sider(
    base_sider_data_dir, reload=reload)

base_sider_dir = "/home/apappu/deepchem-models/toxcast_models/sider/sider_analysis"
tox_model.fit(tox_train)

sider_train_dir = os.path.join(base_sider_dir, "train_dataset")
sider_valid_dir = os.path.join(base_sider_dir, "valid_dataset")
sider_test_dir = os.path.join(base_sider_dir, "test_dataset")
sider_model_dir = os.path.join(base_sider_dir, "model")
# Load sider models now

sider_splitter = RandomSplitter()
sider_train_dataset, sider_valid_dataset, sider_test_dataset = sider_splitter.train_valid_test_split(
  sider_dataset, sider_train_dir, sider_valid_dir, sider_test_dir)
sider_tasks, (sider_train, sider_valid, sider_test), sider_transformers = dc.molnet.load_sider(split="random")

# Fit random forest models (see the RandomForestClassifier in model_builder)
sider_task_types = {task: "classification" for task in sider_tasks}

params_dict = {
  "batch_size": None,
  "data_shape": sider_train_dataset.get_data_shape(),
}

sider_model = SingletaskToMultitask(sider_tasks, sider_task_types, params_dict, sider_model_dir,
                              model_builder, verbosity=verbosity)
sider_model.reload()

"""
Load sweetlead dataset now. Pass in dataset object and appropriate transformers to predict functions
"""
# Build one single-task model per SIDER task and fit on the SIDER training
# split. The task list must be sider_tasks (returned by load_sider), not
# tox_tasks: the model's task names have to match the dataset being fitted.
sider_model = SingletaskToMultitask(sider_tasks, model_builder)
sider_model.fit(sider_train)

base_sweet_data_dir = "/home/apappu/deepchem-models/toxcast_models/sweetlead/sweet_data"
# Load sweetlead dataset now. Pass in dataset object and appropriate
# transformers to predict functions

sweet_dataset, sweet_transformers = dc.molnet.load_sweet(
    base_sweet_data_dir, reload=reload)
sweet_tasks, (sweet_dataset, _, _), sweet_transformers = dc.molnet.load_sweet()

sider_predictions = sider_model.predict(sweet_dataset, sweet_transformers)