Commit 8b004c53 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #370 from CCXD/dev

Fixed transformations on chembl and kaggle datasets
parents 09225083 e7d46be4
Loading
Loading
Loading
Loading
+3 −13
Original line number Diff line number Diff line
@@ -67,8 +67,9 @@ def load_chembl(shard_size=2000, featurizer="ECFP", set="5thresh", split="random
        transformers = [
            dc.trans.NormalizationTransformer(transform_y=True, dataset=train_dataset)]
        for transformer in transformers:
            for dataset in [train_dataset, valid_dataset, test_dataset]:
                transformer.transform(dataset)
            train = transformer.transform(train_dataset)
            valid = transformer.transform(valid_dataset)
            test = transformer.transform(test_dataset)
    else:
        transformers = [
            dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]
@@ -82,17 +83,6 @@ def load_chembl(shard_size=2000, featurizer="ECFP", set="5thresh", split="random
        splitter = splitters[split]
        print("Performing new split.")
        train, valid, test = splitter.train_valid_test_split(dataset)
    elif split == "year":
        print("Featurizing train datasets")
        train = loader.featurize(
            train_files, shard_size=shard_size)

        print("Featurizing valid datasets")
        valid = loader.featurize(
            valid_files, shard_size=shard_size)

        print("Featurizing test datasets")
        test = loader.featurize(
            test_files, shard_size=shard_size)

    return chembl_tasks, (train, valid, test), transformers
+3 −4
Original line number Diff line number Diff line
@@ -78,13 +78,12 @@ def load_kaggle(shard_size=2000, featurizer=None):
      dc.trans.NormalizationTransformer(transform_y=True,
                                        dataset=train_dataset)]

  # TODO(rbharath): Is this a bug in the Kaggle data transformation?
  for transformer in transformers:
    print("Performing transformations with %s"
          % transformer.__class__.__name__)
    for dataset in [train_dataset, valid_dataset, test_dataset]:
      print("Transforming dataset")
      transformer.transform(dataset)
    train_dataset = transformer.transform(train_dataset)
    valid_dataset = transformer.transform(valid_dataset)
    test_dataset = transformer.transform(test_dataset)

  print("Shuffling order of train dataset.")
  train_dataset.sparse_shuffle()