Commit f8af3850 authored by Vignesh's avatar Vignesh
Browse files

Added return out_tensor for all layers

parent 9ad359b2
Loading
Loading
Loading
Loading
+44 −9
Original line number Diff line number Diff line
@@ -46,7 +46,11 @@ class DistanceMatrix(Layer):
    # Calculate pairwise distance
    d = tf.sqrt(tf.reduce_sum(tf.square(tensor1 - tensor2), axis=3))
    # Masking for valid atom index
    self.out_tensor = d * tf.to_float(atom_flags)
    out_tensor = d * tf.to_float(atom_flags)
    if set_tensors:
      self.out_tensor = out_tensor

    return out_tensor


class DistanceCutoff(Layer):
@@ -79,7 +83,11 @@ class DistanceCutoff(Layer):
    d = 0.5 * (tf.cos(np.pi * d / self.Rc) + 1)
    out_tensor = d * d_flag
    out_tensor = out_tensor * tf.expand_dims((1 - tf.eye(self.max_atoms)), 0)
    out_tensor = out_tensor

    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class RadialSymmetry(Layer):
@@ -142,9 +150,14 @@ class RadialSymmetry(Layer):
            tf.expand_dims(atom_number_embedded[:, :, atom_type], axis=1),
            axis=3)
        out_tensors.append(tf.reduce_sum(out_tensor * selected_atoms, axis=2))
      self.out_tensor = tf.concat(out_tensors, axis=2)
      out_tensor = tf.concat(out_tensors, axis=2)
    else:
      self.out_tensor = tf.reduce_sum(out_tensor, axis=2)
      out_tensor = tf.reduce_sum(out_tensor, axis=2)

    if set_tensors:
      self.out_tensor = out_tensor

    return out_tensor


class AngularSymmetry(Layer):
@@ -227,9 +240,14 @@ class AngularSymmetry(Layer):
    out_tensor = tf.pow(1 + lambd * tf.cos(theta), zeta) * \
                 tf.exp(-ita * (tf.square(R_ij) + tf.square(R_ik) + tf.square(R_jk))) * \
                 f_R_ij * f_R_ik * f_R_jk
    self.out_tensor = tf.reduce_sum(out_tensor, axis=[2, 3]) * \
    out_tensor = tf.reduce_sum(out_tensor, axis=[2, 3]) * \
                      tf.pow(tf.constant(2.), 1 - tf.reshape(self.zeta, (1, 1, -1)))

    if set_tensors:
      self.out_tensor = out_tensor

    return out_tensor


class AngularSymmetryMod(Layer):
  """ Angular Symmetry Function """
@@ -345,9 +363,14 @@ class AngularSymmetryMod(Layer):
              tf.expand_dims(selected_atoms, axis=1), axis=4)
          out_tensors.append(
              tf.reduce_sum(out_tensor * selected_atoms, axis=[2, 3]))
      self.out_tensor = tf.concat(out_tensors, axis=2)
      out_tensor = tf.concat(out_tensors, axis=2)
    else:
      self.out_tensor = tf.reduce_sum(out_tensor, axis=[2, 3])
      out_tensor = tf.reduce_sum(out_tensor, axis=[2, 3])

    if set_tensors:
      self.out_tensor = out_tensor

    return out_tensor


class BPFeatureMerge(Layer):
@@ -369,7 +392,12 @@ class BPFeatureMerge(Layer):

    out_tensor = tf.concat(
        [atom_embedding, radial_symmetry, angular_symmetry], axis=2)
    self.out_tensor = out_tensor * atom_flags[:, :, 0:1]
    out_tensor = out_tensor * atom_flags[:, :, 0:1]

    if set_tensors:
      self.out_tensor = out_tensor

    return out_tensor


class BPGather(Layer):
@@ -390,6 +418,8 @@ class BPGather(Layer):
    out_tensor = tf.reduce_sum(out_tensor * flags[:, :, 0:1], axis=1)
    self.out_tensor = out_tensor

    return out_tensor


class AtomicDifferentiatedDense(Layer):
  """ Separate Dense module for different atoms """
@@ -453,7 +483,12 @@ class AtomicDifferentiatedDense(Layer):
      output = tf.reshape(output * tf.expand_dims(mask, 2),
                          (-1, self.max_atoms, self.out_channels))
      outputs.append(output)
    self.out_tensor = tf.add_n(outputs)
    out_tensor = tf.add_n(outputs)

    if set_tensors:
      self.out_tensor = out_tensor

    return out_tensor

  def none_tensors(self):
    w, b, out_tensor = self.W, self.b, self.out_tensor