Commit bfd2d15e authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

review fixes

parent 5faad7ed
Loading
Loading
Loading
Loading
+11 −7
Original line number Diff line number Diff line
@@ -297,7 +297,7 @@ class JaxModel(Model):
    generator: generator
      this should generate batches, each represented as a tuple of the form
      (inputs, labels, weights).
    transformers: list of dc.trans.Transformers
    transformers: List[dc.trans.Transformers]
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.
    uncertainty: bool
@@ -306,7 +306,9 @@ class JaxModel(Model):
      returns the values of the uncertainty outputs.
    other_output_types: list, optional
      Provides a list of other output_types (strings) to predict from model.
    Returns:

    Returns
    -------
      a NumPy array of the model produces a single output, or a list of arrays
      if it produces multiple outputs
    """
@@ -391,14 +393,16 @@ class JaxModel(Model):
    generator: generator
      this should generate batches, each represented as a tuple of the form
      (inputs, labels, weights).
    transformers: list of dc.trans.Transformers
    transformers: List[dc.trans.Transformers]
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.
    output_types: String or list of Strings
      If specified, all outputs of this type will be retrieved
      from the model. If output_types is specified, outputs must
      be None.
    Returns:

    Returns
    -------
      a NumPy array of the model produces a single output, or a list of arrays
      if it produces multiple outputs
    """
@@ -412,7 +416,7 @@ class JaxModel(Model):
    ----------
    X: ndarray
      the input data, as a Numpy array.
    transformers: list of dc.trans.Transformers
    transformers: List[dc.trans.Transformers]
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.

@@ -441,7 +445,7 @@ class JaxModel(Model):
    ----------
    dataset: dc.data.Dataset
      Dataset to make prediction on
    transformers: list of dc.trans.Transformers
    transformers: List[dc.trans.Transformers]
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.
    output_types: String or list of Strings
@@ -527,7 +531,7 @@ class JaxModel(Model):
      (inputs, labels, weights).
    metric: list of deepchem.metrics.Metric
      Evaluation metric
    transformers: list of dc.trans.Transformers
    transformers: List[dc.trans.Transformers]
      Transformers that the input data has been transformed by.  The output
      is passed through these transformers to undo the transformations.
    per_task_metrics: bool
+2 −1
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ def test_jax_model_for_regression():
  def rms_loss(pred, tar, w):
    return jnp.mean(optax.l2_loss(pred, tar))

  # Model Initilisation
  # Model Initialization
  model = hk.transform(f)
  rng = jax.random.PRNGKey(500)
  inputs, _, _, _ = next(iter(dataset.iterbatches(batch_size=256)))
@@ -207,6 +207,7 @@ def test_fit_use_all_losses():


@pytest.mark.jax
@pytest.mark.slow
def test_uncertainty():
  """Test estimating uncertainty a TorchModel."""
  n_samples = 30