Commit e334db4d authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

eval fn done

parent 4225b87f
Loading
Loading
Loading
Loading
+29 −3
Original line number Diff line number Diff line
@@ -284,7 +284,16 @@ class JaxModel(Model):
      transformers: List[Transformer], uncertainty: bool,
      other_output_types: Optional[OneOrMany[str]]) -> OneOrMany[np.ndarray]:

    pass
    eval_fn = self._create_eval_fn(self.model, self.params)

    for batch in generator:
      inputs, _, _ = self._prepare_batch(batch)

      if isinstance(inputs, list) and len(inputs) == 1:
        inputs = inputs[0]

      output_values = eval_fn(inputs)
      output_values = jax.device_get(output_values)

  def predict_on_generator(
      self,
@@ -292,7 +301,7 @@ class JaxModel(Model):
      transformers: List[Transformer] = [],
      output_types: Optional[OneOrMany[str]] = None) -> OneOrMany[np.ndarray]:

    pass
    return self._predict(generator, transformers, False, output_types)

  def predict_on_batch(self, X: ArrayLike, transformers: List[Transformer] = []
                      ) -> OneOrMany[np.ndarray]:
@@ -310,7 +319,11 @@ class JaxModel(Model):
      transformers: List[Transformer] = [],
      output_types: Optional[List[str]] = None) -> OneOrMany[np.ndarray]:

    pass
    generator = self.default_generator(
        dataset, mode='predict', pad_batches=False)
    return self.predict_on_generator(
        generator, transformers=transformers, output_types=output_types)


  def predict_embedding(self, dataset: Dataset) -> OneOrMany[np.ndarray]:

@@ -361,6 +374,19 @@ class JaxModel(Model):
    self.params = params
    self.opt_state = opt_state

  def _create_eval_fn(self, model, params):
    """
    Calls the function to evaulate the model
    """

    @jax.jit
    def eval_model(batch):
      predict = model.apply(params, batch)

      return predict

    return eval_model
    
  def _create_gradient_fn(self, loss, model, optimizer, loss_outputs):
    """
    This function calls the update function, to implement the backpropogation