Unverified Commit 52649464 authored by Karl Leswing's avatar Karl Leswing Committed by GitHub
Browse files

Merge pull request #1195 from WesleyyC/master

Use tf.matmul for batch matrix multiplication in MPNN's Edge Network
parents 2d7990a9 45d8d85e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -925,7 +925,7 @@ class EdgeNetwork(object):

  def forward(self, atom_features, atom_to_pair):
    out = tf.expand_dims(tf.gather(atom_features, atom_to_pair[:, 1]), 2)
    out = tf.reduce_sum(out * self.A, axis=1)
    out = tf.squeeze(tf.matmul(self.A, out), axis=2)
    out = tf.segment_sum(out, atom_to_pair[:, 0])
    return out