Unverified Commit 3ba55f47 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1034 from lilleswing/fix-merge

NumpyDataset Merge
parents 9c2e036b fc42a142
Loading
Loading
Loading
Loading
+24 −0
Original line number Diff line number Diff line
@@ -450,6 +450,30 @@ class NumpyDataset(Dataset):
      d = json.load(fin)
      return NumpyDataset(d['X'], d['y'], d['w'], d['ids'])

  @staticmethod
  def merge(datasets):
    """
    Parameters
    ----------
    datasets: list of deepchem.data.NumpyDataset
      list of datasets to merge

    Returns
    -------
    Single deepchem.data.NumpyDataset with data concatenated over axis 0
    """
    X, y, w, ids = datasets[0].X, datasets[0].y, datasets[0].w, datasets[0].ids
    for dataset in datasets[1:]:
      X = np.concatenate([X, dataset.X], axis=0)
      y = np.concatenate([y, dataset.y], axis=0)
      w = np.concatenate([w, dataset.w], axis=0)
      ids = np.concatenate(
          [ids, dataset.ids],
          axis=0,
      )

    return NumpyDataset(X, y, w, ids, n_tasks=y.shape[1])


class DiskDataset(Dataset):
  """
+1 −1
Original line number Diff line number Diff line
@@ -132,7 +132,7 @@ class TestSplitters(unittest.TestCase):
    assert len(valid_data) == 1
    assert len(test_data) == 1

    merged_dataset = dc.data.DiskDataset.merge(
    merged_dataset = dc.data.NumpyDataset.merge(
        [train_data, valid_data, test_data])
    assert sorted(merged_dataset.ids) == (sorted(solubility_dataset.ids))