Commit ed4a3f7c authored by Bharath Ramsundar
Browse files

Fixes to datasets tests

parent 7f1bef0c
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -27,6 +27,8 @@ class TestDatasetAPI(TestAPI):

  def load_solubility_data(self):
    """Loads solubility data from example.csv"""
    if os.path.exists(self.data_dir):
      shutil.rmtree(self.data_dir)
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["log-solubility"]
    task_type = "regression"
@@ -41,6 +43,8 @@ class TestDatasetAPI(TestAPI):

  def load_classification_data(self):
    """Loads classification data from example.csv"""
    if os.path.exists(self.data_dir):
      shutil.rmtree(self.data_dir)
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["outcome"]
    task_type = "classification"
@@ -55,6 +59,8 @@ class TestDatasetAPI(TestAPI):

  def load_multitask_data(self):
    """Load example multitask data."""
    if os.path.exists(self.data_dir):
      shutil.rmtree(self.data_dir)
    featurizers = [CircularFingerprint(size=1024)]
    tasks = ["task0", "task1", "task2", "task3", "task4", "task5", "task6",
             "task7", "task8", "task9", "task10", "task11", "task12",
+4 −0
Original line number Diff line number Diff line
@@ -30,6 +30,10 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    assert solubility_dataset.get_task_names() == ["log-solubility"]

    multitask_dataset = self.load_multitask_data()
    ############################################ DEBUG
    print("multitask_dataset.get_task_names()")
    print(multitask_dataset.get_task_names())
    ############################################ DEBUG
    assert sorted(multitask_dataset.get_task_names()) == sorted(["task0",
        "task1", "task2", "task3", "task4", "task5", "task6", "task7", "task8",
        "task9", "task10", "task11", "task12", "task13", "task14", "task15",
+3 −14
Original line number Diff line number Diff line
@@ -25,9 +25,7 @@ class TestDrop(TestAPI):
    len_full = 25

    current_dir = os.path.dirname(os.path.realpath(__file__))
    feature_dir = os.path.join(self.base_dir, "features")
    samples_dir = os.path.join(self.base_dir, "samples")
    full_dir = os.path.join(self.base_dir, "full_dataset")
    data_dir = os.path.join(self.base_dir, "dataset")
    model_dir = os.path.join(self.base_dir, "model")

    print("About to load emols dataset.")
@@ -41,18 +39,9 @@ class TestDrop(TestAPI):

    featurizer = DataFeaturizer(tasks=emols_tasks,
                                smiles_field="smiles",
                                compound_featurizers=featurizers,
                                featurizers=featurizers,
                                verbosity=verbosity)
    featurized_samples = featurizer.featurize(
        dataset_file, feature_dir,
        samples_dir, reload=reload)
    print("len(featurized_samples)")
    print(len(featurized_samples))

    # Generate datasets
    dataset = Dataset(data_dir=full_dir, samples=featurized_samples, 
                      featurizers=featurizers, tasks=emols_tasks,
                      verbosity=verbosity, reload=reload)
    dataset = featurizer.featurize(dataset_file, data_dir)

    X, y, w, ids = dataset.to_numpy()
    print("ids.shape, X.shape, y.shape, w.shape")
+36 −95
Original line number Diff line number Diff line
@@ -33,29 +33,11 @@ class TestReload(TestAPI):
  """
  Test reload for datasets.
  """
  def setUp(self):
    self.current_dir = os.path.dirname(os.path.abspath(__file__))
    self.smiles_field = "smiles"
    sys_temp = tempfile.gettempdir()
    self.base_dir = os.path.join(sys_temp, "base_dir")
    # Make sure to remove an alternate instance of this dir if it exists.
    if os.path.exists(self.base_dir):
      shutil.rmtree(self.base_dir)
    os.makedirs(self.base_dir)
    self.feature_dir = os.path.join(self.base_dir, "features")
    self.samples_dir = os.path.join(self.base_dir, "samples")
    self.train_dir = os.path.join(self.base_dir, "train_dataset")
    self.valid_dir = os.path.join(self.base_dir, "valid_dataset")
    self.test_dir = os.path.join(self.base_dir, "test_dataset")

  def tearDown(self):
    shutil.rmtree(self.base_dir)

  def _run_muv_experiment(self, dataset_file, reload=False, verbosity=None):
    """Loads or reloads a small version of MUV dataset."""
    # Load MUV dataset
    dataset = load_from_disk(dataset_file)
    print("Number of examples in dataset: %s" % str(dataset.shape[0]))
    raw_dataset = load_from_disk(dataset_file)
    print("Number of examples in dataset: %s" % str(raw_dataset.shape[0]))

    print("About to featurize compounds")
    featurizers = [CircularFingerprint(size=1024)]
@@ -65,114 +47,73 @@ class TestReload(TestAPI):
                 'MUV-466', 'MUV-832']
    featurizer = DataFeaturizer(tasks=MUV_tasks,
                                smiles_field="smiles",
                                compound_featurizers=featurizers,
                                featurizers=featurizers,
                                verbosity=verbosity)
    featurized_samples = featurizer.featurize(
        dataset_file, self.feature_dir,
        self.samples_dir, shard_size=4096,
        reload=reload)
    assert len(featurized_samples) == len(dataset)
    dataset = featurizer.featurize(dataset_file, self.data_dir)
    assert len(dataset) == len(raw_dataset)

    print("About to split compounds into train/valid/test")
    splitter = ScaffoldSplitter(verbosity=verbosity)
    frac_train, frac_valid, frac_test = .8, .1, .1
    train_samples, valid_samples, test_samples = \
    train_dataset, valid_dataset, test_dataset = \
        splitter.train_valid_test_split(
            featurized_samples, self.train_dir, self.valid_dir, self.test_dir,
            log_every_n=1000, reload=reload, frac_train=frac_train,
            dataset, self.train_dir, self.valid_dir, self.test_dir,
            log_every_n=1000, frac_train=frac_train,
            frac_test=frac_test, frac_valid=frac_valid)
    # Do an approximate comparison since splits are sometimes slightly off from
    # the exact fraction.
    assert relative_difference(
        len(train_samples), frac_train * len(featurized_samples)) < 1e-3
        len(train_dataset), frac_train * len(dataset)) < 1e-3
    assert relative_difference(
        len(valid_samples), frac_valid * len(featurized_samples)) < 1e-3
        len(valid_dataset), frac_valid * len(dataset)) < 1e-3
    assert relative_difference(
        len(test_samples), frac_test * len(featurized_samples)) < 1e-3
    len_train_samples, len_valid_samples, len_test_samples = \
      len(train_samples), len(valid_samples), len(test_samples)

    print("Creating train dataset.")
    train_dataset = Dataset(data_dir=self.train_dir, samples=train_samples, 
                            featurizers=featurizers, tasks=MUV_tasks,
                            verbosity=verbosity, reload=reload)
    print("Creating valid dataset.")
    valid_dataset = Dataset(data_dir=self.valid_dir, samples=valid_samples, 
                            featurizers=featurizers, tasks=MUV_tasks,
                            verbosity=verbosity, reload=reload)
    print("Creating test dataset")
    test_dataset = Dataset(data_dir=self.test_dir, samples=test_samples, 
                           featurizers=featurizers, tasks=MUV_tasks,
                           verbosity=verbosity, reload=reload)
    len_train_dataset, len_valid_dataset, len_test_dataset = \
      len(train_dataset), len(valid_dataset), len(test_dataset)

    assert len(train_samples) == len(train_dataset)
    assert len(valid_samples) == len(valid_dataset)
    assert len(test_samples) == len(test_dataset)
        len(test_dataset), frac_test * len(dataset)) < 1e-3

    # TODO(rbharath): Transformers don't play nice with reload! Namely,
    # reloading will cause the transform to be reapplied. This is undesirable in
    # almost all cases. Need to understand a method to fix this.
    input_transformers = []
    output_transformers = []
    weight_transformers = [BalancingTransformer(transform_w=True,
    dataset=train_dataset)]
    transformers = input_transformers + output_transformers + weight_transformers
    print("Transforming train dataset")
    for transformer in transformers:
        transformer.transform(train_dataset)
    print("Transforming valid dataset")
    for transformer in transformers:
        transformer.transform(valid_dataset)
    print("Transforming test dataset")
    transformers = [
        BalancingTransformer(transform_w=True, dataset=train_dataset)]
    print("Transforming datasets")
    for dataset in [train_dataset, valid_dataset, test_dataset]:
      for transformer in transformers:
        transformer.transform(test_dataset)
          transformer.transform(dataset)

    return (len_train_samples, len_valid_samples, len_test_samples,
            len_train_dataset, len_valid_dataset, len_test_dataset)
    return (len(train_dataset), len(valid_dataset), len(test_dataset))
    
  def test_reload_after_gen(self):
    """Check num samples for loaded and reloaded datasets is equal.

    Runs ``_run_muv_experiment`` on ``mini_muv.csv.gz`` twice: first with
    ``reload=False`` and then with ``reload=True``. Each run returns the
    lengths of the train, valid and test datasets; the test passes only if
    every one of the three lengths is the same across both runs.
    """
    verbosity = None
    current_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_file = os.path.join(
        current_dir, "../../../datasets/mini_muv.csv.gz")

    print("Running experiment for first time without reload.")
    reload = False
    (len_train, len_valid, len_test) = self._run_muv_experiment(
        dataset_file, reload, verbosity)

    print("Running experiment for second time with reload.")
    reload = True
    (len_reload_train, len_reload_valid, len_reload_test) = (
        self._run_muv_experiment(dataset_file, reload, verbosity))

    assert len_train == len_reload_train
    assert len_valid == len_reload_valid
    # Compare test split against test split. The earlier version compared it
    # against the reloaded *valid* length, so the test split was never checked.
    assert len_test == len_reload_test

  def test_reload_twice(self):
    """Check ability to repeatedly run experiments with reload set True.

    Runs ``_run_muv_experiment`` on ``mini_muv.csv.gz`` twice, both times
    with ``reload=True`` and ``verbosity="high"``. Each run returns the
    lengths of the train, valid and test datasets; the test passes only if
    every one of the three lengths is the same across both runs.
    """
    reload = True
    verbosity = "high"
    current_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_file = os.path.join(
        current_dir, "../../../datasets/mini_muv.csv.gz")

    print("Running experiment for first time with reload.")
    (len_train, len_valid, len_test) = self._run_muv_experiment(
        dataset_file, reload, verbosity)

    print("Running experiment for second time with reload.")
    (len_reload_train, len_reload_valid, len_reload_test) = (
        self._run_muv_experiment(dataset_file, reload, verbosity))

    assert len_train == len_reload_train
    assert len_valid == len_reload_valid
    # Compare test split against test split. The earlier version compared it
    # against the reloaded *valid* length, so the test split was never checked.
    assert len_test == len_reload_test