Commit 95816eb8 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

adding base

parent 3d478f01
Loading
Loading
Loading
Loading
+52 −2
Original line number Diff line number Diff line
from deepchem.trans.transformers import Transformer
import numpy as np
import time
import logging
@@ -7,11 +8,12 @@ except:
  from collections import Sequence as SequenceCollection

from deepchem.data import Dataset
from deepchem.metrics import Metric
from deepchem.models.models import Model
from deepchem.models.losses import Loss
from deepchem.models.optimizers import Optimizer
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from deepchem.utils.typing import LossFn
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, Sequence
from deepchem.utils.typing import LossFn, OneOrMany, ArrayLike

# JAX depend
try:
@@ -274,6 +276,54 @@ class JaxModel(Model):
    self._set_trainable_params(params, opt_state)
    return last_avg_loss

  def _predict(
      self, generator: Iterable[Tuple[Any, Any, Any]],
      transformers: List[Transformer], uncertainty: bool,
      other_output_types: Optional[OneOrMany[str]]) -> OneOrMany[np.ndarray]:

    pass

  def predict_on_generator(
      self,
      generator: Iterable[Tuple[Any, Any, Any]],
      transformers: List[Transformer] = [],
      output_types: Optional[OneOrMany[str]] = None) -> OneOrMany[np.ndarray]:

    pass

  def predict_on_batch(self, X: ArrayLike, transformers: List[Transformer] = []
                      ) -> OneOrMany[np.ndarray]:

    pass

  def predict_uncertainty_on_batch(self, X: Sequence, masks: int = 50
                                  ) -> OneOrMany[Tuple[np.ndarray, np.ndarray]]:

    pass

  def predict(
      self,
      dataset: Dataset,
      transformers: List[Transformer] = [],
      output_types: Optional[List[str]] = None) -> OneOrMany[np.ndarray]:

    pass

  def predict_embedding(self, dataset: Dataset) -> OneOrMany[np.ndarray]:

    pass

  def predict_uncertainty(self, dataset: Dataset, masks: int = 50
                         ) -> OneOrMany[Tuple[np.ndarray, np.ndarray]]:
    pass

  def evaluate_generator(self,
                         generator: Iterable[Tuple[Any, Any, Any]],
                         metrics: List[Metric],
                         transformers: List[Transformer] = [],
                         per_task_metrics: bool = False):
    pass

  def _get_trainable_params(self):
    """
    Will be used to seperate freezing parameters while transfer learning