Commit bdcd9a9c authored by nd-02110114's avatar nd-02110114
Browse files

add to_pyg_data function

parent 36b8b1f1
Loading
Loading
Loading
Loading
+20 −1
Original line number Diff line number Diff line
@@ -44,7 +44,7 @@ class MoleculeGraphData(object):
    ----------
    node_features : np.ndarray
      Node feature matrix with shape [num_nodes, num_node_features]
    edge_index : np.ndarray
    edge_index : np.ndarray, dtype int
      Graph connectivity in COO format with shape [2, num_edges]
    targets : np.ndarray
      Graph or node targets with arbitrary shape
@@ -58,6 +58,8 @@ class MoleculeGraphData(object):
      raise ValueError('node_features must be np.ndarray.')
    if isinstance(edge_index, np.ndarray) is False:
      raise ValueError('edge_index must be np.ndarray.')
    elif edge_index.dtype != np.int:
      raise ValueError('edge_index.dtype must be np.int')
    elif edge_index.shape[0] != 2:
      raise ValueError('The shape of edge_index is [2, num_edges].')
    if isinstance(targets, np.ndarray) is False:
@@ -82,6 +84,23 @@ class MoleculeGraphData(object):
    if self.node_features is not None:
      self.num_edge_features = self.edge_features.shape[1]

  def to_pyg_data(self):
    """"Convert to Pytorch Geometric data class"""
    try:
      import torch
      from torch_geometric.data import Data
    except ModuleNotFoundError:
      raise ValueError(
          "This class requires Pytorch and PyTorch Geometric to be installed.")

    return Data(
      x=torch.from_numpy(self.node_features),
      edge_index=torch.from_numpy(self.edge_index),
      edge_attr=None if self.edge_features is None \
        else torch.from_numpy(self.edge_features),
      y=torch.from_numpy(self.targets),
    )


class BatchMoleculeGraphData(MoleculeGraphData):
  """Batch MoleculeGraphData class