Commit 2b01efa5 authored by Bharath Ramsundar
Browse files

Cleanup

parent bd763350
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -328,7 +328,8 @@ class Dataset(object):
    if not os.path.exists(select_dir):
      os.makedirs(select_dir)
    if not len(indices):
      return Dataset(data_dir=select_dir, metadata_row=[], verbosity=self.verbosity)
      return Dataset(
          data_dir=select_dir, metadata_row=[], verbosity=self.verbosity)
    indices = np.array(sorted(indices)).astype(int)
    count, indices_count = 0, 0
    metadata_rows = []
+4 −4
Original line number Diff line number Diff line
@@ -90,7 +90,6 @@ class SingletaskToMultitask(Model):
        y_pred[:, ind] = task_model.predict_on_batch(X)
      else:
        raise ValueError("Invalid task_type")
      ############################################### DEBUG
    return y_pred

  def predict_proba_on_batch(self, X, n_classes=2):
@@ -101,7 +100,8 @@ class SingletaskToMultitask(Model):
    n_samples = X.shape[0]
    y_pred = np.zeros((n_samples, n_tasks, n_classes))
    for ind, task in enumerate(self.tasks):
      task_model = self.model_builder([task], {task: self.task_types[task]}, self.model_params,
      task_model = self.model_builder(
          [task], {task: self.task_types[task]}, self.model_params,
          self.task_model_dirs[task],
          verbosity=self.verbosity)
      task_model.reload()
+0 −13
Original line number Diff line number Diff line
@@ -117,23 +117,10 @@ class RandomSplitter(Splitter):
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    np.random.seed(seed)
    ########################################################### DEBUG
    print("About to compute len!")
    ########################################################### DEBUG
    num_datapoints = len(dataset)
    train_cutoff = int(frac_train * num_datapoints)
    ########################################################### DEBUG
    print("Successfully computed len!")
    ########################################################### DEBUG
    valid_cutoff = int((frac_train+frac_valid) * num_datapoints )
    ########################################################### DEBUG
    print("num_datapoints, train_cutoff, valid_cutoff")
    print(num_datapoints, train_cutoff, valid_cutoff)
    ########################################################### DEBUG
    shuffled = np.random.permutation(range(num_datapoints))
    ########################################################### DEBUG
    print("Successfully computed shuffled.")
    ########################################################### DEBUG
    return (shuffled[:train_cutoff], shuffled[train_cutoff:valid_cutoff],
            shuffled[valid_cutoff:])