Unverified Commit 8566ed94 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1786 from peastman/loss

Be more tolerant of dtypes when computing loss
parents f310f0a8 673e8801
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ class L1Loss(Loss):

  def __call__(self, output, labels):
    """Compute the elementwise absolute-error (L1) loss.

    Shapes are first reconciled and both tensors are coerced to a
    floating dtype before the difference is taken.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    residual = output - labels
    return tf.abs(residual)


@@ -39,6 +40,7 @@ class L2Loss(Loss):

  def __call__(self, output, labels):
    """Compute the elementwise squared-error (L2) loss.

    Shapes are first reconciled and both tensors are coerced to a
    floating dtype before squaring the difference.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    residual = output - labels
    return tf.square(residual)


@@ -63,6 +65,7 @@ class BinaryCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Compute binary cross entropy between predictions and labels.

    Delegates to Keras after reconciling shapes and coercing both
    tensors to a floating dtype. Note the Keras argument order is
    (labels, predictions).
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.keras.losses.binary_crossentropy(labels, output)


@@ -76,6 +79,7 @@ class CategoricalCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Compute categorical cross entropy between predictions and labels.

    Delegates to Keras after reconciling shapes and coercing both
    tensors to a floating dtype. Note the Keras argument order is
    (labels, predictions).
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.keras.losses.categorical_crossentropy(labels, output)


@@ -89,6 +93,7 @@ class SigmoidCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Compute sigmoid cross entropy from logits.

    `output` is interpreted as unnormalized logits by the underlying
    TF op. Shapes are reconciled and both tensors coerced to a
    floating dtype first.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.nn.sigmoid_cross_entropy_with_logits(labels, output)


@@ -103,6 +108,7 @@ class SoftmaxCrossEntropy(Loss):

  def __call__(self, output, labels):
    """Compute softmax cross entropy from logits.

    `output` is interpreted as unnormalized logits by the underlying
    TF op. Shapes are reconciled and both tensors coerced to a
    floating dtype first.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.nn.softmax_cross_entropy_with_logits(labels, output)


@@ -142,3 +148,11 @@ def _make_shapes_consistent(output, labels):
    return (output, labels)
  raise ValueError("Incompatible shapes for outputs and labels: %s versus %s" %
                   (str(shape1), str(shape2)))

def _ensure_float(output, labels):
  """Make sure the outputs and labels are both floating point types.

  Non-floating tensors (e.g. integer labels) are cast so that the loss
  arithmetic that follows is well defined.

  Parameters
  ----------
  output: tf.Tensor
    the model's predictions
  labels: tf.Tensor
    the ground-truth values

  Returns
  -------
  (output, labels) as a tuple of floating point tensors with matching
  dtypes.
  """
  # Use is_floating rather than a (float32, float64) allowlist so that
  # half-precision tensors (float16/bfloat16) are left untouched instead
  # of being silently upcast to float32.
  if not output.dtype.is_floating:
    output = tf.cast(output, tf.float32)
  # Cast labels to match output's dtype: TF binary ops require equal
  # dtypes, so casting labels to an unconditional float32 would fail
  # whenever output is float64 (or float16).
  if labels.dtype != output.dtype:
    labels = tf.cast(labels, output.dtype)
  return (output, labels)