Commit 0f5124a7 authored by leswing's avatar leswing
Browse files

docs

parent 2a3d0bc3
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -645,8 +645,18 @@ class Reshape(Layer):


class Cast(Layer):
  """
  Wrapper around tf.cast.  Changes the dtype of a single layer
  """

  def __init__(self, in_layers=None, dtype=None, **kwargs):
    """
    Parameters
    ----------
    dtype: tf.DType
      the dtype to cast the in_layer to
      e.x. tf.int32
    """
    if dtype is None:
      raise ValueError("Must cast to a dtype")
    self.dtype = dtype