Commit 02627ce8 authored by Peter Eastman's avatar Peter Eastman
Browse files

Improvements to SeqToSeq

parent 27a3000f
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -194,7 +194,7 @@ class SeqToSeq(TensorGraph):
            tf.minimum(1.0, anneal_frac * anneal_frac), name='kl_scale')
      else:
        kl_scale = 1.0
      loss += 0.5 * kl_scale * layers.ReduceMean(layers.ReduceSum(kl, axis=1))
      loss += 0.5 * kl_scale * layers.ReduceMean(kl)
    return loss

  def fit_sequences(self,
@@ -369,8 +369,8 @@ class SeqToSeq(TensorGraph):
    for i, sequence in enumerate(sequences):
      for j, token in enumerate(sequence):
        labels[i, j, self._output_dict[token]] = 1
      if lengths[i] < self._max_output_length:
        labels[i, lengths[i], end_marker_index] = 1
      for j in range(lengths[i], self._max_output_length):
        labels[i, j, end_marker_index] = 1
    return labels

  def _batch_elements(self, elements):