Commit 416c577c authored by Bharath Ramsundar
Browse files

SDF processing.

parent aee5a160
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
# Usage ./process_bace.sh INPUT_SDF_FILE
# Runs the dataset-processing script on a BACE SDF file and writes results under /tmp/.
# The flag is spelled --input-file (hyphen) to match the option declared by the
# script's argument parser; "$1" is quoted so paths containing spaces survive.
python -m deep_chem.scripts.process_dataset --input-file "$1" --input-type sdf --fields Name smiles pIC50 --field-types string string concentration --name BACE --out /tmp/
+47 −29
Original line number Diff line number Diff line
@@ -16,17 +16,19 @@ from vs_utils.utils import SmilesGenerator
def parse_args(input_args=None):
  """Parse command-line arguments."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--input-file', required=1,
                      help='Input file with data.')
  parser.add_argument("--input-file", required=1,
                      help="Input file with data.")
  parser.add_argument("--input-type", default="csv",
                      choices=["csv", "pandas"],
                      help="Type of input file. The pkl.gz must contain a pandas dataframe.")
  parser.add_argument("--columns", required=1, nargs="+",
                      help = "Names of columns.")
  parser.add_argument('--column-types', required=1, nargs="+",
                      choices=['string', 'float', 'list-string', 'list-float',
                               'ndarray'],
                      help='Type of data in columns.')
                      choices=["csv", "pandas", "sdf"],
                      help="Type of input file. If pandas, input must be a pkl.gz\n"
                           "containing a pandas dataframe. If sdf, should be in\n"
                           "gzipped sdf.gz file.")
  parser.add_argument("--fields", required=1, nargs="+",
                      help = "Names of fields.")
  parser.add_argument("--field-types", required=1, nargs="+",
                      choices=["string", "float", "list-string", "list-float",
                               "ndarray", "concentration"],
                      help="Type of data in fields. Concentration is for molar concentrations.")
  parser.add_argument("--name", required=1,
                      help="Name of the dataset.")
  parser.add_argument("--out", required=1,
@@ -79,10 +81,10 @@ def generate_fingerprints(name, out):
                   "circular", "--size", "1024"])

def globavir_specs():
  columns = ["compound_name", "isomeric_smiles", "tdo_ic50_nm", "tdo_Ki_nm",
  fields = ["compound_name", "isomeric_smiles", "tdo_ic50_nm", "tdo_Ki_nm",
    "tdo_percent_activity_10_um", "tdo_percent_activity_1_um", "ido_ic50_nm",
    "ido_Ki_nm", "ido_percent_activity_10_um", "ido_percent_activity_1_um"]
  column_types = ["string", "string", "float", "float", "float", "float",
  field_types = ["string", "string", "float", "float", "float", "float",
      "float", "float", "float", "float"]

def get_rows(input_file, input_type):
@@ -102,8 +104,11 @@ def get_rows(input_file, input_type):
    with gzip.open(input_file) as f:
      df = pickle.load(f)
    return df.iterrows()
  elif input_type == "sdf":
    with gzip.open(input_file) as f:
      return Chem.ForwardSDMolSupplier(f)

def get_row_data(row, input_type, columns):
def get_row_data(row, input_type, fields, field_types):
  """Extract information from row data."""
  if input_type == "xlsx":
    return [cell.internal_value for cell in row]
@@ -112,37 +117,50 @@ def get_row_data(row, input_type, columns):
  elif input_type == "pandas":
    # pandas rows are tuples (row_num, row_info)
    row, row_data = row[1], {}
    # pandas rows are keyed by column-name. Change to key by index to match
    # pandas rows are keyed by field-name. Change to key by index to match
    # csv/xlsx handling
    for ind, column in enumerate(columns):
      row_data[ind] = row[column]
    for ind, field in enumerate(fields):
      row_data[ind] = row[field]
    return row_data
  elif input_type == "sdf":
    row_data, mol = {}, row
    for ind, (field, field_type) in enumerate(zip(fields, field_types)):
      # TODO(rbharath): This is kludgey...
      if field == "smiles":
        row_data[ind] = Chem.MolToSmiles(mol)
        continue
      if not mol.HasProp(field):
        row_data[ind] = None
      else:
        row_data[ind] = mol.GetProp(field)

def process_field(data, column_type):
def process_field(data, field_type):
  """Parse data in a field."""
  if column_type == "string":
  if field_type == "string":
    return data 
  elif column_type == "float":
  elif field_type == "float":
    return parse_float_input(data)
  elif column_type == "list-string":
  elif field_type == "concentration":
    return parse_float_input(data) / 1e-6
  elif field_type == "list-string":
    return data.split(",")
  elif column_type == "list-float":
  elif field_type == "list-float":
    return np.array(data.split(","))
  elif column_type == "ndarray":
  elif field_type == "ndarray":
    return data 

def generate_targets(input_file, input_type, columns, column_types, out_pkl,
def generate_targets(input_file, input_type, fields, field_types, out_pkl,
    out_sdf):
  """Process input data file."""
  rows, mols, smiles = [], [], SmilesGenerator()
  for row_index, raw_row in enumerate(get_rows(input_file, input_type)):
    print row_index
    # Skip row labels.
    if row_index == 0:
    if row_index == 0 or raw_row is None:
      continue
    row, row_data = {}, get_row_data(raw_row, input_type, columns)
    for ind, (column, column_type) in enumerate(zip(columns, column_types)):
      row[column] = process_field(row_data[ind], column_type)
    row, row_data = {}, get_row_data(raw_row, input_type, fields, field_types)
    for ind, (field, field_type) in enumerate(zip(fields, field_types)):
      row[field] = process_field(row_data[ind], field_type)
    # TODO(rbharath): This patch is only in place until the smiles/sequence
    # support is fixed.
    if row["smiles"] is None:
@@ -167,8 +185,8 @@ def generate_targets(input_file, input_type, columns, column_types, out_pkl,
def main():
  args = parse_args()
  out_pkl, out_sdf = generate_directories(args.name, args.out)
  generate_targets(args.input_file, args.input_type, args.columns,
      args.column_types, out_pkl, out_sdf)
  generate_targets(args.input_file, args.input_type, args.fields,
      args.field_types, out_pkl, out_sdf)
  generate_fingerprints(args.name, args.out)