Unverified Commit df466abe authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1026 from lilleswing/guzik-auto-encode

[READY] Variational Autoencoder Asupuru-Guzik
parents 4b8eff91 09cb4c0a
Loading
Loading
Loading
Loading
+115 −56
Original line number Diff line number Diff line
@@ -372,48 +372,93 @@ class Conv1D(Layer):
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               dilation_rate=1,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               in_layers=None,
               **kwargs):
    """Create a 1D convolution layer (e.g. temporal convolution).

    The arguments are only stored here; they are forwarded unchanged to
    `tf.keras.layers.Conv1D` when the layer's tensor is created, so they follow
    the Keras semantics.  The layer convolves a kernel with its input over a
    single spatial (or temporal) dimension.  If `use_bias` is True, a bias
    vector is added to the outputs, and if `activation` is not None it is
    applied to the outputs as well.

    TODO(LESWING): Calculate output shape at construction time

    Parameters
    ----------
    filters: int
      the dimensionality of the output space (i.e. the number of output
      filters in the convolution)
    kernel_size: int or tuple/list of a single int
      the length of the 1D convolution window
    strides: int or tuple/list of a single int
      the stride length of the convolution.  Specifying any stride value != 1
      is incompatible with specifying any `dilation_rate` value != 1.
    padding: str
      one of "valid", "causal" or "same" (case-insensitive).  "causal" results
      in causal (dilated) convolutions, e.g. output[t] does not depend on
      input[t+1:].  Useful when modeling temporal data where the model should
      not violate the temporal order.  See "WaveNet: A Generative Model for
      Raw Audio", section 2.1 (https://arxiv.org/abs/1609.03499).
    dilation_rate: int or tuple/list of a single int
      the dilation rate to use for dilated convolution.  Currently, specifying
      any `dilation_rate` value != 1 is incompatible with specifying any
      `strides` value != 1.
    activation: callable
      activation function to use.  If None, no activation is applied
      (i.e. "linear" activation: a(x) = x).
    use_bias: bool
      whether the layer uses a bias vector
    kernel_initializer: str or callable
      initializer for the kernel weights matrix
    bias_initializer: str or callable
      initializer for the bias vector
    kernel_regularizer: callable
      regularizer function applied to the kernel weights matrix
    bias_regularizer: callable
      regularizer function applied to the bias vector
    activity_regularizer: callable
      regularizer function applied to the output of the layer (its
      "activation")
    kernel_constraint: callable
      constraint function applied to the kernel matrix
    bias_constraint: callable
      constraint function applied to the bias vector
    in_layers: Layer or list of Layers
      the input layer(s), passed on to the Layer base class

    Input shape: 3D tensor with shape (batch_size, steps, input_dim)

    Output shape: 3D tensor with shape (batch_size, new_steps, filters).
    The `steps` value might have changed due to padding or strides.
    """
    self.filters = filters
    self.kernel_size = kernel_size
    self.strides = strides
    self.padding = padding
    self.dilation_rate = dilation_rate
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.activity_regularizer = activity_regularizer
    self.kernel_constraint = kernel_constraint
    self.bias_constraint = bias_constraint
    super(Conv1D, self).__init__(in_layers, **kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
@@ -424,18 +469,21 @@ class Conv1D(Layer):
      parent = tf.expand_dims(parent, 2)
    elif len(parent.get_shape()) != 3:
      raise ValueError("Parent tensor must be (batch, width, channel)")
    parent_shape = parent.get_shape()
    parent_channel_size = parent_shape[2].value
    f = tf.Variable(self.weights_initializer()(
        [self.width, parent_channel_size, self.out_channels]))
    t = tf.nn.conv1d(parent, f, stride=self.stride, padding=self.padding)
    if self.biases_initializer is not None:
      b = tf.Variable(self.biases_initializer()([self.out_channels]))
      t = tf.nn.bias_add(t, b)
    if self.activation_fn is None:
      out_tensor = t
    else:
      out_tensor = self.activation_fn(t)
    out_tensor = tf.keras.layers.Conv1D(
        filters=self.filters,
        kernel_size=self.kernel_size,
        strides=self.strides,
        padding=self.padding,
        dilation_rate=self.dilation_rate,
        activation=self.activation,
        use_bias=self.use_bias,
        kernel_initializer=self.kernel_initializer,
        bias_initializer=self.bias_initializer,
        kernel_regularizer=self.kernel_regularizer,
        bias_regularizer=self.bias_regularizer,
        activity_regularizer=self.activity_regularizer,
        kernel_constraint=self.kernel_constraint,
        bias_constraint=self.bias_constraint)(parent)
    if set_tensors:
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
@@ -727,7 +775,11 @@ class Transpose(Layer):
class CombineMeanStd(Layer):
  """Generate Gaussian nose."""

  def __init__(self, in_layers=None, training_only=False, **kwargs):
  def __init__(self,
               in_layers=None,
               training_only=False,
               noise_epsilon=0.01,
               **kwargs):
    """Create a CombineMeanStd layer.

    This layer should have two inputs with the same shape, and its output also has the
@@ -744,9 +796,12 @@ class CombineMeanStd(Layer):
      if True, noise is only generated during training.  During prediction, the output
      is simply equal to the first input (that is, the mean of the distribution used
      during training).
    noise_epsilon: float
      The standard deviation of the random noise
    """
    super(CombineMeanStd, self).__init__(in_layers, **kwargs)
    self.training_only = training_only
    self.noise_epsilon = noise_epsilon
    try:
      self._shape = self.in_layers[0].shape
    except:
@@ -758,10 +813,10 @@ class CombineMeanStd(Layer):
      raise ValueError("Must have two in_layers")
    mean_parent, std_parent = inputs[0], inputs[1]
    sample_noise = tf.random_normal(
        mean_parent.get_shape(), 0, 1, dtype=tf.float32)
        mean_parent.get_shape(), 0, self.noise_epsilon, dtype=tf.float32)
    if self.training_only:
      sample_noise *= kwargs['training']
    out_tensor = mean_parent + (std_parent * sample_noise)
    out_tensor = mean_parent + tf.exp(std_parent * 0.5) * sample_noise
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
@@ -895,21 +950,25 @@ class GRU(Layer):
      self.rnn_initial_states.append(initial_state)
      self.rnn_final_states.append(final_state)
      self.rnn_zero_states.append(np.zeros(zero_state.get_shape(), np.float32))
      self.out_tensors = [
          self.out_tensor, initial_state, final_state, zero_state
      ]
    return out_tensor

  def none_tensors(self):
    """Clear every tensor reference held by this layer and return them.

    Returns
    -------
    a list of [out_tensor, rnn_initial_states, rnn_final_states,
    rnn_zero_states, out_tensors] holding the values the attributes had before
    they were reset.  The order matches what set_tensors() unpacks.
    """
    saved_tensors = [
        self.out_tensor, self.rnn_initial_states, self.rnn_final_states,
        self.rnn_zero_states, self.out_tensors
    ]
    self.out_tensor = None
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    self.out_tensors = []
    return saved_tensors

  def set_tensors(self, tensor):
    """Restore this layer's tensor references from a saved list.

    Parameters
    ----------
    tensor: sequence of length 5
      the values for out_tensor, rnn_initial_states, rnn_final_states,
      rnn_zero_states and out_tensors, in that order (the layout produced by
      none_tensors())
    """
    (self.out_tensor, self.rnn_initial_states, self.rnn_final_states,
     self.rnn_zero_states, self.out_tensors) = tensor


