Commit abaa902d authored by peastman's avatar peastman
Browse files

More layers support eager mode

parent 887f94bd
Loading
Loading
Loading
Loading
+151 −55
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ import tensorflow as tf
import numpy as np

from deepchem.models.tensorgraph import model_ops, initializations, regularizers, activations
from deepchem.models.tensorgraph.model_ops import create_variable
import tensorflow.contrib.eager as tfe
import math

@@ -88,9 +89,8 @@ class Layer(object):
      return self.clone(in_layers)
    raise ValueError('%s does not implement shared()' % self.__class__.__name__)

  def __call__(self, *in_layers, training=False, **kwargs):
    return self.create_tensor(
        in_layers=in_layers, set_tensors=False, training=training, **kwargs)
  def __call__(self, *in_layers, **kwargs):
    return self.create_tensor(in_layers=in_layers, set_tensors=False, **kwargs)

  @property
  def shape(self):
@@ -508,8 +508,6 @@ class Conv1D(Layer):
    if tfe.in_eager_mode():
      if not self._built:
        self._layer = self._build_layer()
        self.variables = self._layer.variables
        self._built = True
      layer = self._layer
    else:
      layer = self._build_layer()
@@ -517,6 +515,9 @@ class Conv1D(Layer):
    if set_tensors:
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    if tfe.in_eager_mode() and not self._built:
      self._built = True
      self.variables = self._layer.variables
    return out_tensor


@@ -871,7 +872,7 @@ class CombineMeanStd(Layer):
    mean_parent, std_parent = inputs[0], inputs[1]
    sample_noise = tf.random_normal(
        mean_parent.get_shape(), 0, self.noise_epsilon, dtype=tf.float32)
    if self.training_only:
    if self.training_only and 'training' in kwargs:
      sample_noise *= kwargs['training']
    out_tensor = mean_parent + tf.exp(std_parent * 0.5) * sample_noise
    if set_tensors:
@@ -2772,7 +2773,7 @@ class LSTMStep(Layer):
    self.W = init((self.input_dim, 4 * self.output_dim))
    self.U = inner_init((self.output_dim, 4 * self.output_dim))

    self.b = model_ops.create_variable(
    self.b = create_variable(
        np.hstack((np.zeros(self.output_dim), np.ones(self.output_dim),
                   np.zeros(self.output_dim), np.zeros(self.output_dim))),
        dtype=tf.float32)
@@ -3129,20 +3130,41 @@ class IterRefLSTMEmbedding(Layer):

class BatchNorm(Layer):

  def __init__(self, in_layers=None, **kwargs):
  def __init__(self,
               in_layers=None,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               **kwargs):
    super(BatchNorm, self).__init__(in_layers, **kwargs)
    self.axis = axis
    self.momentum = momentum
    self.epsilon = epsilon
    try:
      parent_shape = self.in_layers[0].shape
      self._shape = tuple(self.in_layers[0].shape)
    except:
      pass

  def _build_layer(self):
    return tf.layers.BatchNormalization(
        axis=self.axis, momentum=self.momentum, epsilon=self.epsilon)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    parent_tensor = inputs[0]
    out_tensor = tf.layers.batch_normalization(parent_tensor)
    if tfe.in_eager_mode():
      if not self._built:
        self._layer = self._build_layer()
      layer = self._layer
    else:
      layer = self._build_layer()
    out_tensor = layer(parent_tensor)
    if set_tensors:
      self.out_tensor = out_tensor
    if tfe.in_eager_mode() and not self._built:
      self._built = True
      self.variables = self._layer.variables
    return out_tensor


@@ -3155,6 +3177,9 @@ class BatchNormalization(Layer):
               beta_init='zero',
               gamma_init='one',
               **kwargs):
    warnings.warn(
        'BatchNormalization is deprecated and will be removed in a future release.  Use BatchNorm instead.',
        DeprecationWarning)
    self.beta_init = initializations.get(beta_init)
    self.gamma_init = initializations.get(gamma_init)
    self.epsilon = epsilon
