Commit 58737439 authored by Bharath Ramsundar
Browse files

Fixing move behavior to match expected behavior.

parent 0590c365
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -1290,7 +1290,11 @@ class DiskDataset(Dataset):
    if delete_if_exists and os.path.isdir(new_data_dir):
      shutil.rmtree(new_data_dir)
    shutil.move(self.data_dir, new_data_dir)
    self.data_dir = os.path.join(new_data_dir, os.path.basename(self.data_dir))
    if delete_if_exists:
      self.data_dir = new_data_dir
    else:
      self.data_dir = os.path.join(new_data_dir,
                                   os.path.basename(self.data_dir))

  def copy(self, new_data_dir: str) -> "DiskDataset":
    """Copies dataset to new directory.
+51 −1
Original line number Diff line number Diff line
@@ -347,7 +347,21 @@ def load_pickle_from_disk(filename):


def load_dataset_from_disk(save_dir):
  """
  """Loads MoleculeNet train/valid/test/transformers from disk.

  Expects that data was saved using `save_dataset_to_disk` below. Expects the
  following directory structure for `save_dir`:
  
  save_dir/
    |
    ---> train_dir/
    |
    ---> valid_dir/
    |
    ---> test_dir/
    |
    ---> transformers.pkl

  Parameters
  ----------
  save_dir: str
@@ -361,6 +375,9 @@ def load_dataset_from_disk(save_dir):
  transformers: list of dc.trans.Transformer
    The transformers used for this dataset

  See Also
  --------
  save_dataset_to_disk
  """

  train_dir = os.path.join(save_dir, "train_dir")
@@ -381,6 +398,39 @@ def load_dataset_from_disk(save_dir):


def save_dataset_to_disk(save_dir, train, valid, test, transformers):
  """Utility used by MoleculeNet to save train/valid/test datasets.

  This utility function saves a train/valid/test split of a dataset along
  with transformers in the same directory. The saved datasets will take the
  following structure:
  
  save_dir/
    |
    ---> train_dir/
    |
    ---> valid_dir/
    |
    ---> test_dir/
    |
    ---> transformers.pkl

  Parameters
  ----------
  save_dir: str
    Filename of directory to save datasets to.
  train: DiskDataset
    Training dataset to save.
  valid: DiskDataset
    Validation dataset to save.
  test: DiskDataset
    Test dataset to save.
  transformers: List
    List of transformers to save to disk.

  See Also
  --------
  load_dataset_from_disk 
  """
  train_dir = os.path.join(save_dir, "train_dir")
  valid_dir = os.path.join(save_dir, "valid_dir")
  test_dir = os.path.join(save_dir, "test_dir")