class LSTM(Layer):
+215 −6
Original line number Diff line number Diff line
@@ -128,7 +128,7 @@ class SeqToSeq(TensorGraph):
    self._embedding_dimension = embedding_dimension
    self._annealing_final_step = annealing_final_step
    self._annealing_start_step = annealing_start_step
    self._features = layers.Feature(shape=(None, None, len(input_tokens)))
    self._features = self._create_features()
    self._labels = layers.Label(shape=(None, None, len(output_tokens)))
    self._gather_indices = layers.Feature(
        shape=(self.batch_size, 2), dtype=tf.int32)
@@ -139,6 +139,9 @@ class SeqToSeq(TensorGraph):
    self.set_loss(self._create_loss())
    self.add_output(self.output)

  def _create_features(self):
    """Build the Feature layer that receives the input sequences.

    The first two dimensions are left unspecified (None); the last dimension
    is the number of input tokens.  Subclasses may override this to use a
    different input shape.
    """
    n_input_tokens = len(self._input_tokens)
    feature_shape = (None, None, n_input_tokens)
    return layers.Feature(shape=feature_shape)

  def _create_encoder(self, n_layers, dropout):
    """Create the encoder layers."""
    prev_layer = self._features
@@ -176,18 +179,19 @@ class SeqToSeq(TensorGraph):
    prob = layers.ReduceSum(self.output * self._labels, axis=2)
    mask = layers.ReduceSum(self._labels, axis=2)
    log_prob = layers.Log(prob + 1e-20) * mask
    loss = -layers.ReduceMean(layers.ReduceSum(log_prob, axis=1))
    loss = -layers.ReduceMean(
        layers.ReduceSum(log_prob, axis=1), name='cross_entropy_loss')
    if self._variational:
      mean_sq = self._embedding_mean * self._embedding_mean
      stddev_sq = self._embedding_stddev * self._embedding_stddev
      kl = mean_sq + stddev_sq - layers.Log(stddev_sq) - 1
      kl = mean_sq + stddev_sq - layers.Log(stddev_sq + 1e-20) - 1
      anneal_steps = self._annealing_final_step - self._annealing_start_step
      if anneal_steps > 0:
        current_step = tf.to_float(
            self.get_global_step()) - self._annealing_start_step
        anneal_frac = tf.maximum(0.0, current_step) / anneal_steps
        kl_scale = layers.TensorWrapper(
            tf.minimum(1.0, anneal_frac * anneal_frac))
            tf.minimum(1.0, anneal_frac * anneal_frac), name='kl_scale')
      else:
        kl_scale = 1.0
      loss += 0.5 * kl_scale * layers.ReduceMean(layers.ReduceSum(kl, axis=1))
