Commit 72ebe32e authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes

parent d8aa5c98
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -324,15 +324,14 @@ class SDFLoader(DataLoader):
  Handles loading of SDF files.
  """

  def __init__(self, tasks, clean_mols=False, featurizer=None,
               log_every_n=1000):
  def __init__(self, tasks, sanitize=False, featurizer=None, log_every_n=1000):
    """Initialize SDF Loader

    Parameters
    ----------
    tasks: list[str]
      List of tasknames. These will be loaded from the SDF file.
    clean_mols: bool, optional
    sanitize: bool, optional
      Whether to sanitize molecules.
    featurizer: dc.feat.Featurizer, optional
      Featurizer to use to process data
@@ -340,7 +339,7 @@ class SDFLoader(DataLoader):
      Writes a logging statement this often.
    """
    self.featurizer = featurizer
    self.clean_mols = clean_mols
    self.sanitize = sanitize
    self.tasks = tasks
    # The field in which dc.utils.save.load_sdf_files stores
    # RDKit mol objects
@@ -352,7 +351,7 @@ class SDFLoader(DataLoader):

  def _get_shards(self, input_files, shard_size):
    """Defines a generator which returns data for each shard"""
    return load_sdf_files(input_files, self.clean_mols, tasks=self.tasks)
    return load_sdf_files(input_files, self.sanitize, tasks=self.tasks)

  def _featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
+1 −5
Original line number Diff line number Diff line
@@ -30,11 +30,7 @@ class TestFeaturizedSamples(unittest.TestCase):
    input_file = os.path.join(current_dir, "data/water.sdf")

    featurizer = dc.feat.CoulombMatrixEig(6, remove_hydrogens=False)
    loader = dc.data.SDFLoader(
        tasks=tasks,
        smiles_field="smiles",
        mol_field="mol",
        featurizer=featurizer)
    loader = dc.data.SDFLoader(tasks=tasks, featurizer=featurizer)

    dataset = loader.featurize(input_file)