Commit 45d8d85e authored by wesleyyc's avatar wesleyyc
Browse files

Use tf.matmul for batch matrix multiplication instead of tf.multiply with...

Use tf.matmul for batch matrix multiplication instead of tf.multiply with matrix and vector broadcast in MPNN's Edge Network Implementation #1188
parent 47cfb41a
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -925,7 +925,7 @@ class EdgeNetwork(object):

  def forward(self, atom_features, atom_to_pair):
    out = tf.expand_dims(tf.gather(atom_features, atom_to_pair[:, 1]), 2)
    out = tf.reduce_sum(out * self.A, axis=1)
    out = tf.squeeze(tf.matmul(self.A, out), axis=2)
    out = tf.segment_sum(out, atom_to_pair[:, 0])
    return out