Commit bd763350 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Splitting seems to work reasonably on NCI dataset now

parent 0e57beeb
Loading
Loading
Loading
Loading
+22 −6
Original line number Diff line number Diff line
@@ -325,20 +325,36 @@ class Dataset(object):

  def select(self, select_dir, indices):
    """Creates a new dataset from a selection of indices from self."""
    indices = np.array(indices).astype(int)
    count = 0
    if not os.path.exists(select_dir):
      os.makedirs(select_dir)
    if not len(indices):
      return Dataset(data_dir=select_dir, metadata_row=[], verbosity=self.verbosity)
    indices = np.array(sorted(indices)).astype(int)
    count, indices_count = 0, 0
    metadata_rows = []
    tasks = self.get_task_names()
    for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
      log("Selecting from shard %d" % shard_num, self.verbosity)
      shard_len = len(X)
      X_sel = X[indices[count:count+shard_len]]
      y_sel = y[indices[count:count+shard_len]]
      w_sel = w[indices[count:count+shard_len]]
      ids_sel = ids[indices[count:count+shard_len]]
      # Find indices which rest in this shard
      num_shard_elts = 0
      while indices[indices_count+num_shard_elts] < count + shard_len:
        num_shard_elts += 1
        if indices_count + num_shard_elts >= len(indices):
          break
      # Need to offset indices to fit within shard_size
      shard_indices = (
          indices[indices_count:indices_count+num_shard_elts] - count)
      X_sel = X[shard_indices]
      y_sel = y[shard_indices]
      w_sel = w[shard_indices]
      ids_sel = ids[shard_indices]
      basename = "dataset-%d" % shard_num
      metadata_rows.append(
          Dataset.write_data_to_disk(select_dir, basename, tasks,
                                     X_sel, y_sel, w_sel, ids_sel))
      # Updating counts
      indices_count += num_shard_elts
      count += shard_len
    return Dataset(data_dir=select_dir,
                   metadata_rows=metadata_rows,
+0 −3
Original line number Diff line number Diff line
@@ -47,9 +47,6 @@ class Splitter(object):
        dataset,
        frac_train=frac_train, frac_test=frac_test,
        frac_valid=frac_valid, log_every_n=log_every_n)
    ########################################################### DEBUG
    print("Computed indices successfully!")
    ########################################################### DEBUG
    train_dataset = dataset.select(train_dir, train_inds)
    if valid_dir is not None:
      valid_dataset = dataset.select(valid_dir, valid_inds)
+2 −2
Original line number Diff line number Diff line
@@ -35,8 +35,8 @@ nci_tasks, nci_dataset, transformers = load_nci(

if os.path.exists(base_dir):
  shutil.rmtree(base_dir)
if not os.path.exists(base_dir):
os.makedirs(base_dir)

train_dir = os.path.join(base_dir, "train_dataset")
valid_dir = os.path.join(base_dir, "valid_dataset")
test_dir = os.path.join(base_dir, "test_dataset")