Unverified commit 7ca3a111, authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #1621 from VIGS25/split-transform-order

Swapping Split-Transform order
parents 4e382ee8 2a33962f
Loading
Loading
Loading
Loading
+42 −17
Original line number Diff line number Diff line
@@ -57,17 +57,17 @@ def load_bace_regression(featurizer='ECFP',
      tasks=bace_tasks, smiles_field="mol", featurizer=featurizer)

  dataset = loader.featurize(dataset_file, shard_size=8192)
  if split is None:
    # Initialize transformers
    transformers = [
        deepchem.trans.NormalizationTransformer(
            transform_y=True, dataset=dataset, move_mean=move_mean)
    ]

  logger.info("About to transform data")
    logger.info("Split is None, about to transform data")
    for transformer in transformers:
      dataset = transformer.transform(dataset)

  if split == None:
    return bace_tasks, (dataset, None, None), transformers

  splitters = {
@@ -76,8 +76,20 @@ def load_bace_regression(featurizer='ECFP',
      'scaffold': deepchem.splits.ScaffoldSplitter()
  }
  splitter = splitters[split]
  logger.info("About to split data using {} splitter".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)

  transformers = [
      deepchem.trans.NormalizationTransformer(
          transform_y=True, dataset=train, move_mean=move_mean)
  ]

  logger.info("About to transform data.")
  for transformer in transformers:
    train = transformer.transform(train)
    valid = transformer.transform(valid)
    test = transformer.transform(test)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
@@ -122,16 +134,17 @@ def load_bace_classification(featurizer='ECFP', split='random', reload=True):
      tasks=bace_tasks, smiles_field="mol", featurizer=featurizer)

  dataset = loader.featurize(dataset_file, shard_size=8192)

  if split is None:
    # Initialize transformers
    transformers = [
        deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
    ]

  logger.info("About to transform data")
    logger.info("Split is None, about to transform data")
    for transformer in transformers:
      dataset = transformer.transform(dataset)

  if split == None:
    return bace_tasks, (dataset, None, None), transformers

  splitters = {
@@ -139,9 +152,21 @@ def load_bace_classification(featurizer='ECFP', split='random', reload=True):
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter()
  }

  splitter = splitters[split]
  logger.info("About to split data using {} splitter".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)

  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
  ]

  logger.info("About to transform data.")
  for transformer in transformers:
    train = transformer.transform(train)
    valid = transformer.transform(valid)
    test = transformer.transform(test)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
+10 −2
Original line number Diff line number Diff line
@@ -57,6 +57,8 @@ def load_bbbc001(split='index', reload=True):
  dataset = deepchem.data.DiskDataset.from_numpy(dataset.X, y)

  if split == None:
    transformers = []
    logger.info("Split is None, no transformers used for the dataset.")
    return bbbc001_tasks, (dataset, None, None), transformers

  splitters = {
@@ -67,7 +69,9 @@ def load_bbbc001(split='index', reload=True):
    raise ValueError("Only index and random splits supported.")
  splitter = splitters[split]

  logger.info("About to split dataset with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  transformers = []
  all_dataset = (train, valid, test)
  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
@@ -121,6 +125,8 @@ def load_bbbc002(split='index', reload=True):
  dataset = deepchem.data.DiskDataset.from_numpy(dataset.X, y, ids=ids)

  if split == None:
    transformers = []
    logger.info("Split is None, no transformers used for the dataset.")
    return bbbc002_tasks, (dataset, None, None), transformers

  splitters = {
@@ -131,8 +137,10 @@ def load_bbbc002(split='index', reload=True):
    raise ValueError("Only index and random splits supported.")
  splitter = splitters[split]

  logger.info("About to split dataset with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  all_dataset = (train, valid, test)
  transformers = []
  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
+20 −8
Original line number Diff line number Diff line
@@ -45,16 +45,17 @@ def load_bbbp(featurizer='ECFP', split='random', reload=True):
  loader = deepchem.data.CSVLoader(
      tasks=bbbp_tasks, smiles_field="smiles", featurizer=featurizer)
  dataset = loader.featurize(dataset_file, shard_size=8192)

  if split is None:
    # Initialize transformers
    transformers = [
        deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
    ]

  logger.info("About to transform data")
    logger.info("Split is None, about to transform data")
    for transformer in transformers:
      dataset = transformer.transform(dataset)

  if split == None:
    return bbbp_tasks, (dataset, None, None), transformers

  splitters = {
@@ -63,8 +64,19 @@ def load_bbbp(featurizer='ECFP', split='random', reload=True):
      'scaffold': deepchem.splits.ScaffoldSplitter()
  }
  splitter = splitters[split]
  logger.info("About to split data with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)

  # Initialize transformers
  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
  ]

  for transformer in transformers:
    train = transformer.transform(train)
    valid = transformer.transform(valid)
    test = transformer.transform(test)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
+4 −1
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ def load_cell_counting(split=None, reload=True):
  transformers = []

  if split == None:
    logger.info("Split is None, no transformers used.")
    return cell_counting_tasks, (dataset, None, None), transformers

  splitters = {
@@ -53,7 +54,9 @@ def load_cell_counting(split=None, reload=True):
    raise ValueError("Only index and random splits supported.")
  splitter = splitters[split]

  logger.info("About to split dataset with {} splitter.".format(split))
  train, valid, test = splitter.train_valid_test_split(dataset)
  transformers = []
  all_dataset = (train, valid, test)
  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
+26 −24
Original line number Diff line number Diff line
@@ -80,35 +80,27 @@ def load_chembl(shard_size=2000,

  if split == "year":
    logger.info("Featurizing train datasets")
    train_dataset = loader.featurize(train_files, shard_size=shard_size)
    train = loader.featurize(train_files, shard_size=shard_size)
    logger.info("Featurizing valid datasets")
    valid_dataset = loader.featurize(valid_files, shard_size=shard_size)
    valid = loader.featurize(valid_files, shard_size=shard_size)
    logger.info("Featurizing test datasets")
    test_dataset = loader.featurize(test_files, shard_size=shard_size)
    test = loader.featurize(test_files, shard_size=shard_size)
  else:
    dataset = loader.featurize(dataset_path, shard_size=shard_size)
  # Initialize transformers
  logger.info("About to transform data")
  if split == "year":
    transformers = [
        deepchem.trans.NormalizationTransformer(
            transform_y=True, dataset=train_dataset)
    ]
    for transformer in transformers:
      train = transformer.transform(train_dataset)
      valid = transformer.transform(valid_dataset)
      test = transformer.transform(test_dataset)
  else:

  if split is None:
    transformers = [
        deepchem.trans.NormalizationTransformer(
            transform_y=True, dataset=dataset)
    ]

    logger.info("Split is None, about to transform data.")
    for transformer in transformers:
      dataset = transformer.transform(dataset)

  if split == None:
    return chembl_tasks, (dataset, None, None), transformers

  if split != "year":
    splitters = {
        'index': deepchem.splits.IndexSplitter(),
        'random': deepchem.splits.RandomSplitter(),
@@ -119,6 +111,16 @@ def load_chembl(shard_size=2000,
    logger.info("Performing new split.")
    train, valid, test = splitter.train_valid_test_split(dataset)

  transformers = [
      deepchem.trans.NormalizationTransformer(transform_y=True, dataset=train)
  ]

  logger.info("About to transform data.")
  for transformer in transformers:
    train = transformer.transform(train)
    valid = transformer.transform(valid)
    test = transformer.transform(test)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
Loading