Commit 1c9474dd authored by Bharath Ramsundar
Browse files

Changes

parent d4227698
Loading
Loading
Loading
Loading
+10 −8
Original line number Diff line number Diff line
@@ -29,6 +29,8 @@ class DuplicateBalancingTransformer(Transformer):
  >>> n_features = 3
  >>> n_tasks = 1
  >>> n_classes = 2
  >>> import deepchem as dc
  >>> import numpy as np
  >>> ids = np.arange(n_samples)
  >>> X = np.random.rand(n_samples, n_features)
  >>> y = np.random.randint(n_classes, size=(n_samples, n_tasks))
@@ -98,19 +100,19 @@ class DuplicateBalancingTransformer(Transformer):
    # Remove labels with zero weights
    y = y[w != 0]
    N = len(y)
    class_counts = []
    class_weights = []
    # Note that we may have 0 elements of a given class since we remove those
    # labels with zero weight.
    for c in self.classes:
      # this works because y is 1D
      num_c = len(np.where(y == c)[0])
      class_counts.append(num_c)
    N_largest = max(class_counts)
      c_weight = np.sum(w[y == c])
      class_weights.append(c_weight)
    weight_largest = max(class_weights)
    # This is the right ratio since int(N/num_c) * num_c \approx N
    # for all classes
    duplication_ratio = [
        int(N_largest / float(num_c)) if num_c > 0 else 0
        for num_c in class_counts
        int(weight_largest / float(c_weight)) if c_weight > 0 else 0
        for c_weight in class_weights
    ]
    self.duplication_ratio = duplication_ratio

@@ -141,9 +143,9 @@ class DuplicateBalancingTransformer(Transformer):
    idtrans: np.ndarray
      Transformed array of identifiers
    """
    if not (len(y.shape) == 1 or (len(y.shape) == 2 and y[1] == 1)):
    if not (len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1)):
      raise ValueError("y must be of shape (N,) or (N, 1)")
    if not (len(w.shape) == 1 or (len(w.shape) == 2 and w[1] == 1)):
    if not (len(w.shape) == 1 or (len(w.shape) == 2 and w.shape[1] == 1)):
      raise ValueError("w must be of shape (N,) or (N, 1)")
    # Flattening is safe because of shape check above
    y = y.flatten()
+30 −0
Original line number Diff line number Diff line
@@ -32,6 +32,36 @@ def test_binary_1d():
  assert np.isclose(np.sum(w_t[y_t == 0]), np.sum(w_t[y_t == 1]))


def test_binary_weighted_1d():
  """Test duplicate balancing on a weighted 1D single-task dataset.

  The labels are imbalanced (2 positives, 4 negatives), but the per-sample
  weights already give both classes the same total weight
  (2 + 2 == 1 + 1 + 1 + 1). The transformer is therefore expected to leave
  the dataset size and the per-class sample counts unchanged.
  """
  n_samples = 6
  n_features = 3
  np.random.seed(123)
  X = np.random.rand(n_samples, n_features)
  # Labels and weights deliberately have no explicit task dimension: (N,).
  y = np.array([1, 1, 0, 0, 0, 0])
  w = np.array([2, 2, 1, 1, 1, 1])
  # Expected class counts are derived from the fixture rather than hard-coded.
  n_pos = int(np.sum(y == 1))
  n_neg = int(np.sum(y == 0))
  dataset = dc.data.NumpyDataset(X, y, w)

  duplicator = dc.trans.DuplicateBalancingTransformer(dataset=dataset)
  dataset = duplicator.transform(dataset)
  # Weights already balance, so no sample should have been duplicated.
  assert len(dataset) == n_samples
  X_t, y_t, w_t, ids_t = (dataset.X, dataset.y, dataset.w, dataset.ids)
  # Check shapes
  assert X_t.shape == (n_samples, n_features)
  assert y_t.shape == (n_samples,)
  assert w_t.shape == (n_samples,)
  assert ids_t.shape == (n_samples,)
  # Check that the class counts are the same as in the input
  assert np.sum(y_t == 0) == n_neg
  assert np.sum(y_t == 1) == n_pos
  # Check that the total weight of class 0 equals the total weight of class 1
  assert np.isclose(np.sum(w_t[y_t == 0]), np.sum(w_t[y_t == 1]))


def test_binary_singletask():
  """Test duplicate balancing transformer on single-task dataset."""
  n_samples = 6