Commit 4a1f2e3c authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Initial data loader

parent e2fd409b
Loading
Loading
Loading
Loading
+21 −18
Original line number Diff line number Diff line
@@ -131,7 +131,8 @@ def _get_user_specified_features(df, featurizer):
      pd.to_numeric)
  X_shard = df[featurizer.feature_fields].to_numpy()
  time2 = time.time()
  logger.info("TIMING: user specified processing took %0.3f s" % (time2 - time1))
  logger.info(
      "TIMING: user specified processing took %0.3f s" % (time2 - time1))
  return X_shard


@@ -184,11 +185,7 @@ class DataLoader(object):
  `featurize`.
  """

  def __init__(self,
               tasks,
               id_field=None,
               featurizer=None,
               log_every_n=1000):
  def __init__(self, tasks, id_field=None, featurizer=None, log_every_n=1000):
    """Construct a DataLoader object.

    This constructor is provided as a template mainly. You
@@ -256,13 +253,11 @@ class DataLoader(object):
          assert len(X) == len(ids)

        time2 = time.time()
        logger.info(
            "TIMING: featurizing shard %d took %0.3f s" %
        logger.info("TIMING: featurizing shard %d took %0.3f s" %
                    (shard_num, time2 - time1))
        yield X, y, w, ids

    return DiskDataset.create_dataset(
        shard_generator(), data_dir, self.tasks)
    return DiskDataset.create_dataset(shard_generator(), data_dir, self.tasks)

  def _get_shards(self, input_files, shard_size):
    """Stub for children classes."""
@@ -326,7 +321,11 @@ class CSVLoader(DataLoader):

  def _featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
    return _featurize_smiles_df(shard, self.featurizer, field=self.smiles_field, log_every_N=self.log_every_n)
    return _featurize_smiles_df(
        shard,
        self.featurizer,
        field=self.smiles_field,
        log_every_N=self.log_every_n)


class UserCSVLoader(CSVLoader):
@@ -350,7 +349,8 @@ class SDFLoader(DataLoader):
  Handles loading of SDF files.
  """

  def __init__(self, tasks, clean_mols=False, featurizer=None, log_every_n=1000):
  def __init__(self, tasks, clean_mols=False, featurizer=None,
               log_every_n=1000):
    """Initialize SDF Loader

    Parameters
@@ -381,10 +381,13 @@ class SDFLoader(DataLoader):

  def _featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
    logger.info(
        "Currently featurizing feature_type: %s" %
    logger.info("Currently featurizing feature_type: %s" %
                self.featurizer.__class__.__name__)
    return _featurize_mol_df(shard, self.featurizer, field=self.mol_field, log_every_N=self.log_every_n)
    return _featurize_mol_df(
        shard,
        self.featurizer,
        field=self.mol_field,
        log_every_N=self.log_every_n)


class FASTALoader(DataLoader):
@@ -519,6 +522,7 @@ class ImageLoader(DataLoader):
        raise ValueError("Unsupported image filetype for %s" % image_file)
    return np.array(images)


class MolecularComplexLoader(DataLoader):
  """Handles Loading of Molecular Complex Data

@@ -563,4 +567,3 @@ class MolecularComplexLoader(DataLoader):
      If provided, a numpy ndarray of image weights
    """
    raise NotImplementedError