Commit 3b9fb905 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Adding extra tests

parent da703c9e
Loading
Loading
Loading
Loading
+0 −25
Original line number Diff line number Diff line
@@ -30,11 +30,9 @@ logger = logging.getLogger(__name__)

class JaxModel(Model):
  """This is a DeepChem model implemented by a Jax Model

  Here is a simple example that uses JaxModel to train a
  Haiku (JAX Neural Network Library) based model on deepchem
  dataset.

  >> def f(x):
  >>   net = hk.nets.MLP([512, 256, 128, 1])
  >>   return net(x)
@@ -44,7 +42,6 @@ class JaxModel(Model):
  >> params = model.init(rng, x)
  >> j_m = JaxModel(model, params, 256, 0.001, 100)
  >> j_m.fit(train_dataset)

  All optimizations will be done using the optax library.
  """

@@ -62,7 +59,6 @@ class JaxModel(Model):
               **kwargs):
    """
    Create a new JaxModel

    Parameters
    ----------
    model: hk.State or Function
@@ -87,7 +83,6 @@ class JaxModel(Model):
    log_frequency: int, optional (default 100)
      The frequency at which to log data. Data is logged using
      `logging` by default.

    Miscellaneous Parameters Yet To Add
    ----------------------------------
    model_dir: str, optional (default None)
@@ -96,7 +91,6 @@ class JaxModel(Model):
      whether to log progress to TensorBoard during training
    wandb: bool, optional (default False)
      whether to log progress to Weights & Biases during training

    Work in Progress
    ----------------
    [1] Integrate the optax losses, optimizers, schedulers with Deepchem
@@ -140,11 +134,9 @@ class JaxModel(Model):

  def _ensure_built(self):
    """The first time this is called, create internal data structures.

    Work in Progress
    ----------------
    [1] Integrate the optax losses, optimizers, schedulers with Deepchem

    """
    if self._built:
      return
@@ -161,7 +153,6 @@ class JaxModel(Model):
          callbacks: Union[Callable, List[Callable]] = [],
          all_losses: Optional[List[float]] = None) -> float:
    """Train this model on a dataset.

    Parameters
    ----------
    dataset: Dataset
@@ -182,11 +173,9 @@ class JaxModel(Model):
      If specified, all logged losses are appended into this list. Note that
      you can call `fit()` repeatedly with the same list and losses will
      continue to be appended.

    Returns
    -------
    The average loss over the most recent checkpoint interval

    Miscellaneous Parameters Yet To Add
    ----------------------------------
    max_checkpoints_to_keep: int
@@ -200,7 +189,6 @@ class JaxModel(Model):
    variables: list of hk.Variable
      the variables to train.  If None (the default), all trainable variables in
      the model are used.

    Work in Progress
    ----------------
    [1] Integrate the optax losses, optimizers, schedulers with Deepchem
@@ -287,11 +275,9 @@ class JaxModel(Model):
      other_output_types: Optional[OneOrMany[str]]) -> OneOrMany[np.ndarray]:
    """
    Predict outputs for data provided by a generator.

    This is the private implementation of prediction.  Do not
    call it directly.  Instead call one of the public prediction
    methods.

    Parameters
    ----------
    generator: generator
@@ -306,7 +292,6 @@ class JaxModel(Model):
      returns the values of the uncertainty outputs.
    other_output_types: list, optional
      Provides a list of other output_types (strings) to predict from model.

    Returns
    -------
      a NumPy array of the model produces a single output, or a list of arrays
@@ -400,7 +385,6 @@ class JaxModel(Model):
      If specified, all outputs of this type will be retrieved
      from the model. If output_types is specified, outputs must
      be None.

    Returns
    -------
      a NumPy array of the model produces a single output, or a list of arrays
@@ -411,7 +395,6 @@ class JaxModel(Model):
  def predict_on_batch(self, X: ArrayLike, transformers: List[Transformer] = []
                      ) -> OneOrMany[np.ndarray]:
    """Generates predictions for input samples, processing samples in a batch.

    Parameters
    ----------
    X: ndarray
@@ -419,7 +402,6 @@ class JaxModel(Model):
    transformers: List[dc.trans.Transformers]
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.

    Returns
    -------
    a NumPy array of the model produces a single output, or a list of arrays
