Commit 3c8f339f authored by nd-02110114's avatar nd-02110114
Browse files

Merge branch 'master' into update-data-2

parents 87bf44e5 6bdba2e0
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ import time
import logging
import warnings
from typing import List, Optional, Tuple, Any, Sequence, Union, Iterator

import pandas as pd
import numpy as np

@@ -628,7 +629,7 @@ class JsonLoader(DataLoader):
                    (shard_num, time2 - time1))
        yield X, y, w, ids

    return DiskDataset.create_dataset(shard_generator(), data_dir)
    return DiskDataset.create_dataset(shard_generator(), data_dir, self.tasks)

  def _get_shards(self, input_files: List[str],
                  shard_size: Optional[int]) -> Iterator[pd.DataFrame]:
+229 −143
Original line number Diff line number Diff line
@@ -516,11 +516,14 @@ class Dataset(object):

    return tf.data.Dataset.from_generator(gen_data, dtypes, shapes)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: Optional[int] = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y,
    w, id) for one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
@@ -529,6 +532,9 @@ class Dataset(object):
    deterministic: bool, default False
      If True, the data is produced in order. If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int, optional (default None)
      The number of samples to return in each batch. If None, each returned
      value is a single sample.

    Returns
    -------
@@ -880,19 +886,25 @@ class NumpyDataset(Dataset):
    ids = self.ids[indices]
    return NumpyDataset(X, y, w, ids)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: Optional[int] = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id) for
    one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
    epochs: int, default 1
      The number of times to iterate over the Dataset.
      The number of times to iterate over the Dataset
    deterministic: bool, default False
      If True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
      If True, the data is produced in order. If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int, optional (default None)
      The number of samples to return in each batch. If None, each returned
      value is a single sample.

    Returns
    -------
@@ -910,7 +922,10 @@ class NumpyDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = _TorchNumpyDataset(
        numpy_dataset=self, epochs=epochs, deterministic=deterministic)
        numpy_dataset=self,
        epochs=epochs,
        deterministic=deterministic,
        batch_size=batch_size)
    return pytorch_ds

  @staticmethod
@@ -1757,19 +1772,25 @@ class DiskDataset(Dataset):
    return DiskDataset.write_data_to_disk(out_dir, basename, tasks, X, y, w,
                                          ids)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: Optional[int] = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id) for
    one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
    epochs: int, default 1
      The number of times to iterate over the Dataset.
      The number of times to iterate over the Dataset
    deterministic: bool, default False
      If True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
      If True, the data is produced in order. If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int, optional (default None)
      The number of samples to return in each batch. If None, each returned
      value is a single sample.

    Returns
    -------
@@ -1787,7 +1808,10 @@ class DiskDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = _TorchDiskDataset(
        disk_dataset=self, epochs=epochs, deterministic=deterministic)
        disk_dataset=self,
        epochs=epochs,
        deterministic=deterministic,
        batch_size=batch_size)
    return pytorch_ds

  @staticmethod
