Commit 5faad7ed authored by VIGNESHinZONE

adding extra tests

parent d4493ff9
+119 −18
@@ -57,6 +57,7 @@ class JaxModel(Model):
               learning_rate: float = 0.001,
               optimizer: Union[optax.GradientTransformation,
                                Optimizer] = optax.adam(1e-3),
+               rng=jax.random.PRNGKey(1),
               log_frequency: int = 100,
               **kwargs):
    """
@@ -81,6 +82,8 @@ class JaxModel(Model):
      ignored.
    optimizer: optax object
      For the time being, it must be an optax object
+    rng: jax.random.PRNGKey, optional (default PRNGKey(1))
+      A default global PRNG key to use for drawing random numbers.
    log_frequency: int, optional (default 100)
      The frequency at which to log data. Data is logged using
      `logging` by default.
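A minimal usage sketch for the new `rng` argument (illustrative only: `forward` and `mse_loss` are stand-ins, and the `JaxModel` import path is assumed). Fixing the key makes every dropout mask the model draws during fitting and prediction reproducible:

import jax
import jax.numpy as jnp
import haiku as hk
from deepchem.models.jax_models.jax_model import JaxModel  # path assumed

def forward(x):
  return hk.Linear(1)(x)

model = hk.transform(forward)
x = jnp.ones((4, 3))
params = model.init(jax.random.PRNGKey(500), x)

def mse_loss(pred, tar, weights):
  return jnp.mean((pred - tar)**2)

jm = JaxModel(model.apply, params, mse_loss, rng=jax.random.PRNGKey(42))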
@@ -107,10 +110,11 @@ class JaxModel(Model):
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.optimizer = optimizer
-    self.model = model
+    self.forward_fn = model
    self.params = params
    self._built = False
    self.log_frequency = log_frequency
+    self.rng = rng

    if output_types is None:
      self._prediction_outputs = None
@@ -221,9 +225,10 @@ class JaxModel(Model):
    averaged_batches = 0
    if loss is None:
      loss = self._loss_fn
-    grad_update = self._create_gradient_fn(loss, self.model, self.optimizer,
-                                           self._loss_outputs)
+    grad_update = self._create_gradient_fn(loss, self.forward_fn,
+                                           self.optimizer, self._loss_outputs)
    params, opt_state = self._get_trainable_params()
+    rng = self.rng
    time1 = time.time()

    # Main training loop
@@ -240,9 +245,9 @@ class JaxModel(Model):
      if isinstance(weights, list) and len(weights) == 1:
        weights = weights[0]

-      params, opt_state, batch_loss = grad_update(params, opt_state, inputs,
-                                                  labels, weights)
-
+      params, opt_state, batch_loss = grad_update(
+          params, opt_state, inputs, labels, weights, rng=rng)
+      rng, _ = jax.random.split(rng)
      avg_loss += jax.device_get(batch_loss)
      self._global_step += 1
      current_step = self._global_step
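The two added lines follow the standard JAX idiom for threading randomness through a loop: consume the current key for this step's batch, then split it so the next step sees fresh randomness; reusing one key unchanged would give every batch the identical dropout mask. A self-contained sketch of the pattern (not from this commit):

import jax

rng = jax.random.PRNGKey(0)
draws = []
for step in range(3):
  # use the current key for this step's randomness ...
  draws.append(jax.random.bernoulli(rng, 0.5, (4,)))
  # ... then derive a fresh key for the next step, as the loop above does
  rng, _ = jax.random.split(rng)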
@@ -323,7 +328,8 @@ class JaxModel(Model):
            'This model cannot compute other outputs since no other output_types were specified.'
        )
    self._ensure_built()
-    eval_fn = self._create_eval_fn(self.model, self.params)
+    eval_fn = self._create_eval_fn(self.forward_fn, self.params)
+    rng = self.rng

    for batch in generator:
      inputs, _, _ = self._prepare_batch(batch)
@@ -331,7 +337,7 @@ class JaxModel(Model):
      if isinstance(inputs, list) and len(inputs) == 1:
        inputs = inputs[0]

-      output_values = eval_fn(inputs)
+      output_values = eval_fn(inputs, rng)
      if isinstance(output_values, jnp.ndarray):
        output_values = [output_values]
      output_values = [jax.device_get(t) for t in output_values]
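Passing a key into `eval_fn` keeps stochastic layers such as `hk.dropout` live at inference time, which `predict_uncertainty` below relies on. A hypothetical two-call illustration reusing the names from this diff: different keys sample different dropout masks, so the outputs generally differ.

key1, key2 = jax.random.split(rng)
out1 = eval_fn(inputs, key1)  # one dropout mask
out2 = eval_fn(inputs, key2)  # a different mask, so out1 != out2 in general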
@@ -459,7 +465,53 @@ class JaxModel(Model):

  def predict_uncertainty(self, dataset: Dataset, masks: int = 50
                         ) -> OneOrMany[Tuple[np.ndarray, np.ndarray]]:
-    pass
    """
    Predict the model's outputs, along with the uncertainty in each one.
    The uncertainty is computed as described in https://arxiv.org/abs/1703.04977.
    It involves repeating the prediction many times with different dropout masks.
    The prediction is computed as the average over all the predictions.  The
    uncertainty includes both the variation among the predicted values (epistemic
    uncertainty) and the model's own estimates for how well it fits the data
    (aleatoric uncertainty).  Not all models support uncertainty prediction.
    Parameters
    ----------
    dataset: dc.data.Dataset
      Dataset to make prediction on
    masks: int
      the number of dropout masks to average over
    Returns
    -------
    for each output, a tuple (y_pred, y_std) where y_pred is the predicted
    value of the output, and each element of y_std estimates the standard
    deviation of the corresponding element of y_pred
    """
    sum_pred: List[np.ndarray] = []
    sum_sq_pred: List[np.ndarray] = []
    sum_var: List[np.ndarray] = []
    for i in range(masks):
      generator = self.default_generator(
          dataset, mode='uncertainty', pad_batches=False)
      results = self._predict(generator, [], True, None)
      if len(sum_pred) == 0:
        for p, v in results:
          sum_pred.append(p)
          sum_sq_pred.append(p * p)
          sum_var.append(v)
      else:
        for j, (p, v) in enumerate(results):
          sum_pred[j] += p
          sum_sq_pred[j] += p * p
          sum_var[j] += v
    output = []
    std = []
    for i in range(len(sum_pred)):
      p = sum_pred[i] / masks
      output.append(p)
      std.append(np.sqrt(sum_sq_pred[i] / masks - p * p + sum_var[i] / masks))
    if len(output) == 1:
      return (output[0], std[0])
    else:
      return list(zip(output, std))
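The running sums implement the law of total variance: with mean prediction p̄ over the masks, the reported variance is E[p²] − p̄² + E[v], i.e. the spread of the per-mask predictions (epistemic) plus the average model-reported variance (aleatoric). A quick NumPy check of that identity (illustrative only, arbitrary numbers):

import numpy as np

preds = np.random.default_rng(0).normal(2.0, 0.3, size=50)  # per-mask predictions
avars = np.full_like(preds, 0.01)  # per-mask aleatoric variances
mean_p = preds.sum() / preds.size
total = (preds * preds).sum() / preds.size - mean_p**2 + avars.mean()
assert np.isclose(total, preds.var() + avars.mean())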

  def evaluate_generator(self,
                         generator: Iterable[Tuple[Any, Any, Any]],
@@ -502,36 +554,36 @@ class JaxModel(Model):
    self.params = params
    self.opt_state = opt_state

-  def _create_eval_fn(self, model, params):
+  def _create_eval_fn(self, forward_fn, params):
    """
    Creates the function used to evaluate the model
    """

    @jax.jit
-    def eval_model(batch):
-      predict = model.apply(params, batch)
+    def eval_model(batch, rng=None):
+      predict = forward_fn(params, rng, batch)

      return predict

    return eval_model

-  def _create_gradient_fn(self, loss, model, optimizer, loss_outputs):
+  def _create_gradient_fn(self, loss, forward_fn, optimizer, loss_outputs):
    """
    This function creates the update function that implements backpropagation
    """

    @jax.jit
-    def model_loss(params, batch, target, weights):
-      predict = model.apply(params, batch)
+    def model_loss(params, batch, target, weights, rng):
+      predict = forward_fn(params, rng, batch)
      if loss_outputs is not None:
        predict = [predict[i] for i in loss_outputs]
      return loss(predict, target, weights)

    @jax.jit
-    def update(params, opt_state, batch, target,
-               weights) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
+    def update(params, opt_state, batch, target, weights,
+               rng) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
      batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
-                                                         weights)
+                                                         weights, rng)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optax.apply_updates(params, updates)
      return new_params, opt_state, batch_loss
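The `forward_fn(params, rng, batch)` calls above match the signature that `hk.transform` gives `apply`, which is also why the tests below stop wrapping models in `hk.without_apply_rng`: that wrapper strips exactly the rng slot that dropout now needs. A minimal contrast of the two (standard Haiku behaviour, not code from this commit):

import jax
import jax.numpy as jnp
import haiku as hk

def f(x):
  return hk.Linear(2)(x)

x = jnp.ones((1, 3))
key = jax.random.PRNGKey(0)

with_rng = hk.transform(f)
params = with_rng.init(key, x)
y1 = with_rng.apply(params, key, x)  # apply(params, rng, inputs)

no_rng = hk.without_apply_rng(hk.transform(f))
y2 = no_rng.apply(params, x)  # the rng argument is removed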
@@ -599,3 +651,52 @@ class JaxModel(Model):
          deterministic=deterministic,
          pad_batches=pad_batches):
        yield ([X_b], [y_b], [w_b])


+# def create_default_forward_fn(haiku_model):
+#   """
+#   This function is used to create the forward function for the model.
+#   """
+#   def forward_fn(
+#     data: Union[jnp.ndarray, Mapping[]],
+#     is_training: bool = True
+#   ) -> jnp.ndarray:
+#     """
+#     Forward pass
+#     """
+#     return haiku_model(data, is_training)
+
+#   return forward_fn
+
+# def create_default_eval_fn(forward_fn, params):
+#   """
+#   Calls the function to evaluate the model
+#   """
+#   def eval_model(batch):
+#     predict = forward_fn(params, batch)
+
+#     return predict
+
+#   return eval_model
+
+# def create_default_update_fn(loss_fn, forward_fn, optimizer, loss_outputs):
+#   """
+#   This function creates the update function that implements backpropagation
+#   """
+#   @jax.jit
+#   def model_loss(params, batch, target, weights):
+#     predict = forward_fn(params, batch)
+#     if loss_outputs is not None:
+#       predict = [predict[i] for i in loss_outputs]
+#     return loss_fn(predict, target, weights)
+
+#   @jax.jit
+#   def update(params, opt_state, batch, target,
+#               weights) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
+#     batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
+#                                                         weights)
+#     updates, opt_state = optimizer.update(grads, opt_state)
+#     new_params = optax.apply_updates(params, updates)
+#     return new_params, opt_state, batch_loss
+
+#   return update
+81 −12
@@ -28,7 +28,7 @@ def test_jax_model_for_regression():
    return jnp.mean(optax.l2_loss(pred, tar))

  # Model Initialisation
-  model = hk.without_apply_rng(hk.transform(f))
+  model = hk.transform(f)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=256)))
  modified_inputs = jnp.array(
@@ -39,15 +39,14 @@ def test_jax_model_for_regression():
  criterion = rms_loss

  # JaxModel Working
-  n_tasks = len(tasks)
  j_m = JaxModel(
-      model,
+      model.apply,
      params,
      criterion,
      batch_size=256,
      learning_rate=0.001,
      log_frequency=2)
-  results = j_m.fit(dataset, nb_epochs=25, deterministic=True)
+  _ = j_m.fit(dataset, nb_epochs=25, deterministic=True)
  scores = j_m.evaluate(dataset, [metric])
  assert scores[metric.name] < 0.5

@@ -78,7 +77,7 @@ def test_jax_model_for_classification():
    return jnp.mean(optax.softmax_cross_entropy(pred[0], tar))

  # Model Initialisation
-  model = hk.without_apply_rng(hk.transform(f))
+  model = hk.transform(f)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=256)))
  modified_inputs = jnp.array(
@@ -90,7 +89,7 @@ def test_jax_model_for_classification():

  # JaxModel Working
  j_m = JaxModel(
-      model,
+      model.apply,
      params,
      criterion,
      output_types=['loss', 'prediction'],
@@ -127,11 +126,9 @@ def test_overfit_subclass_model():
    return net(x)

  # Model Initialisation
-  model = hk.without_apply_rng(hk.transform(f))
+  model = hk.transform(f)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))
-  #inputs = np.random.rand(100, n_features)
-  #print(inputs.shape)

  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
@@ -142,7 +139,7 @@ def test_overfit_subclass_model():

  # JaxModel Working
  j_m = JaxModel(
-      model,
+      model.apply,
      params,
      criterion,
      output_types=['loss', 'prediction'],
@@ -157,6 +154,7 @@ def test_overfit_subclass_model():
  assert scores[metric.name] > 0.9


+@pytest.mark.jax
def test_fit_use_all_losses():
  """Test fitting a TorchModel defined by subclassing Module."""
  n_data_points = 10
@@ -181,7 +179,7 @@ def test_fit_use_all_losses():
    return net(x)

  # Model Initialisation
-  model = hk.without_apply_rng(hk.transform(f))
+  model = hk.transform(f)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))

@@ -194,7 +192,7 @@ def test_fit_use_all_losses():

  # JaxModel Working
  j_m = JaxModel(
-      model,
+      model.apply,
      params,
      criterion,
      output_types=['loss', 'prediction'],
@@ -206,3 +204,74 @@ def test_fit_use_all_losses():
  # Each epoch is a single step for this model
  assert len(losses) == 100
  assert np.count_nonzero(np.array(losses)) == 100


+@pytest.mark.jax
+def test_uncertainty():
+  """Test estimating the uncertainty of a JaxModel."""
+  n_samples = 30
+  n_features = 1
+  noise = 0.1
+  X = np.random.rand(n_samples, n_features)
+  y = (10 * X + np.random.normal(scale=noise, size=(n_samples, n_features)))
+  dataset = dc.data.NumpyDataset(X, y)
+
+  class Net(hk.Module):
+
+    def __init__(self, output_size: int = 1):
+      super().__init__()
+      self._network = hk.Sequential([hk.Linear(200), jax.nn.relu])
+      self.output = hk.Linear(output_size)
+      self.log_var = hk.Linear(output_size)
+
+    def __call__(self, x):
+      x = self._network(x)
+      x = hk.dropout(hk.next_rng_key(), 0.1, x)
+      output = self.output(x)
+      log_var = self.log_var(x)
+      var = jnp.exp(log_var)
+      return output, var, output, log_var
+
+  def f(x):
+    net = Net(1)
+    return net(x)
+
+  def loss(outputs, labels, weights):
+    diff = labels[0] - outputs[0]
+    log_var = outputs[1]
+    var = jnp.exp(log_var)
+    return jnp.mean(diff * diff / var + log_var)
+
+  class UncertaintyModel(JaxModel):
+
+    def default_generator(self,
+                          dataset,
+                          epochs=1,
+                          mode='fit',
+                          deterministic=True,
+                          pad_batches=True):
+      for epoch in range(epochs):
+        for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
+            batch_size=self.batch_size,
+            deterministic=deterministic,
+            pad_batches=pad_batches):
+          yield ([X_b], [y_b], [w_b])
+
+  jm_model = hk.transform(f)
+  rng = jax.random.PRNGKey(500)
+  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))
+  modified_inputs = jnp.array(
+      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
+  params = jm_model.init(rng, modified_inputs)
+  model = UncertaintyModel(
+      jm_model.apply,
+      params,
+      loss,
+      output_types=['prediction', 'variance', 'loss', 'loss'],
+      learning_rate=0.003)
+  model.fit(dataset, nb_epochs=2500)
+  pred, std = model.predict_uncertainty(dataset)
+  assert np.mean(np.abs(y - pred)) < 1.0
+  assert noise < np.mean(std) < 1.0
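The `loss` above is, up to constant factors, the heteroscedastic Gaussian negative log-likelihood from the paper cited in `predict_uncertainty` (https://arxiv.org/abs/1703.04977):

L = (1/N) \sum_i [ (y_i - \hat{y}_i)^2 / \sigma_i^2 + \log \sigma_i^2 ]

The network predicts log σ² (the `log_var` head) so the variance stays positive and the loss numerically stable; exponentiating inside the loss recovers σ². The final assertion then checks that the average predicted standard deviation lands between the injected noise level (0.1) and 1.0.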