Commit 673e8801 authored by Peter Eastman's avatar Peter Eastman
Browse files

Convert dtypes for L1 and L2 loss

parent 7ea67ce4
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ class L1Loss(Loss):

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.abs(output - labels)


@@ -39,6 +40,7 @@ class L2Loss(Loss):

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.square(output - labels)