Commit f24e2fd5 authored by Yutong Zhao
Browse files

Add additional tests for transforming unlabelled datasets.

parent 44c50d41
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -623,7 +623,12 @@ class DiskDataset(Dataset):
        for (X_shard, y_shard, w_shard, ids_shard) in dataset.itershards():
            n_samples = X_shard.shape[0]
            for i in range(n_samples):
                yield (X_shard[i], y_shard[i], w_shard[i], ids_shard[i])
                def sanitize(elem):
                  # A missing shard component (None) is passed through as-is;
                  # otherwise pick out the entry at the current sample index i.
                  return None if elem is None else elem[i]
                yield map(sanitize, [X_shard, y_shard, w_shard, ids_shard])
    return iterate(self)

  def transform(self, fn, **args):
+9 −0
Original line number Diff line number Diff line
@@ -91,3 +91,12 @@ def load_gaussian_cdf_data():
  loader = dc.data.UserCSVLoader(
      tasks=tasks, featurizer=featurizer, id_field="id")
  return loader.featurize(input_file)

def load_unlabelled_data():
  """Featurize the label-free test CSV and return the loader's output.

  The CSV path is resolved relative to this module's directory. The loader
  is configured with an empty task list, reads molecules from the "smiles"
  column and featurizes them with CircularFingerprint(size=1024).
  """
  data_path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)),
      "../../data/tests/no_labels.csv")
  csv_loader = dc.data.CSVLoader(
      tasks=[],
      smiles_field="smiles",
      featurizer=dc.feat.CircularFingerprint(size=1024))
  return csv_loader.featurize(data_path)
 No newline at end of file
+14 −0
Original line number Diff line number Diff line
@@ -49,6 +49,20 @@ class TestTransformers(unittest.TestCase):
    # Check that untransform does the right thing.
    np.testing.assert_allclose(log_transformer.untransform(y_t), y)

  def test_transform_unlabelled(self):
    """Tests that transformers behave sensibly on an unlabelled dataset."""
    ul_dataset = dc.data.tests.load_unlabelled_data()

    # Transforming y should raise an exception. The context manager's value
    # is not bound because the exception itself is never inspected.
    with self.assertRaises(ValueError):
      dc.trans.transformers.Transformer(transform_y=True).transform(ul_dataset)

    # Transforming w should raise an exception.
    with self.assertRaises(ValueError):
      dc.trans.transformers.Transformer(transform_w=True).transform(ul_dataset)

    # Transforming X should be okay: it must complete without raising.
    dc.trans.NormalizationTransformer(
        transform_X=True, dataset=ul_dataset).transform(ul_dataset)

  def test_X_log_transformer(self):
    """Tests logarithmic data transformer."""
    solubility_dataset = dc.data.tests.load_solubility_data()
+5 −0
Original line number Diff line number Diff line
@@ -79,6 +79,11 @@ class Transformer(object):
    Transforms all internally stored data.
    Adds X-transform, y-transform columns to metadata.
    """
    _, y_shape, w_shape, _ = dataset.get_shape()
    if y_shape == tuple() and self.transform_y:
      raise ValueError("Cannot transform y when y_values are not present")
    if w_shape == tuple() and self.transform_w:
      raise ValueError("Cannot transform w when w_values are not present")
    return dataset.transform(lambda X, y, w: self.transform_array(X, y, w))

  def transform_on_array(self, X, y, w):