Commit d064e857 authored by nd-02110114's avatar nd-02110114
Browse files

🐛 fix cgcnn bug

parent 06b2d7d5
Loading
Loading
Loading
Loading
+11 −8
Original line number Diff line number Diff line
@@ -57,20 +57,23 @@ class CGCNNLayer(nn.Module):
    """
    super(CGCNNLayer, self).__init__()
    z_dim = 2 * hidden_node_dim + edge_dim
    self.linear_with_sigmoid = nn.Linear(z_dim, hidden_node_dim)
    self.linear_with_softplus = nn.Linear(z_dim, hidden_node_dim)
    self.batch_norm = nn.BatchNorm1d(hidden_node_dim) if batch_norm else None
    liner_out_dim = 2 * hidden_node_dim
    self.linear = nn.Linear(z_dim, liner_out_dim)
    self.batch_norm = nn.BatchNorm1d(liner_out_dim) if batch_norm else None

  def message_func(self, edges):
    z = torch.cat(
        [edges.src['x'], edges.dst['x'], edges.data['edge_attr']], dim=1)
    gated_z = torch.sigmoid(self.linear_with_sigmoid(z))
    message_z = F.softplus(self.linear_with_softplus(z))
    return {'gated_z': gated_z, 'message_z': message_z}
    z = self.linear(z)
    if self.batch_norm is not None:
      z = self.batch_norm(z)
    gated_z, message_z = z.chunk(2, dim=1)
    gated_z = torch.sigmoid(gated_z)
    message_z = F.softplus(message_z)
    return {'message': gated_z * message_z}

  def reduce_func(self, nodes):
    nbr_sumed = torch.sum(
        nodes.mailbox['gated_z'] * nodes.mailbox['message_z'], dim=1)
    nbr_sumed = torch.sum(nodes.mailbox['message'], dim=1)
    new_x = F.softplus(nodes.data['x'] + nbr_sumed)
    return {'new_x': new_x}