@@ -440,7 +422,6 @@ class JaxModel(Model):
      output_types: Optional[List[str]] = None) -> OneOrMany[np.ndarray]:
    """
    Uses self to make predictions on provided Dataset object.

    Parameters
    ----------
    dataset: dc.data.Dataset
@@ -452,7 +433,6 @@ class JaxModel(Model):
      If specified, all outputs of this type will be retrieved
      from the model. If output_types is specified, outputs must
      be None.

    Returns
    -------
    a NumPy array of the model produces a single output, or a list of arrays
@@ -523,7 +503,6 @@ class JaxModel(Model):
                         transformers: List[Transformer] = [],
                         per_task_metrics: bool = False):
    """Evaluate the performance of this model on the data produced by a generator.

    Parameters
    ----------
    generator: generator
@@ -536,7 +515,6 @@ class JaxModel(Model):
      is passed through these transformers to undo the transformations.
    per_task_metrics: bool
      If True, return per-task scores.

    Returns
    -------
    dict
@@ -623,10 +601,8 @@ class JaxModel(Model):
      deterministic: bool = True,
      pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
    """Create a generator that iterates batches for a dataset.

    Subclasses may override this method to customize how model inputs are
    generated from the data.

    Parameters
    ----------
    dataset: Dataset
@@ -642,7 +618,6 @@ class JaxModel(Model):
      data for each epoch
    pad_batches: bool
      whether to pad each batch up to this model's preferred batch size

    Returns
    -------
    a generator that iterates batches, each represented as a tuple of lists:
+123 −21
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ import numpy as np
try:
  import jax
  import jax.numpy as jnp
  from jax import random
  import haiku as hk
  import optax
  from deepchem.models import JaxModel
@@ -14,13 +15,73 @@ except:
  has_haiku_and_optax = False


@pytest.mark.jax
def test_pure_jax_model():
  """Train a fully-connected NN model written in pure Jax (no Haiku).

  The model is adapted from the Jax tutorial
  https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
  and is fit on a tiny synthetic regression dataset (y = x^2 + x + 1).
  """
  n_data_points = 50
  n_features = 1
  np.random.seed(1234)
  X = np.random.rand(n_data_points, n_features)
  y = X * X + X + 1
  dataset = dc.data.NumpyDataset(X, y)

  # Initialize one layer's weights and biases with small random values.
  def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key,
                                 (m, n)), scale * random.normal(b_key, (n,))

  # Build a (weights, biases) pair for each consecutive pair of layer sizes.
  def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [
        random_layer_params(m, n, k)
        for m, n, k in zip(sizes[:-1], sizes[1:], keys)
    ]

  layer_sizes = [1, 256, 128, 1]
  params = init_network_params(layer_sizes, random.PRNGKey(0))

  # Forward pass: relu on every hidden layer, linear output layer.
  # (`rng` is part of the JaxModel forward signature but unused here.)
  def forward_fn(params, rng, x):
    for w, b in params[:-1]:
      x = jax.nn.relu(jnp.dot(x, w) + b)

    final_w, final_b = params[-1]
    output = jnp.dot(x, final_w) + final_b
    return output

  # Loss function; the per-sample weights `w` are unused here.
  def rms_loss(pred, tar, w):
    return jnp.mean(optax.l2_loss(pred, tar))

  # Loss Function
  criterion = rms_loss

  # JaxModel Working
  j_m = JaxModel(
      forward_fn,
      params,
      criterion,
      batch_size=100,
      learning_rate=0.001,
      log_frequency=2)
  j_m.fit(dataset, nb_epochs=1000)
  metric = dc.metrics.Metric(dc.metrics.mean_absolute_error, mode="regression")
  scores = j_m.evaluate(dataset, [metric])
  assert scores[metric.name] < 0.5


@pytest.mark.jax
def test_jax_model_for_regression():
  tasks, dataset, transformers, metric = get_dataset(
      'regression', featurizer='ECFP')

  # sample network
  def f(x):
  def forward_model(x):
    net = hk.nets.MLP([512, 256, 128, 2])
    return net(x)

@@ -28,19 +89,19 @@ def test_jax_model_for_regression():
    return jnp.mean(optax.l2_loss(pred, tar))

  # Model Initialization
  model = hk.transform(f)
  params_init, forward_fn = hk.transform(forward_model)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=256)))
  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = model.init(rng, modified_inputs)
  params = params_init(rng, modified_inputs)

  # Loss Function
  criterion = rms_loss

  # JaxModel Working
  j_m = JaxModel(
      model.apply,
      forward_fn,
      params,
      criterion,
      batch_size=256,
@@ -59,7 +120,7 @@ def test_jax_model_for_classification():
  # sample network
  class Encoder(hk.Module):

    def __init__(self, output_size: int = 1):
    def __init__(self, output_size: int = 2):
      super().__init__()
      self._network = hk.nets.MLP([512, 256, 128, output_size])

@@ -67,29 +128,25 @@ def test_jax_model_for_classification():
      x = self._network(x)
      return x, jax.nn.softmax(x)

  def f(x):
    net = Encoder(2)
    return net(x)

  def bce_loss(pred, tar, w):
    tar = jnp.array(
        [x.astype(np.float32) if x.dtype != np.float32 else x for x in tar])
    return jnp.mean(optax.softmax_cross_entropy(pred[0], tar))

  # Model Initialisation
  model = hk.transform(f)
  params_init, forward_fn = hk.transform(lambda x: Encoder()(x))  # noqa
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=256)))
  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = model.init(rng, modified_inputs)
  params = params_init(rng, modified_inputs)

  # Loss Function
  criterion = bce_loss

  # JaxModel Working
  j_m = JaxModel(
      model.apply,
      forward_fn,
      params,
      criterion,
      output_types=['loss', 'prediction'],
@@ -121,25 +178,21 @@ def test_overfit_subclass_model():
      x = self._network(x)
      return x, jax.nn.sigmoid(x)

  def f(x):
    net = Encoder(1)
    return net(x)

  # Model Initialisation
  model = hk.transform(f)
  params_init, forward_fn = hk.transform(lambda x: Encoder()(x))  # noqa
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))

  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = model.init(rng, modified_inputs)
  params = params_init(rng, modified_inputs)

  # Loss Function
  criterion = lambda pred, tar, w: jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar))
  criterion = lambda pred, tar, w: jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar)) #noqa

  # JaxModel Working
  j_m = JaxModel(
      model.apply,
      forward_fn,
      params,
      criterion,
      output_types=['loss', 'prediction'],
@@ -154,6 +207,55 @@ def test_overfit_subclass_model():
  assert scores[metric.name] > 0.9


@pytest.mark.jax
def test_overfit_sequential_model():
  """Test overfitting a JaxModel built with `hk.Sequential` on a tiny dataset.

  A small MLP is trained on 10 points of y = x^2 + x + 1 and is expected
  to reach a low mean absolute error.
  """
  n_data_points = 10
  n_features = 1
  np.random.seed(1234)
  X = np.random.rand(n_data_points, n_features)
  y = X * X + X + 1
  dataset = dc.data.NumpyDataset(X, y)

  # Sample network: a small Haiku sequential MLP.
  def model_fn(x):
    mlp = hk.Sequential([
        hk.Linear(300),
        jax.nn.relu,
        hk.Linear(100),
        jax.nn.relu,
        hk.Linear(1),
    ])
    return mlp(x)

  # Loss function; the per-sample weights `w` are unused here.
  def rms_loss(pred, tar, w):
    return jnp.mean(optax.l2_loss(pred, tar))

  # Model Initialisation. Use distinct names instead of rebinding
  # `forward_fn` over the Python function it was built from.
  params_init, forward_fn = hk.transform(model_fn)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))

  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = params_init(rng, modified_inputs)

  # Loss Function
  criterion = rms_loss

  # JaxModel Working
  j_m = JaxModel(
      forward_fn,
      params,
      criterion,
      batch_size=100,
      learning_rate=0.001,
      log_frequency=2)
  j_m.fit(dataset, nb_epochs=1000)
  metric = dc.metrics.Metric(dc.metrics.mean_absolute_error, mode="regression")
  scores = j_m.evaluate(dataset, [metric])
  assert scores[metric.name] < 0.5


@pytest.mark.jax
def test_fit_use_all_losses():
  """Test fitting a TorchModel defined by subclassing Module."""
@@ -188,7 +290,7 @@ def test_fit_use_all_losses():
  params = model.init(rng, modified_inputs)

  # Loss Function
  criterion = lambda pred, tar, w: jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar))
  criterion = lambda pred, tar, w: jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar)) #noqa

  # JaxModel Working
  j_m = JaxModel(