Commit ee600b31 authored by Vignesh's avatar Vignesh
Browse files

Upgrade to TF 1.14, Fixes to tests, attr(slow) for test_gan

parent 94809d39
Loading
Loading
Loading
Loading
+8 −10
Original line number Diff line number Diff line
@@ -1789,7 +1789,7 @@ class WeaveLayer(tf.keras.layers.Layer):
    """
    super(WeaveLayer, self).__init__(**kwargs)
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.activation = activations.get(activation)  # Get activations
    self.update_pair = update_pair  # last weave layer does not need to update
    self.n_hidden_AA = n_hidden_AA
    self.n_hidden_PA = n_hidden_PA
@@ -1845,8 +1845,6 @@ class WeaveLayer(tf.keras.layers.Layer):

    inputs: [atom_features, pair_features], pair_split, atom_to_pair
    """
    activation = activations.get(self.activation)  # Get activations

    atom_features = inputs[0]
    pair_features = inputs[1]

@@ -1854,30 +1852,30 @@ class WeaveLayer(tf.keras.layers.Layer):
    atom_to_pair = inputs[3]

    AA = tf.matmul(atom_features, self.W_AA) + self.b_AA
    AA = activation(AA)
    AA = self.activation(AA)
    PA = tf.matmul(pair_features, self.W_PA) + self.b_PA
    PA = activation(PA)
    PA = self.activation(PA)
    PA = tf.segment_sum(PA, pair_split)

    A = tf.matmul(tf.concat([AA, PA], 1), self.W_A) + self.b_A
    A = activation(A)
    A = self.activation(A)

    if self.update_pair:
      AP_ij = tf.matmul(
          tf.reshape(
              tf.gather(atom_features, atom_to_pair),
              [-1, 2 * self.n_atom_input_feat]), self.W_AP) + self.b_AP
      AP_ij = activation(AP_ij)
      AP_ij = self.activation(AP_ij)
      AP_ji = tf.matmul(
          tf.reshape(
              tf.gather(atom_features, tf.reverse(atom_to_pair, [1])),
              [-1, 2 * self.n_atom_input_feat]), self.W_AP) + self.b_AP
      AP_ji = activation(AP_ji)
      AP_ji = self.activation(AP_ji)

      PP = tf.matmul(pair_features, self.W_PP) + self.b_PP
      PP = activation(PP)
      PP = self.activation(PP)
      P = tf.matmul(tf.concat([AP_ij + AP_ji, PP], 1), self.W_P) + self.b_P
      P = activation(P)
      P = self.activation(P)
    else:
      P = pair_features

+4 −4
Original line number Diff line number Diff line
@@ -2636,10 +2636,10 @@ class AttnLSTMEmbedding(KerasLayer):
    result = layer(inputs)
    if set_tensors:
      self.out_tensor = result[1]
      self.trainable_variables = layer.trainable_variables + layer.states_init
      self.trainable_variables = layer.trainable_variables
    if tf.executing_eagerly() and not self._built:
      self._built = True
      self.trainable_variables = layer.trainable_variables + layer.states_init
      self.trainable_variables = layer.trainable_variables
    return result


@@ -2691,10 +2691,10 @@ class IterRefLSTMEmbedding(KerasLayer):
    result = layer(inputs)
    if set_tensors:
      self.out_tensor = result[1]
      self.trainable_variables = layer.trainable_variables + layer.support_states_init + layer.test_states_init
      self.trainable_variables = layer.trainable_variables
    if tf.executing_eagerly() and not self._built:
      self._built = True
      self.trainable_variables = layer.trainable_variables + layer.support_states_init + layer.test_states_init
      self.trainable_variables = layer.trainable_variables
    return result


+2 −1
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import tensorflow as tf
import unittest
from tensorflow.keras.layers import Input, Concatenate, Dense
from flaky import flaky
from nose.plugins.attrib import attr


def generate_batch(batch_size):
@@ -70,7 +71,7 @@ class TestGAN(unittest.TestCase):
    assert np.std(deltas) > 1.0
    assert gan.get_global_step() == 500

  @flaky
  @attr("slow")
  def test_mix_gan(self):
    """Test a GAN with multiple generators and discriminators."""

+1 −1
Original line number Diff line number Diff line
@@ -63,4 +63,4 @@ conda install -y -q -c deepchem -c rdkit -c conda-forge -c omnia \
    setuptools=39.0.1 \
    biopython=1.71 \
    numpy=1.14
yes | pip install $tensorflow==1.13.1
yes | pip install $tensorflow==1.14