Commit db0af691 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

fixing outypes for classification tasks

parent aa94473b
Loading
Loading
Loading
Loading
+15 −12
Original line number Diff line number Diff line
import numpy as np
import time
import logging
import os
try:
  from collections.abc import Sequence as SequenceCollection
except:
  from collections import Sequence as SequenceCollection

from deepchem.data import Dataset, NumpyDataset
from deepchem.data import Dataset
from deepchem.models.models import Model
from deepchem.models.losses import Loss
from deepchem.models.optimizers import Optimizer
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from deepchem.utils.typing import ArrayLike, LossFn, OneOrMany
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from deepchem.utils.typing import LossFn

# JAX depend
import jax.numpy as jnp
@@ -94,8 +93,9 @@ class JaxModel(Model):
    [2] Support for saving & loading the model.
    """

    self.loss = loss  # lambda pred, tar: jnp.mean(optax.l2_loss(pred, tar))
    self._loss_fn = loss  # lambda pred, tar: jnp.mean(optax.l2_loss(pred, tar))
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.optimizer = optimizer
    self.model = model
    self.params = params
@@ -143,7 +143,7 @@ class JaxModel(Model):
          dataset: Dataset,
          nb_epochs: int = 10,
          deterministic: bool = False,
          loss: Optional[LossFn] = None,
          loss: Union[Loss, LossFn] = None,
          callbacks: Union[Callable, List[Callable]] = [],
          all_losses: Optional[List[float]] = None) -> float:
    """Train this model on a dataset.
@@ -200,7 +200,7 @@ class JaxModel(Model):

  def fit_generator(self,
                    generator: Iterable[Tuple[Any, Any, Any]],
                    loss: Optional[LossFn] = None,
                    loss: Union[Loss, LossFn] = None,
                    callbacks: Union[Callable, List[Callable]] = [],
                    all_losses: Optional[List[float]] = None) -> float:
    if not isinstance(callbacks, SequenceCollection):
@@ -210,9 +210,9 @@ class JaxModel(Model):
    last_avg_loss = 0.0
    averaged_batches = 0
    if loss is None:
      loss = self.loss
    grad_update = self._create_gradient_fn(self.loss, self.model,
                                           self.optimizer)
      loss = self._loss_fn
    grad_update = self._create_gradient_fn(loss, self.model, self.optimizer,
                                           self._loss_outputs)
    params, opt_state = self._get_trainable_params()
    time1 = time.time()

@@ -279,7 +279,7 @@ class JaxModel(Model):
    self.params = params
    self.opt_state = opt_state

  def _create_gradient_fn(self, loss, model, optimizer):
  def _create_gradient_fn(self, loss, model, optimizer, loss_outputs):
    """
    This function calls the update function, to implement the backpropogation
    """
@@ -287,10 +287,13 @@ class JaxModel(Model):
    @jax.jit
    def model_loss(params, batch, target, weights):
      predict = model.apply(params, batch)
      if loss_outputs is not None:
        predict = [predict[i] for i in loss_outputs]
      return loss(predict, target, weights)

    @jax.jit
    def update(params, opt_state, batch, target, weights):
    def update(params, opt_state, batch, target,
               weights) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
      batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
                                                         weights)
      updates, opt_state = optimizer.update(grads, opt_state)