Commit 207f7de2 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Merge branch 'jax' into jax2

parents e334db4d f2460e46
Loading
Loading
Loading
Loading
+4 −7
Original line number Diff line number Diff line
@@ -18,13 +18,10 @@ from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, Sequen
from deepchem.utils.typing import LossFn, OneOrMany, ArrayLike

# JAX depend
try:
import jax.numpy as jnp
import jax
import haiku as hk
import optax
except:
  raise ImportError("This class requires Jax, haiku and optax to be installed.")

import warnings

+0 −1
Original line number Diff line number Diff line
@@ -2,7 +2,6 @@ dependencies:
  - pip:
    - -f https://download.pytorch.org/whl/torch_stable.html
    - -f https://pytorch-geometric.com/whl/torch-1.8.1+cu111.html
    - -f https://storage.googleapis.com/jax-releases/jax_releases.html
    - dgl-cu110==0.6.*
    - torch==1.8.1+cu111
    - torchvision==0.9.1+cu111