Commit ff31d0a7 authored by Christabella Irwanto's avatar Christabella Irwanto
Browse files

Pass tasks into load_sdf_files so we can extract tasks from .sdf

parent cce55b72
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -284,13 +284,14 @@ class SDFLoader(DataLoader):
  def __init__(self, tasks, clean_mols=False, **kwargs):
    super(SDFLoader, self).__init__(tasks, **kwargs)
    self.clean_mols = clean_mols
    self.tasks = tasks
    self.smiles_field = "smiles"
    self.mol_field = "mol"
    self.id_field = "smiles"

  def get_shards(self, input_files, shard_size):
    """Defines a generator which returns data for each shard"""
    return load_sdf_files(input_files, self.clean_mols)
    return load_sdf_files(input_files, self.clean_mols, self.tasks)

  def featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""