Unverified Commit 6681561e authored by Daiki Nishikawa's avatar Daiki Nishikawa Committed by GitHub
Browse files

Merge pull request #2100 from nd-02110114/update-graphdata

Convert graph_feature -> node_pos_features in GraphData
parents 588f8a77 daa51312
Loading
Loading
Loading
Loading
+57 −24
Original line number Diff line number Diff line
@@ -16,8 +16,8 @@ class GraphData:
    Graph connectivity in COO format with shape [2, num_edges]
  edge_features: np.ndarray, optional (default None)
    Edge feature matrix with shape [num_edges, num_edge_features]
  graph_features: np.ndarray, optional (default None)
    Graph feature vector with shape [num_graph_features,]
  node_pos_features: np.ndarray, optional (default None)
    Node position matrix with shape [num_nodes, num_dimensions].
  num_nodes: int
    The number of nodes in the graph
  num_node_features: int
@@ -40,7 +40,7 @@ class GraphData:
      node_features: np.ndarray,
      edge_index: np.ndarray,
      edge_features: Optional[np.ndarray] = None,
      graph_features: Optional[np.ndarray] = None,
      node_pos_features: Optional[np.ndarray] = None,
  ):
    """
    Parameters
@@ -51,8 +51,8 @@ class GraphData:
      Graph connectivity in COO format with shape [2, num_edges]
    edge_features: np.ndarray, optional (default None)
      Edge feature matrix with shape [num_edges, num_edge_features]
    graph_features: np.ndarray, optional (default None)
      Graph feature vector with shape [num_graph_features,]
    node_pos_features: np.ndarray, optional (default None)
      Node position matrix with shape [num_nodes, num_dimensions].
    """
    # validate params
    if isinstance(node_features, np.ndarray) is False:
@@ -74,14 +74,18 @@ class GraphData:
        raise ValueError('The first dimension of edge_features must be the \
                          same as the second dimension of edge_index.')

    if graph_features is not None and isinstance(graph_features,
                                                 np.ndarray) is False:
      raise ValueError('graph_features must be np.ndarray or None.')
    if node_pos_features is not None:
      if isinstance(node_pos_features, np.ndarray) is False:
        raise ValueError('node_pos_features must be np.ndarray or None.')
      elif node_pos_features.shape[0] != node_features.shape[0]:
        raise ValueError(
            'The length of node_pos_features must be the same as the \
                          length of node_features.')

    self.node_features = node_features
    self.edge_index = edge_index
    self.edge_features = edge_features
    self.graph_features = graph_features
    self.node_pos_features = node_pos_features
    self.num_nodes, self.num_node_features = self.node_features.shape
    self.num_edges = edge_index.shape[1]
    if self.edge_features is not None:
@@ -106,12 +110,18 @@ class GraphData:
      raise ValueError(
          "This function requires PyTorch Geometric to be installed.")

    edge_features = self.edge_features
    if edge_features is not None:
      edge_features = torch.from_numpy(self.edge_features).float()
    node_pos_features = self.node_pos_features
    if node_pos_features is not None:
      node_pos_features = torch.from_numpy(self.node_pos_features).float()

    return Data(
      x=torch.from_numpy(self.node_features),
        x=torch.from_numpy(self.node_features).float(),
        edge_index=torch.from_numpy(self.edge_index).long(),
      edge_attr=None if self.edge_features is None \
        else torch.from_numpy(self.edge_features),
    )
        edge_attr=edge_features,
        pos=node_pos_features)

  def to_dgl_graph(self):
    """Convert to DGL graph data instance
@@ -136,10 +146,13 @@ class GraphData:
    g.add_edges(
        torch.from_numpy(self.edge_index[0]).long(),
        torch.from_numpy(self.edge_index[1]).long())
    g.ndata['x'] = torch.from_numpy(self.node_features)
    g.ndata['x'] = torch.from_numpy(self.node_features).float()

    if self.node_pos_features is not None:
      g.ndata['pos'] = torch.from_numpy(self.node_pos_features).float()

    if self.edge_features is not None:
      g.edata['edge_attr'] = torch.from_numpy(self.edge_features)
      g.edata['edge_attr'] = torch.from_numpy(self.edge_features).float()

    return g

@@ -149,8 +162,28 @@ class BatchGraphData(GraphData):

  Attributes
  ----------
  node_features: np.ndarray
    Concatenated node feature matrix with shape [num_nodes, num_node_features].
    `num_nodes` is total number of nodes in the batch graph.
  edge_index: np.ndarray, dtype int
    Concatenated graph connectivity in COO format with shape [2, num_edges].
    `num_edges` is total number of edges in the batch graph.
  edge_features: np.ndarray, optional (default None)
    Concatenated edge feature matrix with shape [num_edges, num_edge_features].
    `num_edges` is total number of edges in the batch graph.
  node_pos_features: np.ndarray, optional (default None)
    Concatenated node position matrix with shape [num_nodes, num_dimensions].
    `num_nodes` is total number of edges in the batch graph.
  num_nodes: int
    The number of nodes in the batch graph.
  num_node_features: int
    The number of features per node in the graph.
  num_edges: int
    The number of edges in the batch graph.
  num_edges_features: int, optional (default None)
    The number of features per edge in the graph.
  graph_index: np.ndarray, dtype int
    This vector indicates which graph the node belongs with shape [num_nodes,]
    This vector indicates which graph the node belongs with shape [num_nodes,].

  Examples
  --------
@@ -177,7 +210,7 @@ class BatchGraphData(GraphData):
    batch_node_features = np.vstack(
        [graph.node_features for graph in graph_list])

    # before stacking edge_features or graph_features,
    # before stacking edge_features or node_pos_features,
    # we should check whether these are None or not
    if graph_list[0].edge_features is not None:
      batch_edge_features = np.vstack(
@@ -185,11 +218,11 @@ class BatchGraphData(GraphData):
    else:
      batch_edge_features = None

    if graph_list[0].graph_features is not None:
      batch_graph_features = np.vstack(
          [graph.graph_features for graph in graph_list])
    if graph_list[0].node_pos_features is not None:
      batch_node_pos_features = np.vstack(
          [graph.node_pos_features for graph in graph_list])
    else:
      batch_graph_features = None
      batch_node_pos_features = None

    # create new edge index
    num_nodes_list = [graph.num_nodes for graph in graph_list]
@@ -208,5 +241,5 @@ class BatchGraphData(GraphData):
        node_features=batch_node_features,
        edge_index=batch_edge_index,
        edge_features=batch_edge_features,
        graph_features=batch_graph_features,
        node_pos_features=batch_node_pos_features,
    )
+3 −3
Original line number Diff line number Diff line
@@ -15,13 +15,13 @@ class TestGraph(unittest.TestCase):
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])
    graph_features = None
    node_pos_features = None

    graph = GraphData(
        node_features=node_features,
        edge_index=edge_index,
        edge_features=edge_features,
        graph_features=graph_features)
        node_pos_features=node_pos_features)

    assert graph.num_nodes == num_nodes
    assert graph.num_node_features == num_node_features
@@ -92,7 +92,7 @@ class TestGraph(unittest.TestCase):
            edge_index=edge_index_list[i],
            edge_features=np.random.random_sample((num_edge_list[i],
                                                   num_edge_features)),
            graph_features=None) for i in range(len(num_edge_list))
            node_pos_features=None) for i in range(len(num_edge_list))
    ]
    batch = BatchGraphData(graph_list)