Commit d8aa5c98 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Updating data loader

parent 5cd03aef
Loading
Loading
Loading
Loading
+8 −6
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ def _convert_df_to_numpy(df, tasks):
  return y.astype(float), w.astype(float)


def _featurize_smiles_df(df, featurizer, field, log_every_N=1000):
def _featurize_smiles_df(df, featurizer, field, log_every_n=1000):
  """Featurize individual compounds in dataframe.

  Given a featurizer that operates on individual chemical
@@ -74,7 +74,7 @@ def _featurize_smiles_df(df, featurizer, field, log_every_N=1000):
    if mol:
      new_order = rdmolfiles.CanonicalRankAtoms(mol)
      mol = rdmolops.RenumberAtoms(mol, new_order)
    if ind % log_every_N == 0:
    if ind % log_every_n == 0:
      logger.info("Featurizing sample %d" % ind)
    features.append(featurizer.featurize([mol]))
  valid_inds = np.array(
@@ -109,7 +109,7 @@ def _get_user_specified_features(df, featurizer):
  return X_shard


def _featurize_mol_df(df, featurizer, field, log_every_N=1000):
def _featurize_mol_df(df, featurizer, field, log_every_n=1000):
  """Featurize individual compounds in dataframe.

  Used when processing .sdf files, so the 3-D structure should be
@@ -124,12 +124,14 @@ def _featurize_mol_df(df, featurizer, field, log_every_N=1000):
    Should be created by dc.utils.save.load_sdf_files.
  featurizer: dc.feat.MolecularFeaturizer
    Featurizer for molecules.
  log_every_n: int, optional
    Controls how often logging statements are emitted.
  """
  sample_elems = df[field].tolist()

  features = []
  for ind, mol in enumerate(sample_elems):
    if ind % log_every_N == 0:
    if ind % log_every_n == 0:
      logger.info("Featurizing sample %d" % ind)
    features.append(featurizer.featurize([mol]))
  valid_inds = np.array(
@@ -298,7 +300,7 @@ class CSVLoader(DataLoader):
        shard,
        self.featurizer,
        field=self.smiles_field,
        log_every_N=self.log_every_n)
        log_every_n=self.log_every_n)


class UserCSVLoader(CSVLoader):
@@ -360,7 +362,7 @@ class SDFLoader(DataLoader):
        shard,
        self.featurizer,
        field=self.mol_field,
        log_every_N=self.log_every_n)
        log_every_n=self.log_every_n)


class FASTALoader(DataLoader):