Unverified Commit ac6ceac0 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #997 from lilleswing/995-docstring

Add docstring for dc.util.save.load_dataset_from_disk
parents 0ab3a49b b969b19c
Loading
Loading
Loading
Loading
+27 −14
Original line number Diff line number Diff line
@@ -185,11 +185,28 @@ def load_pickle_from_disk(filename):


def load_dataset_from_disk(save_dir):
  """
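  Load the train, valid, and test datasets and their transformers from disk.

  Parameters
  ----------
  save_dir: str
    Directory expected to contain the "train_dir", "valid_dir", and
    "test_dir" subdirectories and the "transformers.pkl" file.

  Returns
  -------
  loaded: bool
    Whether the load succeeded
  all_dataset: (dc.data.Dataset, dc.data.Dataset, dc.data.Dataset)
    The train, valid, test datasets, or None if any of the three
    subdirectories does not exist
  transformers: list of dc.trans.Transformer
    The transformers used for this dataset, or an empty list if the
    load did not succeed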

  """

  train_dir = os.path.join(save_dir, "train_dir")
  valid_dir = os.path.join(save_dir, "valid_dir")
  test_dir = os.path.join(save_dir, "test_dir")
  if os.path.exists(train_dir) and os.path.exists(valid_dir) and os.path.exists(
      test_dir):
  if not os.path.exists(train_dir) or not os.path.exists(
      valid_dir) or not os.path.exists(test_dir):
    return False, None, list()
  loaded = True
  train = deepchem.data.DiskDataset(train_dir)
  valid = deepchem.data.DiskDataset(valid_dir)
@@ -197,10 +214,6 @@ def load_dataset_from_disk(save_dir):
  all_dataset = (train, valid, test)
  with open(os.path.join(save_dir, "transformers.pkl"), 'rb') as f:
    transformers = pickle.load(f)
  else:
    loaded = False
    all_dataset = None
    transformers = []
    return loaded, all_dataset, transformers