Unverified Commit 58c4ecb2 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2012 from nd-02110114/graph-dataset-2

Add new molecule graph data
parents 49b10d1e ef33bd51
Loading
Loading
Loading
Loading
+186 −0
Original line number Diff line number Diff line
from typing import Optional, Sequence
import numpy as np


class MoleculeGraphData:
  """MoleculeGraphData class

  This data class is almost the same as `torch_geometric.data.Data
  <https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data>`_.

  Attributes
  ----------
  node_features : np.ndarray
    Node feature matrix with shape [num_nodes, num_node_features]
  edge_index : np.ndarray, dtype int
    Graph connectivity in COO format with shape [2, num_edges]
  targets : np.ndarray
    Graph or node targets with arbitrary shape
  edge_features : np.ndarray, optional (default None)
    Edge feature matrix with shape [num_edges, num_edge_features]
  graph_features : np.ndarray, optional (default None)
    Graph feature vector with shape [num_graph_features,]
  num_nodes : int
    The number of nodes in the graph
  num_node_features : int
    The number of features per node in the graph
  num_edges : int
    The number of edges in the graph
  num_edge_features : int, optional (default None)
    The number of features per edge in the graph. This is None when
    `edge_features` is None.
  """

  def __init__(
      self,
      node_features: np.ndarray,
      edge_index: np.ndarray,
      targets: np.ndarray,
      edge_features: Optional[np.ndarray] = None,
      graph_features: Optional[np.ndarray] = None,
  ):
    """
    Parameters
    ----------
    node_features : np.ndarray
      Node feature matrix with shape [num_nodes, num_node_features]
    edge_index : np.ndarray, dtype int
      Graph connectivity in COO format with shape [2, num_edges]
    targets : np.ndarray
      Graph or node targets with arbitrary shape
    edge_features : np.ndarray, optional (default None)
      Edge feature matrix with shape [num_edges, num_edge_features]
    graph_features : np.ndarray, optional (default None)
      Graph feature vector with shape [num_graph_features,]

    Raises
    ------
    ValueError
      If an argument has the wrong type, if `edge_index` is not an integer
      array with shape [2, num_edges], or if the first dimension of
      `edge_features` differs from the number of edges.
    """
    # validate params
    if not isinstance(node_features, np.ndarray):
      raise ValueError('node_features must be np.ndarray.')
    if not isinstance(edge_index, np.ndarray):
      raise ValueError('edge_index must be np.ndarray.')
    # The `np.int` alias no longer exists in recent NumPy releases, so check
    # the dtype family instead of comparing against the alias.
    if not np.issubdtype(edge_index.dtype, np.integer):
      raise ValueError('edge_index.dtype must be an integer type.')
    # The ndim check keeps a 1-D array of length 2 from slipping through and
    # failing later with an IndexError on shape[1].
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
      raise ValueError('The shape of edge_index is [2, num_edges].')
    if not isinstance(targets, np.ndarray):
      raise ValueError('targets must be np.ndarray.')
    if edge_features is not None:
      if not isinstance(edge_features, np.ndarray):
        raise ValueError('edge_features must be np.ndarray or None.')
      if edge_index.shape[1] != edge_features.shape[0]:
        raise ValueError('The first dimension of edge_features must be the '
                         'same as the second dimension of edge_index.')
    if graph_features is not None and \
        not isinstance(graph_features, np.ndarray):
      raise ValueError('graph_features must be np.ndarray or None.')

    self.node_features = node_features
    self.edge_index = edge_index
    self.edge_features = edge_features
    self.graph_features = graph_features
    self.targets = targets
    self.num_nodes, self.num_node_features = self.node_features.shape
    self.num_edges = edge_index.shape[1]
    # Always define the attribute so callers can test it against None.
    self.num_edge_features = None
    if self.edge_features is not None:
      self.num_edge_features = self.edge_features.shape[1]

  def to_pyg_data(self):
    """Convert to PyTorch Geometric Data instance

    Returns
    -------
    torch_geometric.data.Data
      Molecule graph data for PyTorch Geometric

    Raises
    ------
    ValueError
      If PyTorch or PyTorch Geometric is not installed.
    """
    try:
      import torch
      from torch_geometric.data import Data
    except ModuleNotFoundError:
      raise ValueError("This class requires PyTorch Geometric to be installed.")

    # Any integer dtype is accepted at construction time; hand PyTorch an
    # int64 array (no copy is made when the array is already int64).
    edge_index = self.edge_index.astype(np.int64, copy=False)
    edge_attr = None
    if self.edge_features is not None:
      edge_attr = torch.from_numpy(self.edge_features)

    return Data(
        x=torch.from_numpy(self.node_features),
        edge_index=torch.from_numpy(edge_index),
        edge_attr=edge_attr,
        y=torch.from_numpy(self.targets),
    )


