Unverified commit c84e7409, authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #1858 from peastman/npy

DiskDataset stores shards in npy format
parents 51f426b3 420ee9d3
Loading
Loading
Loading
Loading
+11 −11
Original line number Diff line number Diff line
@@ -1021,25 +1021,25 @@ class DiskDataset(Dataset):
                         w=None,
                         ids=None):
    if X is not None:
      out_X = "%s-X.joblib" % basename
      out_X = "%s-X.npy" % basename
      save_to_disk(X, os.path.join(data_dir, out_X))
    else:
      out_X = None

    if y is not None:
      out_y = "%s-y.joblib" % basename
      out_y = "%s-y.npy" % basename
      save_to_disk(y, os.path.join(data_dir, out_y))
    else:
      out_y = None

    if w is not None:
      out_w = "%s-w.joblib" % basename
      out_w = "%s-w.npy" % basename
      save_to_disk(w, os.path.join(data_dir, out_w))
    else:
      out_w = None

    if ids is not None:
      out_ids = "%s-ids.joblib" % basename
      out_ids = "%s-ids.npy" % basename
      save_to_disk(ids, os.path.join(data_dir, out_ids))
    else:
      out_ids = None
+12 −4
Original line number Diff line number Diff line
@@ -21,7 +21,12 @@ def log(string, verbose=True):

def save_to_disk(dataset, filename, compress=3):
  """Save a dataset to file, picking the format from the filename extension.

  A filename ending in '.joblib' is written with `joblib.dump`; one ending
  in '.npy' is written with `np.save`.

  Args:
    dataset: The object to store.
    filename: Destination path as a string or an `os.PathLike` object
      (e.g. `pathlib.Path`). Must end in '.joblib' or '.npy'.
    compress: Compression level forwarded to `joblib.dump`. It is not used
      when writing '.npy' files.

  Raises:
    ValueError: If `filename` ends in neither '.joblib' nor '.npy'.
  """
  # Path-like objects have no `endswith`; convert them to their string form
  # first. Plain strings pass through `os.fspath` unchanged.
  filename = os.fspath(filename)
  if filename.endswith('.joblib'):
    joblib.dump(dataset, filename, compress=compress)
  elif filename.endswith('.npy'):
    np.save(filename, dataset)
  else:
    raise ValueError("Filename with unsupported extension: %s" % filename)


def get_input_type(input_file):
@@ -210,15 +215,18 @@ def load_from_disk(filename):
  name = filename
  if os.path.splitext(name)[1] == ".gz":
    name = os.path.splitext(name)[0]
  if os.path.splitext(name)[1] == ".pkl":
  extension = os.path.splitext(name)[1]
  if extension == ".pkl":
    return load_pickle_from_disk(filename)
  elif os.path.splitext(name)[1] == ".joblib":
  elif extension == ".joblib":
    return joblib.load(filename)
  elif os.path.splitext(name)[1] == ".csv":
  elif extension == ".csv":
    # First line of user-specified CSV *must* be header.
    df = pd.read_csv(filename, header=0)
    df = df.replace(np.nan, str(""), regex=True)
    return df
  elif extension == ".npy":
    return np.load(filename, allow_pickle=True)
  else:
    raise ValueError("Unrecognized filetype for %s" % filename)