summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCoprDistGit <infra@openeuler.org>2023-04-10 08:23:46 +0000
committerCoprDistGit <infra@openeuler.org>2023-04-10 08:23:46 +0000
commit9428332eb11a19a4c333b3ae479952174f949752 (patch)
treeb689aadcfd73aac979a9c701fb6a4af7fb262ec9
parent984782762feb97ecc8052c015c73d0a2e7ea1769 (diff)
automatic import of python-jax
-rw-r--r--.gitignore1
-rw-r--r--python-jax.spec1920
-rw-r--r--sources1
3 files changed, 1922 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
index e69de29..cfa2764 100644
--- a/.gitignore
+++ b/.gitignore
@@ -0,0 +1 @@
+/jax-0.4.8.tar.gz
diff --git a/python-jax.spec b/python-jax.spec
new file mode 100644
index 0000000..d6931cf
--- /dev/null
+++ b/python-jax.spec
@@ -0,0 +1,1920 @@
+%global _empty_manifest_terminate_build 0
+Name: python-jax
+Version: 0.4.8
+Release: 1
+Summary: Differentiate, compile, and transform Numpy code.
+License: Apache-2.0
+URL: https://github.com/google/jax
+Source0: https://mirrors.nju.edu.cn/pypi/web/packages/fe/58/1641614c17fcd7293d250c2cad48011baa1ecef4f109ce2ea027aa8e4898/jax-0.4.8.tar.gz
+BuildArch: noarch
+
+
+%description
+<div align="center">
+<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img>
+</div>
+
+# JAX: Autograd and XLA
+
+![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg)
+![PyPI version](https://img.shields.io/pypi/v/jax)
+
+[**Quickstart**](#quickstart-colab-in-the-cloud)
+| [**Transformations**](#transformations)
+| [**Install guide**](#installation)
+| [**Neural net libraries**](#neural-network-libraries)
+| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html)
+| [**Reference docs**](https://jax.readthedocs.io/en/latest/)
+
+
+## What is JAX?
+
+JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla),
+brought together for high-performance machine learning research.
+
+With its updated version of [Autograd](https://github.com/hips/autograd),
+JAX can automatically differentiate native
+Python and NumPy functions. It can differentiate through loops, branches,
+recursion, and closures, and it can take derivatives of derivatives of
+derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
+via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
+and the two can be composed arbitrarily to any order.
+
+What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla)
+to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
+under the hood by default, with library calls getting just-in-time compiled and
+executed. But JAX also lets you just-in-time compile your own Python functions
+into XLA-optimized kernels using a one-function API,
+[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
+composed arbitrarily, so you can express sophisticated algorithms and get
+maximal performance without leaving Python. You can even program multiple GPUs
+or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
+differentiate through the whole thing.
+
+Dig a little deeper, and you'll see that JAX is really an extensible system for
+[composable function transformations](#transformations). Both
+[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
+are instances of such transformations. Others are
+[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
+[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
+parallel programming of multiple accelerators, with more to come.
+
+This is a research project, not an official Google product. Expect bugs and
+[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Please help by trying it out, [reporting
+bugs](https://github.com/google/jax/issues), and letting us know what you
+think!
+
+```python
+import jax.numpy as jnp
+from jax import grad, jit, vmap
+
+def predict(params, inputs):
+ for W, b in params:
+ outputs = jnp.dot(inputs, W) + b
+ inputs = jnp.tanh(outputs) # inputs to the next layer
+ return outputs # no activation on last layer
+
+def loss(params, inputs, targets):
+ preds = predict(params, inputs)
+ return jnp.sum((preds - targets)**2)
+
+grad_loss = jit(grad(loss)) # compiled gradient evaluation function
+perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads
+```
+
+### Contents
+* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
+* [Transformations](#transformations)
+* [Current gotchas](#current-gotchas)
+* [Installation](#installation)
+* [Neural net libraries](#neural-network-libraries)
+* [Citing JAX](#citing-jax)
+* [Reference documentation](#reference-documentation)
+
+## Quickstart: Colab in the Cloud
+Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
+Here are some starter notebooks:
+- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
+- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
+
+**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
+Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs).
+
+For a deeper dive into JAX:
+- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
+- See the [full list of
+notebooks](https://github.com/google/jax/tree/main/docs/notebooks).
+
+You can also take a look at [the mini-libraries in
+`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md),
+like [`stax` for building neural
+networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax)
+and [`optimizers` for first-order stochastic
+optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization),
+or the [examples](https://github.com/google/jax/tree/main/examples).
+
+## Transformations
+
+At its core, JAX is an extensible system for transforming numerical functions.
+Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and
+`pmap`.
+
+### Automatic differentiation with `grad`
+
+JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
+The most popular function is
+[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
+for reverse-mode gradients:
+
+```python
+from jax import grad
+import jax.numpy as jnp
+
+def tanh(x): # Define a function
+ y = jnp.exp(-2.0 * x)
+ return (1.0 - y) / (1.0 + y)
+
+grad_tanh = grad(tanh) # Obtain its gradient function
+print(grad_tanh(1.0)) # Evaluate it at x = 1.0
+# prints 0.4199743
+```
+
+You can differentiate to any order with `grad`.
+
+```python
+print(grad(grad(grad(tanh)))(1.0))
+# prints 0.62162673
+```
+
+For more advanced autodiff, you can use
+[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
+reverse-mode vector-Jacobian products and
+[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
+forward-mode Jacobian-vector products. The two can be composed arbitrarily with
+one another, and with other JAX transformations. Here's one way to compose those
+to make a function that efficiently computes [full Hessian
+matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian):
+
+```python
+from jax import jit, jacfwd, jacrev
+
+def hessian(fun):
+ return jit(jacfwd(jacrev(fun)))
+```
+
+As with [Autograd](https://github.com/hips/autograd), you're free to use
+differentiation with Python control structures:
+
+```python
+def abs_val(x):
+ if x > 0:
+ return x
+ else:
+ return -x
+
+abs_val_grad = grad(abs_val)
+print(abs_val_grad(1.0)) # prints 1.0
+print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
+```
+
+See the [reference docs on automatic
+differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+and the [JAX Autodiff
+Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+for more.
+
+### Compilation with `jit`
+
+You can use XLA to compile your functions end-to-end with
+[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+used either as an `@jit` decorator or as a higher-order function.
+
+```python
+import jax.numpy as jnp
+from jax import jit
+
+def slow_f(x):
+ # Element-wise ops see a large benefit from fusion
+ return x * x + x * 2.0
+
+x = jnp.ones((5000, 5000))
+fast_f = jit(slow_f)
+%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
+%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
+```
+
+You can mix `jit` and `grad` and any other JAX transformation however you like.
+
+Using `jit` puts constraints on the kind of Python control flow
+the function can use; see
+the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
+for more.
+
+### Auto-vectorization with `vmap`
+
+[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
+the vectorizing map.
+It has the familiar semantics of mapping a function along array axes, but
+instead of keeping the loop on the outside, it pushes the loop down into a
+function’s primitive operations for better performance.
+
+Using `vmap` can save you from having to carry around batch dimensions in your
+code. For example, consider this simple *unbatched* neural network prediction
+function:
+
+```python
+def predict(params, input_vec):
+ assert input_vec.ndim == 1
+ activations = input_vec
+ for W, b in params:
+ outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side!
+ activations = jnp.tanh(outputs) # inputs to the next layer
+ return outputs # no activation on last layer
+```
+
+We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the
+left side of `activations`, but we’ve written this particular prediction function to
+apply only to single input vectors. If we wanted to apply this function to a
+batch of inputs at once, semantically we could just write
+
+```python
+from functools import partial
+predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
+```
+
+But pushing one example through the network at a time would be slow! It’s better
+to vectorize the computation, so that at every layer we’re doing matrix-matrix
+multiplication rather than matrix-vector multiplication.
+
+The `vmap` function does that transformation for us. That is, if we write
+
+```python
+from jax import vmap
+predictions = vmap(partial(predict, params))(input_batch)
+# or, alternatively
+predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
+```
+
+then the `vmap` function will push the outer loop inside the function, and our
+machine will end up executing matrix-matrix multiplications exactly as if we’d
+done the batching by hand.
+
+It’s easy enough to manually batch a simple neural network without `vmap`, but
+in other cases manual vectorization can be impractical or impossible. Take the
+problem of efficiently computing per-example gradients: that is, for a fixed set
+of parameters, we want to compute the gradient of our loss function evaluated
+separately at each example in a batch. With `vmap`, it’s easy:
+
+```python
+per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
+```
+
+Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other
+JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
+differentiation for fast Jacobian and Hessian matrix calculations in
+`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
+
+### SPMD programming with `pmap`
+
+For parallel programming of multiple accelerators, like multiple GPUs, use
+[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
+With `pmap` you write single-program multiple-data (SPMD) programs, including
+fast parallel collective communication operations. Applying `pmap` will mean
+that the function you write is compiled by XLA (similarly to `jit`), then
+replicated and executed in parallel across devices.
+
+Here's an example on an 8-GPU machine:
+
+```python
+from jax import random, pmap
+import jax.numpy as jnp
+
+# Create 8 random 5000 x 6000 matrices, one per GPU
+keys = random.split(random.PRNGKey(0), 8)
+mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
+
+# Run a local matmul on each device in parallel (no data transfer)
+result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
+
+# Compute the mean on each device in parallel and print the result
+print(pmap(jnp.mean)(result))
+# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
+```
+
+In addition to expressing pure maps, you can use fast [collective communication
+operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
+between devices:
+
+```python
+from functools import partial
+from jax import lax
+
+@partial(pmap, axis_name='i')
+def normalize(x):
+ return x / lax.psum(x, 'i')
+
+print(normalize(jnp.arange(4.)))
+# prints [0. 0.16666667 0.33333334 0.5 ]
+```
+
+You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
+sophisticated communication patterns.
+
+It all composes, so you're free to differentiate through parallel computations:
+
+```python
+from jax import grad
+
+@pmap
+def f(x):
+ y = jnp.sin(x)
+ @pmap
+ def g(z):
+ return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
+ return grad(lambda w: jnp.sum(g(w)))(x)
+
+print(f(x))
+# [[ 0. , -0.7170853 ],
+# [-3.1085174 , -0.4824318 ],
+# [10.366636 , 13.135289 ],
+# [ 0.22163185, -0.52112055]]
+
+print(grad(lambda x: jnp.sum(f(x)))(x))
+# [[ -3.2369726, -1.6356447],
+# [ 4.7572474, 11.606951 ],
+# [-98.524414 , 42.76499 ],
+# [ -1.6007166, -1.2568436]]
+```
+
+When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
+backward pass of the computation is parallelized just like the forward pass.
+
+See the [SPMD
+Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
+and the [SPMD MNIST classifier from scratch
+example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
+for more.
+
+## Current gotchas
+
+For a more thorough survey of current gotchas, with examples and explanations,
+we highly recommend reading the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Some standouts:
+
+1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
+1. [In-place mutating updates of
+ arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
+1. [Random numbers are
+ different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md).
+1. If you're looking for [convolution
+ operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
+ they're in the `jax.lax` package.
+1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
+ [to enable
+ double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+ (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
+ startup (or set the environment variable `JAX_ENABLE_X64=True`).
+ On TPU, JAX uses 32-bit values by default for everything _except_ internal
+ temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`.
+ Those ops have a `precision` parameter which can be used to simulate
+ true 32-bit, with a cost of possibly slower runtime.
+1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
+ and NumPy types aren't preserved, namely `np.add(1, np.array([2],
+ np.float32)).dtype` is `float64` rather than `float32`.
+1. Some transformations, like `jit`, [constrain how you can use Python control
+ flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+ You'll always get loud errors if something goes wrong. You might have to use
+ [`jit`'s `static_argnums`
+ parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+ [structured control flow
+ primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
+ like
+ [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
+ or just use `jit` on smaller subfunctions.
+
+## Installation
+
+JAX is written in pure Python, but it depends on XLA, which needs to be
+installed as the `jaxlib` package. Use the following instructions to install a
+binary package with `pip` or `conda`, or to [build JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and
+macOS (10.12 or later) platforms.
+
+Windows users can use JAX on CPU and GPU via the [Windows Subsystem for
+Linux](https://docs.microsoft.com/en-us/windows/wsl/about). In addition, there
+is some initial community-driven native Windows support, but since it is still
+somewhat immature, there are no official binary releases and it must be [built
+from source for Windows](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows).
+For an unofficial discussion of native Windows builds, see also the [Issue #5795
+thread](https://github.com/google/jax/issues/5795).
+
+### pip installation: CPU
+
+To install a CPU-only version of JAX, which might be useful for doing local
+development on a laptop, you can run
+
+```bash
+pip install --upgrade pip
+pip install --upgrade "jax[cpu]"
+```
+
+On Linux, it is often necessary to first update `pip` to a version that supports
+`manylinux2014` wheels. Also note that for Linux, we currently release wheels for `x86_64` architectures only, other architectures require building from source. Trying to pip install with other Linux architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime).
+**These `pip` installations do not work with Windows, and may fail silently; see
+[above](#installation).**
+
+### pip installation: GPU (CUDA, installed via pip, easier)
+
+There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
+installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
+installing CUDA and CUDNN using the pip wheels, since it is much easier!
+
+You must first install the NVIDIA driver. We
+recommend installing the newest driver available from NVIDIA, but the driver
+must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
+
+```bash
+pip install --upgrade pip
+
+# CUDA 12 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# CUDA 11 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+### pip installation: GPU (CUDA, installed locally, harder)
+
+If you prefer to use a preinstalled copy of CUDA, you must first
+install [CUDA](https://developer.nvidia.com/cuda-downloads) and
+[CuDNN](https://developer.nvidia.com/CUDNN).
+
+JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
+combinations of operating system and architecture are possible, but require
+[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+You should use an NVIDIA driver version that is at least as new as your
+[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
+If you need to use an newer CUDA toolkit with an older driver, for example
+on a cluster where you cannot update the NVIDIA driver easily, you may be
+able to use the
+[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
+that NVIDIA provides for this purpose.
+
+JAX currently ships three CUDA wheel variants:
+* CUDA 12.0 and CuDNN 8.8.
+* CUDA 11.8 and CuDNN 8.6.
+* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued
+ with jax 0.4.8.
+
+You may use a JAX wheel provided the major version of your CUDA and CuDNN
+installation matches, and the minor version is at least as new as the version
+JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
+CUDA 12.1 and CuDNN 8.9.
+
+Your CUDA installation must also be new enough to support your GPU. If you have
+an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
+you must use CUDA 11.8 or newer.
+
+
+To install, run
+
+```bash
+pip install --upgrade pip
+
+# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated).
+pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+**These `pip` installations do not work with Windows, and may fail silently; see
+[above](#installation).**
+
+You can find your CUDA version with the command:
+
+```bash
+nvcc --version
+```
+
+Some GPU functionality expects the CUDA installation to be at
+`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
+(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either
+create a symlink:
+
+```bash
+sudo ln -s /path/to/cuda /usr/local/cuda-X.X
+```
+
+Please let us know on [the issue tracker](https://github.com/google/jax/issues)
+if you run into any errors or problems with the prebuilt wheels.
+
+### pip installation: Google Cloud TPU
+JAX also provides pre-built wheels for
+[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
+To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
+the following in your cloud TPU VM:
+```bash
+pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+### pip installation: Colab TPU
+Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ.
+The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
+```python
+import jax.tools.colab_tpu
+jax.tools.colab_tpu.setup_tpu()
+```
+Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer.
+If you need to re-install JAX on a Colab TPU runtime, you can use the following command:
+```
+!pip install jax<=0.3.25 jaxlib<=0.3.25
+```
+
+### Conda installation
+
+There is a community-supported Conda build of `jax`. To install using `conda`,
+simply run
+
+```bash
+conda install jax -c conda-forge
+```
+
+To install on a machine with an NVIDIA GPU, run
+```bash
+conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
+```
+
+Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
+JAX requires. You must therefore either install the `cuda-nvcc` package from
+the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
+is in your path. The channel order above is important (`conda-forge` before
+`nvidia`).
+
+If you would like to override which release of CUDA is used by JAX, or to
+install the CUDA build on a machine without GPUs, follow the instructions in the
+[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
+section of the `conda-forge` website.
+
+See the `conda-forge`
+[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and
+[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
+for more details.
+
+### Building JAX from source
+See [Building JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+## Neural network libraries
+
+Multiple Google research groups develop and share libraries for training neural
+networks in JAX. If you want a fully featured library for neural network
+training with examples and how-to guides, try
+[Flax](https://github.com/google/flax).
+
+In addition, DeepMind has open-sourced an [ecosystem of libraries around
+JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research)
+including [Haiku](https://github.com/deepmind/dm-haiku) for neural network
+modules, [Optax](https://github.com/deepmind/optax) for gradient processing and
+optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and
+[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch
+the NeurIPS 2020 JAX Ecosystem at DeepMind talk
+[here](https://www.youtube.com/watch?v=iDxJxIyzSiM))
+
+## Citing JAX
+
+To cite this repository:
+
+```
+@software{jax2018github,
+ author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
+ title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
+ url = {http://github.com/google/jax},
+ version = {0.3.13},
+ year = {2018},
+}
+```
+
+In the above bibtex entry, names are in alphabetical order, the version number
+is intended to be that from [jax/version.py](../main/jax/version.py), and
+the year corresponds to the project's open-source release.
+
+A nascent version of JAX, supporting only automatic differentiation and
+compilation to XLA, was described in a [paper that appeared at SysML
+2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). We're currently working on
+covering JAX's ideas and capabilities in a more comprehensive and up-to-date
+paper.
+
+## Reference documentation
+
+For details about the JAX API, see the
+[reference documentation](https://jax.readthedocs.io/).
+
+For getting started as a JAX developer, see the
+[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
+
+
+%package -n python3-jax
+Summary: Differentiate, compile, and transform Numpy code.
+Provides: python-jax
+BuildRequires: python3-devel
+BuildRequires: python3-setuptools
+BuildRequires: python3-pip
+%description -n python3-jax
+<div align="center">
+<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img>
+</div>
+
+# JAX: Autograd and XLA
+
+![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg)
+![PyPI version](https://img.shields.io/pypi/v/jax)
+
+[**Quickstart**](#quickstart-colab-in-the-cloud)
+| [**Transformations**](#transformations)
+| [**Install guide**](#installation)
+| [**Neural net libraries**](#neural-network-libraries)
+| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html)
+| [**Reference docs**](https://jax.readthedocs.io/en/latest/)
+
+
+## What is JAX?
+
+JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla),
+brought together for high-performance machine learning research.
+
+With its updated version of [Autograd](https://github.com/hips/autograd),
+JAX can automatically differentiate native
+Python and NumPy functions. It can differentiate through loops, branches,
+recursion, and closures, and it can take derivatives of derivatives of
+derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
+via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
+and the two can be composed arbitrarily to any order.
+
+What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla)
+to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
+under the hood by default, with library calls getting just-in-time compiled and
+executed. But JAX also lets you just-in-time compile your own Python functions
+into XLA-optimized kernels using a one-function API,
+[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
+composed arbitrarily, so you can express sophisticated algorithms and get
+maximal performance without leaving Python. You can even program multiple GPUs
+or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
+differentiate through the whole thing.
+
+Dig a little deeper, and you'll see that JAX is really an extensible system for
+[composable function transformations](#transformations). Both
+[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
+are instances of such transformations. Others are
+[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
+[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
+parallel programming of multiple accelerators, with more to come.
+
+This is a research project, not an official Google product. Expect bugs and
+[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Please help by trying it out, [reporting
+bugs](https://github.com/google/jax/issues), and letting us know what you
+think!
+
+```python
+import jax.numpy as jnp
+from jax import grad, jit, vmap
+
+def predict(params, inputs):
+ for W, b in params:
+ outputs = jnp.dot(inputs, W) + b
+ inputs = jnp.tanh(outputs) # inputs to the next layer
+ return outputs # no activation on last layer
+
+def loss(params, inputs, targets):
+ preds = predict(params, inputs)
+ return jnp.sum((preds - targets)**2)
+
+grad_loss = jit(grad(loss)) # compiled gradient evaluation function
+perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads
+```
+
+### Contents
+* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
+* [Transformations](#transformations)
+* [Current gotchas](#current-gotchas)
+* [Installation](#installation)
+* [Neural net libraries](#neural-network-libraries)
+* [Citing JAX](#citing-jax)
+* [Reference documentation](#reference-documentation)
+
+## Quickstart: Colab in the Cloud
+Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
+Here are some starter notebooks:
+- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
+- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
+
+**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
+Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs).
+
+For a deeper dive into JAX:
+- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
+- See the [full list of
+notebooks](https://github.com/google/jax/tree/main/docs/notebooks).
+
+You can also take a look at [the mini-libraries in
+`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md),
+like [`stax` for building neural
+networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax)
+and [`optimizers` for first-order stochastic
+optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization),
+or the [examples](https://github.com/google/jax/tree/main/examples).
+
+## Transformations
+
+At its core, JAX is an extensible system for transforming numerical functions.
+Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and
+`pmap`.
+
+### Automatic differentiation with `grad`
+
+JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
+The most popular function is
+[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
+for reverse-mode gradients:
+
+```python
+from jax import grad
+import jax.numpy as jnp
+
+def tanh(x): # Define a function
+ y = jnp.exp(-2.0 * x)
+ return (1.0 - y) / (1.0 + y)
+
+grad_tanh = grad(tanh) # Obtain its gradient function
+print(grad_tanh(1.0)) # Evaluate it at x = 1.0
+# prints 0.4199743
+```
+
+You can differentiate to any order with `grad`.
+
+```python
+print(grad(grad(grad(tanh)))(1.0))
+# prints 0.62162673
+```
+
+For more advanced autodiff, you can use
+[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
+reverse-mode vector-Jacobian products and
+[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
+forward-mode Jacobian-vector products. The two can be composed arbitrarily with
+one another, and with other JAX transformations. Here's one way to compose those
+to make a function that efficiently computes [full Hessian
+matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian):
+
+```python
+from jax import jit, jacfwd, jacrev
+
+def hessian(fun):
+ return jit(jacfwd(jacrev(fun)))
+```
+
+As with [Autograd](https://github.com/hips/autograd), you're free to use
+differentiation with Python control structures:
+
+```python
+def abs_val(x):
+ if x > 0:
+ return x
+ else:
+ return -x
+
+abs_val_grad = grad(abs_val)
+print(abs_val_grad(1.0)) # prints 1.0
+print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
+```
+
+See the [reference docs on automatic
+differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+and the [JAX Autodiff
+Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+for more.
+
+### Compilation with `jit`
+
+You can use XLA to compile your functions end-to-end with
+[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+used either as an `@jit` decorator or as a higher-order function.
+
+```python
+import jax.numpy as jnp
+from jax import jit
+
+def slow_f(x):
+ # Element-wise ops see a large benefit from fusion
+ return x * x + x * 2.0
+
+x = jnp.ones((5000, 5000))
+fast_f = jit(slow_f)
+%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
+%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
+```
+
+You can mix `jit` and `grad` and any other JAX transformation however you like.
+
+Using `jit` puts constraints on the kind of Python control flow
+the function can use; see
+the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
+for more.
+
+### Auto-vectorization with `vmap`
+
+[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
+the vectorizing map.
+It has the familiar semantics of mapping a function along array axes, but
+instead of keeping the loop on the outside, it pushes the loop down into a
+function’s primitive operations for better performance.
+
+Using `vmap` can save you from having to carry around batch dimensions in your
+code. For example, consider this simple *unbatched* neural network prediction
+function:
+
+```python
+def predict(params, input_vec):
+ assert input_vec.ndim == 1
+ activations = input_vec
+ for W, b in params:
+ outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side!
+ activations = jnp.tanh(outputs) # inputs to the next layer
+ return outputs # no activation on last layer
+```
+
+We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the
+left side of `activations`, but we’ve written this particular prediction function to
+apply only to single input vectors. If we wanted to apply this function to a
+batch of inputs at once, semantically we could just write
+
+```python
+from functools import partial
+predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
+```
+
+But pushing one example through the network at a time would be slow! It’s better
+to vectorize the computation, so that at every layer we’re doing matrix-matrix
+multiplication rather than matrix-vector multiplication.
+
+The `vmap` function does that transformation for us. That is, if we write
+
+```python
+from jax import vmap
+predictions = vmap(partial(predict, params))(input_batch)
+# or, alternatively
+predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
+```
+
+then the `vmap` function will push the outer loop inside the function, and our
+machine will end up executing matrix-matrix multiplications exactly as if we’d
+done the batching by hand.
+
+It’s easy enough to manually batch a simple neural network without `vmap`, but
+in other cases manual vectorization can be impractical or impossible. Take the
+problem of efficiently computing per-example gradients: that is, for a fixed set
+of parameters, we want to compute the gradient of our loss function evaluated
+separately at each example in a batch. With `vmap`, it’s easy:
+
+```python
+per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
+```
+
+Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other
+JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
+differentiation for fast Jacobian and Hessian matrix calculations in
+`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
+
+### SPMD programming with `pmap`
+
+For parallel programming of multiple accelerators, like multiple GPUs, use
+[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
+With `pmap` you write single-program multiple-data (SPMD) programs, including
+fast parallel collective communication operations. Applying `pmap` will mean
+that the function you write is compiled by XLA (similarly to `jit`), then
+replicated and executed in parallel across devices.
+
+Here's an example on an 8-GPU machine:
+
+```python
+from jax import random, pmap
+import jax.numpy as jnp
+
+# Create 8 random 5000 x 6000 matrices, one per GPU
+keys = random.split(random.PRNGKey(0), 8)
+mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
+
+# Run a local matmul on each device in parallel (no data transfer)
+result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
+
+# Compute the mean on each device in parallel and print the result
+print(pmap(jnp.mean)(result))
+# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
+```
+
+In addition to expressing pure maps, you can use fast [collective communication
+operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
+between devices:
+
+```python
+from functools import partial
+from jax import lax
+
+@partial(pmap, axis_name='i')
+def normalize(x):
+ return x / lax.psum(x, 'i')
+
+print(normalize(jnp.arange(4.)))
+# prints [0. 0.16666667 0.33333334 0.5 ]
+```
+
+You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
+sophisticated communication patterns.
+
+It all composes, so you're free to differentiate through parallel computations:
+
+```python
+from jax import grad
+
+@pmap
+def f(x):
+ y = jnp.sin(x)
+ @pmap
+ def g(z):
+ return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
+ return grad(lambda w: jnp.sum(g(w)))(x)
+
+print(f(x))
+# [[ 0. , -0.7170853 ],
+# [-3.1085174 , -0.4824318 ],
+# [10.366636 , 13.135289 ],
+# [ 0.22163185, -0.52112055]]
+
+print(grad(lambda x: jnp.sum(f(x)))(x))
+# [[ -3.2369726, -1.6356447],
+# [ 4.7572474, 11.606951 ],
+# [-98.524414 , 42.76499 ],
+# [ -1.6007166, -1.2568436]]
+```
+
+When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
+backward pass of the computation is parallelized just like the forward pass.
+
+See the [SPMD
+Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
+and the [SPMD MNIST classifier from scratch
+example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
+for more.
+
+## Current gotchas
+
+For a more thorough survey of current gotchas, with examples and explanations,
+we highly recommend reading the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Some standouts:
+
+1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
+1. [In-place mutating updates of
+ arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
+1. [Random numbers are
+ different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md).
+1. If you're looking for [convolution
+ operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
+ they're in the `jax.lax` package.
+1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
+ [to enable
+ double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+ (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
+ startup (or set the environment variable `JAX_ENABLE_X64=True`).
+ On TPU, JAX uses 32-bit values by default for everything _except_ internal
+ temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`.
+ Those ops have a `precision` parameter which can be used to simulate
+ true 32-bit, with a cost of possibly slower runtime.
+1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
+ and NumPy types aren't preserved, namely `np.add(1, np.array([2],
+ np.float32)).dtype` is `float64` rather than `float32`.
+1. Some transformations, like `jit`, [constrain how you can use Python control
+ flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+ You'll always get loud errors if something goes wrong. You might have to use
+ [`jit`'s `static_argnums`
+ parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+ [structured control flow
+ primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
+ like
+ [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
+ or just use `jit` on smaller subfunctions.
+
+## Installation
+
+JAX is written in pure Python, but it depends on XLA, which needs to be
+installed as the `jaxlib` package. Use the following instructions to install a
+binary package with `pip` or `conda`, or to [build JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and
+macOS (10.12 or later) platforms.
+
+Windows users can use JAX on CPU and GPU via the [Windows Subsystem for
+Linux](https://docs.microsoft.com/en-us/windows/wsl/about). In addition, there
+is some initial community-driven native Windows support, but since it is still
+somewhat immature, there are no official binary releases and it must be [built
+from source for Windows](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows).
+For an unofficial discussion of native Windows builds, see also the [Issue #5795
+thread](https://github.com/google/jax/issues/5795).
+
+### pip installation: CPU
+
+To install a CPU-only version of JAX, which might be useful for doing local
+development on a laptop, you can run
+
+```bash
+pip install --upgrade pip
+pip install --upgrade "jax[cpu]"
+```
+
+On Linux, it is often necessary to first update `pip` to a version that supports
+`manylinux2014` wheels. Also note that for Linux, we currently release wheels for `x86_64` architectures only, other architectures require building from source. Trying to pip install with other Linux architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime).
+**These `pip` installations do not work with Windows, and may fail silently; see
+[above](#installation).**
+
+### pip installation: GPU (CUDA, installed via pip, easier)
+
+There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
+installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
+installing CUDA and CUDNN using the pip wheels, since it is much easier!
+
+You must first install the NVIDIA driver. We
+recommend installing the newest driver available from NVIDIA, but the driver
+must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
+
+```bash
+pip install --upgrade pip
+
+# CUDA 12 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# CUDA 11 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+### pip installation: GPU (CUDA, installed locally, harder)
+
+If you prefer to use a preinstalled copy of CUDA, you must first
+install [CUDA](https://developer.nvidia.com/cuda-downloads) and
+[CuDNN](https://developer.nvidia.com/CUDNN).
+
+JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
+combinations of operating system and architecture are possible, but require
+[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+You should use an NVIDIA driver version that is at least as new as your
+[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
+If you need to use an newer CUDA toolkit with an older driver, for example
+on a cluster where you cannot update the NVIDIA driver easily, you may be
+able to use the
+[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
+that NVIDIA provides for this purpose.
+
+JAX currently ships three CUDA wheel variants:
+* CUDA 12.0 and CuDNN 8.8.
+* CUDA 11.8 and CuDNN 8.6.
+* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued
+ with jax 0.4.8.
+
+You may use a JAX wheel provided the major version of your CUDA and CuDNN
+installation matches, and the minor version is at least as new as the version
+JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
+CUDA 12.1 and CuDNN 8.9.
+
+Your CUDA installation must also be new enough to support your GPU. If you have
+an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
+you must use CUDA 11.8 or newer.
+
+
+To install, run
+
+```bash
+pip install --upgrade pip
+
+# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated).
+pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+**These `pip` installations do not work with Windows, and may fail silently; see
+[above](#installation).**
+
+You can find your CUDA version with the command:
+
+```bash
+nvcc --version
+```
+
+Some GPU functionality expects the CUDA installation to be at
+`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
+(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either
+create a symlink:
+
+```bash
+sudo ln -s /path/to/cuda /usr/local/cuda-X.X
+```
+
+Please let us know on [the issue tracker](https://github.com/google/jax/issues)
+if you run into any errors or problems with the prebuilt wheels.
+
+### pip installation: Google Cloud TPU
+JAX also provides pre-built wheels for
+[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
+To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
+the following in your cloud TPU VM:
+```bash
+pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+### pip installation: Colab TPU
+Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ.
+The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
+```python
+import jax.tools.colab_tpu
+jax.tools.colab_tpu.setup_tpu()
+```
+Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer.
+If you need to re-install JAX on a Colab TPU runtime, you can use the following command:
+```
+!pip install jax<=0.3.25 jaxlib<=0.3.25
+```
+
+### Conda installation
+
+There is a community-supported Conda build of `jax`. To install using `conda`,
+simply run
+
+```bash
+conda install jax -c conda-forge
+```
+
+To install on a machine with an NVIDIA GPU, run
+```bash
+conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
+```
+
+Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
+JAX requires. You must therefore either install the `cuda-nvcc` package from
+the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
+is in your path. The channel order above is important (`conda-forge` before
+`nvidia`).
+
+If you would like to override which release of CUDA is used by JAX, or to
+install the CUDA build on a machine without GPUs, follow the instructions in the
+[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
+section of the `conda-forge` website.
+
+See the `conda-forge`
+[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and
+[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
+for more details.
+
+### Building JAX from source
+See [Building JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+## Neural network libraries
+
+Multiple Google research groups develop and share libraries for training neural
+networks in JAX. If you want a fully featured library for neural network
+training with examples and how-to guides, try
+[Flax](https://github.com/google/flax).
+
+In addition, DeepMind has open-sourced an [ecosystem of libraries around
+JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research)
+including [Haiku](https://github.com/deepmind/dm-haiku) for neural network
+modules, [Optax](https://github.com/deepmind/optax) for gradient processing and
+optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and
+[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch
+the NeurIPS 2020 JAX Ecosystem at DeepMind talk
+[here](https://www.youtube.com/watch?v=iDxJxIyzSiM))
+
+## Citing JAX
+
+To cite this repository:
+
+```
+@software{jax2018github,
+ author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
+ title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
+ url = {http://github.com/google/jax},
+ version = {0.3.13},
+ year = {2018},
+}
+```
+
+In the above bibtex entry, names are in alphabetical order, the version number
+is intended to be that from [jax/version.py](../main/jax/version.py), and
+the year corresponds to the project's open-source release.
+
+A nascent version of JAX, supporting only automatic differentiation and
+compilation to XLA, was described in a [paper that appeared at SysML
+2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). We're currently working on
+covering JAX's ideas and capabilities in a more comprehensive and up-to-date
+paper.
+
+## Reference documentation
+
+For details about the JAX API, see the
+[reference documentation](https://jax.readthedocs.io/).
+
+For getting started as a JAX developer, see the
+[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
+
+
+%package help
+Summary: Development documents and examples for jax
+Provides: python3-jax-doc
+%description help
+<div align="center">
+<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img>
+</div>
+
+# JAX: Autograd and XLA
+
+![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg)
+![PyPI version](https://img.shields.io/pypi/v/jax)
+
+[**Quickstart**](#quickstart-colab-in-the-cloud)
+| [**Transformations**](#transformations)
+| [**Install guide**](#installation)
+| [**Neural net libraries**](#neural-network-libraries)
+| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html)
+| [**Reference docs**](https://jax.readthedocs.io/en/latest/)
+
+
+## What is JAX?
+
+JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla),
+brought together for high-performance machine learning research.
+
+With its updated version of [Autograd](https://github.com/hips/autograd),
+JAX can automatically differentiate native
+Python and NumPy functions. It can differentiate through loops, branches,
+recursion, and closures, and it can take derivatives of derivatives of
+derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
+via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
+and the two can be composed arbitrarily to any order.
+
+What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla)
+to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
+under the hood by default, with library calls getting just-in-time compiled and
+executed. But JAX also lets you just-in-time compile your own Python functions
+into XLA-optimized kernels using a one-function API,
+[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
+composed arbitrarily, so you can express sophisticated algorithms and get
+maximal performance without leaving Python. You can even program multiple GPUs
+or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
+differentiate through the whole thing.
+
+Dig a little deeper, and you'll see that JAX is really an extensible system for
+[composable function transformations](#transformations). Both
+[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
+are instances of such transformations. Others are
+[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
+[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
+parallel programming of multiple accelerators, with more to come.
+
+This is a research project, not an official Google product. Expect bugs and
+[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Please help by trying it out, [reporting
+bugs](https://github.com/google/jax/issues), and letting us know what you
+think!
+
+```python
+import jax.numpy as jnp
+from jax import grad, jit, vmap
+
+def predict(params, inputs):
+ for W, b in params:
+ outputs = jnp.dot(inputs, W) + b
+ inputs = jnp.tanh(outputs) # inputs to the next layer
+ return outputs # no activation on last layer
+
+def loss(params, inputs, targets):
+ preds = predict(params, inputs)
+ return jnp.sum((preds - targets)**2)
+
+grad_loss = jit(grad(loss)) # compiled gradient evaluation function
+perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads
+```
+
+### Contents
+* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
+* [Transformations](#transformations)
+* [Current gotchas](#current-gotchas)
+* [Installation](#installation)
+* [Neural net libraries](#neural-network-libraries)
+* [Citing JAX](#citing-jax)
+* [Reference documentation](#reference-documentation)
+
+## Quickstart: Colab in the Cloud
+Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
+Here are some starter notebooks:
+- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
+- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
+
+**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
+Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs).
+
+For a deeper dive into JAX:
+- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
+- See the [full list of
+notebooks](https://github.com/google/jax/tree/main/docs/notebooks).
+
+You can also take a look at [the mini-libraries in
+`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md),
+like [`stax` for building neural
+networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax)
+and [`optimizers` for first-order stochastic
+optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization),
+or the [examples](https://github.com/google/jax/tree/main/examples).
+
+## Transformations
+
+At its core, JAX is an extensible system for transforming numerical functions.
+Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and
+`pmap`.
+
+### Automatic differentiation with `grad`
+
+JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
+The most popular function is
+[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
+for reverse-mode gradients:
+
+```python
+from jax import grad
+import jax.numpy as jnp
+
+def tanh(x): # Define a function
+ y = jnp.exp(-2.0 * x)
+ return (1.0 - y) / (1.0 + y)
+
+grad_tanh = grad(tanh) # Obtain its gradient function
+print(grad_tanh(1.0)) # Evaluate it at x = 1.0
+# prints 0.4199743
+```
+
+You can differentiate to any order with `grad`.
+
+```python
+print(grad(grad(grad(tanh)))(1.0))
+# prints 0.62162673
+```
+
+For more advanced autodiff, you can use
+[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
+reverse-mode vector-Jacobian products and
+[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
+forward-mode Jacobian-vector products. The two can be composed arbitrarily with
+one another, and with other JAX transformations. Here's one way to compose those
+to make a function that efficiently computes [full Hessian
+matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian):
+
+```python
+from jax import jit, jacfwd, jacrev
+
+def hessian(fun):
+ return jit(jacfwd(jacrev(fun)))
+```
+
+As with [Autograd](https://github.com/hips/autograd), you're free to use
+differentiation with Python control structures:
+
+```python
+def abs_val(x):
+ if x > 0:
+ return x
+ else:
+ return -x
+
+abs_val_grad = grad(abs_val)
+print(abs_val_grad(1.0)) # prints 1.0
+print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
+```
+
+See the [reference docs on automatic
+differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+and the [JAX Autodiff
+Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+for more.
+
+### Compilation with `jit`
+
+You can use XLA to compile your functions end-to-end with
+[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+used either as an `@jit` decorator or as a higher-order function.
+
+```python
+import jax.numpy as jnp
+from jax import jit
+
+def slow_f(x):
+ # Element-wise ops see a large benefit from fusion
+ return x * x + x * 2.0
+
+x = jnp.ones((5000, 5000))
+fast_f = jit(slow_f)
+%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
+%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
+```
+
+You can mix `jit` and `grad` and any other JAX transformation however you like.
+
+Using `jit` puts constraints on the kind of Python control flow
+the function can use; see
+the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
+for more.
+
+### Auto-vectorization with `vmap`
+
+[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
+the vectorizing map.
+It has the familiar semantics of mapping a function along array axes, but
+instead of keeping the loop on the outside, it pushes the loop down into a
+function’s primitive operations for better performance.
+
+Using `vmap` can save you from having to carry around batch dimensions in your
+code. For example, consider this simple *unbatched* neural network prediction
+function:
+
+```python
+def predict(params, input_vec):
+ assert input_vec.ndim == 1
+ activations = input_vec
+ for W, b in params:
+ outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side!
+ activations = jnp.tanh(outputs) # inputs to the next layer
+ return outputs # no activation on last layer
+```
+
+We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the
+left side of `activations`, but we’ve written this particular prediction function to
+apply only to single input vectors. If we wanted to apply this function to a
+batch of inputs at once, semantically we could just write
+
+```python
+from functools import partial
+predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
+```
+
+But pushing one example through the network at a time would be slow! It’s better
+to vectorize the computation, so that at every layer we’re doing matrix-matrix
+multiplication rather than matrix-vector multiplication.
+
+The `vmap` function does that transformation for us. That is, if we write
+
+```python
+from jax import vmap
+predictions = vmap(partial(predict, params))(input_batch)
+# or, alternatively
+predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
+```
+
+then the `vmap` function will push the outer loop inside the function, and our
+machine will end up executing matrix-matrix multiplications exactly as if we’d
+done the batching by hand.
+
+It’s easy enough to manually batch a simple neural network without `vmap`, but
+in other cases manual vectorization can be impractical or impossible. Take the
+problem of efficiently computing per-example gradients: that is, for a fixed set
+of parameters, we want to compute the gradient of our loss function evaluated
+separately at each example in a batch. With `vmap`, it’s easy:
+
+```python
+per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
+```
+
+Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other
+JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
+differentiation for fast Jacobian and Hessian matrix calculations in
+`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
+
+### SPMD programming with `pmap`
+
+For parallel programming of multiple accelerators, like multiple GPUs, use
+[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
+With `pmap` you write single-program multiple-data (SPMD) programs, including
+fast parallel collective communication operations. Applying `pmap` will mean
+that the function you write is compiled by XLA (similarly to `jit`), then
+replicated and executed in parallel across devices.
+
+Here's an example on an 8-GPU machine:
+
+```python
+from jax import random, pmap
+import jax.numpy as jnp
+
+# Create 8 random 5000 x 6000 matrices, one per GPU
+keys = random.split(random.PRNGKey(0), 8)
+mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
+
+# Run a local matmul on each device in parallel (no data transfer)
+result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
+
+# Compute the mean on each device in parallel and print the result
+print(pmap(jnp.mean)(result))
+# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
+```
+
+In addition to expressing pure maps, you can use fast [collective communication
+operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
+between devices:
+
+```python
+from functools import partial
+from jax import lax
+
+@partial(pmap, axis_name='i')
+def normalize(x):
+ return x / lax.psum(x, 'i')
+
+print(normalize(jnp.arange(4.)))
+# prints [0. 0.16666667 0.33333334 0.5 ]
+```
+
+You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
+sophisticated communication patterns.
+
+It all composes, so you're free to differentiate through parallel computations:
+
+```python
+from jax import grad
+
+@pmap
+def f(x):
+ y = jnp.sin(x)
+ @pmap
+ def g(z):
+ return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
+ return grad(lambda w: jnp.sum(g(w)))(x)
+
+print(f(x))
+# [[ 0. , -0.7170853 ],
+# [-3.1085174 , -0.4824318 ],
+# [10.366636 , 13.135289 ],
+# [ 0.22163185, -0.52112055]]
+
+print(grad(lambda x: jnp.sum(f(x)))(x))
+# [[ -3.2369726, -1.6356447],
+# [ 4.7572474, 11.606951 ],
+# [-98.524414 , 42.76499 ],
+# [ -1.6007166, -1.2568436]]
+```
+
+When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
+backward pass of the computation is parallelized just like the forward pass.
+
+See the [SPMD
+Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
+and the [SPMD MNIST classifier from scratch
+example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
+for more.
+
+## Current gotchas
+
+For a more thorough survey of current gotchas, with examples and explanations,
+we highly recommend reading the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Some standouts:
+
+1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
+1. [In-place mutating updates of
+ arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
+1. [Random numbers are
+ different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md).
+1. If you're looking for [convolution
+ operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
+ they're in the `jax.lax` package.
+1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
+ [to enable
+ double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
+ (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
+ startup (or set the environment variable `JAX_ENABLE_X64=True`).
+ On TPU, JAX uses 32-bit values by default for everything _except_ internal
+ temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`.
+ Those ops have a `precision` parameter which can be used to simulate
+ true 32-bit, with a cost of possibly slower runtime.
+1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
+ and NumPy types aren't preserved, namely `np.add(1, np.array([2],
+ np.float32)).dtype` is `float64` rather than `float32`.
+1. Some transformations, like `jit`, [constrain how you can use Python control
+ flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+ You'll always get loud errors if something goes wrong. You might have to use
+ [`jit`'s `static_argnums`
+ parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+ [structured control flow
+ primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
+ like
+ [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
+ or just use `jit` on smaller subfunctions.
+
+## Installation
+
+JAX is written in pure Python, but it depends on XLA, which needs to be
+installed as the `jaxlib` package. Use the following instructions to install a
+binary package with `pip` or `conda`, or to [build JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and
+macOS (10.12 or later) platforms.
+
+Windows users can use JAX on CPU and GPU via the [Windows Subsystem for
+Linux](https://docs.microsoft.com/en-us/windows/wsl/about). In addition, there
+is some initial community-driven native Windows support, but since it is still
+somewhat immature, there are no official binary releases and it must be [built
+from source for Windows](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows).
+For an unofficial discussion of native Windows builds, see also the [Issue #5795
+thread](https://github.com/google/jax/issues/5795).
+
+### pip installation: CPU
+
+To install a CPU-only version of JAX, which might be useful for doing local
+development on a laptop, you can run
+
+```bash
+pip install --upgrade pip
+pip install --upgrade "jax[cpu]"
+```
+
+On Linux, it is often necessary to first update `pip` to a version that supports
+`manylinux2014` wheels. Also note that for Linux, we currently release wheels for `x86_64` architectures only, other architectures require building from source. Trying to pip install with other Linux architectures may lead to `jaxlib` not being installed alongside `jax`, although `jax` may successfully install (but fail at runtime).
+**These `pip` installations do not work with Windows, and may fail silently; see
+[above](#installation).**
+
+### pip installation: GPU (CUDA, installed via pip, easier)
+
+There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
+installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
+installing CUDA and CUDNN using the pip wheels, since it is much easier!
+
+You must first install the NVIDIA driver. We
+recommend installing the newest driver available from NVIDIA, but the driver
+must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
+
+```bash
+pip install --upgrade pip
+
+# CUDA 12 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# CUDA 11 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+### pip installation: GPU (CUDA, installed locally, harder)
+
+If you prefer to use a preinstalled copy of CUDA, you must first
+install [CUDA](https://developer.nvidia.com/cuda-downloads) and
+[CuDNN](https://developer.nvidia.com/CUDNN).
+
+JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
+combinations of operating system and architecture are possible, but require
+[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+You should use an NVIDIA driver version that is at least as new as your
+[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
+If you need to use an newer CUDA toolkit with an older driver, for example
+on a cluster where you cannot update the NVIDIA driver easily, you may be
+able to use the
+[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
+that NVIDIA provides for this purpose.
+
+JAX currently ships three CUDA wheel variants:
+* CUDA 12.0 and CuDNN 8.8.
+* CUDA 11.8 and CuDNN 8.6.
+* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued
+ with jax 0.4.8.
+
+You may use a JAX wheel provided the major version of your CUDA and CuDNN
+installation matches, and the minor version is at least as new as the version
+JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
+CUDA 12.1 and CuDNN 8.9.
+
+Your CUDA installation must also be new enough to support your GPU. If you have
+an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
+you must use CUDA 11.8 or newer.
+
+
+To install, run
+
+```bash
+pip install --upgrade pip
+
+# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated).
+pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+**These `pip` installations do not work with Windows, and may fail silently; see
+[above](#installation).**
+
+You can find your CUDA version with the command:
+
+```bash
+nvcc --version
+```
+
+Some GPU functionality expects the CUDA installation to be at
+`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
+(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either
+create a symlink:
+
+```bash
+sudo ln -s /path/to/cuda /usr/local/cuda-X.X
+```
+
+Please let us know on [the issue tracker](https://github.com/google/jax/issues)
+if you run into any errors or problems with the prebuilt wheels.
+
+### pip installation: Google Cloud TPU
+JAX also provides pre-built wheels for
+[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
+To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
+the following in your cloud TPU VM:
+```bash
+pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+### pip installation: Colab TPU
+Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ.
+The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
+```python
+import jax.tools.colab_tpu
+jax.tools.colab_tpu.setup_tpu()
+```
+Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer.
+If you need to re-install JAX on a Colab TPU runtime, you can use the following command:
+```
+!pip install jax<=0.3.25 jaxlib<=0.3.25
+```
+
+### Conda installation
+
+There is a community-supported Conda build of `jax`. To install using `conda`,
+simply run
+
+```bash
+conda install jax -c conda-forge
+```
+
+To install on a machine with an NVIDIA GPU, run
+```bash
+conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
+```
+
+Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
+JAX requires. You must therefore either install the `cuda-nvcc` package from
+the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
+is in your path. The channel order above is important (`conda-forge` before
+`nvidia`).
+
+If you would like to override which release of CUDA is used by JAX, or to
+install the CUDA build on a machine without GPUs, follow the instructions in the
+[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
+section of the `conda-forge` website.
+
+See the `conda-forge`
+[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and
+[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
+for more details.
+
+### Building JAX from source
+See [Building JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+
+## Neural network libraries
+
+Multiple Google research groups develop and share libraries for training neural
+networks in JAX. If you want a fully featured library for neural network
+training with examples and how-to guides, try
+[Flax](https://github.com/google/flax).
+
+In addition, DeepMind has open-sourced an [ecosystem of libraries around
+JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research)
+including [Haiku](https://github.com/deepmind/dm-haiku) for neural network
+modules, [Optax](https://github.com/deepmind/optax) for gradient processing and
+optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and
+[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch
+the NeurIPS 2020 JAX Ecosystem at DeepMind talk
+[here](https://www.youtube.com/watch?v=iDxJxIyzSiM))
+
+## Citing JAX
+
+To cite this repository:
+
+```
+@software{jax2018github,
+ author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
+ title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
+ url = {http://github.com/google/jax},
+ version = {0.3.13},
+ year = {2018},
+}
+```
+
+In the above bibtex entry, names are in alphabetical order, the version number
+is intended to be that from [jax/version.py](../main/jax/version.py), and
+the year corresponds to the project's open-source release.
+
+A nascent version of JAX, supporting only automatic differentiation and
+compilation to XLA, was described in a [paper that appeared at SysML
+2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). We're currently working on
+covering JAX's ideas and capabilities in a more comprehensive and up-to-date
+paper.
+
+## Reference documentation
+
+For details about the JAX API, see the
+[reference documentation](https://jax.readthedocs.io/).
+
+For getting started as a JAX developer, see the
+[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
+
+
+%prep
+%autosetup -n jax-0.4.8
+
+%build
+%py3_build
+
+%install
+%py3_install
+install -d -m755 %{buildroot}/%{_pkgdocdir}
+if [ -d doc ]; then cp -arf doc %{buildroot}/%{_pkgdocdir}; fi
+if [ -d docs ]; then cp -arf docs %{buildroot}/%{_pkgdocdir}; fi
+if [ -d example ]; then cp -arf example %{buildroot}/%{_pkgdocdir}; fi
+if [ -d examples ]; then cp -arf examples %{buildroot}/%{_pkgdocdir}; fi
+pushd %{buildroot}
+if [ -d usr/lib ]; then
+ find usr/lib -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+if [ -d usr/lib64 ]; then
+ find usr/lib64 -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+if [ -d usr/bin ]; then
+ find usr/bin -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+if [ -d usr/sbin ]; then
+ find usr/sbin -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+touch doclist.lst
+if [ -d usr/share/man ]; then
+ find usr/share/man -type f -printf "/%h/%f.gz\n" >> doclist.lst
+fi
+popd
+mv %{buildroot}/filelist.lst .
+mv %{buildroot}/doclist.lst .
+
+%files -n python3-jax -f filelist.lst
+%dir %{python3_sitelib}/*
+
+%files help -f doclist.lst
+%{_docdir}/*
+
+%changelog
+* Mon Apr 10 2023 Python_Bot <Python_Bot@openeuler.org> - 0.4.8-1
+- Package Spec generated
diff --git a/sources b/sources
new file mode 100644
index 0000000..3f53767
--- /dev/null
+++ b/sources
@@ -0,0 +1 @@
+06e5f085d658fd782189ccca5fbf21f5 jax-0.4.8.tar.gz