Commit ecfdcc5d authored by joegomes's avatar joegomes
Browse files

Fix ClippingTransformer and add unit tests

parent cef2632f
Loading
Loading
Loading
Loading
+50 −0
Original line number Diff line number Diff line
@@ -243,6 +243,56 @@ class TestTransformers(unittest.TestCase):
    # Check that untransform does the right thing.
    np.testing.assert_allclose(cdf_transformer.untransform(y_t), y)

  def test_clipping_X_transformer(self):
    """Test clipping transformer on X of singletask dataset."""
    n_samples = 10
    n_features = 3
    n_tasks = 1
    ids = np.arange(n_samples)
    X = np.ones((n_samples, n_features))
    target = 5.*X
    X *= 6.
    y = np.zeros((n_samples, n_tasks))
    w = np.ones((n_samples, n_tasks))
    dataset = dc.data.NumpyDataset(X, y, w, ids)
    transformer = dc.trans.ClippingTransformer(transform_X=True, x_max=5.)
    clipped_dataset = transformer.transform(dataset)
    X_t, y_t, w_t, ids_t = (clipped_dataset.X, clipped_dataset.y, clipped_dataset.w, clipped_dataset.ids)
    # Check ids are unchanged.
    for id_elt, id_t_elt in zip(ids, ids_t):
      assert id_elt == id_t_elt
    # Check y is unchanged since this is an X transformer
    np.testing.assert_allclose(y, y_t)
    # Check w is unchanged since this is an X transformer
    np.testing.assert_allclose(w, w_t)
    # Check X is now holding the proper values when sorted.
    np.testing.assert_allclose(X_t, target)
 
  def test_clipping_y_transformer(self):
    """Test clipping transformer on y of singletask dataset."""
    n_samples = 10
    n_features = 3
    n_tasks = 1
    ids = np.arange(n_samples)
    X = np.zeros((n_samples, n_features))
    y = np.ones((n_samples, n_tasks))
    target = 5.*y
    y *= 6.
    w = np.ones((n_samples, n_tasks))
    dataset = dc.data.NumpyDataset(X, y, w, ids)
    transformer = dc.trans.ClippingTransformer(transform_y=True, y_max=5.)
    clipped_dataset = transformer.transform(dataset)
    X_t, y_t, w_t, ids_t = (clipped_dataset.X, clipped_dataset.y, clipped_dataset.w, clipped_dataset.ids)
    # Check ids are unchanged.
    for id_elt, id_t_elt in zip(ids, ids_t):
      assert id_elt == id_t_elt
    # Check X is unchanged since this is a y transformer
    np.testing.assert_allclose(X, X_t)
    # Check w is unchanged since this is a y transformer
    np.testing.assert_allclose(w, w_t)
    # Check y is now holding the proper values when sorted.
    np.testing.assert_allclose(y_t, target)
  
  def test_power_X_transformer(self):
    """Test Power transformer on Gaussian normal dataset."""
    gaussian_dataset = dc.data.tests.load_gaussian_cdf_data()
+9 −8
Original line number Diff line number Diff line
@@ -159,27 +159,28 @@ class NormalizationTransformer(Transformer):
class ClippingTransformer(Transformer):

  def __init__(self, transform_X=False, transform_y=False,
               transform_w=False, dataset=None, max_val=5.):
               transform_w=False, dataset=None, x_max=5., y_max=500.):
    """Initialize clipping transformation."""
    super(ClippingTransformer, self).__init__(transform_X=transform_X,
                                              transform_y=transform_y,
                                              transform_w=transform_w,
                                              dataset=dataset)
    self.max_val = max_val
    self.x_max = x_max
    self.y_max = y_max

  def transform_array(self, X, y, w):
    """Transform the data in a set of (X, y, w) arrays."""
    if self.transform_X:
      X[X > self.max_val] = self.max_val
      X[X < (-1.0*self.max_val)] = -1.0 * self.max_val
      X[X > self.x_max] = self.x_max
      X[X < (-1.0*self.x_max)] = -1.0 * self.x_max
    if self.transform_y:
      y[y > trunc] = trunc
      y[y < (-1.0*trunc)] = -1.0 * trunc
      y[y > self.y_max] = self.y_max
      y[y < (-1.0*self.y_max)] = -1.0 * self.y_max
    return (X, y, w)

  def untransform(self, z):
    warnings.warn("Clipping cannot be undone.")
    return z
    raise NotImplementedError(
      "Cannot untransform datasets with ClippingTransformer.")

class LogTransformer(Transformer):