Commit e9878ee7 authored by Vignesh's avatar Vignesh
Browse files

Custom directory support for tox21, sampl, hiv. Split fracs for all datasets

parent a34adf28
Loading
Loading
Loading
Loading
+10 −1
Original line number Diff line number Diff line
@@ -149,7 +149,16 @@ def load_chembl25(featurizer="smiles2seq",
  logger.info("About to split data with {} splitter.".format(split))
  splitter = splitters[split]

  train, valid, test = splitter.train_valid_test_split(dataset, seed=split_seed)
  frac_train = kwargs.get('frac_train', 4 / 6)
  frac_valid = kwargs.get('frac_valid', 1 / 6)
  frac_test = kwargs.get('frac_test', 1 / 6)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      seed=split_seed,
      frac_train=frac_train,
      frac_test=frac_test,
      frac_valid=frac_valid)
  if transformer_type == "minmax":
    transformers = [
        dc.trans.MinMaxTransformer(
+37 −14
Original line number Diff line number Diff line
@@ -10,29 +10,42 @@ import deepchem

logger = logging.getLogger(__name__)

HIV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/HIV.csv'
DEFAULT_DIR = deepchem.utils.get_data_dir()

def load_hiv(featurizer='ECFP', split='index', reload=True, **kwargs):

def load_hiv(featurizer='ECFP',
             split='index',
             reload=True,
             data_dir=None,
             save_dir=None,
             **kwargs):
  """Load hiv datasets. Does not do train/test split"""
  # Featurize hiv dataset
  logger.info("About to featurize hiv dataset.")
  data_dir = deepchem.utils.get_data_dir()
  if data_dir is None:
    data_dir = DEFAULT_DIR
  if save_dir is None:
    save_dir = DEFAULT_DIR

  save_folder = os.path.join(save_dir, "hiv-featurized", str(featurizer),
                             str(split))
  if featurizer == "smiles2img":
    img_spec = kwargs.get("img_spec", "std")
    save_folder = os.path.join(save_folder, img_spec)

  if reload:
    save_dir = os.path.join(data_dir, "hiv/" + featurizer + "/" + str(split))
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_folder)
    if loaded:
      return hiv_tasks, all_dataset, transformers

  dataset_file = os.path.join(data_dir, "HIV.csv")
  if not os.path.exists(dataset_file):
    deepchem.utils.download_url(
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/HIV.csv'
    )
    deepchem.utils.download_url(url=HIV_URL, dest_dir=data_dir)

  hiv_tasks = ["HIV_active"]

  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      return hiv_tasks, all_dataset, transformers

  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
@@ -64,10 +77,20 @@ def load_hiv(featurizer='ECFP', split='index', reload=True, **kwargs):
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'butina': deepchem.splits.ButinaSplitter()
      'butina': deepchem.splits.ButinaSplitter(),
      'stratified': deepchem.splits.RandomStratifiedSplitter()
  }
  splitter = splitters[split]
  logger.info("About to split dataset with {} splitter.".format(split))
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)
  train, valid, test = splitter.train_valid_test_split(dataset)

  transformers = [
@@ -81,6 +104,6 @@ def load_hiv(featurizer='ECFP', split='index', reload=True, **kwargs):
    test = transformer.transform(test)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
    deepchem.utils.save.save_dataset_to_disk(save_folder, train, valid, test,
                                             transformers)
  return hiv_tasks, (train, valid, test), transformers
+33 −12
Original line number Diff line number Diff line
@@ -10,35 +10,47 @@ import deepchem

logger = logging.getLogger(__name__)

SAMPL_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/SAMPL.csv'
DEFAULT_DIR = deepchem.utils.get_data_dir()


def load_sampl(featurizer='ECFP',
               split='index',
               reload=True,
               move_mean=True,
               data_dir=None,
               save_dir=None,
               **kwargs):
  """Load SAMPL datasets."""
  # Featurize SAMPL dataset
  logger.info("About to featurize SAMPL dataset.")
  logger.info("About to load SAMPL dataset.")
  data_dir = deepchem.utils.get_data_dir()
  if reload:

  if data_dir is None:
    data_dir = DEFAULT_DIR
  if save_dir is None:
    save_dir = DEFAULT_DIR

  if move_mean:
      dir_name = "sampl/" + featurizer + "/" + str(split)
    save_folder = os.path.join(data_dir, "sampl-featurized", str(featurizer),
                               str(split))
  else:
      dir_name = "sampl/" + featurizer + "_mean_unmoved/" + str(split)
    save_dir = os.path.join(data_dir, dir_name)
    save_folder = os.path.join(data_dir, "sampl-featurized",
                               str(featurizer) + "_mean_unmoved", str(split))

  if featurizer == "smiles2img":
    img_spec = kwargs.get("img_spec", "std")
    save_folder = os.path.join(save_folder, img_spec)

  dataset_file = os.path.join(data_dir, "SAMPL.csv")
  if not os.path.exists(dataset_file):
    deepchem.utils.download_url(
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/SAMPL.csv'
    )
    deepchem.utils.download_url(url=SAMPL_URL, dest_dir=data_dir)

  SAMPL_tasks = ['expt']

  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
        save_folder)
    if loaded:
      return SAMPL_tasks, all_dataset, transformers

