Commit 329969f5 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

adding evals

parent 6ce6881f
Loading
Loading
Loading
Loading
+137 −6
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ try:
except:
  from collections import Sequence as SequenceCollection

from deepchem.data import Dataset
from deepchem.data import Dataset, NumpyDataset
from deepchem.metrics import Metric
from deepchem.models.models import Model
from deepchem.models.losses import Loss
@@ -280,7 +280,49 @@ class JaxModel(Model):
      self, generator: Iterable[Tuple[Any, Any, Any]],
      transformers: List[Transformer], uncertainty: bool,
      other_output_types: Optional[OneOrMany[str]]) -> OneOrMany[np.ndarray]:
    """
    Predict outputs for data provided by a generator.

    This is the private implementation of prediction.  Do not
    call it directly.  Instead call one of the public prediction
    methods.

    Parameters
    ----------
    generator: generator
      this should generate batches, each represented as a tuple of the form
      (inputs, labels, weights).
    transformers: list of dc.trans.Transformers
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.
    uncertainty: bool
      specifies whether this is being called as part of estimating uncertainty.
      If True, it sets the training flag so that dropout will be enabled, and
      returns the values of the uncertainty outputs.
    other_output_types: list, optional
      Provides a list of other output_types (strings) to predict from model.
    Returns:
      a NumPy array of the model produces a single output, or a list of arrays
      if it produces multiple outputs
    """
    results: Optional[List[List[np.ndarray]]] = None
    variances: Optional[List[List[np.ndarray]]] = None
    if uncertainty and (other_output_types is not None):
      raise ValueError(
          'This model cannot compute uncertainties and other output types simultaneously. Please invoke one at a time.'
      )
    if uncertainty:
      if self._variance_outputs is None or len(self._variance_outputs) == 0:
        raise ValueError('This model cannot compute uncertainties')
      if len(self._variance_outputs) != len(self._prediction_outputs):
        raise ValueError(
            'The number of variances must exactly match the number of outputs')
    if other_output_types:
      if self._other_outputs is None or len(self._other_outputs) == 0:
        raise ValueError(
            'This model cannot compute other outputs since no other output_types were specified.'
        )
    self._ensure_built()
    eval_fn = self._create_eval_fn(self.model, self.params)

    for batch in generator:
@@ -290,20 +332,91 @@ class JaxModel(Model):
        inputs = inputs[0]

      output_values = eval_fn(inputs)
      output_values = jax.device_get(output_values)
      if isinstance(output_values, jnp.ndarray):
        output_values = [output_values]
      output_values = [jax.device_get(t) for t in output_values]

      # Apply tranformers and record results.
      if uncertainty:
        var = [output_values[i] for i in self._variance_outputs]
        if variances is None:
          variances = [var]
        else:
          for i, t in enumerate(var):
            variances[i].append(t)

      access_values = []
      if other_output_types:
        access_values += self._other_outputs
      elif self._prediction_outputs is not None:
        access_values += self._prediction_outputs

      if len(access_values) > 0:
        output_values = [output_values[i] for i in access_values]

      if results is None:
        results = [[] for i in range(len(output_values))]
      for i, t in enumerate(output_values):
        results[i].append(t)

    # Concatenate arrays to create the final results.
    final_results = []
    final_variances = []
    if results is not None:
      for r in results:
        final_results.append(np.concatenate(r, axis=0))
    if uncertainty and variances is not None:
      for v in variances:
        final_variances.append(np.concatenate(v, axis=0))
      return zip(final_results, final_variances)
    if len(final_results) == 1:
      return final_results[0]
    else:
      return final_results

  def predict_on_generator(
      self,
      generator: Iterable[Tuple[Any, Any, Any]],
      transformers: List[Transformer] = [],
      output_types: Optional[OneOrMany[str]] = None) -> OneOrMany[np.ndarray]:

    """
    Parameters
    ----------
    generator: generator
      this should generate batches, each represented as a tuple of the form
      (inputs, labels, weights).
    transformers: list of dc.trans.Transformers
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.
    output_types: String or list of Strings
      If specified, all outputs of this type will be retrieved
      from the model. If output_types is specified, outputs must
      be None.
    Returns:
      a NumPy array of the model produces a single output, or a list of arrays
      if it produces multiple outputs
    """
    return self._predict(generator, transformers, False, output_types)

  def predict_on_batch(self, X: ArrayLike, transformers: List[Transformer] = []
                      ) -> OneOrMany[np.ndarray]:
    """Generates predictions for input samples, processing samples in a batch.

    pass
    Parameters
    ----------
    X: ndarray
      the input data, as a Numpy array.
    transformers: list of dc.trans.Transformers
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.

    Returns
    -------
    a NumPy array of the model produces a single output, or a list of arrays
    if it produces multiple outputs
    """
    dataset = NumpyDataset(X=X, y=None)
    return self.predict(dataset, transformers)

  def predict_uncertainty_on_batch(self, X: Sequence, masks: int = 50
                                  ) -> OneOrMany[Tuple[np.ndarray, np.ndarray]]:
@@ -315,13 +428,31 @@ class JaxModel(Model):
      dataset: Dataset,
      transformers: List[Transformer] = [],
      output_types: Optional[List[str]] = None) -> OneOrMany[np.ndarray]:
    """
    Uses self to make predictions on provided Dataset object.

    Parameters
    ----------
    dataset: dc.data.Dataset
      Dataset to make prediction on
    transformers: list of dc.trans.Transformers
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.
    output_types: String or list of Strings
      If specified, all outputs of this type will be retrieved
      from the model. If output_types is specified, outputs must
      be None.

    Returns
    -------
    a NumPy array of the model produces a single output, or a list of arrays
    if it produces multiple outputs
    """
    generator = self.default_generator(
        dataset, mode='predict', pad_batches=False)
    return self.predict_on_generator(
        generator, transformers=transformers, output_types=output_types)


  def predict_embedding(self, dataset: Dataset) -> OneOrMany[np.ndarray]:

    pass