Unverified Commit 0c5b2b89 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2024 from peastman/transform

Use multiple processes to transform datasets
parents b3bc14ac 4fde2603
Loading
Loading
Loading
Loading
+76 −29
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ import time
import shutil
import json
import warnings
from multiprocessing.dummy import Pool
import multiprocessing
from deepchem.utils.save import save_to_disk, save_metadata
from deepchem.utils.save import load_from_disk

@@ -380,8 +380,7 @@ class Dataset(object):
    """
    raise NotImplementedError()

  def transform(self, fn: Callable[[np.ndarray, np.ndarray, np.ndarray], Tuple[
      np.ndarray, np.ndarray, np.ndarray]], **args) -> "Dataset":
  def transform(self, transformer: "dc.trans.Transformer", **args) -> "Dataset":
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
@@ -394,8 +393,8 @@ class Dataset(object):

    Parameters
    ----------
    fn: function
      A function to apply to each sample in the dataset
    transformer: Transformer
      the transformation to apply to each sample in the dataset

    Returns
    -------
@@ -811,8 +810,8 @@ class NumpyDataset(Dataset):
    return ((self._X[i], self._y[i], self._w[i], self._ids[i])
            for i in range(n_samples))

  def transform(self, fn: Callable[[np.ndarray, np.ndarray, np.ndarray], Tuple[
      np.ndarray, np.ndarray, np.ndarray]], **args) -> "NumpyDataset":
  def transform(self, transformer: "dc.trans.Transformer",
                **args) -> "NumpyDataset":
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
@@ -825,14 +824,14 @@ class NumpyDataset(Dataset):

    Parameters
    ----------
    fn: function
      A function to apply to each sample in the dataset
    transformer: Transformer
      the transformation to apply to each sample in the dataset

    Returns
    -------
    a newly constructed Dataset object
    """
    newx, newy, neww = fn(self._X, self._y, self._w)
    newx, newy, neww = transformer.transform_array(self._X, self._y, self._w)
    return NumpyDataset(newx, newy, neww, self._ids[:])

  def select(self, indices: Sequence[int],
@@ -1218,7 +1217,8 @@ class DiskDataset(Dataset):
      # than process based pools, since process based pools need to pickle/serialize
      # objects as an extra overhead. Also, as hideously as un-thread safe this looks,
      # we're actually protected by the GIL.
      pool = Pool(1)  # mp.dummy aliases ThreadPool to Pool
      pool = multiprocessing.dummy.Pool(
          1)  # mp.dummy aliases ThreadPool to Pool

      if batch_size is None:
        num_global_batches = num_shards
@@ -1336,8 +1336,10 @@ class DiskDataset(Dataset):

    return iterate(self)

  def transform(self, fn: Callable[[np.ndarray, np.ndarray, np.ndarray], Tuple[
      np.ndarray, np.ndarray, np.ndarray]], **args) -> "DiskDataset":
  def transform(self,
                transformer: "dc.trans.Transformer",
                parallel=False,
                **args) -> "DiskDataset":
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
@@ -1350,11 +1352,13 @@ class DiskDataset(Dataset):

    Parameters
    ----------
    fn: function
      A function to apply to each sample in the dataset
    transformer: Transformer
      the transformation to apply to each sample in the dataset
    out_dir: string
      The directory to save the new dataset in.  If this is omitted, a
      temporary directory is created automatically
    parallel: bool
      if True, use multiple processes to transform the dataset in parallel

    Returns
    -------
@@ -1365,18 +1369,61 @@ class DiskDataset(Dataset):
    else:
      out_dir = tempfile.mkdtemp()
    tasks = self.get_task_names()

    n_shards = self.get_number_shards()

    time1 = time.time()
    if parallel:
      results = []
      pool = multiprocessing.Pool()
      for i in range(self.get_number_shards()):
        row = self.metadata_df.iloc[i]
        X_file = os.path.join(self.data_dir, row['X'])
        if row['y'] is not None:
          y_file: Optional[str] = os.path.join(self.data_dir, row['y'])
        else:
          y_file = None
        if row['w'] is not None:
          w_file: Optional[str] = os.path.join(self.data_dir, row['w'])
        else:
          w_file = None
        ids_file = os.path.join(self.data_dir, row['ids'])
        results.append(
            pool.apply_async(DiskDataset._transform_shard,
                             (transformer, i, X_file, y_file, w_file, ids_file,
                              out_dir, tasks)))
      pool.close()
      metadata_rows = [r.get() for r in results]
      metadata_df = DiskDataset._construct_metadata(metadata_rows)
      save_metadata(tasks, metadata_df, out_dir)
      dataset = DiskDataset(out_dir)
    else:

      def generator():
        for shard_num, row in self.metadata_df.iterrows():
          logger.info("Transforming shard %d/%d" % (shard_num, n_shards))
          X, y, w, ids = self.get_shard(shard_num)
        newx, newy, neww = fn(X, y, w)
          newx, newy, neww = transformer.transform_array(X, y, w)
          yield (newx, newy, neww, ids)

    return DiskDataset.create_dataset(
      dataset = DiskDataset.create_dataset(
          generator(), data_dir=out_dir, tasks=tasks)
    time2 = time.time()
    logger.info("TIMING: transforming took %0.3f s" % (time2 - time1))
    return dataset

  @staticmethod
  def _transform_shard(transformer: "dc.trans.Transformer", shard_num: int,
                       X_file: str, y_file: str, w_file: str, ids_file: str,
                       out_dir: str, tasks: np.ndarray):
    """This is called by transform() to transform a single shard."""
    X = None if X_file is None else np.array(load_from_disk(X_file))
    y = None if y_file is None else np.array(load_from_disk(y_file))
    w = None if w_file is None else np.array(load_from_disk(w_file))
    ids = np.array(load_from_disk(ids_file))
    X, y, w = transformer.transform_array(X, y, w)
    basename = "shard-%d" % shard_num
    return DiskDataset.write_data_to_disk(out_dir, basename, tasks, X, y, w,
                                          ids)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.
@@ -2082,8 +2129,8 @@ class ImageDataset(Dataset):
    return ((get_image(self._X, i), get_image(self._y, i), self._w[i],
             self._ids[i]) for i in range(n_samples))

  def transform(self, fn: Callable[[np.ndarray, np.ndarray, np.ndarray], Tuple[
      np.ndarray, np.ndarray, np.ndarray]], **args) -> NumpyDataset:
  def transform(self, transformer: "dc.trans.Transformer",
                **args) -> NumpyDataset:
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
@@ -2096,14 +2143,14 @@ class ImageDataset(Dataset):

    Parameters
    ----------
    fn: function
      A function to apply to each sample in the dataset
    transformer: Transformer
      the transformation to apply to each sample in the dataset

    Returns
    -------
    a newly constructed Dataset object
    """
    newx, newy, neww = fn(self.X, self.y, self.w)
    newx, newy, neww = transformer.transform_array(self.X, self.y, self.w)
    return NumpyDataset(newx, newy, neww, self.ids[:])

  def select(self, indices: Sequence[int],
+19 −15
Original line number Diff line number Diff line
@@ -53,6 +53,12 @@ def load_multitask_data():
  return loader.featurize(input_file)


class TestTransformer(dc.trans.Transformer):

  def transform_array(self, X, y, w):
    return (2 * X, 1.5 * y, w)


class TestDatasets(test_util.TensorFlowTestCase):
  """
  Test basic top-level API for dataset objects.
@@ -386,10 +392,8 @@ class TestDatasets(test_util.TensorFlowTestCase):

    # Transform it

    def fn(x, y, w):
      return (2 * x, 1.5 * y, w)

    transformed = dataset.transform(fn)
    transformer = TestTransformer(transform_X=True, transform_y=True)
    transformed = dataset.transform(transformer)
    np.testing.assert_array_equal(X, dataset.X)
    np.testing.assert_array_equal(y, dataset.y)
    np.testing.assert_array_equal(w, dataset.w)
@@ -408,10 +412,10 @@ class TestDatasets(test_util.TensorFlowTestCase):
    ids = dataset.ids

    # Transform it
    def fn(x, y, w):
      return (2 * x, 1.5 * y, w)

    transformed = dataset.transform(fn)
    transformer = TestTransformer(transform_X=True, transform_y=True)
    for parallel in (True, False):
      transformed = dataset.transform(transformer, parallel=parallel)
      np.testing.assert_array_equal(X, dataset.X)
      np.testing.assert_array_equal(y, dataset.y)
      np.testing.assert_array_equal(w, dataset.w)
+1 −3
Original line number Diff line number Diff line
@@ -56,9 +56,7 @@ class TestReload(unittest.TestCase):
    # TODO(rbharath): Transformers don't play nice with reload! Namely,
    # reloading will cause the transform to be reapplied. This is undesirable in
    # almost all cases. Need to understand a method to fix this.
    transformers = [
        dc.trans.BalancingTransformer(transform_w=True, dataset=train_dataset)
    ]
    transformers = [dc.trans.BalancingTransformer(dataset=train_dataset)]
    logger.info("Transforming datasets")
    for dataset in [train_dataset, valid_dataset, test_dataset]:
      for transformer in transformers:
+2 −6
Original line number Diff line number Diff line
@@ -175,9 +175,7 @@ def load_bace_classification(featurizer='ECFP',

  if split is None:
    # Initialize transformers
    transformers = [
        deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
    ]
    transformers = [deepchem.trans.BalancingTransformer(dataset=dataset)]

    logger.info("Split is None, about to transform data")
    for transformer in transformers:
@@ -204,9 +202,7 @@ def load_bace_classification(featurizer='ECFP',
      frac_valid=frac_valid,
      frac_test=frac_test)

  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
  ]
  transformers = [deepchem.trans.BalancingTransformer(dataset=train)]

  logger.info("About to transform data.")
  for transformer in transformers:
+2 −6
Original line number Diff line number Diff line
@@ -63,9 +63,7 @@ def load_bbbp(featurizer='ECFP',

  if split is None:
    # Initialize transformers
    transformers = [
        deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
    ]
    transformers = [deepchem.trans.BalancingTransformer(dataset=dataset)]

    logger.info("Split is None, about to transform data")
    for transformer in transformers:
@@ -91,9 +89,7 @@ def load_bbbp(featurizer='ECFP',
      frac_test=frac_test)

  # Initialize transformers
  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
  ]
  transformers = [deepchem.trans.BalancingTransformer(dataset=train)]

  for transformer in transformers:
    train = transformer.transform(train)
Loading