Commit 7cf104f9 authored by peastman's avatar peastman
Browse files

Began converting to TensorFlow 2

parent 7492d910
Loading
Loading
Loading
Loading
+10 −9
Original line number Diff line number Diff line
@@ -47,11 +47,10 @@ def initializeWeightsBiases(prev_layer_size,
  """

  if weights is None:
    weights = tf.truncated_normal([prev_layer_size, size], stddev=0.01)
    weights = tf.random.truncated_normal([prev_layer_size, size], stddev=0.01)
  if biases is None:
    biases = tf.zeros([size])

  with tf.name_scope(name, 'fully_connected', [weights, biases]):
  w = tf.Variable(weights, name='w')
  b = tf.Variable(biases, name='b')
  return w, b
@@ -84,7 +83,7 @@ class AtomicConvScore(Layer):
        weight, bias = initializeWeightsBiases(
            prev_layer_size=prev_layer_size,
            size=layer_sizes[i],
            weights=tf.truncated_normal(
            weights=tf.random.truncated_normal(
                shape=[prev_layer_size, layer_sizes[i]],
                stddev=weight_init_stddevs[i]),
            biases=tf.constant(
@@ -104,13 +103,15 @@ class AtomicConvScore(Layer):
    def atomnet(current_input, atomtype):
      prev_layer = current_input
      for i in range(num_layers):
        layer = tf.nn.xw_plus_b(prev_layer, self.type_weights[atomtype][i],
        layer = tf.nn.bias_add(
            tf.matmul(prev_layer, self.type_weights[atomtype][i]),
            self.type_biases[atomtype][i])
        layer = tf.nn.relu(layer)
        prev_layer = layer

      output_layer = tf.squeeze(
          tf.nn.xw_plus_b(prev_layer, self.output_weights[atomtype][0],
          tf.nn.bias_add(
              tf.matmul(prev_layer, self.output_weights[atomtype][0]),
              self.output_biases[atomtype][0]))
      return output_layer

+6 −6
Original line number Diff line number Diff line
@@ -125,10 +125,10 @@ class MultitaskClassifier(KerasModel):
        activation_fns):
      layer = prev_layer
      if next_activation is not None:
        layer = Activation(activation_fn)(layer)
        layer = Activation(next_activation)(layer)
      layer = Dense(
          size,
          kernel_initializer=tf.truncated_normal_initializer(
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=weight_stddev),
          bias_initializer=tf.constant_initializer(value=bias_const),
          kernel_regularizer=regularizer)(layer)
@@ -275,10 +275,10 @@ class MultitaskRegressor(KerasModel):
        activation_fns):
      layer = prev_layer
      if next_activation is not None:
        layer = Activation(activation_fn)(layer)
        layer = Activation(next_activation)(layer)
      layer = Dense(
          size,
          kernel_initializer=tf.truncated_normal_initializer(
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=weight_stddev),
          bias_initializer=tf.constant_initializer(value=bias_const),
          kernel_regularizer=regularizer)(layer)
@@ -295,14 +295,14 @@ class MultitaskRegressor(KerasModel):
    self.neural_fingerprint = prev_layer
    output = Reshape((n_tasks, 1))(Dense(
        n_tasks,
        kernel_initializer=tf.truncated_normal_initializer(
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=weight_init_stddevs[-1]),
        bias_initializer=tf.constant_initializer(
            value=bias_init_consts[-1]))(prev_layer))
    if uncertainty:
      log_var = Reshape((n_tasks, 1))(Dense(
          n_tasks,
          kernel_initializer=tf.truncated_normal_initializer(
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=weight_init_stddevs[-1]),
          bias_initializer=tf.constant_initializer(value=0.0))(prev_layer))
      var = Activation(tf.exp)(log_var)
+27 −19
Original line number Diff line number Diff line
@@ -80,12 +80,16 @@ class GAN(KerasModel):
    # Create the inputs.

    self.noise_input = Input(shape=self.get_noise_input_shape())
    self.data_inputs = []
    self.data_input_layers = []
    for shape in self.get_data_input_shapes():
      self.data_inputs.append(Input(shape=shape))
    self.conditional_inputs = []
      self.data_input_layers.append(Input(shape=shape))
    self.data_inputs = [i.experimental_ref() for i in self.data_input_layers]
    self.conditional_input_layers = []
    for shape in self.get_conditional_input_shapes():
      self.conditional_inputs.append(Input(shape=shape))
      self.conditional_input_layers.append(Input(shape=shape))
    self.conditional_inputs = [
        i.experimental_ref() for i in self.conditional_input_layers
    ]

    # Create the generators.

@@ -97,7 +101,8 @@ class GAN(KerasModel):
      self.generators.append(generator)
      generator_outputs.append(
          generator(
              _list_or_tensor([self.noise_input] + self.conditional_inputs)))
              _list_or_tensor([self.noise_input] +
                              self.conditional_input_layers)))
      self.gen_variables += generator.trainable_variables

    # Create the discriminators.
@@ -111,13 +116,14 @@ class GAN(KerasModel):
      self.discriminators.append(discriminator)
      discrim_train_outputs.append(
          discriminator(
              _list_or_tensor(self.data_inputs + self.conditional_inputs)))
              _list_or_tensor(self.data_input_layers +
                              self.conditional_input_layers)))
      for gen_output in generator_outputs:
        if isinstance(gen_output, tf.Tensor):
          gen_output = [gen_output]
        discrim_gen_outputs.append(
            discriminator(
                _list_or_tensor(gen_output + self.conditional_inputs)))
                _list_or_tensor(gen_output + self.conditional_input_layers)))
      self.discrim_variables += discriminator.trainable_variables

    # Compute the loss functions.
@@ -161,14 +167,15 @@ class GAN(KerasModel):

      # Add an entropy term to the loss.

      entropy = Lambda(lambda x: -(tf.reduce_sum(tf.log(x[0]))/n_generators +
          tf.reduce_sum(tf.log(x[1]))/n_discriminators))([gen_weights, discrim_weights])
      entropy = Lambda(lambda x: -(tf.reduce_sum(tf.math.log(x[0]))/n_generators +
          tf.reduce_sum(tf.math.log(x[1]))/n_discriminators))([gen_weights, discrim_weights])
      total_discrim_loss = Lambda(lambda x: x[0] + x[1])(
          [total_discrim_loss, entropy])

    # Create the Keras model.

    inputs = [self.noise_input] + self.data_inputs + self.conditional_inputs
    inputs = [self.noise_input
             ] + self.data_input_layers + self.conditional_input_layers
    outputs = [total_gen_loss, total_discrim_loss]
    self.gen_loss_fn = lambda outputs, labels, weights: outputs[0]
    self.discrim_loss_fn = lambda outputs, labels, weights: outputs[1]
@@ -257,7 +264,8 @@ class GAN(KerasModel):
    -------
    A Tensor equal to the loss function to use for optimizing the generator.
    """
    return Lambda(lambda x: -tf.reduce_mean(tf.log(x + 1e-10)))(discrim_output)
    return Lambda(lambda x: -tf.reduce_mean(tf.math.log(x + 1e-10)))(
        discrim_output)

  def create_discriminator_loss(self, discrim_output_train, discrim_output_gen):
    """Create the loss function for the discriminator.
@@ -278,7 +286,7 @@ class GAN(KerasModel):
    -------
    A Tensor equal to the loss function to use for optimizing the discriminator.
    """
    return Lambda(lambda x: -tf.reduce_mean(tf.log(x[0]+1e-10) + tf.log(1-x[1]+1e-10)))([discrim_output_train, discrim_output_gen])
    return Lambda(lambda x: -tf.reduce_mean(tf.math.log(x[0]+1e-10) + tf.math.log(1-x[1]+1e-10)))([discrim_output_train, discrim_output_gen])

  def fit_gan(self,
              batches,
@@ -312,7 +320,7 @@ class GAN(KerasModel):
    self._ensure_built()
    if not tf.executing_eagerly():
      global_step_placeholder = tf.placeholder(tf.int32, tuple())
      update_global_step = tf.assign(self._global_step, global_step_placeholder)
      update_global_step = self._global_step.assign(global_step_placeholder)
    gen_train_fraction = 0.0
    discrim_error = 0.0
    gen_error = 0.0
@@ -332,10 +340,10 @@ class GAN(KerasModel):
      # Train the discriminator.

      inputs = [self.get_noise_batch(self.batch_size)]
      for input in self.data_inputs:
        inputs.append(feed_dict[input])
      for input in self.conditional_inputs:
        inputs.append(feed_dict[input])
      for input in self.data_input_layers:
        inputs.append(feed_dict[input.experimental_ref()])
      for input in self.conditional_input_layers:
        inputs.append(feed_dict[input.experimental_ref()])
      discrim_error += self.fit_generator(
          [(inputs, [], [])],
          variables=self.discrim_variables,
@@ -358,7 +366,7 @@ class GAN(KerasModel):
          gen_average_steps += 1
          gen_train_fraction -= 1.0
      if tf.executing_eagerly():
        tf.assign(self._global_step, global_step + 1)
        self._global_step.assign(global_step + 1)
      else:
        self.session.run(update_global_step,
                         {global_step_placeholder: global_step + 1})
@@ -513,7 +521,7 @@ class GradientPenaltyLayer(Layer):
    self.gan = gan

  def call(self, inputs):
    gradients = tf.gradients(inputs, self.gan.data_inputs)
    gradients = tf.gradients(inputs, self.gan.data_input_layers)
    norm2 = 0.0
    for g in gradients:
      g2 = tf.square(g)
+5 −4
Original line number Diff line number Diff line
@@ -385,7 +385,7 @@ class KerasModel(Model):
          vars = variables
        grads = tape.gradient(batch_loss, vars)
        self._tf_optimizer.apply_gradients(zip(grads, vars))
        tf.assign_add(self._global_step, 1)
        self._global_step.assign_add(1)
        current_step = self._global_step.numpy()
      else:

@@ -553,10 +553,11 @@ class KerasModel(Model):
          inputs = inputs[0]
        if outputs is not None:
          outputs = tuple(outputs)
          if outputs not in self._output_functions:
            self._output_functions[outputs] = tf.keras.backend.function(
          key = tuple(t.experimental_ref() for t in outputs)
          if key not in self._output_functions:
            self._output_functions[key] = tf.keras.backend.function(
                self.model.inputs, outputs)
          output_values = self._output_functions[outputs](inputs)
          output_values = self._output_functions[key](inputs)
        else:
          output_values = self.model(inputs, training=False)
          if isinstance(output_values, tf.Tensor):
+58 −58
Original line number Diff line number Diff line
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import collections
from deepchem.models.tensorgraph import model_ops, initializations, activations
@@ -127,8 +128,7 @@ class GraphConv(tf.keras.layers.Layer):
    atom_features = tf.concat(axis=0, values=new_rel_atoms_collection)

    if self.activation_fn is not None:
      activation = activations.get(self.activation_fn)
      atom_features = activation(atom_features)
      atom_features = self.activation_fn(atom_features)

    return atom_features

@@ -219,15 +219,14 @@ class GraphGather(tf.keras.layers.Layer):

    assert self.batch_size > 1, "graph_gather requires batches larger than 1"

    sparse_reps = tf.unsorted_segment_sum(atom_features, membership,
    sparse_reps = tf.math.unsorted_segment_sum(atom_features, membership,
                                               self.batch_size)
    max_reps = tf.unsorted_segment_max(atom_features, membership,
    max_reps = tf.math.unsorted_segment_max(atom_features, membership,
                                            self.batch_size)
    mol_features = tf.concat(axis=1, values=[sparse_reps, max_reps])

    if self.activation_fn is not None:
      activation = activations.get(self.activation_fn)
      mol_features = activation(mol_features)
      mol_features = self.activation_fn(mol_features)
    return mol_features


@@ -272,6 +271,8 @@ class LSTMStep(tf.keras.layers.Layer):
    # No other forget biases supported right now.
    self.activation = activation_fn
    self.inner_activation = inner_activation_fn
    self.activation_fn = activations.get(activation_fn)
    self.inner_activation_fn = activations.get(inner_activation_fn)
    self.input_dim = input_dim

  def get_config(self):
@@ -313,8 +314,6 @@ class LSTMStep(tf.keras.layers.Layer):
    list
      Returns h, [h, c]
    """
    activation = activations.get(self.activation)
    inner_activation = activations.get(self.inner_activation)
    x, h_tm1, c_tm1 = inputs

    # Taken from Keras code [citation needed]
@@ -325,12 +324,12 @@ class LSTMStep(tf.keras.layers.Layer):
    z2 = z[:, 2 * self.output_dim:3 * self.output_dim]
    z3 = z[:, 3 * self.output_dim:]

    i = inner_activation(z0)
    f = inner_activation(z1)
    c = f * c_tm1 + i * activation(z2)
    o = inner_activation(z3)
    i = self.inner_activation_fn(z0)
    f = self.inner_activation_fn(z1)
    c = f * c_tm1 + i * self.activation_fn(z2)
    o = self.inner_activation_fn(z3)

    h = o * activation(c)
    h = o * self.activation_fn(c)
    return h, [h, c]


@@ -1183,10 +1182,10 @@ class AtomicConvolution(tf.keras.layers.Layer):
    # M: Maximum number of neighbors
    # d: Number of coordinates/features/filters
    # B: Batch Size
    N = X.get_shape()[-2].value
    d = X.get_shape()[-1].value
    M = Nbrs.get_shape()[-1].value
    B = X.get_shape()[0].value
    N = X.get_shape()[-2]
    d = X.get_shape()[-1]
    M = Nbrs.get_shape()[-1]
    B = X.get_shape()[0]

    # Compute the distances and radial symmetry functions.
    D = self.distance_tensor(X, Nbrs, self.boxsize, B, N, M, d)
@@ -1299,7 +1298,7 @@ class AtomicConvolution(tf.keras.layers.Layer):
      Coordinates/features distance tensor.
    """
    flat_neighbors = tf.reshape(Nbrs, [-1, N * M])
    neighbor_coords = tf.batch_gather(X, flat_neighbors)
    neighbor_coords = tf.gather(X, flat_neighbors, batch_dims=-1)
    neighbor_coords = tf.reshape(neighbor_coords, [-1, N, M, d])
    D = neighbor_coords - tf.expand_dims(X, 2)
    if boxsize is not None:
@@ -1944,6 +1943,7 @@ class WeaveLayer(tf.keras.layers.Layer):
    super(WeaveLayer, self).__init__(**kwargs)
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.activation_fn = activations.get(activation)
    self.update_pair = update_pair  # last weave layer does not need to update
    self.n_hidden_AA = n_hidden_AA
    self.n_hidden_PA = n_hidden_PA
@@ -2020,13 +2020,13 @@ class WeaveLayer(tf.keras.layers.Layer):
    pair_split = inputs[2]
    atom_to_pair = inputs[3]

    activation = activations.get(self.activation)
    activation = self.activation_fn

    AA = tf.matmul(atom_features, self.W_AA) + self.b_AA
    AA = activation(AA)
    PA = tf.matmul(pair_features, self.W_PA) + self.b_PA
    PA = activation(PA)
    PA = tf.segment_sum(PA, pair_split)
    PA = tf.math.segment_sum(PA, pair_split)

    A = tf.matmul(tf.concat([AA, PA], 1), self.W_A) + self.b_A
    A = activation(A)
@@ -2084,6 +2084,7 @@ class WeaveGather(tf.keras.layers.Layer):
    self.gaussian_expand = gaussian_expand
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.activation_fn = activations.get(activation)
    self.epsilon = epsilon
    self.momentum = momentum

@@ -2108,16 +2109,15 @@ class WeaveGather(tf.keras.layers.Layer):
  def call(self, inputs):
    outputs = inputs[0]
    atom_split = inputs[1]
    activation = activations.get(self.activation)

    if self.gaussian_expand:
      outputs = self.gaussian_histogram(outputs)

    output_molecules = tf.segment_sum(outputs, atom_split)
    output_molecules = tf.math.segment_sum(outputs, atom_split)

    if self.gaussian_expand:
      output_molecules = tf.matmul(output_molecules, self.W) + self.b
      output_molecules = activation(output_molecules)
      output_molecules = self.activation_fn(output_molecules)

    return output_molecules

@@ -2126,10 +2126,7 @@ class WeaveGather(tf.keras.layers.Layer):
                            (-0.468, 0.118), (-0.228, 0.114), (0., 0.114),
                            (0.228, 0.114), (0.468, 0.118), (0.739, 0.134),
                            (1.080, 0.170), (1.645, 0.283)]
    dist = [
        tf.contrib.distributions.Normal(p[0], p[1])
        for p in gaussian_memberships
    ]
    dist = [tfp.distributions.Normal(p[0], p[1]) for p in gaussian_memberships]
    dist_max = [dist[i].prob(gaussian_memberships[i][0]) for i in range(11)]
    outputs = [dist[i].prob(x) / dist_max[i] for i in range(11)]
    outputs = tf.stack(outputs, axis=2)
@@ -2209,6 +2206,7 @@ class DTNNStep(tf.keras.layers.Layer):
    self.n_hidden = n_hidden
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.activation_fn = activations.get(activation)

  def get_config(self):
    config = super(DTNNStep, self).get_config()
@@ -2240,7 +2238,6 @@ class DTNNStep(tf.keras.layers.Layer):
    distance = inputs[1]
    distance_membership_i = inputs[2]
    distance_membership_j = inputs[3]
    activation = activations.get(self.activation)
    distance_hidden = tf.matmul(distance, self.W_df) + self.b_df
    atom_features_hidden = tf.matmul(atom_features, self.W_cf) + self.b_cf
    outputs = tf.multiply(
@@ -2249,15 +2246,15 @@ class DTNNStep(tf.keras.layers.Layer):
    # for atom i in a molecule m, this step multiplies together distance info of atom pair(i,j)
    # and embeddings of atom j(both gone through a hidden layer)
    outputs = tf.matmul(outputs, self.W_fc)
    outputs = activation(outputs)
    outputs = self.activation_fn(outputs)

    output_ii = tf.multiply(self.b_df, atom_features_hidden)
    output_ii = tf.matmul(output_ii, self.W_fc)
    output_ii = activation(output_ii)
    output_ii = self.activation_fn(output_ii)

    # for atom i, sum the influence from all other atom j in the molecule
    return tf.segment_sum(outputs,
                          distance_membership_i) - output_ii + atom_features
    return tf.math.segment_sum(
        outputs, distance_membership_i) - output_ii + atom_features


class DTNNGather(tf.keras.layers.Layer):
@@ -2291,6 +2288,7 @@ class DTNNGather(tf.keras.layers.Layer):
    self.output_activation = output_activation
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.activation_fn = activations.get(activation)

  def get_config(self):
    config = super(DTNNGather, self).get_config()
@@ -2325,23 +2323,21 @@ class DTNNGather(tf.keras.layers.Layer):
    """
    output = inputs[0]
    atom_membership = inputs[1]
    activation = activations.get(self.activation)

    for i, W in enumerate(self.W_list[:-1]):
      output = tf.matmul(output, W) + self.b_list[i]
      output = activation(output)
      output = self.activation_fn(output)
    output = tf.matmul(output, self.W_list[-1]) + self.b_list[-1]
    if self.output_activation:
      output = activation(output)
    return tf.segment_sum(output, atom_membership)
      output = self.activation_fn(output)
    return tf.math.segment_sum(output, atom_membership)


def _DAGgraph_step(batch_inputs, W_list, b_list, activation, dropout,
def _DAGgraph_step(batch_inputs, W_list, b_list, activation_fn, dropout,
                   dropout_switch):
  outputs = batch_inputs
  activation_fn = activations.get(activation)
  for idw, W in enumerate(W_list):
    outputs = tf.nn.xw_plus_b(outputs, W, b_list[idw])
    outputs = tf.nn.bias_add(tf.matmul(outputs, W), b_list[idw])
    outputs = activation_fn(outputs)
    if not dropout is None:
      outputs = tf.nn.dropout(outputs, rate=dropout * dropout_switch)
@@ -2385,6 +2381,7 @@ class DAGLayer(tf.keras.layers.Layer):
    super(DAGLayer, self).__init__(**kwargs)
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.activation_fn = activations.get(activation)
    self.layer_sizes = layer_sizes
    self.dropout = dropout
    self.max_atoms = max_atoms
@@ -2439,6 +2436,7 @@ class DAGLayer(tf.keras.layers.Layer):

    n_atoms = tf.squeeze(inputs[4])
    dropout_switch = tf.squeeze(inputs[5])
    with tf.init_scope():
      # initialize graph features for each graph
      graph_features_initial = tf.zeros((self.max_atoms * self.batch_size,
                                         self.max_atoms + 1, self.n_graph_feat))
@@ -2446,7 +2444,6 @@ class DAGLayer(tf.keras.layers.Layer):
      # another row of zeros is generated for padded dummy atoms
      graph_features = tf.Variable(graph_features_initial, trainable=False)

    activation = activations.get(self.activation)
    for count in range(self.max_atoms):
      # `count`-th step
      # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
@@ -2478,14 +2475,15 @@ class DAGLayer(tf.keras.layers.Layer):
      # of shape: (batch_size*max_atoms) * n_graph_features
      # representing the graph features of target atoms in each graph
      batch_outputs = _DAGgraph_step(batch_inputs, self.W_list, self.b_list,
                                     activation, self.dropout, dropout_switch)
                                     self.activation_fn, self.dropout,
                                     dropout_switch)

      # index for targe atoms
      target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
      target_index = tf.boolean_mask(target_index, mask)
      # update the graph features for target atoms
      graph_features = tf.scatter_nd_update(graph_features, target_index,
                                            batch_outputs)
      graph_features = tf.compat.v1.scatter_nd_update(
          graph_features, target_index, batch_outputs)
    return batch_outputs


@@ -2523,6 +2521,7 @@ class DAGGather(tf.keras.layers.Layer):
    super(DAGGather, self).__init__(**kwargs)
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.activation_fn = activations.get(activation)
    self.layer_sizes = layer_sizes
    self.dropout = dropout
    self.max_atoms = max_atoms
@@ -2565,10 +2564,10 @@ class DAGGather(tf.keras.layers.Layer):
    membership = inputs[1]
    dropout_switch = tf.squeeze(inputs[2])
    # Extract atom_features
    graph_features = tf.segment_sum(atom_features, membership)
    graph_features = tf.math.segment_sum(atom_features, membership)
    # sum all graph outputs
    return _DAGgraph_step(graph_features, self.W_list, self.b_list,
                          self.activation, self.dropout, dropout_switch)
                          self.activation_fn, self.dropout, dropout_switch)


class MessagePassing(tf.keras.layers.Layer):
@@ -2664,11 +2663,11 @@ class EdgeNetwork(tf.keras.layers.Layer):

  def call(self, inputs):
    pair_features, atom_features, atom_to_pair = inputs
    A = tf.nn.xw_plus_b(pair_features, self.W, self.b)
    A = tf.nn.bias_add(tf.matmul(pair_features, self.W), self.b)
    A = tf.reshape(A, (-1, self.n_hidden, self.n_hidden))
    out = tf.expand_dims(tf.gather(atom_features, atom_to_pair[:, 1]), 2)
    out = tf.squeeze(tf.matmul(A, out), axis=2)
    return tf.segment_sum(out, atom_to_pair[:, 0])
    return tf.math.segment_sum(out, atom_to_pair[:, 0])


class GatedRecurrentUnit(tf.keras.layers.Layer):
@@ -2764,7 +2763,8 @@ class SetGather(tf.keras.layers.Layer):
          tf.concat([e_mol, tf.constant([-1000.])], 0) for e_mol in e_mols
      ]
      a = tf.concat([tf.nn.softmax(e_mol)[:-1] for e_mol in e_mols], 0)
      r = tf.segment_sum(tf.reshape(a, [-1, 1]) * atom_features, atom_split)
      r = tf.math.segment_sum(
          tf.reshape(a, [-1, 1]) * atom_features, atom_split)
      # Model using this layer must set pad_batches=True
      q_star = tf.concat([h, r], axis=1)
      h, c = self.LSTMStep(q_star, c)
@@ -2772,7 +2772,7 @@ class SetGather(tf.keras.layers.Layer):

  def LSTMStep(self, h, c, x=None):
    # Perform one step of LSTM
    z = tf.nn.xw_plus_b(h, self.U, self.b)
    z = tf.nn.bias_add(tf.matmul(h, self.U), self.b)
    i = tf.nn.sigmoid(z[:, :self.n_hidden])
    f = tf.nn.sigmoid(z[:, self.n_hidden:2 * self.n_hidden])
    o = tf.nn.sigmoid(z[:, 2 * self.n_hidden:3 * self.n_hidden])
Loading