Commit f80e568d authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

just customary check

parent 403c50b0
Loading
Loading
Loading
Loading
+11 −1
Original line number Diff line number Diff line
@@ -94,6 +94,16 @@ jobs:
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Create env.yml
      shell: bash
      run: |
        python -m pip install --upgrade pip;
        pip install conda-merge;
        if [ "$(uname)" == 'Darwin' ]; then
          conda-merge env_jax.yml env.test.yml > env.yml
        else
          conda-merge env_jax.yml env.test.yml > env.yml
        fi;
    - name: Install all dependencies
      uses: conda-incubator/setup-miniconda@v2
      with:
@@ -102,7 +112,7 @@ jobs:
        activate-environment: deepchem
        channels: omnia,conda-forge,defaults
        python-version: ${{ matrix.python-version }}
        environment-file: env_jax.yml
        environment-file: env.yml
    - name: Install DeepChem
      id: install
      shell: bash -l {0}
+24 −0
Original line number Diff line number Diff line
"""
checking jax imports for new CI build
"""
import jax.numpy as jnp
from jax import random
import numpy as np
import deepchem as dc
import pytest


@pytest.mark.jax
def test_jax_import():
  key = random.PRNGKey(0)
  x = random.normal(key, (10, 10), dtype=jnp.float32)
  y = random.normal(key, (10, 10), dtype=jnp.float32)
  assert jnp.all(x == y)

  n_data_points = 10
  n_features = 2
  np.random.seed(1234)
  X = np.random.rand(n_data_points, n_features)
  y = (X[:, 0] > X[:, 1]).astype(np.float32)
  dataset = dc.data.NumpyDataset(X, y)
  assert dataset.X.shape == (10, 2)