Commit 3544a95f authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

small fixes

parent a2720110
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -92,7 +92,7 @@ class JaxModel(Model):

    Work in Progress
    ----------------
    [1] Integerate the optax losses, optimizers, schedulers with Deepchem
    [1] Integrate the optax losses, optimizers, schedulers with Deepchem
    [2] Support for saving & loading the model.
    """
    super(JaxModel, self).__init__(model=model, **kwargs)
@@ -364,7 +364,7 @@ class JaxModel(Model):
    """

    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
      for (X_b, y_b, w_b, _) in dataset.iterbatches(
          batch_size=self.batch_size,
          deterministic=deterministic,
          pad_batches=pad_batches):