
installation is not working #31

Open
dboonz opened this issue Oct 24, 2023 · 9 comments

dboonz commented Oct 24, 2023

Hi there,

First of all, thank you for supporting Windows! I've used this build before with great success. At the moment, however, it's not working, nor can I find a way to get an older version to work.

I'm trying to set up JAX on a Windows PC with conda, but the provided instructions don't work anymore, and I can't really get any other version to work either.

I'm installing on a laptop, this is the output from nvidia-smi:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 527.83       Driver Version: 527.83       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   50C    P0     9W /  30W |      0MiB /  2048MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

I tried:

conda create -n jax_test python
conda activate jax_test
# install it
pip install jax[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# installs numpy, etc.
# raises a warning:
#   WARNING: jax 0.4.19 does not provide the extra 'cuda111'

python -m jax
#  File "C:\ProgramData\Anaconda3\envs\jax_test\Lib\site-packages\jax\_src\lib\__init__.py", line 27, in <module>
#    raise ModuleNotFoundError(
# ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.
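
A quick sanity check to confirm whether jax and jaxlib both actually ended up installed (a minimal sketch, standard library only):

# sketch: print the installed versions of jax and jaxlib, if any
import importlib.metadata as md

for pkg in ("jax", "jaxlib"):
    try:
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "is not installed")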

This obviously might not work for CUDA 12.0. However, if I run it with

pip install jax[pip_cuda12] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
I get the same result.

I also tried this with python==3.11, python==3.10, and python==3.9. Same result.

When I just download a jaxlib wheel directly, it also does not work. Sometimes I get a bit further, but no computations can be done and I run into 'AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11''.

What should the python version be? And what would be the right command?


dboonz commented Oct 24, 2023

Instructions to get halfway there (Python 3.10):

 pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# now, when importing we get
# ModuleNotFoundError: No module named 'ml_dtypes._ml_dtypes_ext'
# solve as per https://developer.apple.com/forums/thread/737890
pip install ml_dtypes==0.2.0
# now, numpy is not working
pip install -U numpy --force-reinstall
# now, it can be run

However, when I now open python, I get:

from jax import numpy as jnp
a = jnp.zeros(5)
# external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.

So still not usable.

cloudhan (Owner) commented:

Using -f https://whls.blob.core.windows.net/unstable/index.html may not work because jax changed its handling of extras options. You need to open the link, download the whl file manually, and then install a compatible jax, not the latest version.
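
For example (a sketch; the exact wheel file name below is hypothetical and has to match your Python version and CUDA build as listed on the index page):

# download the matching wheel from the index page first, then:
pip install jaxlib-0.4.11+cuda12.cudnn89-cp310-cp310-win_amd64.whl
pip install jax==0.4.13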


dboonz commented Oct 24, 2023

So, I downloaded the wheel and installed

pip install jax==0.4.13

As far as I can see, that should be compatible with jaxlib==0.4.11 (based on the source code).

If I run it, I still get

#jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

cloudhan (Owner) commented Oct 24, 2023

Could you please set the environment variable TF_CPP_MIN_LOG_LEVEL to 0?

import os
# must be set before jax is imported, or the logging level won't take effect
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import jax
jax.numpy.array([0])

There used to be some useful DLL info in there; not sure how it goes now, though. Might be worth a try.


dboonz commented Oct 25, 2023

So, I reinstalled everything from scratch, just to make sure the problem isn't caused by some old environment:

conda create -n jax
conda activate jax
conda install numpy scipy jupyter
# this should download the same file, just putting it in for reproducibility
pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install jax==0.4.13
conda install nvidiatoolkit

And then I ran the script above. I get the following output:

2023-10-25 09:56:55.529123: I external/tsl/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-25 09:56:55.650774: I external/tsl/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-25 09:56:55.653221: I external/tsl/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-25 09:56:55.916782: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:435] TfrtCpuClient created.
2023-10-25 09:56:56.505294: I external/xla/xla/service/service.cc:168] XLA service 0x40f9d20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-25 09:56:56.505474: I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce MX550, Compute Capability 7.5
2023-10-25 09:56:56.506396: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:545] Using BFC allocator.
2023-10-25 09:56:56.508144: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 1610416128 bytes on device 0 for BFCAllocator.
2023-10-25 09:56:56.613734: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:438] TfrtCpuClient destroyed.

It starts out not finding CUDA, but then it does seem to find it.
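
A quick way to see which backend JAX actually ended up selecting (a minimal sketch; both calls are public jax API):

import jax
# prints "gpu" if the CUDA backend initialized successfully, otherwise "cpu"
print(jax.default_backend())
print(jax.devices())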

The full traceback is here:

XlaRuntimeError                           Traceback (most recent call last)
Cell In[3], line 1
----> 1 a = jax.numpy.zeros(512)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\numpy\lax_numpy.py:2153, in zeros(shape, dtype)
   2151 dtypes.check_user_dtype_supported(dtype, "zeros")
   2152 shape = canonicalize_shape(shape)
-> 2153 return lax.full(shape, 0, _jnp_dtype(dtype))

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:1206, in full(shape, fill_value, dtype)
   1204 dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
   1205 fill_value = _convert_element_type(fill_value, dtype, weak_type)
-> 1206 return broadcast(fill_value, shape)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:768, in broadcast(operand, sizes)
    754 """Broadcasts an array, adding new leading dimensions
    755
    756 Args:
   (...)
    765   jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
    766 """
    767 dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
--> 768 return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:797, in broadcast_in_dim(operand, shape, broadcast_dimensions)
    795 else:
    796   dyn_shape, static_shape = [], shape  # type: ignore
--> 797 return broadcast_in_dim_p.bind(
    798     operand, *dyn_shape, shape=tuple(static_shape),
    799     broadcast_dimensions=tuple(broadcast_dimensions))

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:380, in Primitive.bind(self, *args, **params)
    377 def bind(self, *args, **params):
    378   assert (not config.jax_enable_checks or
    379           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380   return self.bind_with_trace(find_top_trace(args), args, params)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
    382 def bind_with_trace(self, trace, args, params):
--> 383   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    384   return map(full_lower, out) if self.multiple_results else full_lower(out)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:815, in EvalTrace.process_primitive(self, primitive, tracers, params)
    814 def process_primitive(self, primitive, tracers, params):
--> 815   return primitive.impl(*tracers, **params)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:132, in apply_primitive(prim, *args, **params)
    130 try:
    131   in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 132   compiled_fun = xla_primitive_callable(
    133       prim, in_avals, OrigShardings(in_shardings), **params)
    134 except pxla.DeviceAssignmentMismatchError as e:
    135   fails, = e.args

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:284, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    282   return f(*args, **kwargs)
    283 else:
--> 284   return cached(config._trace_context(), *args, **kwargs)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:277, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    275 @functools.lru_cache(max_size)
    276 def cached(_, *args, **kwargs):
--> 277   return f(*args, **kwargs)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:223, in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
    221     return out,
    222 donated_invars = (False,) * len(in_avals)
--> 223 compiled = _xla_callable_uncached(
    224     lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
    225     orig_in_shardings)
    226 if not prim.multiple_results:
    227   return lambda *args, **kw: compiled(*args, **kw)[0]

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:253, in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
    248 def _xla_callable_uncached(fun: lu.WrappedFun, name, donated_invars,
    249                            keep_unused, in_avals, orig_in_shardings):
    250   computation = sharded_lowering(
    251       fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
    252       lowering_platform=None)
--> 253   return computation.compile().unsafe_call

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2323, in MeshComputation.compile(self, compiler_options)
   2320   executable = MeshExecutable.from_trivial_jaxpr(
   2321       **self.compile_args)
   2322 else:
-> 2323   executable = UnloadedMeshExecutable.from_hlo(
   2324       self._name,
   2325       self._hlo,
   2326       **self.compile_args,
   2327       compiler_options=compiler_options)
   2328 if compiler_options is None:
   2329   self._executable = executable

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2645, in UnloadedMeshExecutable.from_hlo(***failed resolving arguments***)
   2642       mesh = i.mesh  # type: ignore
   2643       break
-> 2645 xla_executable, compile_options = _cached_compilation(
   2646     hlo, name, mesh, spmd_lowering,
   2647     tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
   2648     tuple(host_callbacks), backend, da, pmap_nreps,
   2649     compiler_options_keys, compiler_options_values)
   2651 if hasattr(backend, "compile_replicated"):
   2652   semantics_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2555, in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
   2550   return None, compile_options
   2552 with dispatch.log_elapsed_time(
   2553     "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
   2554     fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2555   xla_executable = dispatch.compile_or_get_cached(
   2556       backend, computation, dev, compile_options, host_callbacks)
   2557 return xla_executable, compile_options

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:497, in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
    493 use_compilation_cache = (compilation_cache.is_initialized() and
    494                          backend.platform in supported_platforms)
    496 if not use_compilation_cache:
--> 497   return backend_compile(backend, computation, compile_options,
    498                          host_callbacks)
    500 cache_key = compilation_cache.get_cache_key(
    501     computation, devices, compile_options, backend)
    503 cached_executable = _cache_read(module_name, cache_key, compile_options,
    504                                 backend)

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
    311 @wraps(func)
    312 def wrapper(*args, **kwargs):
    313   with TraceAnnotation(name, **decorator_kwargs):
--> 314     return func(*args, **kwargs)
    315   return wrapper

File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
    460   return backend.compile(built_c, compile_options=options,
    461                          host_callbacks=host_callbacks)
    462 # Some backends don't have `host_callbacks` option yet
    463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.


dboonz commented Oct 25, 2023

Once we get it to work, I could make a conda environment file that hopefully works without having to go through the same steps. Would you be interested in including that?
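
Something like the following, perhaps (a hypothetical environment.yml sketch; the pins are unverified and the jaxlib wheel would still have to be fetched from the wheel index above):

# hypothetical environment.yml sketch; pins unverified
name: jax
channels:
  - defaults
dependencies:
  - python=3.10
  - numpy
  - scipy
  - pip
  - pip:
      - jax==0.4.13
      # the matching jaxlib wheel still has to be installed manually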

cloudhan (Owner) commented:

Does conda install nvidiatoolkit configure the PATH for you? If not, you might need to manually configure PATH to include the directory containing the CUDA libraries.
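
Something along these lines might help (a sketch; the directory below is an assumption based on where conda usually places DLLs, i.e. Library\bin inside the environment):

import os

# assumed location of the CUDA/cuDNN DLLs in the conda env; adjust if yours differs
cuda_bin = r"C:\ProgramData\Anaconda3\envs\jax\Library\bin"
os.environ["PATH"] = cuda_bin + os.pathsep + os.environ["PATH"]
# on Python 3.8+ on Windows, DLL search directories must also be added explicitly
os.add_dll_directory(cuda_bin)

import jax  # import only after the search path is set up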


dboonz commented Oct 25, 2023

Yes, it does. I also just checked, looking with

import os
os.environ

This gives, among others, 'CUDA_PATH': 'C:\\ProgramData\\Anaconda3\\envs\\jax'.

I also installed cupy; that works without problems.

cloudhan (Owner) commented:

Then, do the cudnn*.dll files exist under that dir?
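
A quick way to check (a sketch, assuming the CUDA_PATH reported above and conda's usual Library\bin layout):

import glob, os

cuda_path = os.environ.get("CUDA_PATH", "")
# conda typically places DLLs under Library\bin inside the environment
print(glob.glob(os.path.join(cuda_path, "Library", "bin", "cudnn*.dll")))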