@@ -1967,7 +1991,7 @@ class DiskDataset(Dataset):
    time2 = time.time()
    logger.info("TIMING: sparse_shuffle took %0.3f s" % (time2 - time1))

  def complete_shuffle(self, data_dir: Optional[str] = None) -> "DiskDataset":
  def complete_shuffle(self, data_dir: Optional[str] = None) -> Dataset:
    """Completely shuffle across all data, across all shards.

    Notes
@@ -1994,55 +2018,7 @@ class DiskDataset(Dataset):
    N = len(self)
    perm = np.random.permutation(N)
    shard_size = self.get_shard_size()

    def generator():
      start = 0
      shard_num = 0
      while start < N:
        logger.info("Constructing shard %d" % shard_num)
        if start + shard_size < N:
          end = start + shard_size
        else:
          end = N
        shard_indices = perm[start:end]
        # Note that this is in sorted order which doesn't respect the random
        # permutation.
        shard_dataset = self.select(shard_indices)
        # One bit of trickiness here is that select() will return in sorted
        # order. For example, suppose we'd like these elements in our permuted
        # shard:
        #
        # [12, 234, 1, 4]
        #
        # Then select would return elements in order
        #
        # [1, 4, 12, 234]
        #
        # We need to recover the original ordering. We can do this by using
        # np.where to find the locatios of the original indices in the sorted
        # indices.
        sorted_indices = np.array(sorted(shard_indices))
        reverted_indices = np.array(
            # We know there's only one match for np.where since this is a
            # permutation, so the [0][0] pulls out the exact match location.
            [
                np.where(sorted_indices == orig_index)[0][0]
                for orig_index in shard_indices
            ])
        # Let's pull out shard elements
        shard_X, shard_y, shard_w, shard_ids = (shard_dataset.X,
                                                shard_dataset.y,
                                                shard_dataset.w,
                                                shard_dataset.ids)

        yield (shard_X[reverted_indices], shard_y[reverted_indices],
               shard_w[reverted_indices], shard_ids[reverted_indices])

        start = end
        shard_num += 1

    return DiskDataset.create_dataset(
        generator(), data_dir=data_dir, tasks=self.get_task_names())
    return self.select(perm, data_dir, self.get_shard_size())

  def shuffle_each_shard(self,
                         shard_basenames: Optional[List[str]] = None) -> None:
@@ -2262,56 +2238,116 @@ class DiskDataset(Dataset):
    DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w, ids)
    self._cached_shards = None

  def select(self, indices: Sequence[int],
             select_dir: Optional[str] = None) -> "DiskDataset":
  def select(self,
             indices: Sequence[int],
             select_dir: Optional[str] = None,
             select_shard_size: Optional[int] = None,
             output_numpy_dataset: Optional[bool] = False) -> Dataset:
    """Creates a new dataset from a selection of indices from self.

    Note
    ----
    The specified indices will be returned in sorted order. That is, if you
    request that indices `[3, 1, 2]` are returned, you will get a
    `DiskDataset` which contains elements in order `[1, 2, 3]`.
    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.rand(10, 10)
    >>> dataset = dc.data.DiskDataset.from_numpy(X)
    >>> selected = dataset.select([1, 3, 4])
    >>> len(selected)
    3

    Parameters
    ----------
    indices: Sequence
      List of indices to select.
    select_dir: str, optional (default None)
      Path to new directory that the selected indices will be copied
      to.
      Path to new directory that the selected indices will be copied to.
    select_shard_size: Optional[int], (default None)
      If specified, the shard-size to use for output selected `DiskDataset`.
      If not output_numpy_dataset, then this is set to this current dataset's
      shard size if not manually specified. 
    output_numpy_dataset: Optional[bool], (default False)
      If True, output an in-memory `NumpyDataset` instead of a `DiskDataset`.
      Note that `select_dir` and `select_shard_size` must be `None` if this
      is `True`

    Returns
    -------
    DiskDataset
      A selected DiskDataset object
      A DiskDataset contains selected samples.
    """
    if output_numpy_dataset and (select_dir is not None or
                                 select_shard_size is not None):
      raise ValueError(
          "If output_numpy_dataset is set, then select_dir and select_shard_size must both be None"
      )
    if output_numpy_dataset:
      # When outputting a NumpyDataset, we have 1 in-memory shard
      select_shard_size = len(indices)
    else:
      if select_dir is not None:
        if not os.path.exists(select_dir):
          os.makedirs(select_dir)
      else:
        select_dir = tempfile.mkdtemp()
      if select_shard_size is None:
        select_shard_size = self.get_shard_size()
    # Handle edge case with empty indices
    if not len(indices):
      if not output_numpy_dataset:
        return DiskDataset.create_dataset([], data_dir=select_dir)
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()
      else:
        return NumpyDataset(
            np.array([]), np.array([]), np.array([]), np.array([]))

    N = len(indices)
    indices = np.array(indices).astype(int)
    tasks = self.get_task_names()
    n_shards = self.get_number_shards()

    # We use two loops here. The outer while loop walks over selection shards
    # (the chunks of the indices to select that should go into separate
    # output shards), while the inner for loop walks over the shards in the
    # source datasets to select out the shard indices from that  source shard
    def generator():
      start = 0
      select_shard_num = 0
      while start < N:
        logger.info(
            "Constructing selection output shard %d" % (select_shard_num + 1))
        end = min(start + select_shard_size, N)
        select_shard_indices = indices[start:end]
        sorted_indices = np.array(sorted(select_shard_indices)).astype(int)

        Xs, ys, ws, ids_s = [], [], [], []
        count, indices_count = 0, 0
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        logger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
        shard_len = len(X)
        for shard_num in range(self.get_number_shards()):
          logger.info(
              "Selecting from input shard %d/%d for selection output shard %d" %
              (shard_num + 1, n_shards, select_shard_num + 1))
          if self.legacy_metadata:
            ids = self.get_shard_ids(shard_num)
            shard_len = len(ids)
          else:
            shard_X_shape, _, _, _ = self._get_shard_shape(shard_num)
            if len(shard_X_shape) > 0:
              shard_len = shard_X_shape[0]
            else:
              shard_len = 0
          # Find indices which rest in this shard
          num_shard_elts = 0
        while indices[indices_count + num_shard_elts] < count + shard_len:
          while sorted_indices[indices_count +
                               num_shard_elts] < count + shard_len:
            num_shard_elts += 1
          if indices_count + num_shard_elts >= len(indices):
            if (indices_count + num_shard_elts) >= len(sorted_indices):
              break
          if num_shard_elts == 0:
            count += shard_len
            continue
          else:
            X, y, w, ids = self.get_shard(shard_num)
          # Need to offset indices to fit within shard_size
        shard_inds = indices[indices_count:indices_count +
          shard_inds = sorted_indices[indices_count:indices_count +
                                      num_shard_elts] - count
          # Handle empty case where no data from this shard needed
          X_sel = X[shard_inds]
          # Handle the case of datasets with y/w missing
          if y is not None:
@@ -2323,16 +2359,42 @@ class DiskDataset(Dataset):
          else:
            w_sel = None
          ids_sel = ids[shard_inds]
        yield (X_sel, y_sel, w_sel, ids_sel)
        # Updating counts
          Xs.append(X_sel)
          ys.append(y_sel)
          ws.append(w_sel)
          ids_s.append(ids_sel)
          indices_count += num_shard_elts
          count += shard_len
        # Break when all indices have been used up already
        if indices_count >= len(indices):
          return
          # Break if all indices have been used up already
          if indices_count >= len(sorted_indices):
            break
        # Note these will be in the sorted order
        X = np.concatenate(Xs, axis=0)
        y = np.concatenate(ys, axis=0)
        w = np.concatenate(ws, axis=0)
        ids = np.concatenate(ids_s, axis=0)
        # We need to recover the original ordering. We can do this by using
        # np.where to find the locatios of the original indices in the sorted
        # indices.
        reverted_indices = np.array(
            # We know there's only one match for np.where since this is a
            # permutation, so the [0][0] pulls out the exact match location.
            [
                np.where(sorted_indices == orig_index)[0][0]
                for orig_index in select_shard_indices
            ])
        X, y, w, ids = X[reverted_indices], y[reverted_indices], w[
            reverted_indices], ids[reverted_indices]
        yield (X, y, w, ids)
        start = end
        select_shard_num += 1

    if not output_numpy_dataset:
      return DiskDataset.create_dataset(
          generator(), data_dir=select_dir, tasks=tasks)
    else:
      X, y, w, ids = next(generator())
      return NumpyDataset(X, y, w, ids)

  @property
  def ids(self) -> np.ndarray:
@@ -2410,14 +2472,14 @@ class DiskDataset(Dataset):
      total += len(y)
    return total

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
  def _get_shard_shape(self,
                       shard_num: int) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds the shape of the specified shard."""
    if self.legacy_metadata:
      raise ValueError(
          "This function requires the new metadata format to be called. Please reshard this dataset by calling the reshard() method."
      )
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    # If shape metadata is available use it to directly compute shape from
    # metadata
    if not self.legacy_metadata:
      for shard_num in range(n_rows):
    row = self.metadata_df.iloc[shard_num]
    if row['X_shape'] is not None:
      shard_X_shape = make_tuple(str(row['X_shape']))
