Unverified Commit 537c3347 authored by Karl Leswing's avatar Karl Leswing Committed by GitHub
Browse files

Merge pull request #1141 from nitinprakash96/load-functions

[WIP]Apply split==None to all the load functions
parents d5282b37 1a072bcf
Loading
Loading
Loading
Loading
+13 −4
Original line number Diff line number Diff line
@@ -5,14 +5,17 @@ from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem
from deepchem.molnet.load_function.bace_features import bace_user_specified_features

logger = logging.getLogger(__name__)


def load_bace_regression(featurizer='ECFP', split='random', reload=True):
  """Load bace datasets."""
  # Featurize bace dataset
  print("About to featurize bace dataset.")
  logger.info("About to featurize bace dataset.")
  data_dir = deepchem.utils.get_data_dir()
  if reload:
    save_dir = os.path.join(data_dir, "bace_r/" + featurizer + "/" + split)
@@ -53,10 +56,13 @@ def load_bace_regression(featurizer='ECFP', split='random', reload=True):
          transform_y=True, dataset=dataset)
  ]

  print("About to transform data")
  logger.info("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  if split == None:
    return bace_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
@@ -74,7 +80,7 @@ def load_bace_regression(featurizer='ECFP', split='random', reload=True):
def load_bace_classification(featurizer='ECFP', split='random', reload=True):
  """Load bace datasets."""
  # Featurize bace dataset
  print("About to featurize bace dataset.")
  logger.info("About to featurize bace dataset.")
  data_dir = deepchem.utils.get_data_dir()
  if reload:
    save_dir = os.path.join(data_dir, "bace_c/" + featurizer + "/" + split)
@@ -114,10 +120,13 @@ def load_bace_classification(featurizer='ECFP', split='random', reload=True):
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
  ]

  print("About to transform data")
  logger.info("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  if split == None:
    return bace_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
+8 −2
Original line number Diff line number Diff line
@@ -5,13 +5,16 @@ from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem

logger = logging.getLogger(__name__)


def load_bbbp(featurizer='ECFP', split='random', reload=True):
  """Load blood-brain barrier penetration datasets """
  # Featurize bbb dataset
  print("About to featurize bbbp dataset.")
  logger.info("About to featurize bbbp dataset.")
  data_dir = deepchem.utils.get_data_dir()
  if reload:
    save_dir = os.path.join(data_dir, "bbbp/" + featurizer + "/" + split)
@@ -47,10 +50,13 @@ def load_bbbp(featurizer='ECFP', split='random', reload=True):
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
  ]

  print("About to transform data")
  logger.info("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  if split == None:
    return bbbp_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
+15 −10
Original line number Diff line number Diff line
@@ -5,9 +5,12 @@ from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem
from deepchem.molnet.load_function.chembl_tasks import chembl_tasks

logger = logging.getLogger(__name__)


def load_chembl(shard_size=2000,
                featurizer="ECFP",
@@ -46,7 +49,7 @@ def load_chembl(shard_size=2000,
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/chembl_year_sets/chembl_sparse_ts_valid.csv.gz'
    )

  print("About to load ChEMBL dataset.")
  logger.info("About to load ChEMBL dataset.")
  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
@@ -62,7 +65,7 @@ def load_chembl(shard_size=2000,
        data_dir, "./chembl_year_sets/chembl_%s_ts_test.csv.gz" % set)

  # Featurize ChEMBL dataset
  print("About to featurize ChEMBL dataset.")
  logger.info("About to featurize ChEMBL dataset.")
  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
@@ -76,16 +79,16 @@ def load_chembl(shard_size=2000,
      tasks=chembl_tasks, smiles_field="smiles", featurizer=featurizer)

  if split == "year":
    print("Featurizing train datasets")
    logger.info("Featurizing train datasets")
    train_dataset = loader.featurize(train_files, shard_size=shard_size)
    print("Featurizing valid datasets")
    logger.info("Featurizing valid datasets")
    valid_dataset = loader.featurize(valid_files, shard_size=shard_size)
    print("Featurizing test datasets")
    logger.info("Featurizing test datasets")
    test_dataset = loader.featurize(test_files, shard_size=shard_size)
  else:
    dataset = loader.featurize(dataset_path, shard_size=shard_size)
  # Initialize transformers
  print("About to transform data")
  logger.info("About to transform data")
  if split == "year":
    transformers = [
        deepchem.trans.NormalizationTransformer(
@@ -103,15 +106,17 @@ def load_chembl(shard_size=2000,
    for transformer in transformers:
      dataset = transformer.transform(dataset)

  if split == None:
    return chembl_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter()
  }

  if split in splitters:
  splitter = splitters[split]
    print("Performing new split.")
  logger.info("Performing new split.")
  train, valid, test = splitter.train_valid_test_split(dataset)

  if reload:
+9 −3
Original line number Diff line number Diff line
@@ -5,14 +5,17 @@ from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem

logger = logging.getLogger(__name__)


def load_clearance(featurizer='ECFP', split='random', reload=True):
  """Load clearance datasets."""
  # Featurize clearance dataset
  print("About to featurize clearance dataset.")
  print("About to load clearance dataset.")
  logger.info("About to featurize clearance dataset.")
  logger.info("About to load clearance dataset.")
  data_dir = deepchem.utils.get_data_dir()
  if reload:
    save_dir = os.path.join(data_dir, "clearance/" + featurizer + "/" + split)
@@ -50,10 +53,13 @@ def load_clearance(featurizer='ECFP', split='random', reload=True):
          transform_y=True, dataset=dataset)
  ]

  print("About to transform data")
  logger.info("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  if split == None:
    return clearance_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
+14 −7
Original line number Diff line number Diff line
@@ -7,8 +7,11 @@ from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem

logger = logging.getLogger(__name__)


def load_clintox(featurizer='ECFP', split='index', reload=True):
  """Load clintox datasets."""
@@ -23,19 +26,19 @@ def load_clintox(featurizer='ECFP', split='index', reload=True):
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/clintox.csv.gz'
    )

  print("About to load clintox dataset.")
  logger.info("About to load clintox dataset.")
  dataset = deepchem.utils.save.load_from_disk(dataset_file)
  clintox_tasks = dataset.columns.values[1:].tolist()
  print("Tasks in dataset: %s" % (clintox_tasks))
  print("Number of tasks in dataset: %s" % str(len(clintox_tasks)))
  print("Number of examples in dataset: %s" % str(dataset.shape[0]))
  logger.info("Tasks in dataset: %s" % (clintox_tasks))
  logger.info("Number of tasks in dataset: %s" % str(len(clintox_tasks)))
  logger.info("Number of examples in dataset: %s" % str(dataset.shape[0]))
  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      return clintox_tasks, all_dataset, transformers
  # Featurize clintox dataset
  print("About to featurize clintox dataset.")
  logger.info("About to featurize clintox dataset.")
  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
@@ -50,7 +53,7 @@ def load_clintox(featurizer='ECFP', split='index', reload=True):
  dataset = loader.featurize(dataset_file, shard_size=8192)

  # Transform clintox dataset
  print("About to transform clintox dataset.")
  logger.info("About to transform clintox dataset.")
  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
  ]
@@ -58,7 +61,11 @@ def load_clintox(featurizer='ECFP', split='index', reload=True):
    dataset = transformer.transform(dataset)

  # Split clintox dataset
  print("About to split clintox dataset.")
  logger.info("About to split clintox dataset.")

  if split == None:
    return clintox_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
Loading