Alpha state. A community-supported Windows build of jax.
Currently, only CPU and CUDA 11.1 are supported. For CUDA 11.x, please install the cuda/cuda11_cudnn82 package.
jax pins a jaxlib package version in its setup.py, so to install an unstable build you must first make sure the required jaxlib package exists in the package index. You can check it at https://whls.blob.core.windows.net/unstable/index.html
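The check above can also be scripted. The following is a minimal sketch, not part of this project: `jaxlib_in_index` is a hypothetical helper name, and it assumes the index page lists wheel filenames such as `jaxlib-0.3.5+cuda11.cudnn82-cp39-none-win_amd64.whl`, with the `+` of a local version tag possibly URL-encoded as `%2B`.

    # Hypothetical helper (not provided by this project): succeeds if the HTML
    # index page read from stdin links a jaxlib wheel for exactly version $1.
    jaxlib_in_index() {
      # A wheel for version X starts with "jaxlib-X" followed by "-" (no local
      # tag) or "+" / "%2B" (local tag), so 0.3.5 does not also match 0.3.50.
      grep -qF -e "jaxlib-$1-" -e "jaxlib-$1+" -e "jaxlib-$1%2B"
    }

    # Usage (needs curl and network access):
    # curl -s https://whls.blob.core.windows.net/unstable/index.html | jaxlib_in_index 0.3.5 && echo "jaxlib 0.3.5 is available"

Fixed-string matching (`grep -F`) keeps the dots in the version literal, which avoids having to escape them in a regular expression.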
# See https://peps.python.org/pep-0440/#arbitrary-equality for triple `=`
pip install jaxlib===0.3.5 -f https://whls.blob.core.windows.net/unstable/index.html
pip install jax[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# or, from a local jax source checkout:
pip install -e .[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
Alternatively, manually select the version of jaxlib you want to install, install that wheel, and then install jax.
# download jaxlib from https://whls.blob.core.windows.net/unstable/index.html
pip install <jaxlib_whl>
pip install jax
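Picking the right wheel by hand means matching both the jaxlib version and your Python version tag (for example `cp39` for CPython 3.9). The sketch below is one way to filter the index page for that; `pick_jaxlib_wheel` is a hypothetical helper, it assumes wheel filenames appear in the page's links, and it uses `grep -o`, which GNU and BSD grep support but POSIX does not require.

    # Hypothetical helper: print the jaxlib wheel filenames in the HTML index page
    # read from stdin that match version $1 and Python tag $2 (e.g. cp39).
    pick_jaxlib_wheel() {
      # Extract every jaxlib wheel filename, dropping duplicates.
      grep -o 'jaxlib-[^"<>/]*\.whl' | sort -u | while IFS= read -r whl; do
        # Keep only exact version matches (plain, "+local" or URL-encoded "%2Blocal").
        case $whl in
          "jaxlib-$1-"* | "jaxlib-$1+"* | "jaxlib-$1%2B"*) ;;
          *) continue ;;
        esac
        # Keep only wheels built for the requested Python tag.
        case $whl in
          *"-$2-"*) printf '%s\n' "$whl" ;;
        esac
      done
    }

    # Usage (needs curl and network access):
    # curl -s https://whls.blob.core.windows.net/unstable/index.html | pick_jaxlib_wheel 0.3.5 cp39

Download the printed wheel from the index and pass it to `pip install` as `<jaxlib_whl>`.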
For the reason behind --use-deprecated legacy-resolver, refer to pip #9186 and pip #9203.