Commit fdf25145 authored by Peter Eastman's avatar Peter Eastman
Browse files

Reenabled tests in test_datasets.py

parent d5fd17a5
Loading
Loading
Loading
Loading
+14 −14
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
  Test basic top-level API for dataset objects.
  """

  def notest_sparsify_and_densify(self):
  def test_sparsify_and_densify(self):
    """Test that sparsify and densify work as inverses."""
    # Test on identity matrix
    num_samples = 10
@@ -53,7 +53,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    X_reconstructed = densify_features(X_sparse, num_features)
    np.testing.assert_array_equal(X, X_reconstructed)

  def notest_pad_features(self):
  def test_pad_features(self):
    """Test that pad_features pads features correctly."""
    batch_size = 100
    num_features = 10
@@ -99,7 +99,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    assert len(X_out) == batch_size
  

  def notest_pad_batches(self):
  def test_pad_batches(self):
    """Test that pad_batch pads batches correctly."""
    batch_size = 100
    num_features = 10
@@ -171,7 +171,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
        batch_size, X_b, y_b, w_b, ids_b)
    assert len(X_out) == len(y_out) == len(w_out) == len(ids_out) == batch_size
    
  def notest_get_task_names(self):
  def test_get_task_names(self):
    """Test that get_task_names returns correct task_names"""
    solubility_dataset = self.load_solubility_data()
    assert solubility_dataset.get_task_names() == ["log-solubility"]
@@ -182,7 +182,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
        "task9", "task10", "task11", "task12", "task13", "task14", "task15",
        "task16"])

  def notest_get_data_shape(self):
  def test_get_data_shape(self):
    """Test that get_data_shape returns currect data shape"""
    solubility_dataset = self.load_solubility_data()
    assert solubility_dataset.get_data_shape() == (1024,) 
@@ -190,12 +190,12 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    multitask_dataset = self.load_multitask_data()
    assert multitask_dataset.get_data_shape() == (1024,)

  def notest_len(self):
  def test_len(self):
    """Test that len(dataset) works."""
    solubility_dataset = self.load_solubility_data()
    assert len(solubility_dataset) == 10

  def notest_reshard(self):
  def test_reshard(self):
    """Test that resharding the dataset works."""
    solubility_dataset = self.load_solubility_data()
    X, y, w, ids = solubility_dataset.to_numpy()
@@ -220,7 +220,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    np.testing.assert_array_equal(w, w_rr)
    np.testing.assert_array_equal(ids, ids_rr)

  def notest_select(self):
  def test_select(self):
    """Test that dataset select works."""
    num_datapoints = 10
    num_features = 10
@@ -241,7 +241,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    np.testing.assert_array_equal(ids[indices], ids_sel)
    shutil.rmtree(select_dir)

  def notest_get_shape(self):
  def test_get_shape(self):
    """Test that get_shape works."""
    num_datapoints = 100
    num_features = 10
@@ -261,7 +261,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    assert ids_shape == ids.shape


  def notest_to_singletask(self):
  def test_to_singletask(self):
    """Test that to_singletask works."""
    num_datapoints = 100
    num_features = 10
@@ -292,7 +292,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
      for task_dir in task_dirs:
        shutil.rmtree(task_dir)
  
  def notest_iterbatches(self):
  def test_iterbatches(self):
    """Test that iterating over batches of data works."""
    solubility_dataset = self.load_solubility_data()
    batch_size = 2
@@ -304,7 +304,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
      assert w_b.shape == (batch_size,) + (len(tasks),)
      assert ids_b.shape == (batch_size,)

  def notest_to_numpy(self):
  def test_to_numpy(self):
    """Test that transformation to numpy arrays is sensible."""
    solubility_dataset = self.load_solubility_data()
    data_shape = solubility_dataset.get_data_shape()
@@ -318,7 +318,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    assert w.shape == (N_samples, N_tasks)
    assert ids.shape == (N_samples,)

  def notest_consistent_ordering(self):
  def test_consistent_ordering(self):
    """Test that ordering of labels is consistent over time."""
    solubility_dataset = self.load_solubility_data()

@@ -327,7 +327,7 @@ class TestBasicDatasetAPI(TestDatasetAPI):

    assert np.array_equal(ids1, ids2)

  def notest_get_statistics(self):
  def test_get_statistics(self):
    """Test statistics computation of this dataset."""
    solubility_dataset = self.load_solubility_data()
    X, y, _, _ = solubility_dataset.to_numpy()