Commit 0a881330 authored by nd-02110114's avatar nd-02110114
Browse files

add to_pyg_data method

parent bdcd9a9c
Loading
Loading
Loading
Loading
+22 −4
Original line number Diff line number Diff line
@@ -7,13 +7,13 @@ class MoleculeGraphData(object):

  This data class is almost same as `torch_geometric.data.Data
  <https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data>`
  in Pytorch Geometric.
  in PyTorch Geometric.

  Attributes
  ----------
  node_features : np.ndarray
    Node feature matrix with shape [num_nodes, num_node_features]
  edge_index : np.ndarray
  edge_index : np.ndarray, dtype int
    Graph connectivity in COO format with shape [2, num_edges]
  targets : np.ndarray
    Graph or node targets with arbitrary shape
@@ -85,7 +85,7 @@ class MoleculeGraphData(object):
      self.num_edge_features = self.edge_features.shape[1]

  def to_pyg_data(self):
    """"Convert to Pytorch Geometric data class"""
    """"Convert to PyTorch Geometric data class"""
    try:
      import torch
      from torch_geometric.data import Data
@@ -107,7 +107,7 @@ class BatchMoleculeGraphData(MoleculeGraphData):
  
  Attributes
  ----------
  graph_index : np.ndarray
  graph_index : np.ndarray, dtype int
    This vector indicates which graph the node belongs with shape [num_nodes,]
  """

@@ -157,3 +157,21 @@ class BatchMoleculeGraphData(MoleculeGraphData):
        edge_features=batch_edge_features,
        graph_features=batch_graph_features,
    )

    @staticmethod
    def to_pyg_data(molecule_graphs: Iterable[MoleculeGraphData]):
      """"Convert to PyTorch Geometric Batch class

      Parameters
      ----------
      molecule_graphs : Iterable[MoleculeGraphData]
        List of MoleculeGraphData
      """
      try:
        from torch_geometric.data import Batch
      except ModuleNotFoundError:
        raise ValueError(
            "This class requires PyTorch Geometric to be installed.")

      data_list = [mol_graph.to_pyg_data() for mol_graph in molecule_graphs]
      return Batch.from_data_list(data_list=data_list)