Commit aa0403c9 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Making it easier to process sharded datasets.

parent 61bc5bf6
Loading
Loading
Loading
Loading
+16 −68
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ from deepchem.utils.save import load_pickle_from_disk
from deepchem.featurizers import Featurizer, ComplexFeaturizer
from deepchem.featurizers import UserDefinedFeaturizer
from deepchem.datasets import Dataset
from deepchem.utils.save import load_data
from deepchem.utils.save import get_input_type

def _process_field(val):
  """Parse data in a field."""
@@ -37,64 +39,6 @@ def _process_field(val):
  else:
    raise ValueError("Field of unrecognized type: %s" % str(val))

def load_data(input_file, shard_size=None):
  """Loads data from disk.
     
  For CSV files, supports sharded loading for large files.
  """
  input_type = _get_input_type(input_file)
  if input_type == "sdf":
    if shard_size is not None:
      raise ValueError("shard_size must be None for sdf input.")
    return _load_sdf_file(input_file)
  elif input_type == "csv":
    return _load_csv_file(input_file, shard_size)
  elif input_type == "pandas-pickle":
    return [load_pickle_from_disk(input_file)]

def _load_sdf_file(input_file):
  """Load SDF file into dataframe."""
  # Tasks are stored in .sdf.csv file
  raw_df = _load_csv_file(input_file+".csv", shard_size=None).next()
  # Structures are stored in .sdf file
  print("Reading structures from %s." % input_file)
  suppl = Chem.SDMolSupplier(str(input_file), removeHs=False)
  df_rows = []
  for ind, mol in enumerate(suppl):
    if mol is not None:
      smiles = Chem.MolToSmiles(mol)
      df_rows.append([ind,smiles,mol])
  mol_df = pd.DataFrame(df_rows, columns=('mol_id', 'smiles', 'mol'))
  raw_df = pd.concat([mol_df, raw_df], axis=1, join='inner')
  return [raw_df]

def _load_csv_file(filename, shard_size=None):
  """Load data as pandas dataframe."""
  # First line of user-specified CSV *must* be header.
  if shard_size is None:
    yield pd.read_csv(filename)
  else:
    for df in pd.read_csv(filename, chunksize=shard_size):
      df = df.replace(np.nan, str(""), regex=True)
      yield df

def _get_input_type(input_file):
  """Get type of input file. Must be csv/pkl.gz/sdf file."""
  filename, file_extension = os.path.splitext(input_file)
  # If gzipped, need to compute extension again
  if file_extension == ".gz":
    filename, file_extension = os.path.splitext(filename)
  if file_extension == ".csv":
    return "csv"
  elif file_extension == ".pkl":
    return "pandas-pickle"
  elif file_extension == ".joblib":
    return "pandas-joblib"
  elif file_extension == ".sdf":
    return "sdf"
  else:
    raise ValueError("Unrecognized extension %s" % file_extension)

class DataFeaturizer(object):
  """
  Handles loading/featurizing of chemical samples (datapoints).
@@ -132,32 +76,36 @@ class DataFeaturizer(object):
    self.featurizers = featurizers
    self.log_every_n = log_every_n

  def featurize(self, input_file, data_dir, shard_size=8192, worker_pool=None):
    """Featurize provided file and write to specified location."""
  def featurize(self, input_files, data_dir, shard_size=8192, worker_pool=None):
    """Featurize provided files and write to specified location."""
    log("Loading raw samples now.", self.verbosity)

    if not os.path.exists(data_dir):
      os.makedirs(data_dir)

    # Construct partial function to write datasets.
    if not len(input_files):
      return None
    input_type = get_input_type(input_files[0])
    write_fn = partial(
        Dataset.write_dataframe, data_dir=data_dir,
        featurizers=self.featurizers, tasks=self.tasks)
    input_type = _get_input_type(input_file)

    metadata_rows = []
    for shard_num, raw_df_shard in enumerate(load_data(input_file, shard_size)):
      log("Loaded shard %d of size %s from file." % (shard_num+1, str(shard_size)),
    def map_function(args):
      (shard_num, raw_df_shard) = args
      log("Loading shard %d of size %s from file." % (shard_num+1, str(shard_size)),
          self.verbosity)
      log("About to featurize shard.", self.verbosity)

      def process_helper(row, fields, input_type):
        return self._process_raw_sample(input_type, row, fields)
      process_fn = partial(process_helper, fields=raw_df_shard.keys(),
                           input_type=input_type)
      return self._featurize_shard(
          raw_df_shard, process_fn, write_fn, shard_num, input_type)

      metadata_rows.append(self._featurize_shard(
          raw_df_shard, process_fn, write_fn, shard_num, input_type))
    if worker_pool is None:
      worker_pool = mp.Pool(processes=1)
    metadata_rows = worker_pool.map(
        map_function, enumerate(load_data(input_files, shard_size)))

    # TODO(rbharath): This whole bit with metadata_rows is an awkward way of
    # creating a Dataset. Is there a more elegant solutions?
+63 −0
Original line number Diff line number Diff line
@@ -25,6 +25,69 @@ def save_to_disk(dataset, filename, compress=3):
  """Save a dataset to file."""
  joblib.dump(dataset, filename, compress=compress)

def get_input_type(input_file):
  """Get type of input file. Must be csv/pkl.gz/sdf file."""
  filename, file_extension = os.path.splitext(input_file)
  # If gzipped, need to compute extension again
  if file_extension == ".gz":
    filename, file_extension = os.path.splitext(filename)
  if file_extension == ".csv":
    return "csv"
  elif file_extension == ".pkl":
    return "pandas-pickle"
  elif file_extension == ".joblib":
    return "pandas-joblib"
  elif file_extension == ".sdf":
    return "sdf"
  else:
    raise ValueError("Unrecognized extension %s" % file_extension)

def load_data(input_files, shard_size=None):
  """Loads data from disk.
     
  For CSV files, supports sharded loading for large files.
  """
  if not len(input_files):
    return []
  input_type = get_input_type(input_files[0])
  if input_type == "sdf":
    if shard_size is not None:
      raise ValueError("shard_size must be None for sdf input.")
    return load_sdf_files(input_files)
  elif input_type == "csv":
    return load_csv_files(input_files, shard_size)
  elif input_type == "pandas-pickle":
    return [load_pickle_from_disk(input_file) for input_file in input_files]

def load_sdf_files(input_files):
  """Load SDF file into dataframe."""
  dataframes = []
  for input_file in input_files:
    # Tasks are stored in .sdf.csv file
    raw_df = load_csv_file(input_file+".csv", shard_size=None).next()
    # Structures are stored in .sdf file
    print("Reading structures from %s." % input_file)
    suppl = Chem.SDMolSupplier(str(input_file), removeHs=False)
    df_rows = []
    for ind, mol in enumerate(suppl):
      if mol is not None:
        smiles = Chem.MolToSmiles(mol)
        df_rows.append([ind,smiles,mol])
    mol_df = pd.DataFrame(df_rows, columns=('mol_id', 'smiles', 'mol'))
    dataframes.append(pd.concat([mol_df, raw_df], axis=1, join='inner'))
  return dataframes

def load_csv_file(filenames, shard_size=None):
  """Load data as pandas dataframe."""
  # First line of user-specified CSV *must* be header.
  for filename in filenames:
    if shard_size is None:
      yield pd.read_csv(filename)
    else:
      for df in pd.read_csv(filename, chunksize=shard_size):
        df = df.replace(np.nan, str(""), regex=True)
        yield df

def load_from_disk(filename):
  """Load a dataset from file."""
  name = filename