@@ -2439,6 +2501,21 @@ class DiskDataset(Dataset):
      shard_ids_shape = make_tuple(str(row['ids_shape']))
    else:
      shard_ids_shape = tuple()
    X_shape, y_shape, w_shape, ids_shape = tuple(
        np.array(shard_X_shape)), tuple(np.array(shard_y_shape)), tuple(
            np.array(shard_w_shape)), tuple(np.array(shard_ids_shape))
    return X_shape, y_shape, w_shape, ids_shape

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    # If shape metadata is available use it to directly compute shape from
    # metadata
    if not self.legacy_metadata:
      for shard_num in range(n_rows):
        shard_X_shape, shard_y_shape, shard_w_shape, shard_ids_shape = self._get_shard_shape(
            shard_num)
        if shard_num == 0:
          X_shape, y_shape, w_shape, ids_shape = np.array(
              shard_X_shape), np.array(shard_y_shape), np.array(
@@ -2728,11 +2805,14 @@ class ImageDataset(Dataset):
    ids = self._ids[indices]
    return ImageDataset(X, y, w, ids)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: Optional[int] = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y,
    w, id) for one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
@@ -2741,6 +2821,9 @@ class ImageDataset(Dataset):
    deterministic: bool, default False
      If True, the data is produced in order. If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int, optional (default None)
      The number of samples to return in each batch. If None, each returned
      value is a single sample.

    Returns
    -------