@@ -3242,13 +3267,17 @@ class VinaFreeEnergy(Layer):
    self.stop = stop
    super(VinaFreeEnergy, self).__init__(**kwargs)

  def _build_layers(self):
    weighted_combo = WeightedLinearCombo()
    w = create_variable(tf.random_normal((1,), stddev=self.stddev))
    return (weighted_combo, w)

  def cutoff(self, d, x):
    out_tensor = tf.where(d < 8, x, tf.zeros_like(x))
    return out_tensor

  def nonlinearity(self, c):
  def nonlinearity(self, c, w):
    """Computes non-linearity used in Vina."""
    w = tf.Variable(tf.random_normal((1,), stddev=self.stddev))
    out_tensor = c / (1 + w * self.Nrot)
    return w, out_tensor

@@ -3298,6 +3327,14 @@ class VinaFreeEnergy(Layer):
    X = inputs[0]
    Z = inputs[1]

    if tfe.in_eager_mode():
      if not self._built:
        self._weighted_combo, self._w = self._build_layers()
      weighted_combo = self._weighted_combo
      w = self._w
    else:
      weighted_combo, w = self._build_layers()

    # TODO(rbharath): This layer shouldn't be neighbor-listing. Make
    # neighbors lists an argument instead of a part of this layer.
    nbr_list = NeighborList(self.N_atoms, self.M_nbrs, self.ndim,
@@ -3314,20 +3351,22 @@ class VinaFreeEnergy(Layer):
    gauss_2 = self.gaussian_second(dists)

    # Shape (N, M)
    weighted_combo = WeightedLinearCombo()
    interactions = weighted_combo(repulsion, hydrophobic, hbond, gauss_1,
                                  gauss_2)

    # Shape (N, M)
    thresholded = self.cutoff(dists, interactions)

    weight, free_energies = self.nonlinearity(thresholded)
    weight, free_energies = self.nonlinearity(thresholded, w)
    free_energy = ReduceSum()(free_energies)

    out_tensor = free_energy
    if set_tensors:
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    if tfe.in_eager_mode() and not self._built:
      self.variables = weighted_combo.variables + [w]
      self._built = True
    return out_tensor


@@ -3342,14 +3381,23 @@ class WeightedLinearCombo(Layer):
    except:
      pass

  def _create_variables(self, inputs):
    return [
        create_variable(tf.random_normal([1], stddev=self.std))
        for i in range(len(inputs))
    ]

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers, True)
    weights = []
    if tfe.in_eager_mode():
      if not self._built:
        self.variables = self._create_variables(inputs)
        self._built = True
      weights = self.variables
    else:
      weights = self._create_variables(inputs)
    out_tensor = None
    for in_tensor in inputs:
      w = tf.Variable(tf.random_normal([
          1,
      ], stddev=self.std))
    for in_tensor, w in zip(inputs, weights):
      if out_tensor is None:
        out_tensor = w * in_tensor
      else:
@@ -3659,7 +3707,8 @@ class Dropout(Layer):
  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    parent_tensor = inputs[0]
    keep_prob = 1.0 - self.dropout_prob * kwargs['training']
    training = kwargs['training'] if 'training' in kwargs else 1.0
    keep_prob = 1.0 - self.dropout_prob * training
    out_tensor = tf.nn.dropout(parent_tensor, keep_prob)
    if set_tensors:
      self.out_tensor = out_tensor
@@ -3763,11 +3812,18 @@ class AtomicConvolution(Layer):
    R = self.distance_matrix(D)
    sym = []
    rsf_zeros = tf.zeros((B, N, M))
    for param in self.radial_params:
    for i, param in enumerate(self.radial_params):

      if tfe.in_eager_mode():
        if not self._built:
          self.variables += self._create_radial_variables(*param)
        param_variables = self.variables[3 * i:3 * i + 3]
      else:
        param_variables = self._create_radial_variables(*param)

      # We apply the radial pooling filter before atom type conv
      # to reduce computation
      param_variables, rsf = self.radial_symmetry_function(R, *param)
      rsf = self.radial_symmetry_function(R, *param_variables)

      if not self.atom_types:
        cond = tf.not_equal(Nbrs_Z, 0.0)
