Commit 7bb681b9 authored by Bowen Liu's avatar Bowen Liu
Browse files

Added -delimiter flag for csv input. Other fixes

parent 86bae1c4
Loading
Loading
Loading
Loading
+9 −5
Original line number Diff line number Diff line
@@ -33,6 +33,8 @@ def parse_args(input_args=None):
                      help="Name of the dataset.")
  parser.add_argument("--out", required=1,
                      help="Folder to generate processed dataset in.")
  parser.add_argument("--delimiter", default="\t",
                      help="Delimiter in csv file")
  return parser.parse_args(input_args)

def generate_directories(name, out):
@@ -87,7 +89,7 @@ def globavir_specs():
  field_types = ["string", "string", "float", "float", "float", "float",
      "float", "float", "float", "float"]

def get_rows(input_file, input_type):
def get_rows(input_file, input_type, delimiter):
  """Returns an iterator over all rows in input_file"""
  # TODO(rbharath): This function loads into memory, which can be painful. The
  # right option here might be to create a class which internally handles data
@@ -99,7 +101,7 @@ def get_rows(input_file, input_type):
    return p.iter_rows()
  elif input_type == "csv":
    with open(input_file, "rb") as f:
      reader = csv.reader(f, delimiter=",")
      reader = csv.reader(f, delimiter=delimiter)
      return [row for row in reader]
  elif input_type == "pandas":
    with gzip.open(input_file) as f:
@@ -165,10 +167,10 @@ def process_field(data, field_type):
    return data 

def generate_targets(input_file, input_type, fields, field_types, out_pkl,
    out_sdf):
    out_sdf, delimiter):
  """Process input data file."""
  rows, mols, smiles = [], [], SmilesGenerator()
  for row_index, raw_row in enumerate(get_rows(input_file, input_type)):
  for row_index, raw_row in enumerate(get_rows(input_file, input_type, delimiter)):
    print row_index
    print raw_row
    # Skip row labels.
@@ -200,9 +202,11 @@ def generate_targets(input_file, input_type, fields, field_types, out_pkl,

def main():
  args = parse_args()
  if len(args.fields) != len(args.field_types):
    raise ValueError("number of fields does not equal number of field types")
  out_pkl, out_sdf = generate_directories(args.name, args.out)
  generate_targets(args.input_file, args.input_type, args.fields,
      args.field_types, out_pkl, out_sdf)
      args.field_types, out_pkl, out_sdf, args.delimiter)
  generate_fingerprints(args.name, args.out)