Commit 1a7b25c1 authored by nd-02110114's avatar nd-02110114
Browse files

♻️ refactor test

parent ff26a992
Loading
Loading
Loading
Loading
+40 −24
Original line number Diff line number Diff line
@@ -7,13 +7,16 @@ from deepchem.feat.molecule_graph import MoleculeGraphData, BatchMoleculeGraphDa
class TestMoleculeGraph(unittest.TestCase):

  def test_molecule_graph_data(self):
    num_nodes, num_node_features = 10, 32
    num_edges, num_edge_features = 13, 32
    num_nodes, num_node_features = 4, 32
    num_edges, num_edge_features = 6, 32
    node_features = np.random.random_sample((num_nodes, num_node_features))
    edge_features = np.random.random_sample((num_edges, num_edge_features))
    targets = np.random.random_sample(5)
    edge_index = np.array([
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])
    graph_features = None
    node_features = np.ones((num_nodes, num_node_features))
    edge_index = np.ones((2, num_edges))
    edge_features = np.ones((num_edges, num_edge_features))
    targets = np.ones(5)

    mol_graph = MoleculeGraphData(
        node_features=node_features,
@@ -30,19 +33,26 @@ class TestMoleculeGraph(unittest.TestCase):

  def test_invalid_molecule_graph_data(self):
    with pytest.raises(ValueError):
      invalid_node_features = [[0, 1, 2, 3, 4], [5, 6, 7, 8]]
      edge_index = np.ones((2, 5))
      targets = np.ones(5)
      invalid_node_features_type = list(np.random.random_sample((5, 5)))
      edge_index = np.array([
          [0, 1, 2, 2, 3, 4],
          [1, 2, 0, 3, 4, 0],
      ])
      targets = np.random.random_sample(5)
      mol_graph = MoleculeGraphData(
          node_features=invalid_node_features,
          node_features=invalid_node_features_type,
          edge_index=edge_index,
          targets=targets,
      )

    with pytest.raises(ValueError):
      node_features = np.ones((5, 5))
      invalid_edge_index_shape = np.ones((3, 10))
      targets = np.ones(5)
      node_features = np.random.random_sample((5, 5))
      invalid_edge_index_shape = np.array([
          [0, 1, 2, 2, 3, 4],
          [1, 2, 0, 3, 4, 0],
          [2, 2, 1, 4, 0, 3],
      ])
      targets = np.random.random_sample(5)
      mol_graph = MoleculeGraphData(
          node_features=node_features,
          edge_index=invalid_edge_index_shape,
@@ -50,25 +60,31 @@ class TestMoleculeGraph(unittest.TestCase):
      )

    with pytest.raises(TypeError):
      node_features = np.ones((5, 5))
      mol_graph = MoleculeGraphData(node_features=node_features,)
      node_features = np.random.random_sample((5, 5))
      mol_graph = MoleculeGraphData(node_features=node_features)

  def test_batch_molecule_graph_data(self):

    num_nodes_list, num_edge_list = [5, 7, 10], [6, 10, 20]
    num_nodes_list, num_edge_list = [3, 4, 5], [2, 4, 5]
    num_node_features, num_edge_features = 32, 32
    targets = np.ones(5)
    edge_index_list = [
        np.array([[0, 1], [1, 2]]),
        np.array([[0, 1, 2, 3], [1, 2, 0, 2]]),
        np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
    ]
    targets = np.random.random_sample(5)

    molecule_graph_list = [
        MoleculeGraphData(
            node_features=np.ones((num_nodes, num_node_features)),
            edge_index=np.ones((2, num_edges)),
            node_features=np.random.random_sample((num_nodes_list[i],
                                                   num_node_features)),
            edge_index=edge_index_list[i],
            targets=targets,
            edge_features=np.ones((num_edges, num_edge_features)),
            graph_features=None)
        for num_nodes, num_edges in zip(num_nodes_list, num_edge_list)
            edge_features=np.random.random_sample((num_edge_list[i],
                                                   num_edge_features)),
            graph_features=None) for i in range(len(num_edge_list))
    ]

    batch = BatchMoleculeGraphData(molecule_graph_list)

    assert batch.num_nodes == sum(num_nodes_list)
    assert batch.num_node_features == num_node_features
    assert batch.num_edges == sum(num_edge_list)