Commit bc01f5ae authored by nd-02110114's avatar nd-02110114
Browse files

🐛 fix bug

parent 9fcc7a49
Loading
Loading
Loading
Loading
+15 −11
Original line number Diff line number Diff line
@@ -32,9 +32,9 @@ class CGCNNLayer(nn.Module):
  >>> print(type(cgcnn_dgl_graph))
  <class 'dgl.heterograph.DGLHeteroGraph'>
  >>> layer = CGCNNLayer(hidden_node_dim=92, edge_dim=41)
  >>> update_graph = layer(cgcnn_dgl_graph)
  >>> print(type(update_graph))
  <class 'dgl.heterograph.DGLHeteroGraph'>
  >>> node_feats = cgcnn_dgl_graph.ndata.pop('x')
  >>> edge_feats = cgcnn_dgl_graph.edata.pop('edge_attr')
  >>> new_node_feats, new_edge_feats = layer(cgcnn_dgl_graph, node_feats, edge_feats)

  Notes
  -----
@@ -69,23 +69,27 @@ class CGCNNLayer(nn.Module):
    return {'gated_z': gated_z, 'message_z': message_z}

  def reduce_func(self, nodes):
    nbr_sumed = torch.sum(nodes.mailbox['gated_z'] * nodes.mailbox['message_z'], dim=1)
    nbr_sumed = torch.sum(
        nodes.mailbox['gated_z'] * nodes.mailbox['message_z'], dim=1)
    new_x = F.softplus(nodes.data['x'] + nbr_sumed)
    return {'new_x': new_x}

  def forward(self, dgl_graph, node_feats, edge_feats):
    """Update node representaions.
    """Update node representations.

    Parameters
    ----------
    dgl_graph: DGLGraph
      DGLGraph for a batch of graphs. The graph expects that the node features
      are stored in `ndata['x']`, and the edge features are stored in `edata['edge_attr']`.
      DGLGraph for a batch of graphs.
    node_feats: torch.Tensor
      The node features. The shape is `(N, hidden_node_dim)`.
    edge_feats: torch.Tensor
      The edge features. The shape is `(N, hidden_node_dim)`.

    Returns
    -------
    dgl_graph: DGLGraph
      DGLGraph for a batch of updated graphs.
    node_feats: torch.Tensor
      The updated node features. The shape is `(N, hidden_node_dim)`.
    """
    dgl_graph.ndata['x'] = node_feats
    dgl_graph.edata['edge_attr'] = edge_feats
@@ -93,7 +97,7 @@ class CGCNNLayer(nn.Module):
    node_feats = dgl_graph.ndata.pop('new_x')
    if self.batch_norm is not None:
      node_feats = self.batch_norm(node_feats)
    return node_feats, edge_feats
    return node_feats


class CGCNN(nn.Module):
@@ -224,7 +228,7 @@ class CGCNN(nn.Module):

    # convolutional layer
    for conv in self.conv_layers:
      node_feats, edge_feats = conv(graph, node_feats, edge_feats)
      node_feats = conv(graph, node_feats, edge_feats)

    # pooling
    graph.ndata['updated_x'] = node_feats
+7 −5
Original line number Diff line number Diff line
@@ -32,11 +32,11 @@ class GAT(nn.Module):
  >>> pyg_graphs = [graph.to_pyg_graph() for graph in graphs]
  >>> print(type(pyg_graphs[0]))
  <class 'torch_geometric.data.data.Data'>
  >>> model = dc.models.GAT(n_tasks=2)
  >>> out = model(Batch.from_data_list(pyg_graphs))
  >>> print(type(out))
  >>> model = dc.models.GAT(mode='classification', n_tasks=10, n_classes=2)
  >>> preds, logits = model(Batch.from_data_list(pyg_graphs))
  >>> print(type(preds))
  <class 'torch.Tensor'>
  >>> out.shape == (2, 2)
  >>> preds.shape == (2, 10, 2)
  True

  References
@@ -120,7 +120,9 @@ class GAT(nn.Module):
    Returns
    -------
    out: torch.Tensor
      The output value, the shape is `(batch_size, n_out)`.
      If mode == 'regression', the shape is `(batch_size, n_tasks)`.
      If mode == 'classification', the shape is `(batch_size, n_tasks, n_classes)` (n_tasks > 1)
      or `(batch_size, n_classes)` (n_tasks == 1) and the output values are probabilities of each class label.
    """
    node_feat, edge_index = data.x, data.edge_index
    node_feat = self.embedding(node_feat)