class BatchMoleculeGraphData(MoleculeGraphData):
  """Batch MoleculeGraphData class

  Several molecule graphs are merged into one large disconnected graph:
  features and targets are stacked, and the node indices stored in each
  graph's `edge_index` are shifted so they point at the right rows of the
  stacked node feature matrix.

  Attributes
  ----------
  graph_index : np.ndarray, dtype int
    This vector indicates which graph the node belongs with shape [num_nodes,]
  """

  def __init__(self, molecule_graphs: Sequence[MoleculeGraphData]):
    """
    Parameters
    ----------
    molecule_graphs : Sequence[MoleculeGraphData]
      List of MoleculeGraphData

    Raises
    ------
    ValueError
      If `molecule_graphs` is empty, or if `edge_features` /
      `graph_features` is set on some graphs but None on others.
    """
    # Materialize once: the graphs are traversed several times below.
    molecule_graphs = list(molecule_graphs)
    if len(molecule_graphs) == 0:
      raise ValueError('molecule_graphs must contain at least one graph.')

    # stack features and targets
    batch_node_features = np.vstack(
        [graph.node_features for graph in molecule_graphs])
    batch_targets = np.vstack([graph.targets for graph in molecule_graphs])
    batch_edge_features = self._stack_optional(
        [graph.edge_features for graph in molecule_graphs], 'edge_features')
    batch_graph_features = self._stack_optional(
        [graph.graph_features for graph in molecule_graphs], 'graph_features')

    # create new edge index
    # Each graph's node ids must be shifted by the total number of nodes in
    # ALL preceding graphs (a running sum), not only the one just before it.
    num_nodes_list = [graph.num_nodes for graph in molecule_graphs]
    node_offsets = np.cumsum([0] + num_nodes_list[:-1])
    batch_edge_index = np.hstack([
        graph.edge_index + offset
        for offset, graph in zip(node_offsets, molecule_graphs)
    ]).astype(int)

    # graph_index indicates which nodes belong to which graph
    self.graph_index = np.repeat(
        np.arange(len(molecule_graphs)), num_nodes_list).astype(int)

    super().__init__(
        node_features=batch_node_features,
        edge_index=batch_edge_index,
        targets=batch_targets,
        edge_features=batch_edge_features,
        graph_features=batch_graph_features,
    )

  @staticmethod
  def _stack_optional(arrays, name):
    """Stack optional per-graph arrays, or return None if all are None.

    Raises ValueError when only some of the entries are None, because the
    stacked result would no longer line up with the batched graph.
    """
    num_missing = sum(array is None for array in arrays)
    if num_missing == len(arrays):
      return None
    if num_missing > 0:
      raise ValueError(
          '%s must be set for all graphs in the batch or for none of them.' %
          name)
    return np.vstack(arrays)

  @staticmethod
  def to_pyg_batch(molecule_graphs: Sequence[MoleculeGraphData]):
    """Convert to PyTorch Geometric Batch instance

    This is exposed under its own name so that the inherited instance method
    `to_pyg_data()` keeps its signature.

    Parameters
    ----------
    molecule_graphs : Sequence[MoleculeGraphData]
      List of MoleculeGraphData

    Returns
    -------
    torch_geometric.data.Batch
      Batch data of molecule graph for PyTorch Geometric

    Raises
    ------
    ValueError
      If PyTorch Geometric is not installed.
    """
    try:
      from torch_geometric.data import Batch
    except ModuleNotFoundError:
      raise ValueError("This class requires PyTorch Geometric to be installed.")

    data_list = [mol_graph.to_pyg_data() for mol_graph in molecule_graphs]
    return Batch.from_data_list(data_list=data_list)
+93 −0
Original line number Diff line number Diff line
import unittest
import pytest
import numpy as np
from deepchem.utils.molecule_graph import MoleculeGraphData, BatchMoleculeGraphData


