Commit 2b8191d9 authored by miaecle's avatar miaecle
Browse files

unit test for molnet

parent 2f83f641
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

from deepchem.molnet.load_function.chembl_datasets import load_chembl
from deepchem.molnet.load_function.clintox_datasets import load_clintox
from deepchem.molnet.load_function.delaney_datasets import load_delaney
+0 −0

Empty file added.

+67 −0
Original line number Diff line number Diff line
"""
Tests for molnet function 
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import unittest
import numpy as np
import pandas as pd
import deepchem as dc
import csv

class TestMolnet(unittest.TestCase):
  """
  Test basic function of molnet
  """

  def setUp(self):
    super(TestMolnet, self).setUp()
    self.current_dir = os.path.dirname(os.path.abspath(__file__))

  def test_delaney_graphconvreg(self):
    """Tests molnet benchmarking on delaney with graphconvreg."""
    datasets = ['delaney']
    model = 'graphconvreg'
    split = 'random'
    out_path = self.current_dir
    dc.molnet.run_benchmark(datasets, model, split=split, out_path=out_path)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
      assert lastrow[-4] == model
      assert lastrow[-5] == 'valid'
      assert lastrow[-3] > 0.75

  def test_qm7_multitask(self):
    """Tests molnet benchmarking on qm7 with multitask network."""
    datasets = ['qm7']
    model = 'tf_regression'
    split = 'random'
    out_path = self.current_dir
    dc.molnet.run_benchmark(datasets, model, split=split, out_path=out_path)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
      assert lastrow[-4] == model + '_ft'
      assert lastrow[-5] == 'valid'
      assert lastrow[-3] > 0.95

  def test_tox21_multitask(self):
    """Tests molnet benchmarking on tox21 with multitask network."""
    datasets = ['tox21']
    model = 'tf'
    split = 'random'
    out_path = self.current_dir
    dc.molnet.run_benchmark(datasets, model, split=split, out_path=out_path)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
      assert lastrow[-4] == model
      assert lastrow[-5] == 'valid'
      assert lastrow[-3] > 0.75