Commit aa94473b authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Additional parameter support

parent 8ab22511
Loading
Loading
Loading
Loading
+204 −36
Original line number Diff line number Diff line
@@ -2,10 +2,15 @@ import numpy as np
import time
import logging
import os
try:
  from collections.abc import Sequence as SequenceCollection
except:
  from collections import Sequence as SequenceCollection

from deepchem.data import Dataset, NumpyDataset
from deepchem.models.models import Model

from deepchem.models.losses import Loss
from deepchem.models.optimizers import Optimizer
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from deepchem.utils.typing import ArrayLike, LossFn, OneOrMany

@@ -34,34 +39,99 @@ class JaxModel(Model):
  >> params = model.init(rng, x)
  >> j_m = JaxModel(model, params, 256, 0.001, 100)
  >> j_m.fit(train_dataset)

  All optimizations will be done using the optax library.
  """

  def __init__(self,
               model,
               params: hk.Params,
               loss: Union[Loss, LossFn],
               output_types: Optional[List[str]] = None,
               batch_size: int = 100,
               learning_rate: float = 0.001,
               optimizer: Union[optax.GradientTransformation,
                                Optimizer] = optax.adam(1e-3),
               log_frequency: int = 100,
               **kwargs):
    """Create a new JaxModel.

    Parameters
    ----------
    model: object with an `apply` method
      Any Jax based model that has an `apply` method for computing the network,
      e.g. the result of `hk.without_apply_rng(hk.transform(f))`.
    params: hk.Params
      The parameters of the Jax based network, e.g. as returned by
      `model.init(rng, x)`.
    loss: dc.models.losses.Loss or function
      a Loss or function defining how to compute the training loss for each
      batch.  During training it is called as `loss(predict, target, weights)`.
    output_types: list of strings, optional (default None)
      the type of each output from the model.  Recognised values are
      'prediction', 'loss' and 'variance'; any other value is recorded as an
      "other" output.  If no output is marked 'loss', the prediction outputs
      are used as the loss outputs.
    batch_size: int, optional (default 100)
      default batch size for training and evaluating
    learning_rate: float or LearningRateSchedule, optional (default 0.001)
      the learning rate to use for fitting.  NOTE: this constructor currently
      neither stores nor reads this value; `optimizer` is used exactly as
      given.
    optimizer: optax object
      For the time being, it is an optax object.  It is stored unchanged.
    log_frequency: int, optional (default 100)
      The frequency at which to log data. Data is logged using
      `logging` by default.
    **kwargs:
      Additional keyword arguments are accepted but currently ignored.

    Miscellaneous Parameters Yet To Add
    -----------------------------------
    model_dir: str, optional (default None)
      Will be added along with the save & load method
    tensorboard: bool, optional (default False)
      whether to log progress to TensorBoard during training
    wandb: bool, optional (default False)
      whether to log progress to Weights & Biases during training

    Work in Progress
    ----------------
    [1] Integrate the optax losses, optimizers, schedulers with DeepChem
    [2] Support for saving & loading the model.
    """

    self.loss = loss  # invoked as loss(predict, target, weights) when training
    self.batch_size = batch_size
    self.optimizer = optimizer
    self.model = model
    self.params = params
    # The optimizer state is created lazily, the first time training starts.
    self._built = False
    self.log_frequency = log_frequency

    if output_types is None:
      self._prediction_outputs = None
      self._loss_outputs = None
      self._variance_outputs = None
      self._other_outputs = None
    else:
      # Record, for every category, the positions of the matching outputs.
      self._prediction_outputs = []
      self._loss_outputs = []
      self._variance_outputs = []
      self._other_outputs = []
      for i, output_type in enumerate(output_types):
        if output_type == 'prediction':
          self._prediction_outputs.append(i)
        elif output_type == 'loss':
          self._loss_outputs.append(i)
        elif output_type == 'variance':
          self._variance_outputs.append(i)
        else:
          self._other_outputs.append(i)
      if len(self._loss_outputs) == 0:
        # No dedicated loss outputs: fall back to the prediction outputs.
        self._loss_outputs = self._prediction_outputs

  def _ensure_built(self):
    """The first time this is called, create internal data structures.

    Work in Progress
    ----------------
    [1] Integrate the optax losses, optimizers, schedulers with DeepChem

    """
    if self._built:
      return

@@ -69,29 +139,84 @@ class JaxModel(Model):
    self._global_step = 0
    self.opt_state = self.optimizer.init(self.params)

  def fit(self,
          dataset: Dataset,
          nb_epochs: int = 10,
          deterministic: bool = False,
          loss: Optional[LossFn] = None,
          callbacks: Optional[Union[Callable, List[Callable]]] = None,
          all_losses: Optional[List[float]] = None) -> float:
    """Train this model on a dataset.

    Batches are produced by `default_generator` and training itself is
    delegated to `fit_generator`.

    Parameters
    ----------
    dataset: Dataset
      the Dataset to train on
    nb_epochs: int
      the number of epochs to train for
    deterministic: bool
      if True, the samples are processed in order.  If False, a different random
      order is used for each epoch.
    loss: function
      a function of the form f(outputs, labels, weights) that computes the loss
      for each batch.  If None (the default), the model's standard loss function
      is used.
      NOTE(review): `fit_generator` appears to build its gradient function from
      `self.loss` rather than from this argument, so a custom loss may not take
      effect -- confirm.
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step.  This can be used to perform validation, logging, etc.
      None (the default) means no callbacks.
    all_losses: Optional[List[float]], optional (default None)
      If specified, all logged losses are appended into this list. Note that
      you can call `fit()` repeatedly with the same list and losses will
      continue to be appended.

    Returns
    -------
    The value returned by `fit_generator`: the average loss over the most
    recently logged interval.

    Miscellaneous Parameters Yet To Add
    -----------------------------------
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    checkpoint_interval: int
      the frequency at which to write checkpoints, measured in training steps.
      Set this to 0 to disable automatic checkpointing.
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    variables: list
      the variables to train.  If None (the default), all trainable variables in
      the model are used.

    Work in Progress
    ----------------
    [1] Integrate the optax losses, optimizers, schedulers with DeepChem
    [2] Support for saving & loading the model.
    [3] Adding support for output types (choosing only self._loss_outputs)
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    if callbacks is None:
      callbacks = []
    return self.fit_generator(
        self.default_generator(
            dataset, epochs=nb_epochs, deterministic=deterministic), loss,
        callbacks, all_losses)

  def fit_generator(
      self,
  def fit_generator(self,
                    generator: Iterable[Tuple[Any, Any, Any]],
  ):
                    loss: Optional[LossFn] = None,
                    callbacks: Union[Callable, List[Callable]] = [],
                    all_losses: Optional[List[float]] = None) -> float:
    if not isinstance(callbacks, SequenceCollection):
      callbacks = [callbacks]
    self._ensure_built()
    avg_loss = 0.0
    last_avg_loss = 0.0
    averaged_batches = 0

    if loss is None:
      loss = self.loss
    grad_update = self._create_gradient_fn(self.loss, self.model,
                                           self.optimizer)
    params, opt_state = self._get_trainable_params()
    time1 = time.time()

    # Main training loop

    for batch in generator:
      inputs, labels, weights = self._prepare_batch(batch)
