| author | CoprDistGit <infra@openeuler.org> | 2023-04-10 08:23:46 +0000 |
|---|---|---|
| committer | CoprDistGit <infra@openeuler.org> | 2023-04-10 08:23:46 +0000 |
| commit | 9428332eb11a19a4c333b3ae479952174f949752 (patch) | |
| tree | b689aadcfd73aac979a9c701fb6a4af7fb262ec9 | |
| parent | 984782762feb97ecc8052c015c73d0a2e7ea1769 (diff) | |
automatic import of python-jax
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | .gitignore | 1 |
| -rw-r--r-- | python-jax.spec | 1920 |
| -rw-r--r-- | sources | 1 |

3 files changed, 1922 insertions, 0 deletions
@@ -0,0 +1 @@ +/jax-0.4.8.tar.gz diff --git a/python-jax.spec b/python-jax.spec new file mode 100644 index 0000000..d6931cf --- /dev/null +++ b/python-jax.spec @@ -0,0 +1,1920 @@ +%global _empty_manifest_terminate_build 0 +Name: python-jax +Version: 0.4.8 +Release: 1 +Summary: Differentiate, compile, and transform Numpy code. +License: Apache-2.0 +URL: https://github.com/google/jax +Source0: https://mirrors.nju.edu.cn/pypi/web/packages/fe/58/1641614c17fcd7293d250c2cad48011baa1ecef4f109ce2ea027aa8e4898/jax-0.4.8.tar.gz +BuildArch: noarch + + +%description +<div align="center"> +<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img> +</div> + +# JAX: Autograd and XLA + + + + +[**Quickstart**](#quickstart-colab-in-the-cloud) +| [**Transformations**](#transformations) +| [**Install guide**](#installation) +| [**Neural net libraries**](#neural-network-libraries) +| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) +| [**Reference docs**](https://jax.readthedocs.io/en/latest/) + + +## What is JAX? + +JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla), +brought together for high-performance machine learning research. + +With its updated version of [Autograd](https://github.com/hips/autograd), +JAX can automatically differentiate native +Python and NumPy functions. It can differentiate through loops, branches, +recursion, and closures, and it can take derivatives of derivatives of +derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) +via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, +and the two can be composed arbitrarily to any order. + +What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla) +to compile and run your NumPy programs on GPUs and TPUs. Compilation happens +under the hood by default, with library calls getting just-in-time compiled and +executed. But JAX also lets you just-in-time compile your own Python functions +into XLA-optimized kernels using a one-function API, +[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be +composed arbitrarily, so you can express sophisticated algorithms and get +maximal performance without leaving Python. You can even program multiple GPUs +or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and +differentiate through the whole thing. + +Dig a little deeper, and you'll see that JAX is really an extensible system for +[composable function transformations](#transformations). Both +[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit) +are instances of such transformations. Others are +[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and +[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) +parallel programming of multiple accelerators, with more to come. + +This is a research project, not an official Google product. Expect bugs and +[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Please help by trying it out, [reporting +bugs](https://github.com/google/jax/issues), and letting us know what you +think! 
+ +```python +import jax.numpy as jnp +from jax import grad, jit, vmap + +def predict(params, inputs): + for W, b in params: + outputs = jnp.dot(inputs, W) + b + inputs = jnp.tanh(outputs) # inputs to the next layer + return outputs # no activation on last layer + +def loss(params, inputs, targets): + preds = predict(params, inputs) + return jnp.sum((preds - targets)**2) + +grad_loss = jit(grad(loss)) # compiled gradient evaluation function +perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads +``` + +### Contents +* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) +* [Transformations](#transformations) +* [Current gotchas](#current-gotchas) +* [Installation](#installation) +* [Neural net libraries](#neural-network-libraries) +* [Citing JAX](#citing-jax) +* [Reference documentation](#reference-documentation) + +## Quickstart: Colab in the Cloud +Jump right in using a notebook in your browser, connected to a Google Cloud GPU. +Here are some starter notebooks: +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) + +**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU +Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs). + +For a deeper dive into JAX: +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- See the [full list of +notebooks](https://github.com/google/jax/tree/main/docs/notebooks). + +You can also take a look at [the mini-libraries in +`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md), +like [`stax` for building neural +networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax) +and [`optimizers` for first-order stochastic +optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization), +or the [examples](https://github.com/google/jax/tree/main/examples). + +## Transformations + +At its core, JAX is an extensible system for transforming numerical functions. +Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and +`pmap`. + +### Automatic differentiation with `grad` + +JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). +The most popular function is +[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) +for reverse-mode gradients: + +```python +from jax import grad +import jax.numpy as jnp + +def tanh(x): # Define a function + y = jnp.exp(-2.0 * x) + return (1.0 - y) / (1.0 + y) + +grad_tanh = grad(tanh) # Obtain its gradient function +print(grad_tanh(1.0)) # Evaluate it at x = 1.0 +# prints 0.4199743 +``` + +You can differentiate to any order with `grad`. 
+ +```python +print(grad(grad(grad(tanh)))(1.0)) +# prints 0.62162673 +``` + +For more advanced autodiff, you can use +[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for +reverse-mode vector-Jacobian products and +[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for +forward-mode Jacobian-vector products. The two can be composed arbitrarily with +one another, and with other JAX transformations. Here's one way to compose those +to make a function that efficiently computes [full Hessian +matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): + +```python +from jax import jit, jacfwd, jacrev + +def hessian(fun): + return jit(jacfwd(jacrev(fun))) +``` + +As with [Autograd](https://github.com/hips/autograd), you're free to use +differentiation with Python control structures: + +```python +def abs_val(x): + if x > 0: + return x + else: + return -x + +abs_val_grad = grad(abs_val) +print(abs_val_grad(1.0)) # prints 1.0 +print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) +``` + +See the [reference docs on automatic +differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) +and the [JAX Autodiff +Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +for more. + +### Compilation with `jit` + +You can use XLA to compile your functions end-to-end with +[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +used either as an `@jit` decorator or as a higher-order function. + +```python +import jax.numpy as jnp +from jax import jit + +def slow_f(x): + # Element-wise ops see a large benefit from fusion + return x * x + x * 2.0 + +x = jnp.ones((5000, 5000)) +fast_f = jit(slow_f) +%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X +%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX) +``` + +You can mix `jit` and `grad` and any other JAX transformation however you like. + +Using `jit` puts constraints on the kind of Python control flow +the function can use; see +the [Gotchas +Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT) +for more. + +### Auto-vectorization with `vmap` + +[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is +the vectorizing map. +It has the familiar semantics of mapping a function along array axes, but +instead of keeping the loop on the outside, it pushes the loop down into a +function’s primitive operations for better performance. + +Using `vmap` can save you from having to carry around batch dimensions in your +code. For example, consider this simple *unbatched* neural network prediction +function: + +```python +def predict(params, input_vec): + assert input_vec.ndim == 1 + activations = input_vec + for W, b in params: + outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side! + activations = jnp.tanh(outputs) # inputs to the next layer + return outputs # no activation on last layer +``` + +We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the +left side of `activations`, but we’ve written this particular prediction function to +apply only to single input vectors. 
If we wanted to apply this function to a +batch of inputs at once, semantically we could just write + +```python +from functools import partial +predictions = jnp.stack(list(map(partial(predict, params), input_batch))) +``` + +But pushing one example through the network at a time would be slow! It’s better +to vectorize the computation, so that at every layer we’re doing matrix-matrix +multiplication rather than matrix-vector multiplication. + +The `vmap` function does that transformation for us. That is, if we write + +```python +from jax import vmap +predictions = vmap(partial(predict, params))(input_batch) +# or, alternatively +predictions = vmap(predict, in_axes=(None, 0))(params, input_batch) +``` + +then the `vmap` function will push the outer loop inside the function, and our +machine will end up executing matrix-matrix multiplications exactly as if we’d +done the batching by hand. + +It’s easy enough to manually batch a simple neural network without `vmap`, but +in other cases manual vectorization can be impractical or impossible. Take the +problem of efficiently computing per-example gradients: that is, for a fixed set +of parameters, we want to compute the gradient of our loss function evaluated +separately at each example in a batch. With `vmap`, it’s easy: + +```python +per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets) +``` + +Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other +JAX transformation! We use `vmap` with both forward- and reverse-mode automatic +differentiation for fast Jacobian and Hessian matrix calculations in +`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`. + +### SPMD programming with `pmap` + +For parallel programming of multiple accelerators, like multiple GPUs, use +[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). +With `pmap` you write single-program multiple-data (SPMD) programs, including +fast parallel collective communication operations. Applying `pmap` will mean +that the function you write is compiled by XLA (similarly to `jit`), then +replicated and executed in parallel across devices. + +Here's an example on an 8-GPU machine: + +```python +from jax import random, pmap +import jax.numpy as jnp + +# Create 8 random 5000 x 6000 matrices, one per GPU +keys = random.split(random.PRNGKey(0), 8) +mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) + +# Run a local matmul on each device in parallel (no data transfer) +result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000) + +# Compute the mean on each device in parallel and print the result +print(pmap(jnp.mean)(result)) +# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157] +``` + +In addition to expressing pure maps, you can use fast [collective communication +operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) +between devices: + +```python +from functools import partial +from jax import lax + +@partial(pmap, axis_name='i') +def normalize(x): + return x / lax.psum(x, 'i') + +print(normalize(jnp.arange(4.))) +# prints [0. 0.16666667 0.33333334 0.5 ] +``` + +You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more +sophisticated communication patterns. 
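
The text above points to the Pmap Cookbook for nested `pmap` but shows no inline example. Here is a minimal, hypothetical sketch of a nested `pmap`, assuming a machine with at least four devices and an illustrative 2×2 layout; the function name `normalize_rows`, the shapes, and the `XLA_FLAGS` device-count trick are illustrative choices, not taken from the spec or the upstream README:

```python
from functools import partial
import jax.numpy as jnp
from jax import lax, pmap

# Assumes >= 4 devices; on CPU you can fake them with, e.g.,
# XLA_FLAGS=--xla_force_host_platform_device_count=4
@partial(pmap, axis_name='rows')   # outer map over axis 0
@partial(pmap, axis_name='cols')   # inner map over the remaining axis
def normalize_rows(x):
    # psum over the inner 'cols' axis only, so each row sums to 1
    return x / lax.psum(x, 'cols')

x = jnp.arange(4.).reshape((2, 2))
print(normalize_rows(x))
# roughly: [[0.  1. ]
#           [0.4 0.6]]
```

Each collective (`lax.psum` here) names the axis it reduces over, which is what lets the two levels of the nest communicate independently.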
+ +It all composes, so you're free to differentiate through parallel computations: + +```python +from jax import grad + +@pmap +def f(x): + y = jnp.sin(x) + @pmap + def g(z): + return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum() + return grad(lambda w: jnp.sum(g(w)))(x) + +print(f(x)) +# [[ 0. , -0.7170853 ], +# [-3.1085174 , -0.4824318 ], +# [10.366636 , 13.135289 ], +# [ 0.22163185, -0.52112055]] + +print(grad(lambda x: jnp.sum(f(x)))(x)) +# [[ -3.2369726, -1.6356447], +# [ 4.7572474, 11.606951 ], +# [-98.524414 , 42.76499 ], +# [ -1.6007166, -1.2568436]] +``` + +When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the +backward pass of the computation is parallelized just like the forward pass. + +See the [SPMD +Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +and the [SPMD MNIST classifier from scratch +example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) +for more. + +## Current gotchas + +For a more thorough survey of current gotchas, with examples and explanations, +we highly recommend reading the [Gotchas +Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Some standouts: + +1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. +1. [In-place mutating updates of + arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. +1. [Random numbers are + different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md). +1. If you're looking for [convolution + operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), + they're in the `jax.lax` package. +1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and + [to enable + double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at + startup (or set the environment variable `JAX_ENABLE_X64=True`). + On TPU, JAX uses 32-bit values by default for everything _except_ internal + temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. + Those ops have a `precision` parameter which can be used to simulate + true 32-bit, with a cost of possibly slower runtime. +1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars + and NumPy types aren't preserved, namely `np.add(1, np.array([2], + np.float32)).dtype` is `float64` rather than `float32`. +1. Some transformations, like `jit`, [constrain how you can use Python control + flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). + You'll always get loud errors if something goes wrong. 
You might have to use + [`jit`'s `static_argnums` + parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), + [structured control flow + primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) + like + [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), + or just use `jit` on smaller subfunctions. + +## Installation + +JAX is written in pure Python, but it depends on XLA, which needs to be +installed as the `jaxlib` package. Use the following instructions to install a +binary package with `pip` or `conda`, or to [build JAX from +source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and +macOS (10.12 or later) platforms. + +Windows users can use JAX on CPU and GPU via the [Windows Subsystem for +Linux](https://docs.microsoft.com/en-us/windows/wsl/about). In addition, there +is some initial community-driven native Windows support, but since it is still +somewhat immature, there are no official binary releases and it must be [built +from source for Windows](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows). +For an unofficial discussion of native Windows builds, see also the [Issue #5795 +thread](https://github.com/google/jax/issues/5795). + +### pip installation: CPU + +To install a CPU-only version of JAX, which might be useful for doing local +development on a laptop, you can run + +```bash +pip install --upgrade pip +pip install --upgrade "jax[cpu]" +``` + +On Linux, it is often necessary to first update `pip` to a version that supports +`manylinux2014` wheels. Also note that for Linux, we currently release wheels for `x86_64` architectures only, other architectures require building from source. Trying to pip install with other Linux architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime). +**These `pip` installations do not work with Windows, and may fail silently; see +[above](#installation).** + +### pip installation: GPU (CUDA, installed via pip, easier) + +There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN +installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend +installing CUDA and CUDNN using the pip wheels, since it is much easier! + +You must first install the NVIDIA driver. We +recommend installing the newest driver available from NVIDIA, but the driver +must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux. + +```bash +pip install --upgrade pip + +# CUDA 12 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# CUDA 11 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +### pip installation: GPU (CUDA, installed locally, harder) + +If you prefer to use a preinstalled copy of CUDA, you must first +install [CUDA](https://developer.nvidia.com/cuda-downloads) and +[CuDNN](https://developer.nvidia.com/CUDNN). + +JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. 
Other +combinations of operating system and architecture are possible, but require +[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +You should use an NVIDIA driver version that is at least as new as your +[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions). +If you need to use an newer CUDA toolkit with an older driver, for example +on a cluster where you cannot update the NVIDIA driver easily, you may be +able to use the +[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) +that NVIDIA provides for this purpose. + +JAX currently ships three CUDA wheel variants: +* CUDA 12.0 and CuDNN 8.8. +* CUDA 11.8 and CuDNN 8.6. +* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued + with jax 0.4.8. + +You may use a JAX wheel provided the major version of your CUDA and CuDNN +installation matches, and the minor version is at least as new as the version +JAX expects. For example, you would be able to use the CUDA 12.0 wheel with +CUDA 12.1 and CuDNN 8.9. + +Your CUDA installation must also be new enough to support your GPU. If you have +an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU, +you must use CUDA 11.8 or newer. + + +To install, run + +```bash +pip install --upgrade pip + +# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer. +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer. +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated). +pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +**These `pip` installations do not work with Windows, and may fail silently; see +[above](#installation).** + +You can find your CUDA version with the command: + +```bash +nvcc --version +``` + +Some GPU functionality expects the CUDA installation to be at +`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number +(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either +create a symlink: + +```bash +sudo ln -s /path/to/cuda /usr/local/cuda-X.X +``` + +Please let us know on [the issue tracker](https://github.com/google/jax/issues) +if you run into any errors or problems with the prebuilt wheels. + +### pip installation: Google Cloud TPU +JAX also provides pre-built wheels for +[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm). +To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run +the following in your cloud TPU VM: +```bash +pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +### pip installation: Colab TPU +Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ. 
+The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU: +```python +import jax.tools.colab_tpu +jax.tools.colab_tpu.setup_tpu() +``` +Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer. +If you need to re-install JAX on a Colab TPU runtime, you can use the following command: +``` +!pip install jax<=0.3.25 jaxlib<=0.3.25 +``` + +### Conda installation + +There is a community-supported Conda build of `jax`. To install using `conda`, +simply run + +```bash +conda install jax -c conda-forge +``` + +To install on a machine with an NVIDIA GPU, run +```bash +conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia +``` + +Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which +JAX requires. You must therefore either install the `cuda-nvcc` package from +the `nvidia` channel, or install CUDA on your machine separately so that `ptxas` +is in your path. The channel order above is important (`conda-forge` before +`nvidia`). + +If you would like to override which release of CUDA is used by JAX, or to +install the CUDA build on a machine without GPUs, follow the instructions in the +[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch) +section of the `conda-forge` website. + +See the `conda-forge` +[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and +[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories +for more details. + +### Building JAX from source +See [Building JAX from +source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +## Neural network libraries + +Multiple Google research groups develop and share libraries for training neural +networks in JAX. If you want a fully featured library for neural network +training with examples and how-to guides, try +[Flax](https://github.com/google/flax). + +In addition, DeepMind has open-sourced an [ecosystem of libraries around +JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) +including [Haiku](https://github.com/deepmind/dm-haiku) for neural network +modules, [Optax](https://github.com/deepmind/optax) for gradient processing and +optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and +[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch +the NeurIPS 2020 JAX Ecosystem at DeepMind talk +[here](https://www.youtube.com/watch?v=iDxJxIyzSiM)) + +## Citing JAX + +To cite this repository: + +``` +@software{jax2018github, + author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, + title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, + url = {http://github.com/google/jax}, + version = {0.3.13}, + year = {2018}, +} +``` + +In the above bibtex entry, names are in alphabetical order, the version number +is intended to be that from [jax/version.py](../main/jax/version.py), and +the year corresponds to the project's open-source release. + +A nascent version of JAX, supporting only automatic differentiation and +compilation to XLA, was described in a [paper that appeared at SysML +2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). 
We're currently working on +covering JAX's ideas and capabilities in a more comprehensive and up-to-date +paper. + +## Reference documentation + +For details about the JAX API, see the +[reference documentation](https://jax.readthedocs.io/). + +For getting started as a JAX developer, see the +[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). + + +%package -n python3-jax +Summary: Differentiate, compile, and transform Numpy code. +Provides: python-jax +BuildRequires: python3-devel +BuildRequires: python3-setuptools +BuildRequires: python3-pip +%description -n python3-jax +<div align="center"> +<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img> +</div> + +# JAX: Autograd and XLA + + + + +[**Quickstart**](#quickstart-colab-in-the-cloud) +| [**Transformations**](#transformations) +| [**Install guide**](#installation) +| [**Neural net libraries**](#neural-network-libraries) +| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) +| [**Reference docs**](https://jax.readthedocs.io/en/latest/) + + +## What is JAX? + +JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla), +brought together for high-performance machine learning research. + +With its updated version of [Autograd](https://github.com/hips/autograd), +JAX can automatically differentiate native +Python and NumPy functions. It can differentiate through loops, branches, +recursion, and closures, and it can take derivatives of derivatives of +derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) +via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, +and the two can be composed arbitrarily to any order. + +What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla) +to compile and run your NumPy programs on GPUs and TPUs. Compilation happens +under the hood by default, with library calls getting just-in-time compiled and +executed. But JAX also lets you just-in-time compile your own Python functions +into XLA-optimized kernels using a one-function API, +[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be +composed arbitrarily, so you can express sophisticated algorithms and get +maximal performance without leaving Python. You can even program multiple GPUs +or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and +differentiate through the whole thing. + +Dig a little deeper, and you'll see that JAX is really an extensible system for +[composable function transformations](#transformations). Both +[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit) +are instances of such transformations. Others are +[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and +[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) +parallel programming of multiple accelerators, with more to come. + +This is a research project, not an official Google product. Expect bugs and +[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Please help by trying it out, [reporting +bugs](https://github.com/google/jax/issues), and letting us know what you +think! 
+ +```python +import jax.numpy as jnp +from jax import grad, jit, vmap + +def predict(params, inputs): + for W, b in params: + outputs = jnp.dot(inputs, W) + b + inputs = jnp.tanh(outputs) # inputs to the next layer + return outputs # no activation on last layer + +def loss(params, inputs, targets): + preds = predict(params, inputs) + return jnp.sum((preds - targets)**2) + +grad_loss = jit(grad(loss)) # compiled gradient evaluation function +perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads +``` + +### Contents +* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) +* [Transformations](#transformations) +* [Current gotchas](#current-gotchas) +* [Installation](#installation) +* [Neural net libraries](#neural-network-libraries) +* [Citing JAX](#citing-jax) +* [Reference documentation](#reference-documentation) + +## Quickstart: Colab in the Cloud +Jump right in using a notebook in your browser, connected to a Google Cloud GPU. +Here are some starter notebooks: +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) + +**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU +Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs). + +For a deeper dive into JAX: +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- See the [full list of +notebooks](https://github.com/google/jax/tree/main/docs/notebooks). + +You can also take a look at [the mini-libraries in +`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md), +like [`stax` for building neural +networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax) +and [`optimizers` for first-order stochastic +optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization), +or the [examples](https://github.com/google/jax/tree/main/examples). + +## Transformations + +At its core, JAX is an extensible system for transforming numerical functions. +Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and +`pmap`. + +### Automatic differentiation with `grad` + +JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). +The most popular function is +[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) +for reverse-mode gradients: + +```python +from jax import grad +import jax.numpy as jnp + +def tanh(x): # Define a function + y = jnp.exp(-2.0 * x) + return (1.0 - y) / (1.0 + y) + +grad_tanh = grad(tanh) # Obtain its gradient function +print(grad_tanh(1.0)) # Evaluate it at x = 1.0 +# prints 0.4199743 +``` + +You can differentiate to any order with `grad`. 
+ +```python +print(grad(grad(grad(tanh)))(1.0)) +# prints 0.62162673 +``` + +For more advanced autodiff, you can use +[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for +reverse-mode vector-Jacobian products and +[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for +forward-mode Jacobian-vector products. The two can be composed arbitrarily with +one another, and with other JAX transformations. Here's one way to compose those +to make a function that efficiently computes [full Hessian +matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): + +```python +from jax import jit, jacfwd, jacrev + +def hessian(fun): + return jit(jacfwd(jacrev(fun))) +``` + +As with [Autograd](https://github.com/hips/autograd), you're free to use +differentiation with Python control structures: + +```python +def abs_val(x): + if x > 0: + return x + else: + return -x + +abs_val_grad = grad(abs_val) +print(abs_val_grad(1.0)) # prints 1.0 +print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) +``` + +See the [reference docs on automatic +differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) +and the [JAX Autodiff +Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +for more. + +### Compilation with `jit` + +You can use XLA to compile your functions end-to-end with +[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +used either as an `@jit` decorator or as a higher-order function. + +```python +import jax.numpy as jnp +from jax import jit + +def slow_f(x): + # Element-wise ops see a large benefit from fusion + return x * x + x * 2.0 + +x = jnp.ones((5000, 5000)) +fast_f = jit(slow_f) +%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X +%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX) +``` + +You can mix `jit` and `grad` and any other JAX transformation however you like. + +Using `jit` puts constraints on the kind of Python control flow +the function can use; see +the [Gotchas +Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT) +for more. + +### Auto-vectorization with `vmap` + +[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is +the vectorizing map. +It has the familiar semantics of mapping a function along array axes, but +instead of keeping the loop on the outside, it pushes the loop down into a +function’s primitive operations for better performance. + +Using `vmap` can save you from having to carry around batch dimensions in your +code. For example, consider this simple *unbatched* neural network prediction +function: + +```python +def predict(params, input_vec): + assert input_vec.ndim == 1 + activations = input_vec + for W, b in params: + outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side! + activations = jnp.tanh(outputs) # inputs to the next layer + return outputs # no activation on last layer +``` + +We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the +left side of `activations`, but we’ve written this particular prediction function to +apply only to single input vectors. 
If we wanted to apply this function to a +batch of inputs at once, semantically we could just write + +```python +from functools import partial +predictions = jnp.stack(list(map(partial(predict, params), input_batch))) +``` + +But pushing one example through the network at a time would be slow! It’s better +to vectorize the computation, so that at every layer we’re doing matrix-matrix +multiplication rather than matrix-vector multiplication. + +The `vmap` function does that transformation for us. That is, if we write + +```python +from jax import vmap +predictions = vmap(partial(predict, params))(input_batch) +# or, alternatively +predictions = vmap(predict, in_axes=(None, 0))(params, input_batch) +``` + +then the `vmap` function will push the outer loop inside the function, and our +machine will end up executing matrix-matrix multiplications exactly as if we’d +done the batching by hand. + +It’s easy enough to manually batch a simple neural network without `vmap`, but +in other cases manual vectorization can be impractical or impossible. Take the +problem of efficiently computing per-example gradients: that is, for a fixed set +of parameters, we want to compute the gradient of our loss function evaluated +separately at each example in a batch. With `vmap`, it’s easy: + +```python +per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets) +``` + +Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other +JAX transformation! We use `vmap` with both forward- and reverse-mode automatic +differentiation for fast Jacobian and Hessian matrix calculations in +`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`. + +### SPMD programming with `pmap` + +For parallel programming of multiple accelerators, like multiple GPUs, use +[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). +With `pmap` you write single-program multiple-data (SPMD) programs, including +fast parallel collective communication operations. Applying `pmap` will mean +that the function you write is compiled by XLA (similarly to `jit`), then +replicated and executed in parallel across devices. + +Here's an example on an 8-GPU machine: + +```python +from jax import random, pmap +import jax.numpy as jnp + +# Create 8 random 5000 x 6000 matrices, one per GPU +keys = random.split(random.PRNGKey(0), 8) +mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) + +# Run a local matmul on each device in parallel (no data transfer) +result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000) + +# Compute the mean on each device in parallel and print the result +print(pmap(jnp.mean)(result)) +# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157] +``` + +In addition to expressing pure maps, you can use fast [collective communication +operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) +between devices: + +```python +from functools import partial +from jax import lax + +@partial(pmap, axis_name='i') +def normalize(x): + return x / lax.psum(x, 'i') + +print(normalize(jnp.arange(4.))) +# prints [0. 0.16666667 0.33333334 0.5 ] +``` + +You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more +sophisticated communication patterns. 
+ +It all composes, so you're free to differentiate through parallel computations: + +```python +from jax import grad + +@pmap +def f(x): + y = jnp.sin(x) + @pmap + def g(z): + return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum() + return grad(lambda w: jnp.sum(g(w)))(x) + +print(f(x)) +# [[ 0. , -0.7170853 ], +# [-3.1085174 , -0.4824318 ], +# [10.366636 , 13.135289 ], +# [ 0.22163185, -0.52112055]] + +print(grad(lambda x: jnp.sum(f(x)))(x)) +# [[ -3.2369726, -1.6356447], +# [ 4.7572474, 11.606951 ], +# [-98.524414 , 42.76499 ], +# [ -1.6007166, -1.2568436]] +``` + +When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the +backward pass of the computation is parallelized just like the forward pass. + +See the [SPMD +Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +and the [SPMD MNIST classifier from scratch +example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) +for more. + +## Current gotchas + +For a more thorough survey of current gotchas, with examples and explanations, +we highly recommend reading the [Gotchas +Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Some standouts: + +1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. +1. [In-place mutating updates of + arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. +1. [Random numbers are + different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md). +1. If you're looking for [convolution + operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), + they're in the `jax.lax` package. +1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and + [to enable + double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at + startup (or set the environment variable `JAX_ENABLE_X64=True`). + On TPU, JAX uses 32-bit values by default for everything _except_ internal + temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. + Those ops have a `precision` parameter which can be used to simulate + true 32-bit, with a cost of possibly slower runtime. +1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars + and NumPy types aren't preserved, namely `np.add(1, np.array([2], + np.float32)).dtype` is `float64` rather than `float32`. +1. Some transformations, like `jit`, [constrain how you can use Python control + flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). + You'll always get loud errors if something goes wrong. 
You might have to use + [`jit`'s `static_argnums` + parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), + [structured control flow + primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) + like + [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), + or just use `jit` on smaller subfunctions. + +## Installation + +JAX is written in pure Python, but it depends on XLA, which needs to be +installed as the `jaxlib` package. Use the following instructions to install a +binary package with `pip` or `conda`, or to [build JAX from +source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and +macOS (10.12 or later) platforms. + +Windows users can use JAX on CPU and GPU via the [Windows Subsystem for +Linux](https://docs.microsoft.com/en-us/windows/wsl/about). In addition, there +is some initial community-driven native Windows support, but since it is still +somewhat immature, there are no official binary releases and it must be [built +from source for Windows](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows). +For an unofficial discussion of native Windows builds, see also the [Issue #5795 +thread](https://github.com/google/jax/issues/5795). + +### pip installation: CPU + +To install a CPU-only version of JAX, which might be useful for doing local +development on a laptop, you can run + +```bash +pip install --upgrade pip +pip install --upgrade "jax[cpu]" +``` + +On Linux, it is often necessary to first update `pip` to a version that supports +`manylinux2014` wheels. Also note that for Linux, we currently release wheels for `x86_64` architectures only, other architectures require building from source. Trying to pip install with other Linux architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime). +**These `pip` installations do not work with Windows, and may fail silently; see +[above](#installation).** + +### pip installation: GPU (CUDA, installed via pip, easier) + +There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN +installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend +installing CUDA and CUDNN using the pip wheels, since it is much easier! + +You must first install the NVIDIA driver. We +recommend installing the newest driver available from NVIDIA, but the driver +must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux. + +```bash +pip install --upgrade pip + +# CUDA 12 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# CUDA 11 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +### pip installation: GPU (CUDA, installed locally, harder) + +If you prefer to use a preinstalled copy of CUDA, you must first +install [CUDA](https://developer.nvidia.com/cuda-downloads) and +[CuDNN](https://developer.nvidia.com/CUDNN). + +JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. 
Other +combinations of operating system and architecture are possible, but require +[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +You should use an NVIDIA driver version that is at least as new as your +[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions). +If you need to use an newer CUDA toolkit with an older driver, for example +on a cluster where you cannot update the NVIDIA driver easily, you may be +able to use the +[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) +that NVIDIA provides for this purpose. + +JAX currently ships three CUDA wheel variants: +* CUDA 12.0 and CuDNN 8.8. +* CUDA 11.8 and CuDNN 8.6. +* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued + with jax 0.4.8. + +You may use a JAX wheel provided the major version of your CUDA and CuDNN +installation matches, and the minor version is at least as new as the version +JAX expects. For example, you would be able to use the CUDA 12.0 wheel with +CUDA 12.1 and CuDNN 8.9. + +Your CUDA installation must also be new enough to support your GPU. If you have +an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU, +you must use CUDA 11.8 or newer. + + +To install, run + +```bash +pip install --upgrade pip + +# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer. +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer. +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated). +pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +**These `pip` installations do not work with Windows, and may fail silently; see +[above](#installation).** + +You can find your CUDA version with the command: + +```bash +nvcc --version +``` + +Some GPU functionality expects the CUDA installation to be at +`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number +(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either +create a symlink: + +```bash +sudo ln -s /path/to/cuda /usr/local/cuda-X.X +``` + +Please let us know on [the issue tracker](https://github.com/google/jax/issues) +if you run into any errors or problems with the prebuilt wheels. + +### pip installation: Google Cloud TPU +JAX also provides pre-built wheels for +[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm). +To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run +the following in your cloud TPU VM: +```bash +pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +### pip installation: Colab TPU +Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ. 
+The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU: +```python +import jax.tools.colab_tpu +jax.tools.colab_tpu.setup_tpu() +``` +Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer. +If you need to re-install JAX on a Colab TPU runtime, you can use the following command: +``` +!pip install jax<=0.3.25 jaxlib<=0.3.25 +``` + +### Conda installation + +There is a community-supported Conda build of `jax`. To install using `conda`, +simply run + +```bash +conda install jax -c conda-forge +``` + +To install on a machine with an NVIDIA GPU, run +```bash +conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia +``` + +Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which +JAX requires. You must therefore either install the `cuda-nvcc` package from +the `nvidia` channel, or install CUDA on your machine separately so that `ptxas` +is in your path. The channel order above is important (`conda-forge` before +`nvidia`). + +If you would like to override which release of CUDA is used by JAX, or to +install the CUDA build on a machine without GPUs, follow the instructions in the +[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch) +section of the `conda-forge` website. + +See the `conda-forge` +[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and +[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories +for more details. + +### Building JAX from source +See [Building JAX from +source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +## Neural network libraries + +Multiple Google research groups develop and share libraries for training neural +networks in JAX. If you want a fully featured library for neural network +training with examples and how-to guides, try +[Flax](https://github.com/google/flax). + +In addition, DeepMind has open-sourced an [ecosystem of libraries around +JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) +including [Haiku](https://github.com/deepmind/dm-haiku) for neural network +modules, [Optax](https://github.com/deepmind/optax) for gradient processing and +optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and +[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch +the NeurIPS 2020 JAX Ecosystem at DeepMind talk +[here](https://www.youtube.com/watch?v=iDxJxIyzSiM)) + +## Citing JAX + +To cite this repository: + +``` +@software{jax2018github, + author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, + title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, + url = {http://github.com/google/jax}, + version = {0.3.13}, + year = {2018}, +} +``` + +In the above bibtex entry, names are in alphabetical order, the version number +is intended to be that from [jax/version.py](../main/jax/version.py), and +the year corresponds to the project's open-source release. + +A nascent version of JAX, supporting only automatic differentiation and +compilation to XLA, was described in a [paper that appeared at SysML +2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). 
We're currently working on +covering JAX's ideas and capabilities in a more comprehensive and up-to-date +paper. + +## Reference documentation + +For details about the JAX API, see the +[reference documentation](https://jax.readthedocs.io/). + +For getting started as a JAX developer, see the +[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). + + +%package help +Summary: Development documents and examples for jax +Provides: python3-jax-doc +%description help +<div align="center"> +<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img> +</div> + +# JAX: Autograd and XLA + + + + +[**Quickstart**](#quickstart-colab-in-the-cloud) +| [**Transformations**](#transformations) +| [**Install guide**](#installation) +| [**Neural net libraries**](#neural-network-libraries) +| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) +| [**Reference docs**](https://jax.readthedocs.io/en/latest/) + + +## What is JAX? + +JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla), +brought together for high-performance machine learning research. + +With its updated version of [Autograd](https://github.com/hips/autograd), +JAX can automatically differentiate native +Python and NumPy functions. It can differentiate through loops, branches, +recursion, and closures, and it can take derivatives of derivatives of +derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) +via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, +and the two can be composed arbitrarily to any order. + +What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla) +to compile and run your NumPy programs on GPUs and TPUs. Compilation happens +under the hood by default, with library calls getting just-in-time compiled and +executed. But JAX also lets you just-in-time compile your own Python functions +into XLA-optimized kernels using a one-function API, +[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be +composed arbitrarily, so you can express sophisticated algorithms and get +maximal performance without leaving Python. You can even program multiple GPUs +or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and +differentiate through the whole thing. + +Dig a little deeper, and you'll see that JAX is really an extensible system for +[composable function transformations](#transformations). Both +[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit) +are instances of such transformations. Others are +[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and +[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) +parallel programming of multiple accelerators, with more to come. + +This is a research project, not an official Google product. Expect bugs and +[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Please help by trying it out, [reporting +bugs](https://github.com/google/jax/issues), and letting us know what you +think! 
+ +```python +import jax.numpy as jnp +from jax import grad, jit, vmap + +def predict(params, inputs): + for W, b in params: + outputs = jnp.dot(inputs, W) + b + inputs = jnp.tanh(outputs) # inputs to the next layer + return outputs # no activation on last layer + +def loss(params, inputs, targets): + preds = predict(params, inputs) + return jnp.sum((preds - targets)**2) + +grad_loss = jit(grad(loss)) # compiled gradient evaluation function +perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads +``` + +### Contents +* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) +* [Transformations](#transformations) +* [Current gotchas](#current-gotchas) +* [Installation](#installation) +* [Neural net libraries](#neural-network-libraries) +* [Citing JAX](#citing-jax) +* [Reference documentation](#reference-documentation) + +## Quickstart: Colab in the Cloud +Jump right in using a notebook in your browser, connected to a Google Cloud GPU. +Here are some starter notebooks: +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) + +**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU +Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs). + +For a deeper dive into JAX: +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- See the [full list of +notebooks](https://github.com/google/jax/tree/main/docs/notebooks). + +You can also take a look at [the mini-libraries in +`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md), +like [`stax` for building neural +networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax) +and [`optimizers` for first-order stochastic +optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization), +or the [examples](https://github.com/google/jax/tree/main/examples). + +## Transformations + +At its core, JAX is an extensible system for transforming numerical functions. +Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and +`pmap`. + +### Automatic differentiation with `grad` + +JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). +The most popular function is +[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) +for reverse-mode gradients: + +```python +from jax import grad +import jax.numpy as jnp + +def tanh(x): # Define a function + y = jnp.exp(-2.0 * x) + return (1.0 - y) / (1.0 + y) + +grad_tanh = grad(tanh) # Obtain its gradient function +print(grad_tanh(1.0)) # Evaluate it at x = 1.0 +# prints 0.4199743 +``` + +You can differentiate to any order with `grad`. 
+ +```python +print(grad(grad(grad(tanh)))(1.0)) +# prints 0.62162673 +``` + +For more advanced autodiff, you can use +[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for +reverse-mode vector-Jacobian products and +[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for +forward-mode Jacobian-vector products. The two can be composed arbitrarily with +one another, and with other JAX transformations. Here's one way to compose those +to make a function that efficiently computes [full Hessian +matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): + +```python +from jax import jit, jacfwd, jacrev + +def hessian(fun): + return jit(jacfwd(jacrev(fun))) +``` + +As with [Autograd](https://github.com/hips/autograd), you're free to use +differentiation with Python control structures: + +```python +def abs_val(x): + if x > 0: + return x + else: + return -x + +abs_val_grad = grad(abs_val) +print(abs_val_grad(1.0)) # prints 1.0 +print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) +``` + +See the [reference docs on automatic +differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) +and the [JAX Autodiff +Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +for more. + +### Compilation with `jit` + +You can use XLA to compile your functions end-to-end with +[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +used either as an `@jit` decorator or as a higher-order function. + +```python +import jax.numpy as jnp +from jax import jit + +def slow_f(x): + # Element-wise ops see a large benefit from fusion + return x * x + x * 2.0 + +x = jnp.ones((5000, 5000)) +fast_f = jit(slow_f) +%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X +%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX) +``` + +You can mix `jit` and `grad` and any other JAX transformation however you like. + +Using `jit` puts constraints on the kind of Python control flow +the function can use; see +the [Gotchas +Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT) +for more. + +### Auto-vectorization with `vmap` + +[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is +the vectorizing map. +It has the familiar semantics of mapping a function along array axes, but +instead of keeping the loop on the outside, it pushes the loop down into a +function’s primitive operations for better performance. + +Using `vmap` can save you from having to carry around batch dimensions in your +code. For example, consider this simple *unbatched* neural network prediction +function: + +```python +def predict(params, input_vec): + assert input_vec.ndim == 1 + activations = input_vec + for W, b in params: + outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side! + activations = jnp.tanh(outputs) # inputs to the next layer + return outputs # no activation on last layer +``` + +We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the +left side of `activations`, but we’ve written this particular prediction function to +apply only to single input vectors. 
If we wanted to apply this function to a +batch of inputs at once, semantically we could just write + +```python +from functools import partial +predictions = jnp.stack(list(map(partial(predict, params), input_batch))) +``` + +But pushing one example through the network at a time would be slow! It’s better +to vectorize the computation, so that at every layer we’re doing matrix-matrix +multiplication rather than matrix-vector multiplication. + +The `vmap` function does that transformation for us. That is, if we write + +```python +from jax import vmap +predictions = vmap(partial(predict, params))(input_batch) +# or, alternatively +predictions = vmap(predict, in_axes=(None, 0))(params, input_batch) +``` + +then the `vmap` function will push the outer loop inside the function, and our +machine will end up executing matrix-matrix multiplications exactly as if we’d +done the batching by hand. + +It’s easy enough to manually batch a simple neural network without `vmap`, but +in other cases manual vectorization can be impractical or impossible. Take the +problem of efficiently computing per-example gradients: that is, for a fixed set +of parameters, we want to compute the gradient of our loss function evaluated +separately at each example in a batch. With `vmap`, it’s easy: + +```python +per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets) +``` + +Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other +JAX transformation! We use `vmap` with both forward- and reverse-mode automatic +differentiation for fast Jacobian and Hessian matrix calculations in +`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`. + +### SPMD programming with `pmap` + +For parallel programming of multiple accelerators, like multiple GPUs, use +[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). +With `pmap` you write single-program multiple-data (SPMD) programs, including +fast parallel collective communication operations. Applying `pmap` will mean +that the function you write is compiled by XLA (similarly to `jit`), then +replicated and executed in parallel across devices. + +Here's an example on an 8-GPU machine: + +```python +from jax import random, pmap +import jax.numpy as jnp + +# Create 8 random 5000 x 6000 matrices, one per GPU +keys = random.split(random.PRNGKey(0), 8) +mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) + +# Run a local matmul on each device in parallel (no data transfer) +result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000) + +# Compute the mean on each device in parallel and print the result +print(pmap(jnp.mean)(result)) +# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157] +``` + +In addition to expressing pure maps, you can use fast [collective communication +operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) +between devices: + +```python +from functools import partial +from jax import lax + +@partial(pmap, axis_name='i') +def normalize(x): + return x / lax.psum(x, 'i') + +print(normalize(jnp.arange(4.))) +# prints [0. 0.16666667 0.33333334 0.5 ] +``` + +You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more +sophisticated communication patterns. 
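+
+Beyond `psum`, `jax.lax` offers other collectives such as `lax.pmean`. As a
+minimal sketch of our own (not from the README; the function name `center` and
+the `'devices'` axis label are ours), the snippet below subtracts the
+cross-device mean from each device's shard. It assumes only however many local
+devices `jax.local_device_count()` reports, so it also runs, trivially, on a
+single-device CPU install:
+
+```python
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+from jax import lax, pmap
+
+@partial(pmap, axis_name='devices')
+def center(x):
+  # `lax.pmean` is a collective: every device receives the mean over 'devices'.
+  return x - lax.pmean(x, 'devices')
+
+n = jax.local_device_count()
+shards = jnp.arange(float(n))  # one scalar per device
+print(center(shards))          # e.g. [-1.5 -0.5  0.5  1.5] on 4 devices
+```
+
+Because the collective is expressed inside the `pmap`-ed function, the same
+code runs unchanged whether the shards live on one device or on eight.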
+ +It all composes, so you're free to differentiate through parallel computations: + +```python +from jax import grad + +@pmap +def f(x): + y = jnp.sin(x) + @pmap + def g(z): + return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum() + return grad(lambda w: jnp.sum(g(w)))(x) + +print(f(x)) +# [[ 0. , -0.7170853 ], +# [-3.1085174 , -0.4824318 ], +# [10.366636 , 13.135289 ], +# [ 0.22163185, -0.52112055]] + +print(grad(lambda x: jnp.sum(f(x)))(x)) +# [[ -3.2369726, -1.6356447], +# [ 4.7572474, 11.606951 ], +# [-98.524414 , 42.76499 ], +# [ -1.6007166, -1.2568436]] +``` + +When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the +backward pass of the computation is parallelized just like the forward pass. + +See the [SPMD +Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +and the [SPMD MNIST classifier from scratch +example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) +for more. + +## Current gotchas + +For a more thorough survey of current gotchas, with examples and explanations, +we highly recommend reading the [Gotchas +Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Some standouts: + +1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. +1. [In-place mutating updates of + arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. +1. [Random numbers are + different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md). +1. If you're looking for [convolution + operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), + they're in the `jax.lax` package. +1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and + [to enable + double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at + startup (or set the environment variable `JAX_ENABLE_X64=True`). + On TPU, JAX uses 32-bit values by default for everything _except_ internal + temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. + Those ops have a `precision` parameter which can be used to simulate + true 32-bit, with a cost of possibly slower runtime. +1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars + and NumPy types aren't preserved, namely `np.add(1, np.array([2], + np.float32)).dtype` is `float64` rather than `float32`. +1. Some transformations, like `jit`, [constrain how you can use Python control + flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). + You'll always get loud errors if something goes wrong. 
You might have to use + [`jit`'s `static_argnums` + parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), + [structured control flow + primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) + like + [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), + or just use `jit` on smaller subfunctions. + +## Installation + +JAX is written in pure Python, but it depends on XLA, which needs to be +installed as the `jaxlib` package. Use the following instructions to install a +binary package with `pip` or `conda`, or to [build JAX from +source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and +macOS (10.12 or later) platforms. + +Windows users can use JAX on CPU and GPU via the [Windows Subsystem for +Linux](https://docs.microsoft.com/en-us/windows/wsl/about). In addition, there +is some initial community-driven native Windows support, but since it is still +somewhat immature, there are no official binary releases and it must be [built +from source for Windows](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows). +For an unofficial discussion of native Windows builds, see also the [Issue #5795 +thread](https://github.com/google/jax/issues/5795). + +### pip installation: CPU + +To install a CPU-only version of JAX, which might be useful for doing local +development on a laptop, you can run + +```bash +pip install --upgrade pip +pip install --upgrade "jax[cpu]" +``` + +On Linux, it is often necessary to first update `pip` to a version that supports +`manylinux2014` wheels. Also note that for Linux, we currently release wheels for `x86_64` architectures only, other architectures require building from source. Trying to pip install with other Linux architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime). +**These `pip` installations do not work with Windows, and may fail silently; see +[above](#installation).** + +### pip installation: GPU (CUDA, installed via pip, easier) + +There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN +installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend +installing CUDA and CUDNN using the pip wheels, since it is much easier! + +You must first install the NVIDIA driver. We +recommend installing the newest driver available from NVIDIA, but the driver +must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux. + +```bash +pip install --upgrade pip + +# CUDA 12 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# CUDA 11 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``` + +### pip installation: GPU (CUDA, installed locally, harder) + +If you prefer to use a preinstalled copy of CUDA, you must first +install [CUDA](https://developer.nvidia.com/cuda-downloads) and +[CuDNN](https://developer.nvidia.com/CUDNN). + +JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. 
Other
+combinations of operating system and architecture are possible, but require
+[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+You should use an NVIDIA driver version that is at least as new as your
+[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
+If you need to use a newer CUDA toolkit with an older driver, for example
+on a cluster where you cannot update the NVIDIA driver easily, you may be
+able to use the
+[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
+that NVIDIA provides for this purpose.
+
+JAX currently ships three CUDA wheel variants:
+* CUDA 12.0 and cuDNN 8.8.
+* CUDA 11.8 and cuDNN 8.6.
+* CUDA 11.4 and cuDNN 8.2. This wheel is deprecated and will be discontinued
+  with jax 0.4.8.
+
+You may use a JAX wheel provided the major versions of your CUDA and cuDNN
+installations match and the minor versions are at least as new as the versions
+JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
+CUDA 12.1 and cuDNN 8.9.
+
+Your CUDA installation must also be new enough to support your GPU. If you have
+an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
+you must use CUDA 11.8 or newer.
+
+To install, run
+
+```bash
+pip install --upgrade pip
+
+# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
+# Note: wheels are only available on Linux.
+pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
+# Note: wheels are only available on Linux.
+pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with CUDA 11.4+ and cuDNN 8.2+ (deprecated).
+pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+**These `pip` installations do not work with Windows, and may fail silently; see
+[above](#installation).**
+
+You can find your CUDA version with the command:
+
+```bash
+nvcc --version
+```
+
+Some GPU functionality expects the CUDA installation to be at
+`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
+(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can
+create a symlink:
+
+```bash
+sudo ln -s /path/to/cuda /usr/local/cuda-X.X
+```
+
+Please let us know on [the issue tracker](https://github.com/google/jax/issues)
+if you run into any errors or problems with the prebuilt wheels.
+
+### pip installation: Google Cloud TPU
+JAX also provides pre-built wheels for
+[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
+To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
+the following in your cloud TPU VM:
+```bash
+pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+### pip installation: Colab TPU
+Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ.
+The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
+```python
+import jax.tools.colab_tpu
+jax.tools.colab_tpu.setup_tpu()
+```
+Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer.
+If you need to re-install JAX on a Colab TPU runtime, you can use the following command:
+```
+!pip install "jax<=0.3.25" "jaxlib<=0.3.25"
+```
+
+### Conda installation
+
+There is a community-supported Conda build of `jax`. To install using `conda`,
+simply run
+
+```bash
+conda install jax -c conda-forge
+```
+
+To install on a machine with an NVIDIA GPU, run
+```bash
+conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
+```
+
+Note that the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
+JAX requires. You must therefore either install the `cuda-nvcc` package from
+the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
+is in your path. The channel order above is important (`conda-forge` before
+`nvidia`).
+
+If you would like to override which release of CUDA is used by JAX, or to
+install the CUDA build on a machine without GPUs, follow the instructions in the
+[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
+section of the `conda-forge` website.
+
+See the `conda-forge`
+[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and
+[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
+for more details.
+
+### Building JAX from source
+See [Building JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+## Neural network libraries
+
+Multiple Google research groups develop and share libraries for training neural
+networks in JAX. If you want a fully featured library for neural network
+training with examples and how-to guides, try
+[Flax](https://github.com/google/flax).
+
+In addition, DeepMind has open-sourced an [ecosystem of libraries around
+JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research),
+including [Haiku](https://github.com/deepmind/dm-haiku) for neural network
+modules, [Optax](https://github.com/deepmind/optax) for gradient processing and
+optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and
+[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch
+the NeurIPS 2020 "JAX Ecosystem at DeepMind" talk
+[here](https://www.youtube.com/watch?v=iDxJxIyzSiM).)
+
+## Citing JAX
+
+To cite this repository:
+
+```
+@software{jax2018github,
+  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
+  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
+  url = {http://github.com/google/jax},
+  version = {0.3.13},
+  year = {2018},
+}
+```
+
+In the above BibTeX entry, names are in alphabetical order, the version number
+is intended to be that from [jax/version.py](../main/jax/version.py), and
+the year corresponds to the project's open-source release.
+
+A nascent version of JAX, supporting only automatic differentiation and
+compilation to XLA, was described in a [paper that appeared at SysML
+2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf).
We're currently working on +covering JAX's ideas and capabilities in a more comprehensive and up-to-date +paper. + +## Reference documentation + +For details about the JAX API, see the +[reference documentation](https://jax.readthedocs.io/). + +For getting started as a JAX developer, see the +[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). + + +%prep +%autosetup -n jax-0.4.8 + +%build +%py3_build + +%install +%py3_install +install -d -m755 %{buildroot}/%{_pkgdocdir} +if [ -d doc ]; then cp -arf doc %{buildroot}/%{_pkgdocdir}; fi +if [ -d docs ]; then cp -arf docs %{buildroot}/%{_pkgdocdir}; fi +if [ -d example ]; then cp -arf example %{buildroot}/%{_pkgdocdir}; fi +if [ -d examples ]; then cp -arf examples %{buildroot}/%{_pkgdocdir}; fi +pushd %{buildroot} +if [ -d usr/lib ]; then + find usr/lib -type f -printf "/%h/%f\n" >> filelist.lst +fi +if [ -d usr/lib64 ]; then + find usr/lib64 -type f -printf "/%h/%f\n" >> filelist.lst +fi +if [ -d usr/bin ]; then + find usr/bin -type f -printf "/%h/%f\n" >> filelist.lst +fi +if [ -d usr/sbin ]; then + find usr/sbin -type f -printf "/%h/%f\n" >> filelist.lst +fi +touch doclist.lst +if [ -d usr/share/man ]; then + find usr/share/man -type f -printf "/%h/%f.gz\n" >> doclist.lst +fi +popd +mv %{buildroot}/filelist.lst . +mv %{buildroot}/doclist.lst . + +%files -n python3-jax -f filelist.lst +%dir %{python3_sitelib}/* + +%files help -f doclist.lst +%{_docdir}/* + +%changelog +* Mon Apr 10 2023 Python_Bot <Python_Bot@openeuler.org> - 0.4.8-1 +- Package Spec generated @@ -0,0 +1 @@ +06e5f085d658fd782189ccca5fbf21f5 jax-0.4.8.tar.gz |