Commit 9a1066b4 authored by mufeili's avatar mufeili Committed by Ubuntu
Browse files

Update

parent 93898573
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -146,9 +146,6 @@ class GraphData:

    src = self.edge_index[0]
    dst = self.edge_index[1]
    if self_loop:
      src = np.concatenate([src, np.arange(self.num_nodes)])
      dst = np.concatenate([dst, np.arange(self.num_nodes)])

    g = dgl.graph(
        (torch.from_numpy(src).long(), torch.from_numpy(dst).long()),
@@ -161,6 +158,11 @@ class GraphData:
    if self.edge_features is not None:
      g.edata['edge_attr'] = torch.from_numpy(self.edge_features).float()

    if self_loop:
      # This assumes that the edge features for self loops are full-zero tensors
      # In the future we may want to support featurization for self loops
      g.add_edges(np.arange(self.num_nodes), np.arange(self.num_nodes))

    return g