Install
Install
It is often easiest to install jax and CUDA via conda.
Installing locally can lead to challenges.
https://github.com/google/jax/discussions/6843
Note: On Ubuntu, use the 545 NVIDIA driver to work with CUDA 12.3. Install JAX first using pip and cuda12_pip option.
Then use conda to install CUDA.