Unverified commit 1d57213d, authored by Bharath Ramsundar and committed by GitHub
Browse files

Merge pull request #1447 from christabella/issue/1185-sdf-loading-without-csv

#1185: SDF loader should load tasks from SDF File
parents 2a72d8e1 dfac79f3
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -284,13 +284,14 @@ class SDFLoader(DataLoader):
  def __init__(self, tasks, clean_mols=False, **kwargs):
    """Set up an SDF loader.

    `tasks` and any extra keyword arguments are forwarded to the parent
    initializer. `tasks` and `clean_mols` are also kept on the instance,
    where `get_shards` reads them when calling `load_sdf_files`.
    """
    super(SDFLoader, self).__init__(tasks, **kwargs)
    self.tasks = tasks
    self.clean_mols = clean_mols
    # These values match the 'mol' and 'smiles' column names that
    # load_sdf_files gives the dataframes it builds; the SMILES column
    # doubles as the identifier field.
    self.mol_field = "mol"
    self.smiles_field = self.id_field = "smiles"

  def get_shards(self, input_files, shard_size):
    """Returns the dataframes loaded from the given SDF files.

    The result is whatever `load_sdf_files` returns: a list with one
    dataframe per input file. `shard_size` is not used by this
    implementation.

    `self.tasks` is forwarded so that task targets can be read from the SDF
    file itself when no companion ".csv" file exists.
    """
    # A stale two-argument call used to precede this return, which made the
    # task-aware call unreachable; only the task-aware call is kept.
    return load_sdf_files(input_files, self.clean_mols, tasks=self.tasks)

  def featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
+19 −8
Original line number Diff line number Diff line
@@ -70,22 +70,33 @@ def load_data(input_files, shard_size=None, verbose=True):
      yield load_pickle_from_disk(input_file)


def load_sdf_files(input_files, clean_mols, tasks=None):
  """Load SDF files into dataframes, one dataframe per input file.

  Every dataframe starts with the columns 'mol_id' (position of the record
  in the SDF file), 'smiles' and 'mol'. Task targets come from one of two
  places:

  * if a companion file named `input_file + ".csv"` exists, its columns are
    concatenated column-wise onto the molecule columns;
  * otherwise every name in `tasks` is read from the molecule's properties
    via `mol.GetProp` and stored in a column of that name.

  Records for which the supplier yields `None` are skipped.

  Parameters
  ----------
  input_files: list
    Paths of the SDF files to read.
  clean_mols: bool
    Passed as the second positional argument of `Chem.SDMolSupplier`
    (presumably RDKit's `sanitize` flag -- confirm).
  tasks: list, optional
    Property names to read from the SDF records when no companion CSV is
    present. `None` (the default) or an empty list adds no task columns.

  Returns
  -------
  list
    One `pd.DataFrame` per entry of `input_files`, in the same order.
  """
  # A None default avoids sharing one mutable list between calls.
  task_names = () if tasks is None else tuple(tasks)
  base_columns = ('mol_id', 'smiles', 'mol')
  dataframes = []
  for input_file in input_files:
    # Tasks are either in .sdf.csv file or in the .sdf file itself
    csv_file = input_file + ".csv"
    has_csv = os.path.isfile(csv_file)
    # Structures are stored in .sdf file
    print("Reading structures from %s." % input_file)
    suppl = Chem.SDMolSupplier(str(input_file), clean_mols, False, False)
    df_rows = []
    for ind, mol in enumerate(suppl):
      if mol is None:
        continue
      smiles = Chem.MolToSmiles(mol)
      df_row = [ind, smiles, mol]
      if not has_csv:  # Get task targets from .sdf file
        # NOTE(review): GetProp presumably returns the raw property text and
        # raises when the property is absent -- confirm whether numeric
        # conversion / missing values need handling here.
        for task in task_names:
          df_row.append(mol.GetProp(str(task)))
      df_rows.append(df_row)
    if has_csv:
      mol_df = pd.DataFrame(df_rows, columns=base_columns)
      raw_df = next(load_csv_files([csv_file], shard_size=None))
      # NOTE(review): mol_df gets a fresh positional index, so if any record
      # was skipped above, later rows presumably pair with a different CSV
      # row than their 'mol_id' -- confirm the CSV/SDF row correspondence.
      dataframes.append(pd.concat([mol_df, raw_df], axis=1, join='inner'))
    else:
      mol_df = pd.DataFrame(df_rows, columns=base_columns + task_names)
      dataframes.append(mol_df)
  return dataframes