Commit adc366b5 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Further changes to preprocess.

parent b13e058a
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -58,12 +58,12 @@ def dataset_to_numpy(dataset, mode):
      task_ys[task] = np.reshape(np.array(task_ys[task]), (n_samples, 1))
      task_ws[task] = np.reshape(np.array(task_ws[task]), (n_samples, 1))
    return sorted_ids, X, task_ys, task_ws
  elif mode == multitask:
  elif mode == "multitask":
    y = np.zeros((n_samples, n_tasks))
    w = np.ones((n_samples, n_tasks))
    for task in sorted_tasks:
      y[:,task] = np.array(task_ys[task])
      w[:, task] = np.array(task_ws[task])
    for ind, task in enumerate(sorted_tasks):
      y[:,ind] = np.array(task_ys[task])
      w[:,ind] = np.array(task_ws[task])
    return sorted_ids, X, {"all": y}, {"all": w} 
  else:
    raise ValueError("Unsupported mode for process_datasets.")