Commit 7861927d authored by peastman's avatar peastman
Browse files

Transformers don't store a reference to a dataset

parent 57e56474
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -62,7 +62,6 @@ class Transformer(object):
               transform_w=False,
               dataset=None):
    """Initializes transformation based on dataset statistics."""
    self.dataset = dataset
    self.transform_X = transform_X
    self.transform_y = transform_y
    self.transform_w = transform_w
@@ -482,12 +481,12 @@ class BalancingTransformer(Transformer):
    assert transform_w

    # Compute weighting factors from dataset.
    y = self.dataset.y
    w = self.dataset.w
    y = dataset.y
    w = dataset.w
    # Ensure dataset is binary
    np.testing.assert_allclose(sorted(np.unique(y)), np.array([0., 1.]))
    weights = []
    for ind, task in enumerate(self.dataset.get_task_names()):
    for ind, task in enumerate(dataset.get_task_names()):
      task_w = w[:, ind]
      task_y = y[:, ind]
      # Remove labels with zero weights
@@ -505,7 +504,7 @@ class BalancingTransformer(Transformer):
  def transform_array(self, X, y, w):
    """Transform the data in a set of (X, y, w) arrays."""
    w_balanced = np.zeros_like(w)
    for ind, task in enumerate(self.dataset.get_task_names()):
    for ind in range(y.shape[1]):
      task_y = y[:, ind]
      task_w = w[:, ind]
      zero_indices = np.logical_and(task_y == 0, task_w != 0)