Commit b66208b6 authored by evanfeinberg's avatar evanfeinberg
Browse files

fixed tox21 issue 3

parent 875e0ac5
Loading
Loading
Loading
Loading
+41 −23
Original line number | Diff line number | Diff line
@@ -6,16 +6,29 @@ from __future__ import division
from __future__ import unicode_literals

import os
import deepchem as dc
import deepchem


def load_tox21(featurizer='ECFP', split='index', reload=True, K=4):
  """Load Tox21 datasets. Does not do train/test split.

  Downloads the Tox21 CSV (if absent), featurizes it, applies a balancing
  transformer, splits it, and (optionally) caches the split datasets on disk
  so subsequent calls with `reload=True` return instantly.

  Parameters
  ----------
  featurizer: str
    One of 'ECFP', 'GraphConv', 'Weave', 'Raw', 'AdjacencyConv'. Selects the
    deepchem featurizer used on the SMILES strings.
  split: str
    One of 'index', 'random', 'scaffold', 'butina', 'task'. Selects the
    dataset splitter.
  reload: bool
    If True, try to load a previously saved featurized/split dataset from
    disk, and save the result to disk after computing it.
  K: int
    Number of folds, used only when split == 'task' (k-fold split).

  Returns
  -------
  tuple
    (tox21_tasks, all_dataset, transformers) where all_dataset is either a
    (train, valid, test) triple or, for split == 'task', a list of K folds.
  """
  # Data lives under $DEEPCHEM_DATA_DIR when set, else /tmp.
  if "DEEPCHEM_DATA_DIR" in os.environ:
    data_dir = os.environ["DEEPCHEM_DATA_DIR"]
  else:
    data_dir = "/tmp"
  if reload:
    # Cache location is keyed on featurizer and split so different
    # configurations don't collide.
    save_dir = os.path.join(data_dir, "tox21/" + featurizer + "/" + split)

  tox21_tasks = [
      'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
      'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'
  ]

  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      # Cache hit: skip featurization/splitting entirely.
      return tox21_tasks, all_dataset, transformers

  dataset_file = os.path.join(data_dir, "tox21.csv.gz")
  if not os.path.exists(dataset_file):
    # NOTE(review): the opening of this download statement falls in a hidden
    # gap between diff hunks; reconstructed as the wget idiom used by sibling
    # deepchem loaders, from the visible URL argument — verify against repo.
    os.system(
        'wget -P ' + data_dir +
        ' http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz'
    )

  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
    featurizer = deepchem.feat.ConvMolFeaturizer()
  elif featurizer == 'Weave':
    featurizer = deepchem.feat.WeaveFeaturizer()
  elif featurizer == 'Raw':
    featurizer = deepchem.feat.RawFeaturizer()
  elif featurizer == 'AdjacencyConv':
    # FIX: this branch still referenced `dc.feat.AdjacencyFingerprint` after
    # the commit replaced `import deepchem as dc` with `import deepchem`,
    # which would raise NameError. Use the `deepchem.` prefix like the
    # other branches.
    featurizer = deepchem.feat.AdjacencyFingerprint(
        max_n_atoms=150, max_valence=6)

  loader = deepchem.data.CSVLoader(
      tasks=tox21_tasks, smiles_field="smiles", featurizer=featurizer)
  dataset = loader.featurize(dataset_file, shard_size=8192)

  # Initialize transformers
  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
  ]

  print("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'butina': deepchem.splits.ButinaSplitter(),
      'task': deepchem.splits.TaskSplitter()
  }
  splitter = splitters[split]
  if split == 'task':
    # Task split yields K folds rather than a train/valid/test triple;
    # fold datasets are not cached to disk.
    fold_datasets = splitter.k_fold_split(dataset, K)
    all_dataset = fold_datasets
  else:
    train, valid, test = splitter.train_valid_test_split(dataset)
    all_dataset = (train, valid, test)
    if reload:
      deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                               transformers)
  return tox21_tasks, all_dataset, transformers