Commit 0336e0d8 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Added changes to process scripts.

parent 1ed0ec47
Loading
Loading
Loading
Loading
+37 −9
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ import pandas as pd
import openpyxl as px
import numpy as np
import argparse
import csv
from rdkit import Chem
import subprocess
from vs_utils.utils import SmilesGenerator
@@ -15,8 +16,13 @@ from vs_utils.utils import SmilesGenerator
def parse_args(input_args=None):
  """Parse command-line arguments."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--data', required=1,
  parser.add_argument('--input-file', required=1,
                      help='Input file with data.')
  parser.add_argument("--columns", required=1, nargs="+",
                      help = "Names of columns.")
  parser.add_argument('--column-types', required=1, nargs="+",
                      choices=['string', 'float', 'list', 'float-array'],
                      help='Name of dataset to process.')
  parser.add_argument("--name", required=1,
                      help="Name of the dataset.")
  parser.add_argument("--out", required=1,
@@ -75,25 +81,47 @@ def globavir_specs():
  column_types = ["string", "string", "float", "float", "float", "float",
      "float", "float", "float", "float"]

def generate_targets(xlsx_file, columns, column_types, out_pkl, out_sdf):
  """Process input data file."""
  rows, mols = [], []
def gen_xlsx_rows(xlxs_file):
  W = px.load_workbook(xlsx_file, use_iterators=True)
  p = W.get_sheet_by_name(name="Sheet1")
  return p.iter_rows()

def get_xlsx_row_data(row):
  return [cell.internal_value for cell in row]

def gen_csv_rows(csv_file):
  # This is a memory leak...
  f = open(csv_file, "rb")
  return csv.reader(f, delimiter="\t")

def generate_targets(input_file, columns, column_types, out_pkl, out_sdf, type="csv"):
  """Process input data file."""
  rows, mols = [], []
  smiles = SmilesGenerator()
  for row_index, row in enumerate(p.iter_rows()):
  if type == "xlsx":
    row_gen = gen_xlsx_rows(input_file)
  elif type == "csv":
    row_gen = gen_csv_rows(input_file)
  for row_index, raw_row in enumerate(row_gen):
    print row_index
    # Skip row labels.
    if row_index == 0:
      continue
    row_data = [cell.internal_value for cell in row]
    if type == "xlsx":
      row_data = get_xlsx_row_data(raw_row)
    elif type == "csv":
      row_data = raw_row 
      
    row = {}
    for ind, (column, column_type) in enumerate(zip(columns, column_types)):
      if column_type == "string":
        row[column] = row_data[ind]
      elif column_type == "float":
        row[column] = parse_float_input(row_data[ind])
      elif column_type == "float-array" and ind = len(columns) - 1:
        row[column] = np.array(row_data[ind:])
      elif column_type == "list":
        row[column] = row_data[ind].split(",")
      elif column_type == "float-array":
        row[column] = np.array(row_data[ind].split(","))

    mol = Chem.MolFromSmiles(row["smiles"])
    row["smiles"] = smiles.get_smiles(mol)
@@ -113,7 +141,7 @@ def generate_targets(xlsx_file, columns, column_types, out_pkl, out_sdf):
def main():
  args = parse_args()
  out_pkl, out_sdf = generate_directories(args.name, args.out)
  generate_targets(args.data, out_pkl, out_sdf)
  generate_targets(args.input_file, args.columns, args.column_types, out_pkl, out_sdf)
  generate_fingerprints(args.name, args.out)