Commit bf1810e4 authored by Bharath Ramsundar

Resolved merge issues.

parents 819de9f4 cab666f3
+18 −12
@@ -19,10 +19,10 @@ def parse_args(input_args=None):
  parser.add_argument("--input-file", required=1,
                      help="Input file with data.")
  parser.add_argument("--input-type", default="csv",
-                      choices=["csv", "pandas", "sdf"],
+                      choices=["xlsx", "csv", "pandas", "sdf"],
                      help="Type of input file. If pandas, input must be a pkl.gz\n"
                           "containing a pandas dataframe. If sdf, should be in\n"
-                           "gzipped sdf.gz file.")
+                           "(perhaps gzipped) sdf file.")
  parser.add_argument("--fields", required=1, nargs="+",
                      help = "Names of fields.")
  parser.add_argument("--field-types", required=1, nargs="+",
@@ -36,6 +36,8 @@ def parse_args(input_args=None):
                      help="Name of measured endpoint to predict.")
  parser.add_argument("--threshold", type=float, default=None,
                      help="Used to turn real-valued data into binary.")
+  parser.add_argument("--delimiter", default="\t",
+                      help="Delimiter in csv file")
  return parser.parse_args(input_args)

def generate_directories(name, out):
@@ -100,18 +102,19 @@ def generate_descriptors(name, out):
                   sdf, descriptors,
                   "descriptors"])

-def get_rows(input_file, input_type):
+def get_rows(input_file, input_type, delimiter):
  """Returns an iterator over all rows in input_file"""
  # TODO(rbharath): This function loads into memory, which can be painful. The
  # right option here might be to create a class which internally handles data
  # loading.
  if input_type == "xlsx":
-    W = px.load_workbook(xlsx_file, use_iterators=True)
-    p = W.get_sheet_by_name(name="Sheet1")
+    W = px.load_workbook(input_file, use_iterators=True)
+    sheet_names = W.get_sheet_names()
+    p = W.get_sheet_by_name(name=sheet_names[0])    # Take first sheet as the active sheet
    return p.iter_rows()
  elif input_type == "csv":
-    with open(csv_file, "rb") as f:
-      reader = csv.reader(f, delimiter="\t")
+    with open(input_file, "rb") as f:
+      reader = csv.reader(f, delimiter=delimiter)
      return [row for row in reader]
  elif input_type == "pandas":
    with gzip.open(input_file) as f:
@@ -171,10 +174,10 @@ def process_field(data, field_type):
    return data 

def generate_targets(input_file, input_type, fields, field_types, out_pkl,
-    out_sdf, prediction_endpoint, threshold):
+    out_sdf, prediction_endpoint, threshold, delimiter):
  """Process input data file."""
  rows, mols, smiles = [], [], SmilesGenerator()
-  for row_index, raw_row in enumerate(get_rows(input_file, input_type)):
+  for row_index, raw_row in enumerate(get_rows(input_file, input_type, delimiter)):
    print row_index
    # Skip row labels.
    if row_index == 0 or raw_row is None:  
@@ -209,9 +212,12 @@ def generate_targets(input_file, input_type, fields, field_types, out_pkl,

def main():
  args = parse_args()
+  if len(args.fields) != len(args.field_types):
+    raise ValueError("number of fields does not equal number of field types")
  out_pkl, out_sdf = generate_directories(args.name, args.out)
  generate_targets(args.input_file, args.input_type, args.fields,
-      args.field_types, out_pkl, out_sdf, args.prediction_endpoint, args.threshold)
+      args.field_types, out_pkl, out_sdf, args.prediction_endpoint,
+      args.threshold, args.delimiter)
  generate_fingerprints(args.name, args.out)
  generate_descriptors(args.name, args.out)
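
For context (not part of the commit), a minimal standalone sketch of the row-loading dispatch this change introduces: the csv branch splits rows on the caller-supplied delimiter (the new --delimiter flag) and the xlsx branch reads the first sheet of the workbook. The sketch assumes Python 3 and the current openpyxl API (read_only= rather than the deprecated use_iterators= used in the script); get_rows_sketch and the file name in the usage comment are hypothetical.

import csv

import openpyxl  # needed only for the xlsx branch


def get_rows_sketch(input_file, input_type, delimiter="\t"):
    """Return the rows of input_file as an iterable.

    csv rows are split on the caller-supplied delimiter (the value of
    the new --delimiter flag); xlsx input uses the first sheet of the
    workbook, mirroring the "take first sheet" behavior in this commit.
    """
    if input_type == "csv":
        with open(input_file, newline="") as f:
            return [row for row in csv.reader(f, delimiter=delimiter)]
    elif input_type == "xlsx":
        workbook = openpyxl.load_workbook(input_file, read_only=True)
        first_sheet = workbook[workbook.sheetnames[0]]
        return first_sheet.iter_rows()
    raise ValueError("unsupported input type: %s" % input_type)


# Hypothetical usage with a comma-delimited file:
# rows = get_rows_sketch("assay.csv", "csv", delimiter=",")

As the TODO in the diff notes, the csv branch materializes every row in memory; a generator-based reader would avoid that for large files.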