Commit bfe276a1 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Adding warnings

parent f5dcc776
Loading
Loading
Loading
Loading
+13 −5
Original line number Diff line number Diff line
@@ -14,10 +14,15 @@ from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from deepchem.utils.typing import LossFn

# JAX depend
try:
  import jax.numpy as jnp
  import jax
  import haiku as hk
  import optax
except:
  raise ImportError("This class requires Jax, haiku and optax installed.")

import warnings

logger = logging.getLogger(__name__)

@@ -26,7 +31,7 @@ class JaxModel(Model):
  """This is a DeepChem model implemented by a Jax Model

  Here is a simple example of that uses JaxModel to train a
  Haiku (JAX Nueral Network Library) based model on deepchem
  Haiku (JAX Neural Network Library) based model on deepchem
  dataset.

  >> def f(x):
@@ -93,6 +98,9 @@ class JaxModel(Model):
    [2] Support for saving & loading the model.
    """
    super(JaxModel, self).__init__(model=model, **kwargs)
    warnings.warn(
        'JaxModel is still under work and hence a lot of the method might not be implemented'
    )
    self._loss_fn = loss  # lambda pred, tar: jnp.mean(optax.l2_loss(pred, tar))
    self.batch_size = batch_size
    self.learning_rate = learning_rate