@@ -2758,7 +2841,10 @@ class ImageDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = _TorchImageDataset(
        image_dataset=self, epochs=epochs, deterministic=deterministic)
        image_dataset=self,
        epochs=epochs,
        deterministic=deterministic,
        batch_size=batch_size)
    return pytorch_ds


+63 −18
Original line number Diff line number Diff line
@@ -6,8 +6,11 @@ from deepchem.data.datasets import NumpyDataset, DiskDataset, ImageDataset

class _TorchNumpyDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, numpy_dataset: NumpyDataset, epochs: int,
               deterministic: bool):
  def __init__(self,
               numpy_dataset: NumpyDataset,
               epochs: int,
               deterministic: bool,
               batch_size: int = None):
    """
    Parameters
    ----------
@@ -18,10 +21,14 @@ class _TorchNumpyDataset(torch.utils.data.IterableDataset): # type: ignore
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.
    """
    self.numpy_dataset = numpy_dataset
    self.epochs = epochs
    self.deterministic = deterministic
    self.batch_size = batch_size

  def __iter__(self):
    n_samples = self.numpy_dataset._X.shape[0]
@@ -36,16 +43,28 @@ class _TorchNumpyDataset(torch.utils.data.IterableDataset): # type: ignore
      if self.deterministic:
        order = first_sample + np.arange(last_sample - first_sample)
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
        # Ensure that every worker will pick the same random order for each epoch.
        random = np.random.RandomState(epoch)
        order = random.permutation(n_samples)[first_sample:last_sample]
      if self.batch_size is None:
        for i in order:
          yield (self.numpy_dataset._X[i], self.numpy_dataset._y[i],
                 self.numpy_dataset._w[i], self.numpy_dataset._ids[i])
      else:
        for i in range(0, len(order), self.batch_size):
          indices = order[i:i + self.batch_size]
          yield (self.numpy_dataset._X[indices], self.numpy_dataset._y[indices],
                 self.numpy_dataset._w[indices],
                 self.numpy_dataset._ids[indices])


class _TorchDiskDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, disk_dataset: DiskDataset, epochs: int,
               deterministic: bool):
  def __init__(self,
               disk_dataset: DiskDataset,
               epochs: int,
               deterministic: bool,
               batch_size: int = None):
    """
    Parameters
    ----------
@@ -56,10 +75,14 @@ class _TorchDiskDataset(torch.utils.data.IterableDataset): # type: ignore
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.
    """
    self.disk_dataset = disk_dataset
    self.epochs = epochs
    self.deterministic = deterministic
    self.batch_size = batch_size

  def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