class TestMoleculeGraph(unittest.TestCase):
  """Tests for MoleculeGraphData and BatchMoleculeGraphData."""

  def test_molecule_graph_data(self):
    """A valid graph exposes the expected node/edge counts."""
    # edge_index below refers to node ids 0..4, so the graph needs 5 nodes.
    num_nodes, num_node_features = 5, 32
    num_edges, num_edge_features = 6, 32
    node_features = np.random.random_sample((num_nodes, num_node_features))
    edge_features = np.random.random_sample((num_edges, num_edge_features))
    targets = np.random.random_sample(5)
    edge_index = np.array([
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])
    graph_features = None

    mol_graph = MoleculeGraphData(
        node_features=node_features,
        edge_index=edge_index,
        targets=targets,
        edge_features=edge_features,
        graph_features=graph_features)

    assert mol_graph.num_nodes == num_nodes
    assert mol_graph.num_node_features == num_node_features
    assert mol_graph.num_edges == num_edges
    assert mol_graph.num_edge_features == num_edge_features
    assert mol_graph.targets.shape == (5,)

  def test_invalid_molecule_graph_data(self):
    """Invalid constructor arguments are rejected."""
    # Fixtures are built outside the `pytest.raises` blocks so that only the
    # constructor call itself can satisfy the expected exception.
    node_features = np.random.random_sample((5, 5))
    edge_index = np.array([
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])
    targets = np.random.random_sample(5)

    invalid_node_features_type = list(np.random.random_sample((5, 5)))
    with pytest.raises(ValueError):
      MoleculeGraphData(
          node_features=invalid_node_features_type,
          edge_index=edge_index,
          targets=targets,
      )

    invalid_edge_index_shape = np.array([
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
        [2, 2, 1, 4, 0, 3],
    ])
    with pytest.raises(ValueError):
      MoleculeGraphData(
          node_features=node_features,
          edge_index=invalid_edge_index_shape,
          targets=targets,
      )

    # edge_index and targets are required positional arguments.
    with pytest.raises(TypeError):
      MoleculeGraphData(node_features=node_features)

  def test_batch_molecule_graph_data(self):
    """Batching stacks features and records graph membership per node."""
    num_nodes_list, num_edge_list = [3, 4, 5], [2, 4, 5]
    num_node_features, num_edge_features = 32, 32
    # Every node id stays below the node count of its own graph.
    edge_index_list = [
        np.array([[0, 1], [1, 2]]),
        np.array([[0, 1, 2, 3], [1, 2, 0, 2]]),
        np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
    ]
    targets = np.random.random_sample(5)

    molecule_graphs = [
        MoleculeGraphData(
            node_features=np.random.random_sample((num_nodes,
                                                   num_node_features)),
            edge_index=edge_index,
            targets=targets,
            edge_features=np.random.random_sample((num_edges,
                                                   num_edge_features)),
            graph_features=None) for num_nodes, num_edges, edge_index in zip(
                num_nodes_list, num_edge_list, edge_index_list)
    ]
    batch = BatchMoleculeGraphData(molecule_graphs)

    assert batch.num_nodes == sum(num_nodes_list)
    assert batch.num_node_features == num_node_features
    assert batch.num_edges == sum(num_edge_list)
    assert batch.num_edge_features == num_edge_features
    assert batch.targets.shape == (3, 5)
    assert batch.graph_index.shape == (sum(num_nodes_list),)
    assert batch.edge_index.shape == (2, sum(num_edge_list))
    # The first graph keeps its node ids; the second is shifted by the size
    # of the first graph.
    expected_head = np.hstack(
        [edge_index_list[0], edge_index_list[1] + num_nodes_list[0]])
    assert np.array_equal(batch.edge_index[:, :expected_head.shape[1]],
                          expected_head)
    assert np.array_equal(batch.graph_index,
                          np.repeat(np.arange(3), num_nodes_list))
+6 −0
Original line number Diff line number Diff line
@@ -18,3 +18,9 @@ These classes document the data classes for graph convolutions. We plan to simpl

.. autoclass:: deepchem.feat.mol_graphs.WeaveMol
  :members:

.. autoclass:: deepchem.utils.molecule_graph.MoleculeGraphData
  :members:

.. autoclass:: deepchem.utils.molecule_graph.BatchMoleculeGraphData
  :members:
+5 −0
Original line number Diff line number Diff line
@@ -70,6 +70,10 @@ DeepChem has a number of "soft" requirements.
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
| `PyTorch Geometric`_           | Not Testing   | :code:`dc.utils.molecule_graph`                   |
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
| `RDKit`_                       | 2020.03.4     | Many modules                                      |
|                                |               | (we recommend installing it)                      |
|                                |               |                                                   |
@@ -108,6 +112,7 @@ DeepChem has a number of "soft" requirements.
.. _`pyGPGO`: https://pygpgo.readthedocs.io/en/latest/
.. _`Pymatgen`: https://pymatgen.org/
.. _`PyTorch`: https://pytorch.org/
.. _`PyTorch Geometric`: https://pytorch-geometric.readthedocs.io/en/latest/
.. _`RDKit`: http://www.rdkit.org/docs/Install.html
.. _`simdna`: https://github.com/kundajelab/simdna
.. _`Tensorflow Probability`: https://www.tensorflow.org/probability