Commit 6a2f7d0b authored by Bharath Ramsundar
Browse files

Fixes to parallel featurization

parent 69931760
Loading
Loading
Loading
Loading
+21 −15
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ import dill
import multiprocessing as mp
from functools import partial
from rdkit import Chem
import itertools as it
from deepchem.utils.save import log
from deepchem.utils.save import save_to_disk
from deepchem.utils.save import load_pickle_from_disk
@@ -23,6 +24,23 @@ from deepchem.datasets import Dataset
from deepchem.utils.save import load_data
from deepchem.utils.save import get_input_type

def _process_helper(row, loader, fields, input_type):
  """Process a single raw row by delegating to the loader.

  Forwards to loader._process_raw_sample, reordering the arguments to
  (input_type, row, fields), and returns whatever that call returns.
  """
  process_sample = loader._process_raw_sample
  return process_sample(input_type, row, fields)

def featurize_map_function(args):
  """Featurize one raw data shard.

  Takes a single packed argument so it can be used directly as the function
  passed to a pool's map() call.

  args is a pair of the form
    ((loader, shard_size, input_type, data_dir), (shard_num, raw_df_shard))
  where the first tuple is shared across shards and the second identifies
  the shard to process.

  Returns the result of loader._featurize_shard for this shard.
  """
  shared_state, indexed_shard = args
  loader, shard_size, input_type, data_dir = shared_state
  shard_num, raw_df_shard = indexed_shard

  # shard_num is zero-based; the log message reports it one-based.
  loading_msg = "Loading shard %d of size %s from file." % (
      shard_num+1, str(shard_size))
  log(loading_msg, loader.verbosity)
  log("About to featurize shard.", loader.verbosity)

  # Row-level processing, bound to this shard's fields.
  process_fn = partial(
      _process_helper,
      loader=loader,
      fields=raw_df_shard.keys(),
      input_type=input_type)
  # Writer bound to the output directory and the loader's configuration.
  write_fn = partial(
      Dataset.write_dataframe,
      data_dir=data_dir,
      featurizers=loader.featurizers,
      tasks=loader.tasks)

  return loader._featurize_shard(
      raw_df_shard, process_fn, write_fn, shard_num, input_type)

def _process_field(val):
  """Parse data in a field."""
  if (isinstance(val, numbers.Number) or isinstance(val, np.ndarray)):
@@ -91,25 +109,13 @@ class DataFeaturizer(object):
    if not len(input_files):
      return None
    input_type = get_input_type(input_files[0])
    write_fn = partial(
        Dataset.write_dataframe, data_dir=data_dir,
        featurizers=self.featurizers, tasks=self.tasks)
    def map_function(args):
      (shard_num, raw_df_shard) = args
      log("Loading shard %d of size %s from file." % (shard_num+1, str(shard_size)),
          self.verbosity)
      log("About to featurize shard.", self.verbosity)
      def process_helper(row, fields, input_type):
        return self._process_raw_sample(input_type, row, fields)
      process_fn = partial(process_helper, fields=raw_df_shard.keys(),
                           input_type=input_type)
      return self._featurize_shard(
          raw_df_shard, process_fn, write_fn, shard_num, input_type)

    if worker_pool is None:
      worker_pool = mp.Pool(processes=1)
    metadata_rows = worker_pool.map(
        map_function, enumerate(load_data(input_files, shard_size)))
        featurize_map_function,
        it.izip(it.repeat((self, shard_size, input_type, data_dir)),
                enumerate(load_data(input_files, shard_size))))

    # TODO(rbharath): This whole bit with metadata_rows is an awkward way of
    # creating a Dataset. Is there a more elegant solution?