Commit 91e53855 authored by Bharath Ramsundar
Browse files

Changes for external data eval.

parent f7801c00
Loading
Loading
Loading
Loading
+58 −46
Original line number | Diff line number | Diff line
@@ -133,6 +133,9 @@ def add_model_command(subparsers):
  model_cmd.add_argument(
      "--featurize", action="store_true",
      help="Perform the featurization step.")
  model_cmd.add_argument(
      "--generate-dataset", action="store_true",
      help="Generate dataset from featurized data.")
  model_cmd.add_argument(
      "--train-test-split", action="store_true",
      help="Perform the train-test-split step.")
@@ -142,6 +145,9 @@ def add_model_command(subparsers):
  model_cmd.add_argument(
      "--eval", action="store_true",
      help="Perform model eval step.")
  model_cmd.add_argument(
      "--eval-full", action="store_true",
      help="Evaluate model on full dataset.")
  model_cmd.add_argument(
      "--base-dir", type=str, default=None,
      help="The base directory for the model.")
@@ -154,15 +160,6 @@ def add_model_command(subparsers):
  model_cmd.add_argument(
      "--model-dir", type=str, default=None,
      help="The model storage directory for the model.")
  model_cmd.add_argument(
      "--eval-train", type=bool, default=True,
      help="Evaluate model on train dataset.")
  model_cmd.add_argument(
      "--eval-test", type=bool, default=True,
      help="Evaluate model on test dataset.")
  model_cmd.add_argument(
      "--eval-full", type=bool, default=False,
      help="Evaluate model on full dataset.")
  add_featurize_group(model_cmd)

  add_transforms_group(model_cmd)
@@ -205,32 +202,45 @@ def create_model(args):
    ensure_exists([feature_dir, data_dir, model_dir])
                

  if args.featurize:
    print("+++++++++++++++++++++++++++++++++")
    print("Perform featurization")
  if args.featurize:
    featurize_inputs(
        feature_dir, data_dir, args.input_files, args.user_specified_features,
        args.tasks, args.smiles_field, args.split_field, args.id_field,
        args.threshold, args.parallel)

  if args.generate_dataset:
    print("+++++++++++++++++++++++++++++++++")
    print("Generate dataset for featurized samples")
    samples_dir = os.path.join(data_dir, "samples")
    samples = FeaturizedSamples(samples_dir, reload=True)

    print("Generating dataset.")
    full_data_dir = os.path.join(data_dir, "full-data")
    full_dataset = Dataset(full_data_dir, samples, args.feature_types)

    print("Transform data.")
    full_dataset.transform(args.input_transforms, args.output_transforms)
  

  if args.train_test_split:
    print("+++++++++++++++++++++++++++++++++")
    print("Perform train-test split")
    paths = [feature_dir]
  if args.train_test_split:
    train_test_split(
        paths, args.input_transforms, args.output_transforms, args.feature_types,
        args.splittype, args.mode, data_dir)

  if args.fit:
    print("+++++++++++++++++++++++++++++++++")
    print("Fit model")
  if args.fit:
    model_params = extract_model_params(args)
    fit_model(
        model_name, model_params, model_dir, data_dir)

  print("+++++++++++++++++++++++++++++++++")
  if args.eval:
    if args.eval_train:
    print("+++++++++++++++++++++++++++++++++")
    print("Eval Model on Train")
    print("-------------------")
    train_dir = os.path.join(data_dir, "train-data")
@@ -240,7 +250,6 @@ def create_model(args):
        model_name, model_dir, train_dir, csv_out_train,
        stats_out_train)

    if args.eval_test:
    print("Eval Model on Test")
    print("------------------")
    test_dir = os.path.join(data_dir, "test-data")
@@ -251,12 +260,14 @@ def create_model(args):
        stats_out_test)

  if args.eval_full:
    print("+++++++++++++++++++++++++++++++++")
    print("Eval Model on Full Dataset")
    print("--------------------------")
    full_data_dir = os.path.join(data_dir, "full-data")
    csv_out_full = os.path.join(data_dir, "full.csv")
    stats_out_full = os.path.join(data_dir, "full-stats.txt")
    eval_trained_model(
          model_name, model_dir, data_dir, csv_out_full,
        model_name, model_dir, full_data_dir, csv_out_full,
        stats_out_full)

def parse_args(input_args=None):
@@ -288,7 +299,8 @@ def featurize_inputs(feature_dir, data_dir, input_files,
      featurize_input_partial(input_file)

  dataset_files = glob.glob(os.path.join(feature_dir, "*.joblib"))
  print("Loading featurized data.")

  print("Writing samples to disk.")
  samples_dir = os.path.join(data_dir, "samples")
  samples = FeaturizedSamples(samples_dir, dataset_files)