Commit cab666f3 authored by Bowen Liu's avatar Bowen Liu
Browse files

Merge pull request #12 from bowenliu16/testing

Fixes csv and xlsx input errors in deep_chem.scripts.process_datasets
parents 9bed9e66 2496eef0
Loading
Loading
Loading
Loading
+16 −11
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ def parse_args(input_args=None):
  parser.add_argument("--input-file", required=1,
                      help="Input file with data.")
  parser.add_argument("--input-type", default="csv",
                      choices=["csv", "pandas", "sdf"],
                      choices=["xlsx", "csv", "pandas", "sdf"],
                      help="Type of input file. If pandas, input must be a pkl.gz\n"
                           "containing a pandas dataframe. If sdf, should be in\n"
                           "(perhaps gzipped) sdf file.")
@@ -33,6 +33,8 @@ def parse_args(input_args=None):
                      help="Name of the dataset.")
  parser.add_argument("--out", required=1,
                      help="Folder to generate processed dataset in.")
  parser.add_argument("--delimiter", default="\t",
                      help="Delimiter in csv file")
  return parser.parse_args(input_args)

def generate_directories(name, out):
@@ -87,18 +89,19 @@ def globavir_specs():
  field_types = ["string", "string", "float", "float", "float", "float",
      "float", "float", "float", "float"]

def get_rows(input_file, input_type):
def get_rows(input_file, input_type, delimiter):
  """Returns an iterator over all rows in input_file"""
  # TODO(rbharath): This function loads into memory, which can be painful. The
  # right option here might be to create a class which internally handles data
  # loading.
  if input_type == "xlsx":
    W = px.load_workbook(xlsx_file, use_iterators=True)
    p = W.get_sheet_by_name(name="Sheet1")
    W = px.load_workbook(input_file, use_iterators=True)
    sheet_names = W.get_sheet_names()
    p = W.get_sheet_by_name(name=sheet_names[0])    # Take first sheet as the active sheet
    return p.iter_rows()
  elif input_type == "csv":
    with open(csv_file, "rb") as f:
      reader = csv.reader(f, delimiter="\t")
    with open(input_file, "rb") as f:
      reader = csv.reader(f, delimiter=delimiter)
      return [row for row in reader]
  elif input_type == "pandas":
    with gzip.open(input_file) as f:
@@ -164,10 +167,10 @@ def process_field(data, field_type):
    return data 

def generate_targets(input_file, input_type, fields, field_types, out_pkl,
    out_sdf):
    out_sdf, delimiter):
  """Process input data file."""
  rows, mols, smiles = [], [], SmilesGenerator()
  for row_index, raw_row in enumerate(get_rows(input_file, input_type)):
  for row_index, raw_row in enumerate(get_rows(input_file, input_type, delimiter)):
    print row_index
    print raw_row
    # Skip row labels.
@@ -199,9 +202,11 @@ def generate_targets(input_file, input_type, fields, field_types, out_pkl,

def main():
  """Entry point for the dataset-processing script.

  Parses command-line arguments, validates that every requested field has a
  matching field type, creates the output directories, then generates the
  processed targets (pickle + sdf) and the fingerprint files.

  Raises:
    ValueError: if --fields and --field-types have different lengths, since
      process_field needs a one-to-one pairing to parse each column.
  """
  args = parse_args()
  if len(args.fields) != len(args.field_types):
    raise ValueError("number of fields does not equal number of field types")
  out_pkl, out_sdf = generate_directories(args.name, args.out)
  # Forward the user-specified delimiter so get_rows no longer hard-codes
  # "\t" when reading csv input (matches generate_targets' updated signature).
  generate_targets(args.input_file, args.input_type, args.fields,
      args.field_types, out_pkl, out_sdf, args.delimiter)
  generate_fingerprints(args.name, args.out)