Commit 7ea67ce4 authored by peastman's avatar peastman
Browse files

Be more tolerant of dtypes when computing loss

parent f310f0a8
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -63,6 +63,7 @@ class BinaryCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Return the binary cross entropy of `output` against `labels`.

    Shapes are reconciled and both tensors are coerced to a floating
    dtype before delegating to the Keras loss.
    """
    out, lab = _ensure_float(*_make_shapes_consistent(output, labels))
    return tf.keras.losses.binary_crossentropy(lab, out)


@@ -76,6 +77,7 @@ class CategoricalCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Return the categorical cross entropy of `output` against `labels`.

    Shapes are reconciled and both tensors are coerced to a floating
    dtype before delegating to the Keras loss.
    """
    out, lab = _ensure_float(*_make_shapes_consistent(output, labels))
    return tf.keras.losses.categorical_crossentropy(lab, out)


@@ -89,6 +91,7 @@ class SigmoidCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Return the sigmoid cross entropy of `output` (logits) against `labels`.

    Shapes are reconciled and both tensors are coerced to a floating
    dtype before delegating to tf.nn.
    """
    out, lab = _ensure_float(*_make_shapes_consistent(output, labels))
    return tf.nn.sigmoid_cross_entropy_with_logits(lab, out)


@@ -103,6 +106,7 @@ class SoftmaxCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Return the softmax cross entropy of `output` (logits) against `labels`.

    Shapes are reconciled and both tensors are coerced to a floating
    dtype before delegating to tf.nn.
    """
    out, lab = _ensure_float(*_make_shapes_consistent(output, labels))
    return tf.nn.softmax_cross_entropy_with_logits(lab, out)


@@ -142,3 +146,11 @@ def _make_shapes_consistent(output, labels):
    return (output, labels)
  raise ValueError("Incompatible shapes for outputs and labels: %s versus %s" %
                   (str(shape1), str(shape2)))

def _ensure_float(output, labels):
  """Make sure the outputs and labels are both floating point types.

  Tensors already of dtype float32 or float64 are returned unchanged;
  any other dtype (e.g. integer labels) is cast to float32.
  """

  def _to_float(tensor):
    # Leave float32/float64 alone; everything else becomes float32.
    if tensor.dtype in (tf.float32, tf.float64):
      return tensor
    return tf.cast(tensor, tf.float32)

  return (_to_float(output), _to_float(labels))