Commit eb0df084 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

formatting

parent 70f273f4
Loading
Loading
Loading
Loading
+15 −9
Original line number Diff line number Diff line
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)

def create_default_eval_fn(forward_fn, params):
  """
    Calls the function to evaulate the model
  Calls the function to evaluate the model
  """

  @jax.jit
@@ -108,6 +108,7 @@ class JaxModel(Model):
               **kwargs):
    """
    Create a new JaxModel

    Parameters
    ----------
    model: hk.State or Function
@@ -332,6 +333,7 @@ class JaxModel(Model):
    This is the private implementation of prediction.  Do not
    call it directly. Instead call one of the public prediction
    methods.

    Parameters
    ----------
    generator: generator
@@ -346,10 +348,10 @@ class JaxModel(Model):
      returns the values of the uncertainty outputs.
    other_output_types: list, optional
      Provides a list of other output_types (strings) to predict from model.

    Returns
    -------
      a NumPy array of the model produces a single output, or a list of arrays
      if it produces multiple outputs
    A NumpyArray if the model produces a single output, or a list of arrays otherwise.
    """
    results: Optional[List[List[np.ndarray]]] = None
    variances: Optional[List[List[np.ndarray]]] = None
@@ -439,6 +441,7 @@ class JaxModel(Model):
      If specified, all outputs of this type will be retrieved
      from the model. If output_types is specified, outputs must
      be None.

    Returns
    -------
      a NumPy array of the model produces a single output, or a list of arrays
@@ -456,6 +459,7 @@ class JaxModel(Model):
    transformers: List[dc.trans.Transformers]
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.

    Returns
    -------
    a NumPy array of the model produces a single output, or a list of arrays
@@ -476,6 +480,7 @@ class JaxModel(Model):
      output_types: Optional[List[str]] = None) -> OneOrMany[np.ndarray]:
    """
    Uses self to make predictions on provided Dataset object.

    Parameters
    ----------
    dataset: dc.data.Dataset
@@ -487,6 +492,7 @@ class JaxModel(Model):
      If specified, all outputs of this type will be retrieved
      from the model. If output_types is specified, outputs must
      be None.

    Returns
    -------
    a NumPy array of the model produces a single output, or a list of arrays