@@ -77,6 +89,15 @@ def load_sampl(featurizer='ECFP',
  }
  splitter = splitters[split]
  logger.info("About to split dataset with {} splitter.".format(split))
  frac_train = kwargs.get("frac_train", 0.8)
  frac_valid = kwargs.get('frac_valid', 0.1)
  frac_test = kwargs.get('frac_test', 0.1)

  train, valid, test = splitter.train_valid_test_split(
      dataset,
      frac_train=frac_train,
      frac_valid=frac_valid,
      frac_test=frac_test)
  train, valid, test = splitter.train_valid_test_split(dataset)

  transformers = [
@@ -91,6 +112,6 @@ def load_sampl(featurizer='ECFP',
    test = transformer.transform(test)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
    deepchem.utils.save.save_dataset_to_disk(save_folder, train, valid, test,
                                             transformers)
  return SAMPL_tasks, (train, valid, test), transformers
+35 −10
Original line number Diff line number Diff line
@@ -10,8 +10,17 @@ import deepchem

logger = logging.getLogger(__name__)

TOX21_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz'
DEFAULT_DIR = deepchem.utils.get_data_dir()

def load_tox21(featurizer='ECFP', split='index', reload=True, K=4, **kwargs):

def load_tox21(featurizer='ECFP',
               split='index',
               reload=True,
               K=4,
               data_dir=None,
               save_dir=None,
               **kwargs):
  """Load Tox21 datasets. Does not do train/test split"""
  # Featurize Tox21 dataset

@@ -20,19 +29,26 @@ def load_tox21(featurizer='ECFP', split='index', reload=True, K=4, **kwargs):
      'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'
  ]

  data_dir = deepchem.utils.get_data_dir()
  if data_dir is None:
    data_dir = DEFAULT_DIR
  if save_dir is None:
    save_dir = DEFAULT_DIR

  save_folder = os.path.join(save_dir, "tox21-featurized", str(featurizer),
                             str(split))
  if featurizer == "smiles2img":
    img_spec = kwargs.get("img_spec", "std")
    save_folder = os.path.join(save_folder, img_spec)

  if reload:
    save_dir = os.path.join(data_dir, "tox21/" + featurizer + "/" + str(split))
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
        save_folder)
    if loaded:
      return tox21_tasks, all_dataset, transformers

  dataset_file = os.path.join(data_dir, "tox21.csv.gz")
  if not os.path.exists(dataset_file):
    deepchem.utils.download_url(
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz'
    )
    deepchem.utils.download_url(url=TOX21_URL, dest_dir=data_dir)

  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
@@ -70,16 +86,25 @@ def load_tox21(featurizer='ECFP', split='index', reload=True, K=4, **kwargs):
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'butina': deepchem.splits.ButinaSplitter(),
      'task': deepchem.splits.TaskSplitter()
      'task': deepchem.splits.TaskSplitter(),
      'stratified': deepchem.splits.RandomStratifiedSplitter()
  }
  splitter = splitters[split]
  if split == 'task':
    fold_datasets = splitter.k_fold_split(dataset, K)
    all_dataset = fold_datasets
  else:
    train, valid, test = splitter.train_valid_test_split(dataset)
    frac_train = kwargs.get("frac_train", 0.8)
    frac_valid = kwargs.get('frac_valid', 0.1)
    frac_test = kwargs.get('frac_test', 0.1)

    train, valid, test = splitter.train_valid_test_split(
        dataset,
        frac_train=frac_train,
        frac_valid=frac_valid,
        frac_test=frac_test)
    all_dataset = (train, valid, test)
    if reload:
      deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
      deepchem.utils.save.save_dataset_to_disk(save_folder, train, valid, test,
                                               transformers)
  return tox21_tasks, all_dataset, transformers