@@ -3784,8 +3840,17 @@ class AtomicConvolution(Layer):
    if set_tensors:
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    if tfe.in_eager_mode() and not self._built:
      self._built = True
    return out_tensor

  def _create_radial_variables(self, rc, rs, e):
    with tf.name_scope(None, "NbrRadialSymmetryFunction", [rc, rs, e]):
      rc = create_variable(rc)
      rs = create_variable(rs)
      e = create_variable(e)
    return (rc, rs, e)

  def radial_symmetry_function(self, R, rc, rs, e):
    """Calculates radial symmetry function.

@@ -3809,13 +3874,9 @@ class AtomicConvolution(Layer):

    """

    with tf.name_scope(None, "NbrRadialSymmetryFunction", [rc, rs, e]):
      rc = tf.Variable(rc)
      rs = tf.Variable(rs)
      e = tf.Variable(e)
    K = self.gaussian_distance_matrix(R, rs, e)
    FC = self.radial_cutoff(R, rc)
    return [rc, rs, e], tf.multiply(K, FC)
    return tf.multiply(K, FC)

  def radial_cutoff(self, R, rc):
    """Calculates radial cutoff matrix.
@@ -4021,7 +4082,17 @@ class AlphaShareLayer(Layer):
    subspaces = tf.reshape(tf.stack(subspaces), [n_alphas, -1])

    # create the alpha learnable parameters
    alphas = tf.Variable(tf.random_normal([n_alphas, n_alphas]), name='alphas')
    if tfe.in_eager_mode():
      if not self._built:
        self.variables = [
            create_variable(
                tf.random_normal([n_alphas, n_alphas]), name='alphas')
        ]
        self._built = True
      alphas = self.variables[0]
    else:
      alphas = create_variable(
          tf.random_normal([n_alphas, n_alphas]), name='alphas')

    subspaces = tf.matmul(alphas, subspaces)

@@ -4113,7 +4184,15 @@ class BetaShare(Layer):
    n_betas = len(inputs)
    subspaces = tf.reshape(tf.stack(subspaces), [n_betas, -1])

    betas = tf.Variable(tf.random_normal([1, n_betas]), name='betas')
    if tfe.in_eager_mode():
      if not self._built:
        self.variables = [
            create_variable(tf.random_normal([1, n_betas]), name='betas')
        ]
        self._built = True
      betas = self.variables[0]
    else:
      betas = create_variable(tf.random_normal([1, n_betas]), name='betas')
    out_tensor = tf.matmul(betas, subspaces)
    self.betas = betas
    self.out_tensor = tf.reshape(out_tensor, [-1, original_cols])
@@ -4134,7 +4213,7 @@ class ANIFeat(Layer):
  """

  def __init__(self,
               in_layers,
               in_layers=None,
               max_atoms=23,
               radial_cutoff=4.6,
               angular_cutoff=3.1,
@@ -4405,17 +4484,25 @@ class GraphEmbedPoolLayer(Layer):
    self.out_tensors = [result, result_A]
    return result, result_A

  def _create_variables(self, no_features, no_filters, name):
    W = create_variable(
        tf.truncated_normal(
            [no_features, no_filters], stddev=1.0 / math.sqrt(no_features)),
        name='%s_weights' % name,
        dtype=tf.float32)
    b = create_variable(
        tf.constant(0.1), name='%s_bias' % self.name, dtype=tf.float32)
    return [W, b]

  def embedding_factors(self, V, no_filters, name="default"):
    no_features = V.get_shape()[-1].value
    W = tf.get_variable(
        '%s_weights' % name, [no_features, no_filters],
        initializer=tf.truncated_normal_initializer(
            stddev=1.0 / math.sqrt(no_features)),
        dtype=tf.float32)
    b = tf.get_variable(
        '%s_bias' % self.name, [no_filters],
        initializer=tf.constant_initializer(0.1),
        dtype=tf.float32)
    if tfe.in_eager_mode():
      if not self._built:
        self.variables = self._create_variables(no_features, no_filters, name)
        self._built = True
      W, b = self.variables
    else:
      W, b = self._create_variables(no_features, no_filters, name)
    V_reshape = tf.reshape(V, (-1, no_features))
    s = tf.slice(tf.shape(V), [0], [len(V.get_shape()) - 1])
    s = tf.concat([s, tf.stack([no_filters])], 0)
@@ -4487,6 +4574,23 @@ class GraphCNN(Layer):
    self.num_filters = num_filters
    super(GraphCNN, self).__init__(**kwargs)

  def _create_variables(self, no_features, no_A):
    W = create_variable(
        tf.truncated_normal(
            [no_features * no_A, self.num_filters],
            stddev=math.sqrt(1.0 / (no_features * (no_A + 1) * 1.0))),
        name='%s_weights' % self.name,
        dtype=tf.float32)
    W_I = create_variable(
        tf.truncated_normal(
            [no_features, self.num_filters],
            stddev=math.sqrt(1.0 / (no_features * (no_A + 1) * 1.0))),
        name='%s_weights_I' % self.name,
        dtype=tf.float32)
    b = create_variable(
        tf.constant(0.1), name='%s_bias' % self.name, dtype=tf.float32)
    return [W, W_I, b]

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    if len(inputs) == 3:
@@ -4495,21 +4599,13 @@ class GraphCNN(Layer):
      V, A = inputs
    no_A = A.get_shape()[2].value
    no_features = V.get_shape()[2].value
    W = tf.get_variable(
        '%s_weights' % self.name, [no_features * no_A, self.num_filters],
        initializer=tf.truncated_normal_initializer(
            stddev=math.sqrt(1.0 / (no_features * (no_A + 1) * 1.0))),
        dtype=tf.float32)
    W_I = tf.get_variable(
        '%s_weights_I' % self.name, [no_features, self.num_filters],
        initializer=tf.truncated_normal_initializer(
            stddev=math.sqrt(1.0 / (no_features * (no_A + 1) * 1.0))),
        dtype=tf.float32)

    b = tf.get_variable(
        '%s_bias' % self.name, [self.num_filters],
        initializer=tf.constant_initializer(0.1),
        dtype=tf.float32)
    if tfe.in_eager_mode():
      if not self._built:
        self.variables = self._create_variables(no_features, no_A)
        self._built = True
      W, W_I, b = self.variables
    else:
      W, W_I, b = self._create_variables(no_features, no_A)

    n = self.graphConvolution(V, A)
    A_shape = tf.shape(A)
+3 −3
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ from deepchem.models.tensorgraph.graph_layers import WeaveGather, \
from deepchem.models.tensorgraph.graph_layers import WeaveLayerFactory
from deepchem.models.tensorgraph.layers import Dense, SoftMax, \
    SoftMaxCrossEntropy, GraphConv, BatchNorm, \
    GraphPool, GraphGather, WeightedError, Dropout, BatchNormalization, Stack, Flatten, GraphCNN, GraphCNNPool
    GraphPool, GraphGather, WeightedError, Dropout, BatchNorm, Stack, Flatten, GraphCNN, GraphCNNPool
from deepchem.models.tensorgraph.layers import L2Loss, Label, Weights, Feature
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.trans import undo_transforms
@@ -86,7 +86,7 @@ class WeaveTensorGraph(TensorGraph):
        out_channels=self.n_graph_feat,
        activation_fn=tf.nn.tanh,
        in_layers=weave_layer2A)
    batch_norm1 = BatchNormalization(epsilon=1e-5, mode=1, in_layers=[dense1])
    batch_norm1 = BatchNorm(epsilon=1e-5, in_layers=[dense1])
    weave_gather = WeaveGather(
        self.batch_size,
        n_input=self.n_graph_feat,
+1 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ import copy

from deepchem.metrics import to_one_hot, from_one_hot
from deepchem.models.tensorgraph.layers import Dense, Concat, SoftMax, \
  SoftMaxCrossEntropy, BatchNorm, WeightedError, Dropout, BatchNormalization, \
  SoftMaxCrossEntropy, BatchNorm, WeightedError, Dropout, \
  Conv1D, ReduceMax, Squeeze, Stack, Highway
from deepchem.models.tensorgraph.graph_layers import DTNNEmbedding

+249 −1
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ class TestLayersEager(test_util.TensorFlowTestCase):
        result = layer(input)
        self.assertEqual(result.shape[0], batch_size)
        self.assertEqual(result.shape[2], filters)
        assert len(layer.variables) == 2

        # Creating a second layer should produce different results, since it has
        # different random weights.
@@ -137,7 +138,7 @@ class TestLayersEager(test_util.TensorFlowTestCase):
        mean = np.random.rand(5, 3).astype(np.float32)
        std = np.random.rand(5, 3).astype(np.float32)
        layer = layers.CombineMeanStd(training_only=True, noise_epsilon=0.01)
        result1 = layer(mean, std)
        result1 = layer(mean, std, training=False)
        assert np.array_equal(result1, mean)  # No noise in test mode
        result2 = layer(mean, std, training=True)
        assert not np.array_equal(result2, mean)
@@ -730,3 +731,250 @@ class TestLayersEager(test_util.TensorFlowTestCase):
        assert test_out.shape == (n_test, n_feat)
        assert support_out.shape == (n_support, n_feat)
        assert len(layer.variables) == 8

  def test_batch_norm(self):
    """Test invoking BatchNorm in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        n_features = 5
        input = np.random.rand(batch_size, n_features).astype(np.float32)
        layer = layers.BatchNorm()
        result = layer(input)
        assert result.shape == (batch_size, n_features)
        assert len(layer.variables) == 4

  def test_weighted_error(self):
    """Test invoking WeightedError in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input1 = np.random.rand(5, 10).astype(np.float32)
        input2 = np.random.rand(5, 10).astype(np.float32)
        result = layers.WeightedError()(input1, input2)
        expected = np.sum(input1 * input2)
        assert np.allclose(result, expected)

  def test_vina_free_energy(self):
    """Test invoking VinaFreeEnergy in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        n_atoms = 5
        m_nbrs = 1
        ndim = 3
        nbr_cutoff = 1
        start = 0
        stop = 4
        X = np.random.rand(n_atoms, ndim).astype(np.float32)
        Z = np.random.randint(0, 2, (n_atoms)).astype(np.float32)
        layer = layers.VinaFreeEnergy(n_atoms, m_nbrs, ndim, nbr_cutoff, start,
                                      stop)
        result = layer(X, Z)
        assert len(layer.variables) == 6
        assert result.shape == tuple()

        # Creating a second layer should produce different results, since it has
        # different random weights.

        layer2 = layers.VinaFreeEnergy(n_atoms, m_nbrs, ndim, nbr_cutoff, start,
                                       stop)
        result2 = layer2(X, Z)
        assert not np.allclose(result, result2)

        # But evaluating the first layer again should produce the same result as before.

        result3 = layer(X, Z)
        assert np.allclose(result, result3)

  def test_weighted_linear_combo(self):
    """Test invoking WeightedLinearCombo in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input1 = np.random.rand(5, 10).astype(np.float32)
        input2 = np.random.rand(5, 10).astype(np.float32)
        layer = layers.WeightedLinearCombo()
        result = layer(input1, input2)
        assert len(layer.variables) == 2
        expected = input1 * layer.variables[0] + input2 * layer.variables[1]
        assert np.allclose(result, expected)

  def test_neighbor_list(self):
    """Test invoking NeighborList in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        N_atoms = 5
        start = 0
        stop = 12
        nbr_cutoff = 3
        ndim = 3
        M_nbrs = 2
        coords = start + np.random.rand(N_atoms, ndim) * (stop - start)
        coords = tf.to_float(tf.stack(coords))
        layer = layers.NeighborList(N_atoms, M_nbrs, ndim, nbr_cutoff, start,
                                    stop)
        result = layer(coords)
        assert result.shape == (N_atoms, M_nbrs)

  def test_dropout(self):
    """Test invoking Dropout in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        rate = 0.5
        input = np.random.rand(5, 10).astype(np.float32)
        layer = layers.Dropout(rate)
        result1 = layer(input, training=False)
        assert np.allclose(result1, input)
        result2 = layer(input, training=True)
        assert not np.allclose(result2, input)
        nonzero = result2.numpy() != 0
        assert np.allclose(result2.numpy()[nonzero], input[nonzero] / rate)

  def test_atomic_convolution(self):
    """Test invoking AtomicConvolution in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 4
        max_atoms = 5
        max_neighbors = 2
        dimensions = 3
        params = [[5.0, 2.0, 0.5], [10.0, 2.0, 0.5]]
        input1 = np.random.rand(batch_size, max_atoms, dimensions).astype(
            np.float32)
        input2 = np.random.randint(
            max_atoms, size=(batch_size, max_atoms, max_neighbors))
        input3 = np.random.randint(
            1, 10, size=(batch_size, max_atoms, max_neighbors))
        layer = layers.AtomicConvolution(radial_params=params)
        result = layer(input1, input2, input3)
        assert result.shape == (batch_size, max_atoms, len(params))
        assert len(layer.variables) == 3 * len(params)

  def test_alpha_share_layer(self):
    """Test invoking AlphaShareLayer in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        length = 6
        input1 = np.random.rand(batch_size, length).astype(np.float32)
        input2 = np.random.rand(batch_size, length).astype(np.float32)
        layer = layers.AlphaShareLayer()
        result = layer(input1, input2)
        assert input1.shape == result[0].shape
        assert input2.shape == result[1].shape

        # Creating a second layer should produce different results, since it has
        # different random weights.

        layer2 = layers.AlphaShareLayer()
        result2 = layer2(input1, input2)
        assert not np.allclose(result[0], result2[0])
        assert not np.allclose(result[1], result2[1])

        # But evaluating the first layer again should produce the same result as before.

        result3 = layer(input1, input2)
        assert np.allclose(result[0], result3[0])
        assert np.allclose(result[1], result3[1])

  def test_sluice_loss(self):
    """Test invoking SluiceLoss in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input1 = np.ones((3, 4)).astype(np.float32)
        input2 = np.ones((2, 2)).astype(np.float32)
        result = layers.SluiceLoss()(input1, input2)
        assert np.allclose(result, 40.0)

  def test_beta_share(self):
    """Test invoking BetaShare in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        length = 6
        input1 = np.random.rand(batch_size, length).astype(np.float32)
        input2 = np.random.rand(batch_size, length).astype(np.float32)
        layer = layers.BetaShare()
        result = layer(input1, input2)
        assert input1.shape == result.shape
        assert input2.shape == result.shape

        # Creating a second layer should produce different results, since it has
        # different random weights.

        layer2 = layers.BetaShare()
        result2 = layer2(input1, input2)
        assert not np.allclose(result, result2)

        # But evaluating the first layer again should produce the same result as before.

        result3 = layer(input1, input2)
        assert np.allclose(result, result3)

  def test_ani_feat(self):
    """Test invoking ANIFeat in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        max_atoms = 5
        input = np.random.rand(batch_size, max_atoms, 4).astype(np.float32)
        layer = layers.ANIFeat(max_atoms=max_atoms)
        result = layer(input)
        # TODO What should the output shape be?  It's not documented, and there
        # are no other test cases for it.

  def test_graph_embed_pool_layer(self):
    """Test invoking GraphEmbedPoolLayer in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        V = np.random.uniform(size=(10, 100, 50)).astype(np.float32)
        adjs = np.random.uniform(size=(10, 100, 5, 100)).astype(np.float32)
        layer = layers.GraphEmbedPoolLayer(num_vertices=6)
        result = layer(V, adjs)
        assert result[0].shape == (10, 6, 50)
        assert result[1].shape == (10, 6, 5, 6)

        # Creating a second layer should produce different results, since it has
        # different random weights.

        layer2 = layers.GraphEmbedPoolLayer(num_vertices=6)
        result2 = layer2(V, adjs)
        assert not np.allclose(result[0], result2[0])
        assert not np.allclose(result[1], result2[1])

        # But evaluating the first layer again should produce the same result as before.

        result3 = layer(V, adjs)
        assert np.allclose(result[0], result3[0])
        assert np.allclose(result[1], result3[1])

  def test_graph_cnn(self):
    """Test invoking GraphCNN in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        V = np.random.uniform(size=(10, 100, 50)).astype(np.float32)
        adjs = np.random.uniform(size=(10, 100, 5, 100)).astype(np.float32)
        layer = layers.GraphCNN(num_filters=6)
        result = layer(V, adjs)
        assert result.shape == (10, 100, 6)

        # Creating a second layer should produce different results, since it has
        # different random weights.

        layer2 = layers.GraphCNN(num_filters=6)
        result2 = layer2(V, adjs)
        assert not np.allclose(result, result2)

        # But evaluating the first layer again should produce the same result as before.

        result3 = layer(V, adjs)
        assert np.allclose(result, result3)

  def test_hinge_loss(self):
    """Test invoking HingeLoss in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        n_labels = 1
        n_logits = 1
        logits = np.random.rand(n_logits).astype(np.float32)
        labels = np.random.rand(n_labels).astype(np.float32)
        result = layers.Hingeloss()(labels, logits)
        assert result.shape == (n_labels,)