Commit d4493ff9 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

adding few tests

parent d024726f
Loading
Loading
Loading
Loading
+115 −5
Original line number Diff line number Diff line
import pytest
from deepchem.models.tests.test_graph_models import get_dataset
import deepchem as dc
import numpy as np

try:
@@ -20,7 +21,7 @@ def test_jax_model_for_regression():

  # sample network
  def f(x):
    net = hk.nets.MLP([512, 256, 128, 1])
    net = hk.nets.MLP([512, 256, 128, 2])
    return net(x)

  def rms_loss(pred, tar, w):
@@ -38,6 +39,7 @@ def test_jax_model_for_regression():
  criterion = rms_loss

  # JaxModel Working
  n_tasks = len(tasks)
  j_m = JaxModel(
      model,
      params,
@@ -45,8 +47,9 @@ def test_jax_model_for_regression():
      batch_size=256,
      learning_rate=0.001,
      log_frequency=2)
  results = j_m.fit(dataset, deterministic=True)
  assert results < 0.5
  results = j_m.fit(dataset, nb_epochs=25, deterministic=True)
  scores = j_m.evaluate(dataset, [metric])
  assert scores[metric.name] < 0.5


@pytest.mark.jax
@@ -94,5 +97,112 @@ def test_jax_model_for_classification():
      batch_size=256,
      learning_rate=0.001,
      log_frequency=2)
  results = j_m.fit(dataset, nb_epochs=50, deterministic=True)
  assert results < 1.0
  _ = j_m.fit(dataset, nb_epochs=25, deterministic=True)
  scores = j_m.evaluate(dataset, [metric])
  assert scores[metric.name] > 0.8


@pytest.mark.jax
def test_overfit_subclass_model():
  """Test fitting a JaxModel defined by subclassing Module."""
  n_data_points = 10
  n_features = 2
  np.random.seed(1234)
  X = np.random.rand(n_data_points, n_features)
  y = (X[:, 0] > X[:, 1]).astype(np.float32)
  dataset = dc.data.NumpyDataset(X, np.expand_dims(y, axis=1))

  class Encoder(hk.Module):

    def __init__(self, output_size: int = 1):
      super().__init__()
      self._network = hk.nets.MLP([512, 256, 128, output_size])

    def __call__(self, x: jnp.ndarray):
      x = self._network(x)
      return x, jax.nn.sigmoid(x)

  def f(x):
    net = Encoder(1)
    return net(x)

  # Model Initilisation
  model = hk.without_apply_rng(hk.transform(f))
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))
  #inputs = np.random.rand(100, n_features)
  #print(inputs.shape)

  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = model.init(rng, modified_inputs)

  # Loss Function
  criterion = lambda pred, tar, w: jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar))

  # JaxModel Working
  j_m = JaxModel(
      model,
      params,
      criterion,
      output_types=['loss', 'prediction'],
      batch_size=100,
      learning_rate=0.001,
      log_frequency=2)
  j_m.fit(dataset, nb_epochs=1000)
  prediction = np.squeeze(j_m.predict_on_batch(X))
  assert np.array_equal(y, np.round(prediction))
  metric = dc.metrics.Metric(dc.metrics.roc_auc_score)
  scores = j_m.evaluate(dataset, [metric])
  assert scores[metric.name] > 0.9


def test_fit_use_all_losses():
  """Test fitting a TorchModel defined by subclassing Module."""
  n_data_points = 10
  n_features = 2
  np.random.seed(1234)
  X = np.random.rand(n_data_points, n_features)
  y = (X[:, 0] > X[:, 1]).astype(np.float32)
  dataset = dc.data.NumpyDataset(X, np.expand_dims(y, axis=1))

  class Encoder(hk.Module):

    def __init__(self, output_size: int = 1):
      super().__init__()
      self._network = hk.nets.MLP([512, 256, 128, output_size])

    def __call__(self, x: jnp.ndarray):
      x = self._network(x)
      return x, jax.nn.sigmoid(x)

  def f(x):
    net = Encoder(1)
    return net(x)

  # Model Initilisation
  model = hk.without_apply_rng(hk.transform(f))
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=100)))

  modified_inputs = jnp.array(
      [x.astype(np.float32) if x.dtype == np.float64 else x for x in inputs])
  params = model.init(rng, modified_inputs)

  # Loss Function
  criterion = lambda pred, tar, w: jnp.mean(optax.sigmoid_binary_cross_entropy(pred[0], tar))

  # JaxModel Working
  j_m = JaxModel(
      model,
      params,
      criterion,
      output_types=['loss', 'prediction'],
      learning_rate=0.005,
      log_frequency=10)

  losses = []
  j_m.fit(dataset, nb_epochs=1000, all_losses=losses)
  # Each epoch is a single step for this model
  assert len(losses) == 100
  assert np.count_nonzero(np.array(losses)) == 100