Commit 979ba0d5 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

test fixes

parent 4f6d254e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ class Dataset(object):
        metadata_rows = []
        ids, X, y, w = raw_data
        metadata_rows.append(
            Dataset.write_data_to_disk(self.data_dir, "data", X, y, w, ids))
            Dataset.write_data_to_disk(self.data_dir, "data",tasks, X, y, w, ids))
        self.metadata_df = Dataset.construct_metadata(metadata_rows)
        self.save_to_disk()
      else:
+6 −6
Original line number Diff line number Diff line
@@ -23,7 +23,7 @@ class TestTransformerAPI(TestDatasetAPI):

  def test_y_log_transformer(self):
    """Tests logarithmic data transformer."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    log_transformer = LogTransformer(
        transform_y=True, dataset=solubility_dataset)
    X, y, w, ids = solubility_dataset.to_numpy()
@@ -45,7 +45,7 @@ class TestTransformerAPI(TestDatasetAPI):

  def test_X_log_transformer(self):
    """Tests logarithmic data transformer."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    log_transformer = LogTransformer(
        transform_X=True, dataset=solubility_dataset)
    X, y, w, ids = solubility_dataset.to_numpy()
@@ -67,7 +67,7 @@ class TestTransformerAPI(TestDatasetAPI):

  def test_y_normalization_transformer(self):
    """Tests normalization transformer."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    normalization_transformer = NormalizationTransformer(
        transform_y=True, dataset=solubility_dataset)
    X, y, w, ids = solubility_dataset.to_numpy()
@@ -89,7 +89,7 @@ class TestTransformerAPI(TestDatasetAPI):

  def test_X_normalization_transformer(self):
    """Tests normalization transformer."""
    solubility_dataset = self._load_solubility_data()
    solubility_dataset = self.load_solubility_data()
    normalization_transformer = NormalizationTransformer(
        transform_X=True, dataset=solubility_dataset)
    X, y, w, ids = solubility_dataset.to_numpy()
@@ -121,7 +121,7 @@ class TestTransformerAPI(TestDatasetAPI):
  def test_singletask_balancing_transformer(self):
    """Test balancing transformer on single-task dataset."""

    classification_dataset = self._load_classification_data()
    classification_dataset = self.load_classification_data()
    balancing_transformer = BalancingTransformer(
      transform_w=True, dataset=classification_dataset)
    X, y, w, ids = classification_dataset.to_numpy()
@@ -147,7 +147,7 @@ class TestTransformerAPI(TestDatasetAPI):

  def test_multitask_balancing_transformer(self):
    """Test balancing transformer on multitask dataset."""
    multitask_dataset = self._load_multitask_data()
    multitask_dataset = self.load_multitask_data()
    balancing_transformer = BalancingTransformer(
      transform_w=True, dataset=multitask_dataset)
    X, y, w, ids = multitask_dataset.to_numpy()