@@ -394,8 +398,213 @@ class SeqToSeq(TensorGraph):
      feed_dict = {}
      feed_dict[self._features] = self._create_input_array(inputs)
      feed_dict[self._labels] = self._create_output_array(outputs)
      feed_dict[self._gather_indices] = [(i, len(x))
                                         for i, x in enumerate(inputs)]
      feed_dict[self._gather_indices] = [
          (i, len(x)) for i, x in enumerate(inputs)
      ]
      for initial, zero in zip(self.rnn_initial_states, self.rnn_zero_states):
        feed_dict[initial] = zero
      yield feed_dict


class AspuruGuzikAutoEncoder(SeqToSeq):
  """Variational autoencoder for molecules with a Conv1D encoder and GRU decoder.

  This is an implementation of "Automatic Chemical Design Using a Continuous
  Representation of Molecules"
  http://pubs.acs.org/doi/full/10.1021/acscentsci.7b00572

  Abstract
  --------
  We report a method to convert discrete representations of molecules to and
  from a multidimensional continuous representation. This model allows us to
  generate new molecules for efficient exploration and optimization through
  open-ended spaces of chemical compounds. A deep neural network was trained on
  hundreds of thousands of existing chemical structures to construct three
  coupled functions: an encoder, a decoder, and a predictor. The encoder
  converts the discrete representation of a molecule into a real-valued
  continuous vector, and the decoder converts these continuous vectors back to
  discrete molecular representations. The predictor estimates chemical
  properties from the latent continuous vector representation of the molecule.
  Continuous representations of molecules allow us to automatically generate
  novel chemical structures by performing simple operations in the latent space,
  such as decoding random vectors, perturbing known chemical structures, or
  interpolating between molecules. Continuous representations also allow the use
  of powerful gradient-based optimization to efficiently guide the search for
  optimized functional compounds. We demonstrate our method in the domain of
  drug-like molecules and also in a set of molecules with fewer than nine heavy
  atoms.

  Notes
  -----
  This is currently an imperfect reproduction of the paper.  One difference is
  that teacher forcing in the decoder is not implemented.  The paper also
  discusses co-learning molecular properties at the same time as training the
  encoder/decoder.  This is not done here.  The hyperparameters chosen are from
  the ZINC dataset.

  This network also currently suffers from exploding gradients.  Care has to be
  taken when training.

  NOTE(LESWING): Will need to play around with annealing schedule to not have
  exploding gradients
  TODO(LESWING): Teacher Forcing
  TODO(LESWING): Sigmoid variational loss annealing schedule

  For reference, the paper describes the pieces that are still missing here
  as follows:

    "The output GRU layer had one additional input, corresponding to the
    character sampled from the softmax output of the previous time step and
    was trained using teacher forcing. This increased the accuracy of
    generated SMILES strings, which resulted in higher fractions of valid
    SMILES strings for latent points outside the training data, but also made
    training more difficult, since the decoder showed a tendency to ignore the
    (variational) encoding and rely solely on the input sequence. The
    variational loss was annealed according to sigmoid schedule after 29
    epochs, running for a total 120 epochs"

  A BatchNorm was also added before the mean and std embedding layers.  This
  has empirically made training more stable, and is discussed in Ladder
  Variational Autoencoders (https://arxiv.org/pdf/1602.02282.pdf).  Maybe if
  teacher forcing and the sigmoid variational loss annealing schedule are used
  the BatchNorm will no longer be necessary.
  """

  def __init__(self,
               num_tokens,
               max_output_length,
               embedding_dimension=196,
               filter_sizes=[9, 9, 10],
               kernel_sizes=[9, 9, 11],
               decoder_dimension=488,
               **kwargs):
    """Create the autoencoder.

    Parameters
    ----------
    num_tokens: sized collection
      the tokens that may appear in sequences.  It is passed to SeqToSeq as
      both `input_tokens` and `output_tokens`.
    max_output_length: int
      passed to SeqToSeq as `max_output_length`; also used as the fixed
      sequence length of the input features and of the decoder
    embedding_dimension: int
      passed to SeqToSeq as `embedding_dimension`; the width of the embedding
      mean/std layers
    filter_sizes: list of int
      Number of filters for each 1D convolution in the encoder
    kernel_sizes: list of int
      Kernel size for each 1D convolution in the encoder
    decoder_dimension: int
      Number of channels for the GRU Decoder
    **kwargs:
      any further keyword arguments are forwarded to SeqToSeq

    Raises
    ------
    ValueError
      if `filter_sizes` and `kernel_sizes` differ in length
    """
    if len(filter_sizes) != len(kernel_sizes):
      raise ValueError("Must have same number of layers and kernels")
    # Store copies: the defaults are shared list objects and the caller may
    # mutate its own lists later; neither should be able to alter this model.
    self._filter_sizes = list(filter_sizes)
    self._kernel_sizes = list(kernel_sizes)
    self._decoder_dimension = decoder_dimension
    super(AspuruGuzikAutoEncoder, self).__init__(
        input_tokens=num_tokens,
        output_tokens=num_tokens,
        max_output_length=max_output_length,
        embedding_dimension=embedding_dimension,
        variational=True,
        reverse_input=False,
        **kwargs)

  def _create_features(self):
    """Build the input Feature layer with a fully specified shape.

    The shape is (batch_size, max_output_length, number of input tokens).
    """
    return layers.Feature(
        shape=(self.batch_size, self._max_output_length,
               len(self._input_tokens)))

  def _create_encoder(self, n_layers, dropout):
    """Create the encoder layers.

    The encoder is one Conv1D per (filter_sizes, kernel_sizes) pair, followed
    by Flatten, a Dense layer of width decoder_dimension and a BatchNorm.  In
    variational mode the output is a CombineMeanStd over two Dense embedding
    layers.

    Parameters
    ----------
    n_layers: int
      unused; the number of convolutions is set by filter_sizes/kernel_sizes
    dropout: float
      if greater than 0, a Dropout layer with this rate precedes every
      convolution
    """
    prev_layer = self._features
    for filter_size, kernel_size in zip(self._filter_sizes,
                                        self._kernel_sizes):
      if dropout > 0.0:
        prev_layer = layers.Dropout(dropout, in_layers=prev_layer)
      # Conv1D names this argument `activation`.  Passing `activation_fn`
      # would fall into **kwargs and never reach the underlying Keras layer.
      prev_layer = layers.Conv1D(
          filters=filter_size,
          kernel_size=kernel_size,
          in_layers=prev_layer,
          activation=tf.nn.relu)
    prev_layer = layers.Flatten(prev_layer)
    prev_layer = layers.Dense(
        self._decoder_dimension, in_layers=prev_layer, activation_fn=tf.nn.relu)
    prev_layer = layers.BatchNorm(prev_layer)
    if self._variational:
      self._embedding_mean = layers.Dense(
          self._embedding_dimension,
          in_layers=prev_layer,
          name='embedding_mean')
      self._embedding_stddev = layers.Dense(
          self._embedding_dimension, in_layers=prev_layer, name='embedding_std')
      prev_layer = layers.CombineMeanStd(
          [self._embedding_mean, self._embedding_stddev], training_only=True)
    return prev_layer

  def _create_decoder(self, n_layers, dropout):
    """Create the decoder layers.

    The embedding goes through a Dense layer, is repeated max_output_length
    times, passes through three GRU layers and ends in a softmax Dense layer
    over the output tokens.

    Parameters
    ----------
    n_layers: int
      unused; the decoder always stacks three GRU layers
    dropout: float
      if greater than 0, a Dropout layer with this rate precedes every GRU
    """
    prev_layer = layers.Dense(
        self._embedding_dimension,
        in_layers=self.embedding,
        activation_fn=tf.nn.relu)
    prev_layer = layers.Repeat(self._max_output_length, in_layers=prev_layer)
    for _ in range(3):
      if dropout > 0.0:
        prev_layer = layers.Dropout(dropout, in_layers=prev_layer)
      prev_layer = layers.GRU(
          self._decoder_dimension, self.batch_size, in_layers=prev_layer)
    retval = layers.Dense(
        len(self._output_tokens),
        in_layers=prev_layer,
        activation_fn=tf.nn.softmax,
        name='output')
    return retval

  def _generate_batches(self, sequences):
    """Create feed_dicts for fitting.

    Parameters
    ----------
    sequences: iterable
      (input, output) pairs of sequences

    Yields
    ------
    one feed_dict per batch.  Short batches are padded with empty sequences
    up to batch_size.
    """
    for batch in self._batch_elements(sequences):
      inputs = []
      outputs = []
      for source, target in batch:
        inputs.append(source)
        outputs.append(target)
      # Pad the last, short batch so the arrays always hold batch_size rows.
      n_missing = self.batch_size - len(inputs)
      inputs.extend([] for _ in range(n_missing))
      outputs.extend([] for _ in range(n_missing))
      feed_dict = {}
      # NOTE(review): features are encoded with _create_output_array,
      # presumably because the fixed-length input needs the same layout as
      # the labels; confirm against SeqToSeq._create_output_array.
      feed_dict[self._features] = self._create_output_array(inputs)
      feed_dict[self._labels] = self._create_output_array(outputs)
      for initial, zero in zip(self.rnn_initial_states, self.rnn_zero_states):
        feed_dict[initial] = zero
      yield feed_dict

  def _create_prediction_feed_dict(self, batch):
    """Build the feed_dict for running one batch of sequences in prediction mode."""
    feed_dict = {}
    feed_dict[self._features] = self._create_output_array(batch)
    # 0.0 selects prediction (non-training) behaviour.
    feed_dict[self._training_placeholder] = 0.0
    for initial, zero in zip(self.rnn_initial_states, self.rnn_zero_states):
      feed_dict[initial] = zero
    return feed_dict

  def predict_from_sequences(self, sequences, beam_width=5):
    """Given a set of input sequences, predict the output sequences.

    The prediction is done using a beam search with length normalization.

    Parameters
    ----------
    sequences: iterable
      the input sequences to generate a prediction for
    beam_width: int
      the beam width to use for searching.  Set to 1 to use a simple greedy search.

    Returns
    -------
    a list with one beam search result per input sequence
    """
    result = []
    with self._get_tf("Graph").as_default():
      for batch in self._batch_elements(sequences):
        feed_dict = self._create_prediction_feed_dict(batch)
        probs = self.session.run(self.output, feed_dict=feed_dict)
        # Only the first len(batch) rows belong to real sequences.
        for i, _ in enumerate(batch):
          result.append(self._beam_search(probs[i], beam_width))
    return result

  def predict_embeddings(self, sequences):
    """Given a set of input sequences, compute the embedding vectors.

    Parameters
    ----------
    sequences: iterable
      the input sequences to generate an embedding vector for

    Returns
    -------
    a float32 numpy array with one embedding vector per input sequence
    """
    result = []
    with self._get_tf("Graph").as_default():
      for batch in self._batch_elements(sequences):
        feed_dict = self._create_prediction_feed_dict(batch)
        embeddings = self.session.run(self.embedding, feed_dict=feed_dict)
        # Only the first len(batch) rows belong to real sequences.
        for i, _ in enumerate(batch):
          result.append(embeddings[i])
    return np.array(result, dtype=np.float32)
