Commit b361d250 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

bug fix

parent bfe276a1
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -66,16 +66,16 @@ def test_jax_model_for_classification():

    def __call__(self, x: jnp.ndarray):
      x = self._network(x)
      return x, jax.nn.sigmoid(x)
      return x, jax.nn.softmax(x)

  def f(x):
    net = Encoder(1)
    net = Encoder(2)
    return net(x)

  def bce_loss(pred, tar, w):
    tar = jnp.array(
        [x.astype(np.float32) if x.dtype != np.float32 else x for x in tar])
    return jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar))
    return jnp.mean(optax.softmax_cross_entropy(pred[0], tar))

  # Model Initilisation
  model = hk.without_apply_rng(hk.transform(f))
@@ -97,5 +97,5 @@ def test_jax_model_for_classification():
      batch_size=256,
      learning_rate=0.001,
      log_frequency=2)
  results = j_m.fit(dataset, deterministic=True)
  results = j_m.fit(dataset, nb_epochs=50, deterministic=True)
  assert results < 0.5