Unverified commit 2a0aeacc, authored by Karl Leswing, committed by GitHub
Browse files

Merge pull request #1424 from VIGS25/enhance-disk-dataset

#1349 - DiskDataset.from_numpy(X)
parents fac87370 a9c2a1ff
Loading
Loading
Loading
Loading
+25 −9
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import time
import shutil
import json
from multiprocessing.dummy import Pool
import warnings

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
@@ -955,7 +956,7 @@ class DiskDataset(Dataset):

  @staticmethod
  def from_numpy(X,
                 y,
                 y=None,
                 w=None,
                 ids=None,
                 tasks=None,
@@ -965,14 +966,29 @@ class DiskDataset(Dataset):
    n_samples = len(X)
    if ids is None:
      ids = np.arange(n_samples)

    if y is not None:
      if w is None:
        w = np.ones_like(y)

      if tasks is None:
        if len(y.shape) > 1:
          n_tasks = y.shape[1]
        else:
          n_tasks = 1
        tasks = np.arange(n_tasks)

    else:
      if w is not None:
        warnings.warn('y is None but w is not None. Setting w to None',
                      UserWarning)
        w = None

      if tasks is not None:
        warnings.warn('y is None but tasks is not None. Setting tasks to None',
                      UserWarning)
        tasks = None

    # raw_data = (X, y, w, ids)
    return DiskDataset.create_dataset(
        [(X, y, w, ids)], data_dir=data_dir, tasks=tasks, verbose=verbose)