Commit f5dcc776 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

bug fix

parent 9cb323dc
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -92,7 +92,7 @@ class JaxModel(Model):
    [1] Integerate the optax losses, optimizers, schedulers with Deepchem
    [2] Support for saving & loading the model.
    """

    super(JaxModel, self).__init__(model=model, **kwargs)
    self._loss_fn = loss  # lambda pred, tar: jnp.mean(optax.l2_loss(pred, tar))
    self.batch_size = batch_size
    self.learning_rate = learning_rate
+5 −3
Original line number Diff line number Diff line
@@ -8,14 +8,14 @@ try:
  import jax.numpy as jnp
  import haiku as hk
  import optax
  import deepchem.models as JaxModel
  from deepchem.models import JaxModel
  has_haiku_and_optax = True
except:
  has_haiku_and_optax = False


@unittest.skipIf(not has_haiku_and_optax,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
                 'Jax, Haiku, or Optax are not installed')
def test_jax_model_for_regression():
  tasks, dataset, transformers, metric = get_dataset(
      'regression', featurizer='ECFP')
@@ -52,7 +52,7 @@ def test_jax_model_for_regression():


@unittest.skipIf(not has_haiku_and_optax,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
                 'Jax, Haiku, or Optax are not installed')
def test_jax_model_for_classification():
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer='ECFP')
@@ -73,6 +73,8 @@ def test_jax_model_for_classification():
    return net(x)

  def bce_loss(pred, tar, w):
    tar = jnp.array(
        [x.astype(np.float32) if x.dtype != np.float32 else x for x in tar])
    return jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar))

  # Model Initilisation