+6 −6
Original line number Diff line number Diff line
@@ -84,7 +84,7 @@ class TextCNNTensorGraph(TensorGraph):
      char_dict,
      seq_length,
      n_embedding=75,
      filter_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20],
      kernel_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20],
      num_filters=[100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160],
      dropout=0.25,
      mode="classification",
@@ -111,7 +111,7 @@ class TextCNNTensorGraph(TensorGraph):
    self.char_dict = char_dict
    self.seq_length = seq_length
    self.n_embedding = n_embedding
    self.filter_sizes = filter_sizes
    self.kernel_sizes = kernel_sizes
    self.num_filters = num_filters
    self.dropout = dropout
    self.mode = mode
@@ -164,13 +164,13 @@ class TextCNNTensorGraph(TensorGraph):
        in_layers=[self.smiles_seqs])
    self.pooled_outputs = []
    self.conv_layers = []
    for filter_size, num_filter in zip(self.filter_sizes, self.num_filters):
    for filter_size, num_filter in zip(self.kernel_sizes, self.num_filters):
      # Multiple convolutional layers with different filter widths
      self.conv_layers.append(
          Conv1D(
              filter_size,
              num_filter,
              padding='VALID',
              kernel_size=filter_size,
              filters=num_filter,
              padding='valid',
              in_layers=[self.Embedding]))
      # Max-over-time pooling
      self.pooled_outputs.append(
+5 −3
Original line number Diff line number Diff line
@@ -62,16 +62,18 @@ class TestLayers(test_util.TensorFlowTestCase):
    """Test that Conv1D can be invoked."""
    width = 5
    in_channels = 2
    out_channels = 3
    filters = 3
    kernel_size = 2
    batch_size = 10
    in_tensor = np.random.rand(batch_size, width, in_channels)
    with self.test_session() as sess:
      in_tensor = tf.convert_to_tensor(in_tensor, dtype=tf.float32)
      out_tensor = Conv1D(width, out_channels)(in_tensor)
      out_tensor = Conv1D(filters, kernel_size)(in_tensor)
      sess.run(tf.global_variables_initializer())
      out_tensor = out_tensor.eval()

      assert out_tensor.shape == (batch_size, width, out_channels)
      self.assertEqual(out_tensor.shape[0], batch_size)
      self.assertEqual(out_tensor.shape[2], filters)

  def test_dense(self):
    """Test that Dense can be invoked."""
+40 −0
Original line number Diff line number Diff line
@@ -58,6 +58,46 @@ class TestSeqToSeq(unittest.TestCase):
    assert count1 >= 12
    assert count4 >= 12

  def test_aspuru_guzik(self):
    """Smoke-test AspuruGuzikAutoEncoder: fit, then run every predict method.

    The model takes too long to fit for an overfit test, so this only checks
    that nothing hard errors and that decoding sequences directly agrees with
    decoding their predicted embeddings.
    """
    train_smiles = [
        'Cc1cccc(N2CCN(C(=O)C34CC5CC(CC(C5)C3)C4)CC2)c1C',
        'Cn1ccnc1SCC(=O)Nc1ccc(Oc2ccccc2)cc1',
        'COc1cc2c(cc1NC(=O)CN1C(=O)NC3(CCc4ccccc43)C1=O)oc1ccccc12',
        'O=C1/C(=C/NC2CCS(=O)(=O)C2)c2ccccc2C(=O)N1c1ccccc1',
        'NC(=O)NC(Cc1ccccc1)C(=O)O', 'CCn1c(CSc2nccn2C)nc2cc(C(=O)O)ccc21',
        'CCc1cccc2c1NC(=O)C21C2C(=O)N(Cc3ccccc3)C(=O)C2C2CCCN21',
        'COc1ccc(C2C(C(=O)NCc3ccccc3)=C(C)N=C3N=CNN32)cc1OC',
        'CCCc1cc(=O)nc(SCC(=O)N(CC(C)C)C2CCS(=O)(=O)C2)[nH]1',
        'CCn1cnc2c1c(=O)n(CC(=O)Nc1cc(C)on1)c(=O)n2Cc1ccccc1'
    ]
    # The vocabulary is every distinct character in the training strings.
    tokens = sorted(set().union(*train_smiles))
    max_length = 1 + max(map(len, train_smiles))
    model = dc.models.tensorgraph.models.seqtoseq.AspuruGuzikAutoEncoder(
        tokens, max_length)

    def generate_sequences(smiles, epochs):
      # Autoencoder training pairs: each string is its own target.
      for _ in range(epochs):
        for smile in smiles:
          yield (smile, smile)

    model.fit_sequences(generate_sequences(train_smiles, 100))

    # Test it out.
    pred1 = model.predict_from_sequences(train_smiles, beam_width=1)
    pred4 = model.predict_from_sequences(train_smiles, beam_width=4)
    embeddings = model.predict_embeddings(train_smiles)
    pred1e = model.predict_from_embeddings(embeddings, beam_width=1)
    pred4e = model.predict_from_embeddings(embeddings, beam_width=4)
    for index, _ in enumerate(train_smiles):
      assert pred1[index] == pred1e[index]
      assert pred4[index] == pred4e[index]

  def test_variational(self):
    """Test using a SeqToSeq model as a variational autoenconder."""