Commit b02cb094 authored by peastman's avatar peastman
Browse files

Improved logic for generating weights for datasets

parent 53a8b7b7
Loading
Loading
Loading
Loading
+23 −8
Original line number Diff line number Diff line
@@ -317,16 +317,19 @@ class NumpyDataset(Dataset):
    if n_samples > 0:
      if y is None:
        # Set labels to be zero, with zero weights
        y = np.zeros((n_samples, n_tasks))
        w = np.zeros_like(y)
        y = np.zeros((n_samples, n_tasks), np.float32)
        w = np.zeros((n_samples, 1), np.float32)
    if ids is None:
      ids = np.arange(n_samples)
    if w is None:
      w = np.ones_like(y)
    if not isinstance(X, np.ndarray):
      X = np.array(X)
    if not isinstance(y, np.ndarray):
      y = np.array(y)
    if w is None:
      if len(y.shape) == 1:
        w = np.ones(y.shape[0], np.float32)
      else:
        w = np.ones((y.shape[0], 1), np.float32)
    if not isinstance(w, np.ndarray):
      w = np.array(w)
    self._X = X
@@ -751,7 +754,10 @@ class DiskDataset(Dataset):
          if os.path.exists(w_filename):
            w = np.array(load_from_disk(w_filename))
          else:
            w = np.ones(y.shape)
            if len(y.shape) == 1:
              w = np.ones(y.shape[0], np.float32)
            else:
              w = np.ones((y.shape[0], 1), np.float32)
        else:
          w = None
        yield (X, y, w, ids)
@@ -975,7 +981,10 @@ class DiskDataset(Dataset):

    if y is not None:
      if w is None:
        w = np.ones_like(y)
        if len(y.shape) == 1:
          w = np.ones(y.shape[0], np.float32)
        else:
          w = np.ones((y.shape[0], 1), np.float32)

      if tasks is None:
        if len(y.shape) > 1:
@@ -1170,7 +1179,10 @@ class DiskDataset(Dataset):
      if os.path.exists(w_filename):
        w = np.array(load_from_disk(w_filename))
      else:
        w = np.ones(y.shape)
        if len(y.shape) == 1:
          w = np.ones(y.shape[0], np.float32)
        else:
          w = np.ones((y.shape[0], 1), np.float32)
    else:
      w = None

@@ -1370,7 +1382,10 @@ class ImageDataset(Dataset):
    self._X_shape = self._find_array_shape(X)
    self._y_shape = self._find_array_shape(y)
    if w is None:
      w = np.ones(self._y_shape[:2])
      if len(self._y_shape) == 1:
        w = np.ones(self._y_shape[0], np.float32)
      else:
        w = np.ones((self._y_shape[0], 1), np.float32)
    if ids is None:
      if not isinstance(X, np.ndarray):
        ids = X
+1 −1
Original line number Diff line number Diff line
@@ -45,7 +45,7 @@ class TestImageDataset(test_util.TensorFlowTestCase):
    x_shape, y_shape, w_shape, ids_shape = ds2.get_shape()
    np.testing.assert_array_equal([10], x_shape)
    np.testing.assert_array_equal([10, 28, 28], y_shape)
    np.testing.assert_array_equal([10, 28], w_shape)
    np.testing.assert_array_equal([10, 1], w_shape)
    np.testing.assert_array_equal([10], ids_shape)
    np.testing.assert_array_equal(ds2.X.shape, x_shape)
    np.testing.assert_array_equal(ds2.y.shape, y_shape)
+2 −0
Original line number Diff line number Diff line
@@ -304,6 +304,8 @@ class Metric(object):
      y_pred_task = y_pred[:, task]
      if len(w.shape) == 1:
        w_task = w
      elif w.shape[1] == 1:
        w_task = w[:, 0]
      else:
        w_task = w[:, task]

+8 −2
Original line number Diff line number Diff line
@@ -47,7 +47,8 @@ class SingletaskToMultitask(Model):
      task_data_dirs.append(task_data_dir)
    task_datasets = self._to_singletask(dataset, task_data_dirs)
    for task, task_dataset in zip(self.tasks, task_datasets):
      log("Dataset for task %s has shape %s" % (task,
      log(
          "Dataset for task %s has shape %s" % (task,
                                                str(task_dataset.get_shape())),
          self.verbose)
    return task_datasets
@@ -68,6 +69,11 @@ class SingletaskToMultitask(Model):
      basename = "dataset-%d" % shard_num
      for task_num, task in enumerate(tasks):
        log("\tTask %s" % task, dataset.verbose)
        if len(w.shape) == 1:
          w_task = w
        elif w.shape[1] == 1:
          w_task = w[:, 0]
        else:
          w_task = w[:, task_num]
        y_task = y[:, task_num]