Commit e98965e2 authored by leswing's avatar leswing
Browse files

Portable

parent 979f707e
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -481,6 +481,7 @@ class DiskDataset(Dataset):
    metadata_filename = os.path.join(self.data_dir, "metadata.joblib")
    if os.path.exists(metadata_filename):
      tasks, metadata_df = load_from_disk(metadata_filename)
      del metadata_df['task_names']
      save_metadata(tasks, metadata_df, self.data_dir)
      return tasks, metadata_df
    raise ValueError("No Metadata Found On Disk")
@@ -492,7 +493,7 @@ class DiskDataset(Dataset):
    metadata_entries should have elements returned by write_data_to_disk
    above.
    """
    columns = ('basename', 'task_names', 'ids', 'X', 'y', 'w')
    columns = ('basename', 'ids', 'X', 'y', 'w')
    metadata_df = pd.DataFrame(metadata_entries, columns=columns)
    return metadata_df

@@ -529,7 +530,7 @@ class DiskDataset(Dataset):
      out_ids = None

    # note that this corresponds to the _construct_metadata column order
    return [basename, tasks, out_ids, out_X, out_y, out_w]
    return [basename, out_ids, out_X, out_y, out_w]

  def save_to_disk(self):
    """Save dataset to disk."""
+2 −1
Original line number Diff line number Diff line
@@ -821,6 +821,7 @@ class GraphConvTensorGraph(TensorGraph):
        self.default_generator(dataset, predict=True),
        metrics,
        labels=self.my_labels,
        transformers=transformers,
        weights=[self.my_task_weights],
        per_task_metrics=per_task_metrics)

+1 −2
Original line number Diff line number Diff line
@@ -117,7 +117,6 @@ def save_metadata(tasks, metadata_df, data_dir):
  Returns
  -------
  """

  if isinstance(tasks, np.ndarray):
    tasks = tasks.tolist()
  metadata_filename = os.path.join(data_dir, "metadata.hd5")
@@ -125,7 +124,7 @@ def save_metadata(tasks, metadata_df, data_dir):
  with open(tasks_filename, 'w') as fout:
    json.dump(tasks, fout)
  hdf = pd.HDFStore(metadata_filename)
  hdf.put('metadata', metadata_df)
  hdf.put('metadata', metadata_df, format='table')
  hdf.close()