Commit 36b8b1f1 authored by nd-02110114's avatar nd-02110114
Browse files

♻️ refactor

parent 1a7b25c1
Loading
Loading
Loading
Loading
+51 −20
Original line number Diff line number Diff line
from typing import Optional, List
from typing import Optional, Iterable
import numpy as np


class MoleculeGraphData(object):
  """Molecule Graph Data class for sparse pattern"""
  """MoleculeGraphData class

  This data class is almost the same as `torch_geometric.data.Data
  <https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data>`_
  in PyTorch Geometric.

  Attributes
  ----------
  node_features : np.ndarray
    Node feature matrix with shape [num_nodes, num_node_features]
  edge_index : np.ndarray
    Graph connectivity in COO format with shape [2, num_edges]
  targets : np.ndarray
    Graph or node targets with arbitrary shape
  edge_features : np.ndarray, optional (default None)
    Edge feature matrix with shape [num_edges, num_edge_features]
  graph_features : np.ndarray, optional (default None)
    Graph feature vector with shape [num_graph_features,]
  num_nodes : int
    The number of nodes in the graph
  num_node_features : int
    The number of features per node in the graph
  num_edges : int
    The number of edges in the graph
  num_edge_features : int, optional (default None)
    The number of features per edge in the graph
  """

  def __init__(
      self,
@@ -14,7 +40,6 @@ class MoleculeGraphData(object):
      graph_features: Optional[np.ndarray] = None,
  ):
    """

    Parameters
    ----------
    node_features : np.ndarray
@@ -53,52 +78,58 @@ class MoleculeGraphData(object):
    self.graph_features = graph_features
    self.targets = targets
    self.num_nodes, self.num_node_features = self.node_features.shape
    self.num_edges, self.num_edge_features = None, None
    self.num_edges = edge_index.shape[1]
    if self.node_features is not None:
      self.num_edges, self.num_edge_features = self.edge_features.shape
      self.num_edge_features = self.edge_features.shape[1]


class BatchMoleculeGraphData(MoleculeGraphData):
  """Batch Data class for sparse pattern"""
  """Batch MoleculeGraphData class
  
  Attributes
  ----------
  graph_index : np.ndarray
    This vector indicates which graph each node belongs to, with shape [num_nodes,]
  """

  def __init__(self, molecule_graph_list: List[MoleculeGraphData]):
  def __init__(self, molecule_graphs: Iterable[MoleculeGraphData]):
    """
    Parameters
    ----------
    molecule_graph_list : List[MoleculeGraphData]
    molecule_graphs : Iterable[MoleculeGraphData]
      The MoleculeGraphData objects to combine into a single batch
    """
    # stack features and targets
    batch_node_features = np.vstack(
        [graph.node_features for graph in molecule_graph_list])
    batch_targets = np.vstack([graph.targets for graph in molecule_graph_list])
        [graph.node_features for graph in molecule_graphs])
    batch_targets = np.vstack([graph.targets for graph in molecule_graphs])

    # before stacking edge_features or graph_features,
    # we should check whether these are None or not
    if molecule_graph_list[0].edge_features is not None:
    if molecule_graphs[0].edge_features is not None:
      batch_edge_features = np.vstack(
          [graph.edge_features for graph in molecule_graph_list])
          [graph.edge_features for graph in molecule_graphs])
    else:
      batch_edge_features = None

    if molecule_graph_list[0].graph_features is not None:
    if molecule_graphs[0].graph_features is not None:
      batch_graph_features = np.vstack(
          [graph.graph_features for graph in molecule_graph_list])
          [graph.graph_features for graph in molecule_graphs])
    else:
      batch_graph_features = None

    # create new edge index
    num_nodes_list = [graph.num_nodes for graph in molecule_graph_list]
    num_nodes_list = [graph.num_nodes for graph in molecule_graphs]
    batch_edge_index = np.hstack(
      [graph.edge_index + prev_num_node for prev_num_node, graph \
        in zip([0] + num_nodes_list[:-1], molecule_graph_list)]
        in zip([0] + num_nodes_list[:-1], molecule_graphs)]
    ).astype(int)

    # graph idx indicates which nodes belong to which graph
    graph_idx = []
    # graph_index indicates which nodes belong to which graph
    graph_index = []
    for i, num_nodes in enumerate(num_nodes_list):
      graph_idx.extend([i] * num_nodes)
    self.graph_idx = np.array(graph_idx, dtype=int)
      graph_index.extend([i] * num_nodes)
    self.graph_index = np.array(graph_index, dtype=int)

    super().__init__(
        node_features=batch_node_features,
+4 −4
Original line number Diff line number Diff line
import unittest
import pytest
import numpy as np
from deepchem.feat.molecule_graph import MoleculeGraphData, BatchMoleculeGraphData
from deepchem.utils.molecule_graph import MoleculeGraphData, BatchMoleculeGraphData


class TestMoleculeGraph(unittest.TestCase):
@@ -73,7 +73,7 @@ class TestMoleculeGraph(unittest.TestCase):
    ]
    targets = np.random.random_sample(5)

    molecule_graph_list = [
    molecule_graphs = [
        MoleculeGraphData(
            node_features=np.random.random_sample((num_nodes_list[i],
                                                   num_node_features)),
@@ -83,11 +83,11 @@ class TestMoleculeGraph(unittest.TestCase):
                                                   num_edge_features)),
            graph_features=None) for i in range(len(num_edge_list))
    ]
    batch = BatchMoleculeGraphData(molecule_graph_list)
    batch = BatchMoleculeGraphData(molecule_graphs)

    assert batch.num_nodes == sum(num_nodes_list)
    assert batch.num_node_features == num_node_features
    assert batch.num_edges == sum(num_edge_list)
    assert batch.num_edge_features == num_edge_features
    assert batch.targets.shape == (3, 5)
    assert batch.graph_idx.shape == (sum(num_nodes_list),)
    assert batch.graph_index.shape == (sum(num_nodes_list),)