Commit 6d791a50 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

adding jax markers

parent 6eda0c48
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ try:
  import haiku as hk
  import optax
except:
  raise ImportError("This class requires Jax, haiku and optax installed.")
  raise ImportError("This class requires Jax, haiku and optax to be installed.")

import warnings

@@ -48,7 +48,7 @@ class JaxModel(Model):
  """

  def __init__(self,
               model,
               model: hk.State,
               params: hk.Params,
               loss: Union[Loss, LossFn],
               output_types: Optional[List[str]] = None,
@@ -64,7 +64,8 @@ class JaxModel(Model):
    Parameters
    ----------
    model: hk.State or Function
      Any Jax based model that has a `apply` method for computing the network.
      Any Jax based model that has a `apply` method for computing the network. Currently
      only haiku models are supported.
    params: hk.Params
      The parameter of the Jax based networks
    loss: dc.models.losses.Loss or function
@@ -99,7 +100,7 @@ class JaxModel(Model):
    """
    super(JaxModel, self).__init__(model=model, **kwargs)
    warnings.warn(
        'JaxModel is still under work and hence a lot of the method might not be implemented'
        'JaxModel is still in active development and all features may not yet be implemented'
    )
    self._loss_fn = loss  # lambda pred, tar: jnp.mean(optax.l2_loss(pred, tar))
    self.batch_size = batch_size
+3 −6
Original line number Diff line number Diff line
import unittest

import pytest
from deepchem.models.tests.test_graph_models import get_dataset
import numpy as np

@@ -14,8 +13,7 @@ except:
  has_haiku_and_optax = False


@unittest.skipIf(not has_haiku_and_optax,
                 'Jax, Haiku, or Optax are not installed')
@pytest.mark.jax
def test_jax_model_for_regression():
  tasks, dataset, transformers, metric = get_dataset(
      'regression', featurizer='ECFP')
@@ -51,8 +49,7 @@ def test_jax_model_for_regression():
  assert results < 0.5


@unittest.skipIf(not has_haiku_and_optax,
                 'Jax, Haiku, or Optax are not installed')
@pytest.mark.jax
def test_jax_model_for_classification():
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer='ECFP')