Commit 9290c70e authored by miaecle's avatar miaecle
Browse files

test modification

parent 5e3f001f
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -200,12 +200,12 @@ def run_benchmark(datasets,
      for i in train_score:
        output_line = [
            dataset, str(split), mode, 'train', i,
            train_score[i][train_score[i].keys()[0]], 'valid', i,
            valid_score[i][valid_score[i].keys()[0]]
            train_score[i][list(train_score[i].keys())[0]], 'valid', i,
            valid_score[i][list(valid_score[i].keys())[0]]
        ]
        if test:
          output_line.extend(
              ['test', i, test_score[i][test_score[i].keys()[0]]])
              ['test', i, test_score[i][list(test_score[i].keys(0))[0]]])
        output_line.extend(
            ['time_for_running', time_finish_fitting - time_start_fitting])
        writer.writerow(output_line)
+10 −6
Original line number Diff line number Diff line
@@ -28,14 +28,15 @@ class TestMolnet(unittest.TestCase):
    model = 'graphconvreg'
    split = 'random'
    out_path = self.current_dir
    dc.molnet.run_benchmark(datasets, model, split=split, out_path=out_path)
    dc.molnet.run_benchmark(datasets, str(model), split=split, out_path=out_path)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
      assert lastrow[-4] == model
      assert lastrow[-5] == 'valid'
      assert lastrow[-3] > 0.75
      assert float(lastrow[-3]) > 0.75
    os.remove(os.path.join(out_path, 'results.csv'))

  def test_qm7_multitask(self):
    """Tests molnet benchmarking on qm7 with multitask network."""
@@ -43,14 +44,15 @@ class TestMolnet(unittest.TestCase):
    model = 'tf_regression'
    split = 'random'
    out_path = self.current_dir
    dc.molnet.run_benchmark(datasets, model, split=split, out_path=out_path)
    dc.molnet.run_benchmark(datasets, str(model), split=split, out_path=out_path)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
      assert lastrow[-4] == model + '_ft'
      assert lastrow[-5] == 'valid'
      assert lastrow[-3] > 0.95
      assert float(lastrow[-3]) > 0.95
    os.remove(os.path.join(out_path, 'results.csv'))

  def test_tox21_multitask(self):
    """Tests molnet benchmarking on tox21 with multitask network."""
@@ -58,11 +60,13 @@ class TestMolnet(unittest.TestCase):
    model = 'tf'
    split = 'random'
    out_path = self.current_dir
    dc.molnet.run_benchmark(datasets, model, split=split, out_path=out_path)
    dc.molnet.run_benchmark(datasets, str(model), split=split, out_path=out_path)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
      assert lastrow[-4] == model
      assert lastrow[-5] == 'valid'
      assert lastrow[-3] > 0.75
      assert float(lastrow[-3]) > 0.75
    os.remove(os.path.join(out_path, 'results.csv'))