Commit 92d262ee authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fixed insidious typo in train_test_split causing eval issues. Some NaN issues now though.

parent f664d079
Loading
Loading
Loading
Loading
+0 −7
Original line number Diff line number Diff line
@@ -277,9 +277,6 @@ def train_test_split(paths, input_transforms, output_transforms,
  dataset_files = []
  for path in paths:
    dataset_files += glob.glob(os.path.join(path, "*.joblib"))
  print("paths")
  print(paths)

  print("Loading featurized data.")
  samples_dir = os.path.join(data_dir, "samples")
  samples = FeaturizedSamples(samples_dir, dataset_files, reload=False)
@@ -321,10 +318,6 @@ def eval_trained_model(model_type, model_dir, data_dir,
                       csv_out, stats_out, split="test"):
  """Evaluates a trained model on specified data."""
  model = Model.load(model_type, model_dir)
  print("eval_trained_model()")
  print("data_dir")
  print(data_dir)
  
  data = Dataset(data_dir)

  evaluator = Evaluator(model, data, verbose=True)
+8 −2
Original line number Diff line number Diff line
@@ -38,12 +38,18 @@ class Dataset(object):
          write_dataset_single, data_dir=self.data_dir,
          feature_types=feature_types)
      print("Dataset()")
      print("samples.compounds_df")
      print(samples.compounds_df)
      print("data_dir")
      print(data_dir)
      print("len(samples.compounds_df)")
      print(len(samples.compounds_df))

      metadata_rows = []
      # TODO(rbharath): Still a bit of information leakage.
      for df_file, df in zip(samples.dataset_files, samples.itersamples()):
        print("df_file")
        print(df_file)
        print("len(df)")
        print(len(df))
        retval = write_dataset_single_partial((df_file, df))
        if retval is not None:
          metadata_rows.append(retval)
+11 −11
Original line number Diff line number Diff line
@@ -247,18 +247,13 @@ class FeaturizedSamples(object):
    if not os.path.exists(feature_dir):
      os.makedirs(feature_dir)
    self.feature_dir = feature_dir
    print("FeaturizedSamples()")
    if os.path.exists(self._get_compounds_filename()) and reload:
      print("compounds loaded from disk")
      compounds_df = load_from_disk(self._get_compounds_filename())
    else:
      print("compounds recomputed")
      compounds_df = self._get_compounds()
      # compounds_df is not altered by any method after initialization, so it's
      # safe to keep a copy in memory and on disk.
      save_to_disk(compounds_df, self._get_compounds_filename())
    print("len(compounds_df)")
    print(len(compounds_df))
    self._check_validity(compounds_df)
    self.compounds_df = compounds_df
    
@@ -307,7 +302,7 @@ class FeaturizedSamples(object):
    """Internal method used to replace compounds_df."""
    self._check_validity(df)
    save_to_disk(df, self._get_compounds_filename())
    self.compounsd_df = df
    self.compounds_df = df

  # TODO(rbharath): Might this be inefficient?
  def itersamples(self):
@@ -338,12 +333,17 @@ class FeaturizedSamples(object):
      train_inds, test_inds = self.train_test_specified_split()
    else:
      raise ValueError("improper splittype.")
    train_dataset = FeaturizedSamples(train_dir, self.dataset_files)
    train_dataset._set_compound_df(self.compounds_df.iloc[train_inds])
    test_dataset = FeaturizedSamples(test_dir, self.dataset_files)
    test_dataset._set_compound_df(self.compounds_df.iloc[test_inds])
    print("train_test_split()")
    train_samples = FeaturizedSamples(train_dir, self.dataset_files)
    train_samples._set_compound_df(self.compounds_df.iloc[train_inds])
    print("len(train_inds)")
    print(len(train_inds))
    test_samples = FeaturizedSamples(test_dir, self.dataset_files)
    test_samples._set_compound_df(self.compounds_df.iloc[test_inds])
    print("len(test_inds)")
    print(len(test_inds))

    return train_dataset, test_dataset
    return train_samples, test_samples

  def _train_test_random_split(self, seed=None, frac_train=.8):
    """