@@ -106,7 +231,7 @@ class JaxModel(Model):
        weights = weights[0]

      params, opt_state, batch_loss = grad_update(params, opt_state, inputs,
                                                  labels)
                                                  labels, weights)

      avg_loss += jax.device_get(batch_loss)
      self._global_step += 1
@@ -118,39 +243,56 @@ class JaxModel(Model):
        avg_loss = float(avg_loss) / averaged_batches
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
        if all_losses is not None:
          all_losses.append(avg_loss)
        # Capture the last avg_loss in case of return since we're resetting to 0 now
        last_avg_loss = avg_loss
        avg_loss = 0.0
        averaged_batches = 0
      for c in callbacks:
        c(self, current_step)

    # Report final results.
    if averaged_batches > 0:
      avg_loss = float(avg_loss) / averaged_batches
      logger.info(
          'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
      if all_losses is not None:
        all_losses.append(avg_loss)
      last_avg_loss = avg_loss

    time2 = time.time()
    logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
    self._set_trainable_params(params, opt_state)
    return last_avg_loss

  def _get_trainable_params(self):
    """
    Will be used to seperate freezing parameters while transfer learning
    """
    return self.params, self.opt_state

  def _set_trainable_params(self, params: hk.Params,
                            opt_state: optax.OptState) -> None:
    """Store the parameters and optimizer state produced by training.

    A functional approach to setting the final parameters after training.

    Parameters
    ----------
    params: hk.Params
      the parameters to store on this model as `self.params`
    opt_state: optax.OptState
      the optimizer state to store on this model as `self.opt_state`
    """
    self.params = params
    self.opt_state = opt_state

  def _create_gradient_fn(self, loss, model, optimizer, p=None):
  def _create_gradient_fn(self, loss, model, optimizer):
    """
    This function calls the update function to implement backpropagation
    """

    def model_loss(params, batch, target):
    @jax.jit
    def model_loss(params, batch, target, weights):
      predict = model.apply(params, batch)
      return loss(predict, target)
      return loss(predict, target, weights)

    @jax.jit
    def update(params, opt_state, batch, target):
      batch_loss, grads = jax.value_and_grad(loss)(params, batch, target)
    def update(params, opt_state, batch, target, weights):
      batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
                                                         weights)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optax.apply_updates(params, updates)
      return new_params, opt_state, batch_loss
@@ -185,6 +327,32 @@ class JaxModel(Model):
      mode: str = 'fit',
      deterministic: bool = True,
      pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
    """Create a generator that iterates batches for a dataset.

    Subclasses may override this method to customize how model inputs are
    generated from the data.

    Parameters
    ----------
    dataset: Dataset
      the data to iterate
    epochs: int
      the number of times to iterate over the full dataset
    mode: str
      allowed values are 'fit' (called during training), 'predict' (called
      during prediction), and 'uncertainty' (called during uncertainty
      prediction)
    deterministic: bool
      whether to iterate over the dataset in order, or randomly shuffle the
      data for each epoch
    pad_batches: bool
      whether to pad each batch up to this model's preferred batch size

    Returns
    -------
    a generator that iterates batches, each represented as a tuple of lists:
    ([inputs], [outputs], [weights])
    """

    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(