@@ -74,17 +97,25 @@ class _TorchDiskDataset(torch.utils.data.IterableDataset): # type: ignore
      return

    shard_indices = list(range(first_shard, last_shard))
    for epoch in range(self.epochs):
    for X, y, w, ids in self.disk_dataset._iterbatches_from_shards(
          shard_indices, deterministic=self.deterministic):
        shard_indices,
        batch_size=self.batch_size,
        epochs=self.epochs,
        deterministic=self.deterministic):
      if self.batch_size is None:
        for i in range(X.shape[0]):
          yield (X[i], y[i], w[i], ids[i])
      else:
        yield (X, y, w, ids)


class _TorchImageDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, image_dataset: ImageDataset, epochs: int,
               deterministic: bool):
  def __init__(self,
               image_dataset: ImageDataset,
               epochs: int,
               deterministic: bool,
               batch_size: int = None):
    """
    Parameters
    ----------
@@ -95,10 +126,14 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.
    """
    self.image_dataset = image_dataset
    self.epochs = epochs
    self.deterministic = deterministic
    self.batch_size = batch_size

  def __iter__(self):
    n_samples = self.image_dataset._X_shape[0]
@@ -113,8 +148,18 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
      if self.deterministic:
        order = first_sample + np.arange(last_sample - first_sample)
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
        # Ensure that every worker will pick the same random order for each epoch.
        random = np.random.RandomState(epoch)
        order = random.permutation(n_samples)[first_sample:last_sample]
      if self.batch_size is None:
        for i in order:
          yield (self.image_dataset._get_image(self.image_dataset._X, i),
                 self.image_dataset._get_image(self.image_dataset._y, i),
                 self.image_dataset._w[i], self.image_dataset._ids[i])
      else:
        for i in range(0, len(order), self.batch_size):
          indices = order[i:i + self.batch_size]
          yield (self.image_dataset._get_image(self.image_dataset._X, indices),
                 self.image_dataset._get_image(self.image_dataset._y,
                                 indices), self.image_dataset._w[indices],
                 self.image_dataset._ids[indices])
+17 −21
Original line number Diff line number Diff line
@@ -272,27 +272,6 @@ def test_reshard():
  np.testing.assert_array_equal(ids, ids_rr)


def test_select():
  """Test that dataset select works."""
  num_datapoints = 10
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [0, 4, 5, 8]
  select_dataset = dataset.select(indices)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_complete_shuffle():
  shard_sizes = [1, 2, 3, 4, 5]

@@ -742,9 +721,26 @@ def _validate_pytorch_dataset(dataset):
    id_count[iter_id] += 1
  assert all(id_count[id] == 2 for id in ids)

  # Test iterating in batches.

  ds = dataset.make_pytorch_dataset(epochs=2, deterministic=False, batch_size=7)
  id_to_index = dict((id, i) for i, id in enumerate(ids))
  id_count = dict((id, 0) for id in ids)
  for iter_X, iter_y, iter_w, iter_id in ds:
    size = len(iter_id)
    assert size <= 7
    for i in range(size):
      j = id_to_index[iter_id[i]]
      np.testing.assert_array_equal(X[j, :], iter_X[i])
      np.testing.assert_array_equal(y[j, :], iter_y[i])
      np.testing.assert_array_equal(w[j, :], iter_w[i])
      id_count[iter_id[i]] += 1
  assert all(id_count[id] == 2 for id in ids)

  # Test iterating with multiple workers.

  import torch  # noqa
  ds = dataset.make_pytorch_dataset(epochs=2, deterministic=False)
  loader = torch.utils.data.DataLoader(ds, num_workers=3)
  id_count = dict((id, 0) for id in ids)
  for iter_X, iter_y, iter_w, iter_id in loader:
+0 −4

File changed.

Preview size limit exceeded, changes collapsed.

Loading