Commit 70f273f4 authored by VIGNESHinZONE's avatar VIGNESHinZONE

more modularisation

parent ff891781
+108 −139
@@ -28,6 +28,52 @@ import warnings
logger = logging.getLogger(__name__)


def create_default_eval_fn(forward_fn, params):
  """
    Calls the function to evaulate the model
    """

  @jax.jit
  def eval_model(batch, rng=None):
    predict = forward_fn(params, rng, batch)

    return predict

  return eval_model


def create_default_update_fn(optimizer, model_loss):
  """
    This function calls the update function, to implement the backpropogation
    """

  @jax.jit
  def update(params, opt_state, batch, target, weights,
             rng) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
    batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
                                                       weights, rng)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, batch_loss

  return update


def create_default_gradient_fn(forward_fn, loss_outputs, loss_fn):
  """
    This function calls the gradient function, to implement the backpropogation
    """

  @jax.jit
  def model_loss(params, batch, target, weights, rng):
    predict = forward_fn(params, rng, batch)
    if loss_outputs is not None:
      predict = [predict[i] for i in loss_outputs]
    return loss_fn(predict, target, weights)

  return model_loss
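
A minimal usage sketch of how the three factories above compose into one
training step. The toy network, loss function, and data below are illustrative
assumptions, not part of this commit:

import haiku as hk
import jax
import jax.numpy as jnp
import optax

def _toy_net(x):
  return hk.Linear(1)(x)

forward = hk.transform(_toy_net)
rng = jax.random.PRNGKey(0)
x = jnp.ones((8, 4))
params = forward.init(rng, x)

def mse_loss(predict, target, weights):
  # loss_fn contract used above: (predict, target, weights) -> scalar
  return jnp.mean(weights[0] * (predict - target[0]) ** 2)

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

# loss_outputs=None because the toy network returns a single output
model_loss = create_default_gradient_fn(forward.apply, None, mse_loss)
update = create_default_update_fn(optimizer, model_loss)
params, opt_state, batch_loss = update(params, opt_state, x,
                                       [jnp.zeros((8, 1))], [1.0], rng)
predict = create_default_eval_fn(forward.apply, params)(x, rng)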


class JaxModel(Model):
  """This is a DeepChem model implemented by a Jax Model
  Here is a simple example of that uses JaxModel to train a
@@ -46,7 +92,7 @@ class JaxModel(Model):
  """

  def __init__(self,
-               model: hk.State,
+               forward_fn: hk.State,
               params: hk.Params,
               loss: Union[Loss, LossFn],
               output_types: Optional[List[str]] = None,
@@ -54,6 +100,9 @@ class JaxModel(Model):
               learning_rate: float = 0.001,
               optimizer: Union[optax.GradientTransformation,
                                Optimizer] = optax.adam(1e-3),
+               grad_fn: Callable = create_default_gradient_fn,
+               update_fn: Callable = create_default_update_fn,
+               eval_fn: Callable = create_default_eval_fn,
               rng=jax.random.PRNGKey(1),
               log_frequency: int = 100,
               **kwargs):
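
Because grad_fn, update_fn, and eval_fn are now constructor arguments, a
caller can inject variants that honour the same signatures. A hedged sketch
(the gradient-clipping rule is an illustrative assumption, not part of this
commit):

import jax
import jax.numpy as jnp
import optax

def create_clipped_update_fn(optimizer, model_loss):
  # Same contract as create_default_update_fn, but clips gradients first.
  @jax.jit
  def update(params, opt_state, batch, target, weights, rng):
    batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
                                                       weights, rng)
    grads = jax.tree_util.tree_map(lambda g: jnp.clip(g, -1.0, 1.0), grads)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, batch_loss

  return update

# model = JaxModel(forward_fn, params, loss, update_fn=create_clipped_update_fn)
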
@@ -96,7 +145,7 @@ class JaxModel(Model):
    [1] Integrate the optax losses, optimizers, schedulers with Deepchem
    [2] Support for saving & loading the model.
    """
-    super(JaxModel, self).__init__(model=model, **kwargs)
+    super(JaxModel, self).__init__(model=(forward_fn, params), **kwargs)
    warnings.warn(
        'JaxModel is still in active development and all features may not yet be implemented'
    )
@@ -104,11 +153,14 @@ class JaxModel(Model):
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.optimizer = optimizer
-    self.forward_fn = model
+    self.forward_fn = forward_fn
    self.params = params
    self._built = False
    self.log_frequency = log_frequency
    self.rng = rng
+    self._create_gradient_fn = grad_fn
+    self._create_update_fn = update_fn
+    self._create_eval_fn = eval_fn

    if output_types is None:
      self._prediction_outputs = None
@@ -213,8 +265,10 @@ class JaxModel(Model):
    averaged_batches = 0
    if loss is None:
      loss = self._loss_fn
-    grad_update = self._create_gradient_fn(loss, self.forward_fn,
-                                           self.optimizer, self._loss_outputs)
+    model_loss_fn = self._create_gradient_fn(self.forward_fn,
+                                             self._loss_outputs, loss)
+    grad_update = self._create_update_fn(self.optimizer, model_loss_fn)
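    # (model_loss_fn closes over forward_fn and the selected loss outputs;
    # grad_update performs one jitted optimizer step per call)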

    params, opt_state = self._get_trainable_params()
    rng = self.rng
    time1 = time.time()
@@ -447,55 +501,55 @@ class JaxModel(Model):

    pass

  def predict_uncertainty(self, dataset: Dataset, masks: int = 50
                         ) -> OneOrMany[Tuple[np.ndarray, np.ndarray]]:
    """
    Predict the model's outputs, along with the uncertainty in each one.
    The uncertainty is computed as described in https://arxiv.org/abs/1703.04977.
    It involves repeating the prediction many times with different dropout masks.
    The prediction is computed as the average over all the predictions.  The
    uncertainty includes both the variation among the predicted values (epistemic
    uncertainty) and the model's own estimates for how well it fits the data
    (aleatoric uncertainty).  Not all models support uncertainty prediction.
    Parameters
    ----------
    dataset: dc.data.Dataset
      Dataset to make prediction on
    masks: int
      the number of dropout masks to average over
    Returns
    -------
    for each output, a tuple (y_pred, y_std) where y_pred is the predicted
    value of the output, and each element of y_std estimates the standard
    deviation of the corresponding element of y_pred
    """
    sum_pred: List[np.ndarray] = []
    sum_sq_pred: List[np.ndarray] = []
    sum_var: List[np.ndarray] = []
    for i in range(masks):
      generator = self.default_generator(
          dataset, mode='uncertainty', pad_batches=False)
      results = self._predict(generator, [], True, None)
      if len(sum_pred) == 0:
        for p, v in results:
          sum_pred.append(p)
          sum_sq_pred.append(p * p)
          sum_var.append(v)
      else:
        for j, (p, v) in enumerate(results):
          sum_pred[j] += p
          sum_sq_pred[j] += p * p
          sum_var[j] += v
    output = []
    std = []
    for i in range(len(sum_pred)):
      p = sum_pred[i] / masks
      output.append(p)
      std.append(np.sqrt(sum_sq_pred[i] / masks - p * p + sum_var[i] / masks))
    if len(output) == 1:
      return (output[0], std[0])
    else:
      return list(zip(output, std))
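
The std computed above decomposes as the epistemic variance across dropout
masks plus the mean aleatoric variance: std = sqrt(E[p^2] - E[p]^2 + E[v]).
A small NumPy restatement with toy numbers (illustrative only):

import numpy as np

preds = np.array([1.0, 1.2, 0.8])         # p from three dropout masks (toy)
variances = np.array([0.04, 0.05, 0.03])  # model-estimated v per mask (toy)

mean_pred = preds.mean()
epistemic = (preds ** 2).mean() - mean_pred ** 2  # spread across masks
aleatoric = variances.mean()                      # mean predicted variance
std = np.sqrt(epistemic + aleatoric)  # same formula as the loop above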
  # def predict_uncertainty(self, dataset: Dataset, masks: int = 50
  #                        ) -> OneOrMany[Tuple[np.ndarray, np.ndarray]]:
  #   """
  #   Predict the model's outputs, along with the uncertainty in each one.
  #   The uncertainty is computed as described in https://arxiv.org/abs/1703.04977.
  #   It involves repeating the prediction many times with different dropout masks.
  #   The prediction is computed as the average over all the predictions.  The
  #   uncertainty includes both the variation among the predicted values (epistemic
  #   uncertainty) and the model's own estimates for how well it fits the data
  #   (aleatoric uncertainty).  Not all models support uncertainty prediction.
  #   Parameters
  #   ----------
  #   dataset: dc.data.Dataset
  #     Dataset to make prediction on
  #   masks: int
  #     the number of dropout masks to average over
  #   Returns
  #   -------
  #   for each output, a tuple (y_pred, y_std) where y_pred is the predicted
  #   value of the output, and each element of y_std estimates the standard
  #   deviation of the corresponding element of y_pred
  #   """
  #   sum_pred: List[np.ndarray] = []
  #   sum_sq_pred: List[np.ndarray] = []
  #   sum_var: List[np.ndarray] = []
  #   for i in range(masks):
  #     generator = self.default_generator(
  #         dataset, mode='uncertainty', pad_batches=False)
  #     results = self._predict(generator, [], True, None)
  #     if len(sum_pred) == 0:
  #       for p, v in results:
  #         sum_pred.append(p)
  #         sum_sq_pred.append(p * p)
  #         sum_var.append(v)
  #     else:
  #       for j, (p, v) in enumerate(results):
  #         sum_pred[j] += p
  #         sum_sq_pred[j] += p * p
  #         sum_var[j] += v
  #   output = []
  #   std = []
  #   for i in range(len(sum_pred)):
  #     p = sum_pred[i] / masks
  #     output.append(p)
  #     std.append(np.sqrt(sum_sq_pred[i] / masks - p * p + sum_var[i] / masks))
  #   if len(output) == 1:
  #     return (output[0], std[0])
  #   else:
  #     return list(zip(output, std))

  def evaluate_generator(self,
                         generator: Iterable[Tuple[Any, Any, Any]],
@@ -536,42 +590,6 @@ class JaxModel(Model):
    self.params = params
    self.opt_state = opt_state

  def _create_eval_fn(self, forward_fn, params):
    """
    Calls the function to evaluate the model
    """

    @jax.jit
    def eval_model(batch, rng=None):
      predict = forward_fn(params, rng, batch)

      return predict

    return eval_model

  def _create_gradient_fn(self, loss, forward_fn, optimizer, loss_outputs):
    """
    This function creates the update function that implements backpropagation
    """

    @jax.jit
    def model_loss(params, batch, target, weights, rng):
      predict = forward_fn(params, rng, batch)
      if loss_outputs is not None:
        predict = [predict[i] for i in loss_outputs]
      return loss(predict, target, weights)

    @jax.jit
    def update(params, opt_state, batch, target, weights,
               rng) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
      batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
                                                         weights, rng)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optax.apply_updates(params, updates)
      return new_params, opt_state, batch_loss

    return update

  def _prepare_batch(self, batch):
    inputs, labels, weights = batch
    inputs = [
@@ -630,52 +648,3 @@ class JaxModel(Model):
          deterministic=deterministic,
          pad_batches=pad_batches):
        yield ([X_b], [y_b], [w_b])


# def create_default_forward_fn(haiku_model):
#   """
#   This function is used to create the forward function for the model.
#   """
#   def forward_fn(
#     data: Union[jnp.ndarray, Mapping],
#     is_training: bool = True
#   ) -> jnp.ndarray:
#     """
#     Forward pass
#     """
#     return haiku_model(data, is_training)

#   return forward_fn

# def create_default_eval_fn(forward_fn, params):
#   """
#   Calls the function to evaluate the model
#   """
#   def eval_model(batch):
#     predict = forward_fn(params, batch)

#     return predict

#   return eval_model

# def create_default_update_fn(loss_fn, forward_fn, optimizer, loss_outputs):
#   """
#   This function creates the update function that implements backpropagation
#   """
#   @jax.jit
#   def model_loss(params, batch, target, weights):
#     predict = forward_fn(params, batch)
#     if loss_outputs is not None:
#       predict = [predict[i] for i in loss_outputs]
#     return loss_fn(predict, target, weights)

#   @jax.jit
#   def update(params, opt_state, batch, target,
#               weights) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
#     batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
#                                                         weights)
#     updates, opt_state = optimizer.update(grads, opt_state)
#     new_params = optax.apply_updates(params, updates)
#     return new_params, opt_state, batch_loss

#   return update
+73 −73
@@ -308,74 +308,74 @@ def test_fit_use_all_losses():
  assert np.count_nonzero(np.array(losses)) == 100


@pytest.mark.jax
@pytest.mark.slow
def test_uncertainty():
  """Test estimating uncertainty a TorchModel."""
  n_samples = 30
  n_features = 1
  noise = 0.1
  X = np.random.rand(n_samples, n_features)
  y = (10 * X + np.random.normal(scale=noise, size=(n_samples, n_features)))
  dataset = dc.data.NumpyDataset(X, y)

  class Net(hk.Module):

    def __init__(self, output_size: int = 1):
      super().__init__()
      self._network1 = hk.Sequential([hk.Linear(200), jax.nn.relu])
      self._network2 = hk.Sequential([hk.Linear(200), jax.nn.relu])
      self.output = hk.Linear(output_size)
      self.log_var = hk.Linear(output_size)

    def __call__(self, x):
      x = self._network1(x)
      x = hk.dropout(hk.next_rng_key(), 0.1, x)
      x = self._network2(x)
      x = hk.dropout(hk.next_rng_key(), 0.1, x)
      output = self.output(x)
      log_var = self.log_var(x)
      var = jnp.exp(log_var)
      return output, var, output, log_var

  def f(x):
    net = Net(1)
    return net(x)

  def loss(outputs, labels, weights):
    diff = labels[0] - outputs[0]
    log_var = outputs[1]
    var = jnp.exp(log_var)
    return jnp.mean(diff * diff / var + log_var)

  class UncertaintyModel(JaxModel):

    def default_generator(self,
                          dataset,
                          epochs=1,
                          mode='fit',
                          deterministic=True,
                          pad_batches=True):
      for epoch in range(epochs):
        for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
            batch_size=self.batch_size,
            deterministic=deterministic,
            pad_batches=pad_batches):
          yield ([X_b], [y_b], [w_b])

  jm_model = hk.transform(f)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))
  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = jm_model.init(rng, modified_inputs)
  model = UncertaintyModel(
      jm_model.apply,
      params,
      loss,
      output_types=['prediction', 'variance', 'loss', 'loss'],
      learning_rate=0.003)
  model.fit(dataset, nb_epochs=2500)
  pred, std = model.predict_uncertainty(dataset)
  assert np.mean(np.abs(y - pred)) < 2.0
  assert noise < np.mean(std) < 1.0
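
For reference, the loss in this test is the heteroscedastic regression
objective from the paper cited in predict_uncertainty; a standalone
restatement of the same math (the function name here is an assumption):

import jax.numpy as jnp

def heteroscedastic_loss(pred, log_var, label):
  # mean((label - pred)^2 / exp(log_var) + log_var), as in `loss` above
  var = jnp.exp(log_var)
  diff = label - pred
  return jnp.mean(diff * diff / var + log_var)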
# @pytest.mark.jax
# @pytest.mark.slow
# def test_uncertainty():
#   """Test estimating uncertainty a TorchModel."""
#   n_samples = 30
#   n_features = 1
#   noise = 0.1
#   X = np.random.rand(n_samples, n_features)
#   y = (10 * X + np.random.normal(scale=noise, size=(n_samples, n_features)))
#   dataset = dc.data.NumpyDataset(X, y)

#   class Net(hk.Module):

#     def __init__(self, output_size: int = 1):
#       super().__init__()
#       self._network1 = hk.Sequential([hk.Linear(200), jax.nn.relu])
#       self._network2 = hk.Sequential([hk.Linear(200), jax.nn.relu])
#       self.output = hk.Linear(output_size)
#       self.log_var = hk.Linear(output_size)

#     def __call__(self, x):
#       x = self._network1(x)
#       x = hk.dropout(hk.next_rng_key(), 0.1, x)
#       x = self._network2(x)
#       x = hk.dropout(hk.next_rng_key(), 0.1, x)
#       output = self.output(x)
#       log_var = self.log_var(x)
#       var = jnp.exp(log_var)
#       return output, var, output, log_var

#   def f(x):
#     net = Net(1)
#     return net(x)

#   def loss(outputs, labels, weights):
#     diff = labels[0] - outputs[0]
#     log_var = outputs[1]
#     var = jnp.exp(log_var)
#     return jnp.mean(diff * diff / var + log_var)

#   class UncertaintyModel(JaxModel):

#     def default_generator(self,
#                           dataset,
#                           epochs=1,
#                           mode='fit',
#                           deterministic=True,
#                           pad_batches=True):
#       for epoch in range(epochs):
#         for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
#             batch_size=self.batch_size,
#             deterministic=deterministic,
#             pad_batches=pad_batches):
#           yield ([X_b], [y_b], [w_b])

#   jm_model = hk.transform(f)
#   rng = jax.random.PRNGKey(500)
#   inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))
#   modified_inputs = jnp.array(
#       [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
#   params = jm_model.init(rng, modified_inputs)
#   model = UncertaintyModel(
#       jm_model.apply,
#       params,
#       loss,
#       output_types=['prediction', 'variance', 'loss', 'loss'],
#       learning_rate=0.003)
#   model.fit(dataset, nb_epochs=2500)
#   pred, std = model.predict_uncertainty(dataset)
#   assert np.mean(np.abs(y - pred)) < 2.0
#   assert noise < np.mean(std) < 1.0