installation is not working #31
Instructions to get halfway (Python 3.10):
pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# now, when importing we get
# ModuleNotFoundError: No module named 'ml_dtypes._ml_dtypes_ext'
# solved as per https://developer.apple.com/forums/thread/737890
pip install ml_dtypes==0.2.0
# now, numpy is not working
pip install -U numpy --force-reinstall
# now, it can be run
However, when I now open Python and run:
from jax import numpy as jnp
a = jnp.zeros(5)
I get:
external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
So it is still not usable.
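Since several of the failures in this thread come down to mismatched versions of jax, jaxlib, ml_dtypes and numpy, it may help to print what is actually installed before importing jax at all. A minimal sketch using only the standard library; the package list is an assumption about which distributions matter here.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "ml_dtypes", "numpy")):
    """Return {distribution name: installed version, or None if missing}.

    Only package metadata is read, nothing is imported, so a broken
    jaxlib or ml_dtypes build cannot crash this check.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:10s} {version or 'not installed'}")
```

Pasting this output alongside a bug report makes it easier to see whether, for example, ml_dtypes was upgraded past what the installed jaxlib expects.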
Use
So, I downloaded the wheel, and installed
As far as I can see, that should be compatible with jaxlib==0.4.11 (based on the source code). If I run it, I still get
Could you please set the environment variable before importing jax:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
import jax
jax.numpy.array([0])
There used to be some useful DLL info in that output; not sure how it looks now, though. Might be worth a try.
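A slightly extended version of that check, sketched under the assumption that knowing which backend jax actually picked is useful alongside the verbose log. It keeps the same environment variable and only imports jax if it is installed; `describe_backend` is a hypothetical helper name.

```python
import importlib.util
import os

# Set this before importing jax so the XLA C++ runtime sees it when it
# sets up its logging; "0" shows all messages, including INFO lines.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"


def describe_backend():
    """Return (default_backend, devices), or None if jax is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    # default_backend() reports "cpu" when no GPU backend could be
    # initialised, which makes a silent CPU fallback easy to spot.
    return jax.default_backend(), jax.devices()


if __name__ == "__main__":
    print(describe_backend())
```

If this prints a GPU backend but a computation such as `jax.numpy.zeros(5)` still fails, the problem is more likely in loading cuDNN/cuBLAS at compile time than in device discovery.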
So, I reinstalled everything from scratch, just to make sure it's not caused by some old environment I had tried:
conda create -n jax
conda activate jax
conda install numpy scipy jupyter
# this should download the same file, just putting it in for reproducibility
pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install jax==0.4.13
conda install nvidiatoolkit
And then I ran the script above. I get the following output:
It starts with not finding CUDA, but then it does seem to find it. The full traceback is here:
XlaRuntimeError                           Traceback (most recent call last)
Cell In[3], line 1
----> 1 a = jax.numpy.zeros(512)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\numpy\lax_numpy.py:2153, in zeros(shape, dtype)
2151 dtypes.check_user_dtype_supported(dtype, "zeros")
2152 shape = canonicalize_shape(shape)
-> 2153 return lax.full(shape, 0, _jnp_dtype(dtype))
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:1206, in full(shape, fill_value, dtype)
1204 dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
1205 fill_value = _convert_element_type(fill_value, dtype, weak_type)
-> 1206 return broadcast(fill_value, shape)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:768, in broadcast(operand, sizes)
754 """Broadcasts an array, adding new leading dimensions
755
756 Args:
(...)
765 jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
766 """
767 dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
--> 768 return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:797, in broadcast_in_dim(operand, shape, broadcast_dimensions)
795 else:
796 dyn_shape, static_shape = [], shape # type: ignore
--> 797 return broadcast_in_dim_p.bind(
798 operand, *dyn_shape, shape=tuple(static_shape),
799 broadcast_dimensions=tuple(broadcast_dimensions))
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
378 assert (not config.jax_enable_checks or
379 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380 return self.bind_with_trace(find_top_trace(args), args, params)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
382 def bind_with_trace(self, trace, args, params):
--> 383 out = trace.process_primitive(self, map(trace.full_raise, args), params)
384 return map(full_lower, out) if self.multiple_results else full_lower(out)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:815, in EvalTrace.process_primitive(self, primitive, tracers, params)
814 def process_primitive(self, primitive, tracers, params):
--> 815 return primitive.impl(*tracers, **params)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:132, in apply_primitive(prim, *args, **params)
130 try:
131 in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 132 compiled_fun = xla_primitive_callable(
133 prim, in_avals, OrigShardings(in_shardings), **params)
134 except pxla.DeviceAssignmentMismatchError as e:
135 fails, = e.args
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:284, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
282 return f(*args, **kwargs)
283 else:
--> 284 return cached(config._trace_context(), *args, **kwargs)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:277, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
275 @functools.lru_cache(max_size)
276 def cached(_, *args, **kwargs):
--> 277 return f(*args, **kwargs)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:223, in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
221 return out,
222 donated_invars = (False,) * len(in_avals)
--> 223 compiled = _xla_callable_uncached(
224 lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
225 orig_in_shardings)
226 if not prim.multiple_results:
227 return lambda *args, **kw: compiled(*args, **kw)[0]
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:253, in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
248 def _xla_callable_uncached(fun: lu.WrappedFun, name, donated_invars,
249 keep_unused, in_avals, orig_in_shardings):
250 computation = sharded_lowering(
251 fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
252 lowering_platform=None)
--> 253 return computation.compile().unsafe_call
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2323, in MeshComputation.compile(self, compiler_options)
2320 executable = MeshExecutable.from_trivial_jaxpr(
2321 **self.compile_args)
2322 else:
-> 2323 executable = UnloadedMeshExecutable.from_hlo(
2324 self._name,
2325 self._hlo,
2326 **self.compile_args,
2327 compiler_options=compiler_options)
2328 if compiler_options is None:
2329 self._executable = executable
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2645, in UnloadedMeshExecutable.from_hlo(***failed resolving arguments***)
2642 mesh = i.mesh # type: ignore
2643 break
-> 2645 xla_executable, compile_options = _cached_compilation(
2646 hlo, name, mesh, spmd_lowering,
2647 tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
2648 tuple(host_callbacks), backend, da, pmap_nreps,
2649 compiler_options_keys, compiler_options_values)
2651 if hasattr(backend, "compile_replicated"):
2652 semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2555, in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
2550 return None, compile_options
2552 with dispatch.log_elapsed_time(
2553 "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
2554 fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2555 xla_executable = dispatch.compile_or_get_cached(
2556 backend, computation, dev, compile_options, host_callbacks)
2557 return xla_executable, compile_options
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:497, in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
493 use_compilation_cache = (compilation_cache.is_initialized() and
494 backend.platform in supported_platforms)
496 if not use_compilation_cache:
--> 497 return backend_compile(backend, computation, compile_options,
498 host_callbacks)
500 cache_key = compilation_cache.get_cache_key(
501 computation, devices, compile_options, backend)
503 cached_executable = _cache_read(module_name, cache_key, compile_options,
504 backend)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
311 @wraps(func)
312 def wrapper(*args, **kwargs):
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
460 return backend.compile(built_c, compile_options=options,
461 host_callbacks=host_callbacks)
462 # Some backends don't have `host_callbacks` option yet
463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Once we get it to work, I could make a conda environment file that hopefully works without having to go through all of these steps. Would you be interested in including that?
Yes, it does. I also just checked, looking with
Gives, among others:
I also installed CuPy, and that works without problems.
Then does cudnn*.dll exist under that directory?
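The question of whether cudnn*.dll is present can be answered from Python as well. A minimal sketch, assuming the directories to check are the ones on PATH (pass the CUDA/cuDNN bin directory explicitly if it is not on PATH); `find_cudnn_dlls` is a hypothetical helper, and finding the DLL here does not by itself guarantee that jaxlib will load it.

```python
import glob
import os


def find_cudnn_dlls(search_dirs=None, pattern="cudnn*.dll"):
    """Return sorted paths matching `pattern` in each of `search_dirs`.

    Defaults to the directories listed on PATH. Empty entries and
    directories that do not exist are skipped.
    """
    if search_dirs is None:
        search_dirs = os.environ.get("PATH", "").split(os.pathsep)
    matches = []
    for directory in search_dirs:
        if directory and os.path.isdir(directory):
            matches.extend(glob.glob(os.path.join(directory, pattern)))
    return sorted(matches)


if __name__ == "__main__":
    for path in find_cudnn_dlls():
        print(path)
```

An empty result points to cuDNN not being reachable via PATH; several hits in different directories can point to conflicting cuDNN versions.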
Hi there,
First of all, thank you for supporting Windows! I've used this build before with great success. However, at the moment it's not working, nor can I find a way to get an older version to work.
I'm trying to set up jax on a Windows PC with conda, but the provided instructions do not work anymore. I also can't really get any other version to work.
I'm installing on a laptop; this is the output from nvidia-smi:
I tried:
This might obviously not work with CUDA 12.0. However, if I run it with
pip install jax[pip_cuda12] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
I get the same result.
I also tried this with python==3.11, python==3.10, or python==3.9. Same result. When I just download a jaxlib, it also does not work; sometimes I get a bit further, but no computations can be done and I run into 'AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11''.
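An AttributeError like this suggests that the installed ml_dtypes does not match what that jaxlib build expects. A minimal sketch for checking this without importing jax, assuming the attribute name from the error message is the one to look for; `missing_attributes` is a hypothetical helper.

```python
import importlib


def missing_attributes(module_name, names):
    """Return the subset of `names` not defined on `module_name`.

    Returns None if the module itself cannot be imported (this also
    catches errors such as a missing compiled extension submodule).
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return [name for name in names if not hasattr(module, name)]


if __name__ == "__main__":
    # [] means ml_dtypes provides the attribute; None means ml_dtypes
    # could not be imported at all.
    print(missing_attributes("ml_dtypes", ["float8_e4m3b11"]))
```

A non-empty list would be consistent with pinning ml_dtypes to the version the jaxlib wheel was built against, as was done with ml_dtypes==0.2.0 earlier in the thread.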
What should the python version be? And what would be the right command?