Commit 896cc659 authored by Christabella Irwanto's avatar Christabella Irwanto
Browse files

Handle backwards compatibility where .sdf.csv exists storing tasks

parent 130e98dd
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -291,7 +291,7 @@ class SDFLoader(DataLoader):

  def get_shards(self, input_files, shard_size):
    """Defines a generator which returns data for each shard"""
    return load_sdf_files(input_files, self.clean_mols, self.tasks)
    return load_sdf_files(input_files, self.clean_mols, tasks=self.tasks)

  def featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
+9 −7
Original line number Diff line number Diff line
@@ -81,19 +81,21 @@ def load_sdf_files(input_files, clean_mols, tasks=[]):
    suppl = Chem.SDMolSupplier(str(input_file), clean_mols, False, False)
    df_rows = []
    for ind, mol in enumerate(suppl):
      if mol is not None:
      if mol is None:
        continue
      smiles = Chem.MolToSmiles(mol)
      df_row = [ind, smiles, mol]
      if not has_csv:  # Get task targets from .sdf file
        for task in tasks:
          df_row.append(mol.GetProp(str(task)))
      df_rows.append(df_row)
    mol_df = pd.DataFrame(
        df_rows, columns=('mol_id', 'smiles', 'mol') + tuple(tasks))
    # Tasks are stored either in .sdf.csv file
    if has_csv:
      mol_df = pd.DataFrame(df_rows, columns=('mol_id', 'smiles', 'mol'))
      raw_df = next(load_csv_files([input_file + ".csv"], shard_size=None))
      dataframes.append(pd.concat([mol_df, raw_df], axis=1, join='inner'))
    else:
      mol_df = pd.DataFrame(
        df_rows, columns=('mol_id', 'smiles', 'mol') + tuple(tasks))
      dataframes.append(mol_df)
  return dataframes