Diffstat (limited to 'python-optax.spec')
-rw-r--r--  python-optax.spec  930
1 file changed, 930 insertions, 0 deletions
diff --git a/python-optax.spec b/python-optax.spec
new file mode 100644
index 0000000..d024a5f
--- /dev/null
+++ b/python-optax.spec
@@ -0,0 +1,930 @@
+%global _empty_manifest_terminate_build 0
+Name: python-optax
+Version: 0.1.4
+Release: 1
+Summary: A gradient processing and optimisation library in JAX.
+License: Apache 2.0
+URL: https://github.com/deepmind/optax
+Source0: https://mirrors.nju.edu.cn/pypi/web/packages/9a/74/66b5c8c59b21017d50a04db5b2eca113418d35c99e17f7a62c76a22c8e88/optax-0.1.4.tar.gz
+BuildArch: noarch
+
+Requires: python3-absl-py
+Requires: python3-chex
+Requires: python3-jax
+Requires: python3-jaxlib
+Requires: python3-numpy
+Requires: python3-typing-extensions
+
+%description
+# Optax
+
+![CI status](https://github.com/deepmind/optax/workflows/tests/badge.svg)
+[![Documentation Status](https://readthedocs.org/projects/optax/badge/?version=latest)](http://optax.readthedocs.io)
+![pypi](https://img.shields.io/pypi/v/optax)
+
+## Introduction
+
+Optax is a gradient processing and optimization library for JAX.
+
+Optax is designed to facilitate research by providing building blocks
+that can be easily recombined in custom ways.
+
+Our goals are to
+
+* Provide simple, well-tested, efficient implementations of core components.
+* Improve research productivity by making it easy to combine low-level
+  ingredients into custom optimisers (or other gradient processing components).
+* Accelerate adoption of new ideas by making it easy for anyone to contribute.
+
+We favour focusing on small composable building blocks that can be effectively
+combined into custom solutions. Others may build more complicated abstractions
+on top of these basic components. Whenever reasonable, implementations
+prioritise readability and structuring code to match standard equations over
+code reuse.
+
+An initial prototype of this library was made available in JAX's experimental
+folder as `jax.experimental.optix`. Given the wide adoption across DeepMind
+of `optix`, and after a few iterations on the API, `optix` was eventually moved
+out of `experimental` as a standalone open-source library, renamed `optax`.
+
+Documentation on Optax can be found at [optax.readthedocs.io](https://optax.readthedocs.io/).
+
+## Installation
+
+You can install the latest released version of Optax from PyPI via:
+
+```sh
+pip install optax
+```
+
+or you can install the latest development version from GitHub:
+
+```sh
+pip install git+https://github.com/deepmind/optax.git
+```
+
+## Quickstart
+
+Optax contains implementations of [many popular optimizers](https://optax.readthedocs.io/en/latest/api.html#Common-Optimizers) and
+[loss functions](https://optax.readthedocs.io/en/latest/api.html#common-losses).
+For example the following code snippet uses the Adam optimizer from `optax.adam`
+and the mean squared error from `optax.l2_loss`. We initialize the optimizer
+state using the `init` function and `params` of the model.
+
+```python
+optimizer = optax.adam(learning_rate)
+# Obtain the `opt_state` that contains statistics for the optimizer.
+params = {'w': jnp.ones((num_weights,))}
+opt_state = optimizer.init(params)
+```
+
+To write the update loop we need a loss function that can be differentiated by
+JAX (with `jax.grad` in this example) to obtain the gradients.
+
+```python
+compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
+grads = jax.grad(compute_loss)(params, xs, ys)
+```
+
+The gradients are then converted via `optimizer.update` to obtain the updates
+that should be applied to the current params to obtain the new ones.
+`optax.apply_updates` is a convenience utility to do this.
+
+```python
+updates, opt_state = optimizer.update(grads, opt_state)
+params = optax.apply_updates(params, updates)
+```
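+
+Putting these pieces together, a minimal end-to-end training loop might look
+as follows. This is a sketch only: `learning_rate`, `num_weights` and the
+synthetic data are illustrative placeholders, and the loss takes the mean over
+the batch so that `jax.grad` receives a scalar.
+
+```python
+import jax
+import jax.numpy as jnp
+import optax
+
+learning_rate, num_weights = 1e-2, 3
+params = {'w': jnp.ones((num_weights,))}
+optimizer = optax.adam(learning_rate)
+opt_state = optimizer.init(params)
+
+# Synthetic regression data, purely for illustration.
+xs = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ys = jnp.array([1.0, 2.0])
+
+def compute_loss(params, x, y):
+  # Mean of the per-example l2 losses, so the output is a scalar.
+  return jnp.mean(optax.l2_loss(x.dot(params['w']), y))
+
+for _ in range(100):
+  grads = jax.grad(compute_loss)(params, xs, ys)
+  updates, opt_state = optimizer.update(grads, opt_state)
+  params = optax.apply_updates(params, updates)
+```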
+
+You can continue the quick start in [the Optax quickstart notebook.](https://github.com/deepmind/optax/blob/master/examples/quick_start.ipynb)
+
+
+## Components
+
+We refer to the [docs](https://optax.readthedocs.io/en/latest/index.html)
+for a detailed list of available Optax components. Here, we highlight
+the main categories of building blocks provided by Optax.
+
+### Gradient Transformations ([transform.py](https://github.com/deepmind/optax/blob/master/optax/_src/transform.py))
+
+One of the key building blocks of `optax` is a `GradientTransformation`.
+
+Each transformation is defined by two functions:
+
+* `state = init(params)`
+* `grads, state = update(grads, state, params=None)`
+
+The `init` function initializes a (possibly empty) set of statistics (aka state)
+and the `update` function transforms a candidate gradient given some statistics,
+and (optionally) the current value of the parameters.
+
+For example:
+
+```python
+tx = scale_by_rms()
+state = tx.init(params) # init stats
+grads, state = tx.update(grads, state, params) # transform & update stats.
+```
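+
+To make the two-function structure concrete, here is a minimal sketch of a
+hypothetical custom transformation that simply rescales gradients by a
+constant factor; it keeps no statistics, so its state is empty. The name
+`scale_by_constant` is illustrative, not part of the Optax API.
+
+```python
+import jax
+import optax
+
+def scale_by_constant(factor):
+  """A hypothetical transformation multiplying every gradient by `factor`."""
+
+  def init_fn(params):
+    del params  # This transformation keeps no statistics.
+    return ()
+
+  def update_fn(updates, state, params=None):
+    del params  # Unused: the transform does not depend on parameter values.
+    updates = jax.tree_util.tree_map(lambda g: factor * g, updates)
+    return updates, state
+
+  return optax.GradientTransformation(init_fn, update_fn)
+```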
+
+### Composing Gradient Transformations ([combine.py](https://github.com/deepmind/optax/blob/master/optax/_src/combine.py))
+
+The fact that transformations take candidate gradients as input and return
+processed gradients as output (in contrast to returning the updated parameters)
+is critical for combining arbitrary transformations into a custom
+optimiser / gradient processor, and also allows combining transformations for
+different gradients that operate on a shared set of variables.
+
+For instance, `chain` combines them sequentially, and returns a
+new `GradientTransformation` that applies several transformations in sequence.
+
+For example:
+
+```python
+my_optimiser = chain(
+ clip_by_global_norm(max_norm),
+ scale_by_adam(eps=1e-4),
+ scale(-learning_rate))
+```
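+
+The resulting `my_optimiser` is itself a `GradientTransformation`, so it is
+used exactly like any other optimiser (reusing `params` and `grads` from the
+Quickstart above):
+
+```python
+opt_state = my_optimiser.init(params)
+updates, opt_state = my_optimiser.update(grads, opt_state, params)
+params = optax.apply_updates(params, updates)
+```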
+
+### Wrapping Gradient Transformations ([wrappers.py](https://github.com/deepmind/optax/blob/master/optax/_src/wrappers.py))
+
+Optax also provides several wrappers that take a `GradientTransformation` as
+input and return a new `GradientTransformation` that modifies the behaviour
+of the inner transformation in a specific way.
+
+For instance the `flatten` wrapper flattens gradients into a single large vector
+before applying the inner `GradientTransformation`. The transformed updates are
+then unflattened before being returned to the user. This can be used to reduce
+the overhead of performing many calculations on lots of small variables,
+at the cost of increasing memory usage.
+
+For example:
+```python
+my_optimiser = flatten(adam(learning_rate))
+```
+
+Other examples of wrappers include accumulating gradients over multiple steps,
+or applying the inner transformation only to specific parameters or at
+specific steps.
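+
+As a sketch, restricting the inner transformation to a subset of the
+parameters can be done with the `masked` wrapper; the boolean mask below,
+which mirrors the structure of `params`, is illustrative.
+
+```python
+import jax.numpy as jnp
+import optax
+
+params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
+
+# Only the entries marked True are handled by the inner transformation;
+# the other entries are passed through unchanged.
+mask = {'w': True, 'b': False}
+my_optimiser = optax.masked(optax.adam(1e-3), mask)
+
+opt_state = my_optimiser.init(params)
+```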
+
+### Schedules ([schedule.py](https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py))
+
+Many popular transformations use time dependent components, e.g. to anneal
+some hyper-parameter (e.g. the learning rate). Optax provides for this purpose
+`schedules` that can be used to decay scalars as a function of a `step` count.
+
+For example you may use a polynomial schedule (with `power=1`) to decay
+a hyper-parameter linearly over a number of steps:
+
+```python
+schedule_fn = polynomial_schedule(
+ init_value=1., end_value=0., power=1, transition_steps=5)
+
+for step_count in range(6):
+ print(schedule_fn(step_count)) # [1., 0.8, 0.6, 0.4, 0.2, 0.]
+```
+
+Schedules are used by certain gradient transformations, for instance:
+
+```python
+schedule_fn = polynomial_schedule(
+ init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
+optimiser = chain(
+ clip_by_global_norm(max_norm),
+ scale_by_adam(eps=1e-4),
+ scale_by_schedule(schedule_fn))
+```
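+
+The optimiser aliases described below also accept a schedule in place of a
+constant learning rate (assuming a reasonably recent Optax version), which is
+often the most convenient way to anneal it:
+
+```python
+learning_rate_fn = optax.polynomial_schedule(
+    init_value=1e-3, end_value=1e-5, power=1, transition_steps=1000)
+optimiser = optax.adam(learning_rate=learning_rate_fn)
+```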
+
+### Popular optimisers ([alias.py](https://github.com/deepmind/optax/blob/master/optax/_src/alias.py))
+
+In addition to the low-level building blocks, we also provide aliases for
+popular optimisers built using these components (e.g. RMSProp, Adam, AdamW).
+These are all still instances of a `GradientTransformation`, and can therefore
+be further combined with any of the individual building blocks.
+
+For example:
+
+```python
+def adamw(learning_rate, b1, b2, eps, weight_decay):
+  # Decoupled weight decay followed by scaling with the (negative) learning rate.
+  return chain(
+    scale_by_adam(b1=b1, b2=b2, eps=eps),
+    add_decayed_weights(weight_decay),
+    scale(-learning_rate))
+```
+
+### Applying updates ([update.py](https://github.com/deepmind/optax/blob/master/optax/_src/update.py))
+
+After transforming an update using a `GradientTransformation` or any custom
+manipulation of the update, you will typically apply the update to a set
+of parameters. This can be done trivially using `tree_map`.
+
+For convenience, we expose an `apply_updates` function to apply updates to
+parameters. The function just adds the updates and the parameters together,
+i.e. `tree_map(lambda p, u: p + u, params, updates)`.
+
+```python
+updates, state = tx.update(grads, state, params) # transform & update stats.
+new_params = optax.apply_updates(params, updates) # update the parameters.
+```
+
+Note that separating gradient transformations from the parameter update is
+critical to support composing sequences of transformations (e.g. `chain`), as
+well as combining multiple updates to the same parameters (e.g. in multi-task
+settings where different tasks need different sets of gradient transformations).
+
+### Losses ([loss.py](https://github.com/deepmind/optax/blob/master/optax/_src/loss.py))
+
+Optax provides a number of standard losses used in deep learning, such as
+`l2_loss`, `softmax_cross_entropy`, `cosine_distance`, etc.
+
+```python
+loss = huber_loss(predictions, targets)
+```
+
+The losses accept batches as inputs; however, they perform no reduction across
+the batch dimension(s). Reductions are trivial to do in JAX, for example:
+
+```python
+avg_loss = jnp.mean(huber_loss(predictions, targets))
+sum_loss = jnp.sum(huber_loss(predictions, targets))
+```
+
+### Second Order ([second_order.py](https://github.com/deepmind/optax/blob/master/optax/_src/second_order.py))
+
+Computing the Hessian or Fisher information matrices for neural networks is
+typically intractable due to the quadratic memory requirements. Solving for the
+diagonals of these matrices is often a better solution. The library offers
+functions for computing these diagonals with sub-quadratic memory requirements.
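+
+As a rough illustration of the underlying idea (plain JAX, not the Optax
+`second_order` API itself), a Hutchinson-style estimate of the Hessian
+diagonal can be built from Hessian-vector products, so the full Hessian is
+never materialised; `params` is assumed to be a single array here.
+
+```python
+import jax
+import jax.numpy as jnp
+
+def hessian_diag_estimate(loss, params, key, num_samples=16):
+  """Estimates diag(H) of `loss` at `params` via Hutchinson's estimator."""
+
+  def hvp(v):
+    # Forward-over-reverse Hessian-vector product: H @ v.
+    return jax.jvp(jax.grad(loss), (params,), (v,))[1]
+
+  def one_sample(k):
+    v = jax.random.rademacher(k, params.shape).astype(params.dtype)
+    return v * hvp(v)  # E[v * (H v)] equals diag(H) for Rademacher v.
+
+  keys = jax.random.split(key, num_samples)
+  return jnp.mean(jax.vmap(one_sample)(keys), axis=0)
+```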
+
+### Stochastic gradient estimators ([stochastic_gradient_estimators.py](https://github.com/deepmind/optax/blob/master/optax/_src/stochastic_gradient_estimators.py))
+
+Stochastic gradient estimators compute Monte Carlo estimates of gradients of
+the expectation of a function under a distribution with respect to the
+distribution's parameters.
+
+Unbiased estimators, such as the score function estimator (REINFORCE),
+pathwise estimator (reparameterization trick) or measure valued estimator,
+are implemented: `score_function_jacobians`, `pathwise_jacobians` and
+`measure_valued_jacobians`. Their applicability (both in terms of functions and
+distributions) is discussed in their respective documentation.
+
+Stochastic gradient estimators can be combined with common control variates for
+variance reduction via `control_variates_jacobians`. For provided control
+variates see `delta` and `moving_avg_baseline`.
+
+The result of a gradient estimator or `control_variates_jacobians` contains the
+Jacobians of the function with respect to the samples from the input
+distribution. These can then be used to update distributional parameters, or
+to assess gradient variance.
+
+Example of how to use the `pathwise_jacobians` estimator:
+
+```python
+dist_params = [mean, log_scale]
+function = lambda x: jnp.sum(x * weights)
+jacobians = pathwise_jacobians(
+ function, dist_params,
+ utils.multi_normal, rng, num_samples)
+
+mean_grads = jnp.mean(jacobians[0], axis=0)
+log_scale_grads = jnp.mean(jacobians[1], axis=0)
+grads = [mean_grads, log_scale_grads]
+optim_update, optim_state = optim.update(grads, optim_state)
+updated_dist_params = optax.apply_updates(dist_params, optim_update)
+```
+
+where `optim` is an optax optimizer.
+
+## Citing Optax
+
+Optax is part of the [DeepMind JAX Ecosystem]; to cite Optax, please use
+the [DeepMind JAX Ecosystem citation].
+
+[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
+[DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
+
+
+%package -n python3-optax
+Summary: A gradient processing and optimisation library in JAX.
+Provides: python-optax
+BuildRequires: python3-devel
+BuildRequires: python3-setuptools
+BuildRequires: python3-pip
+%description -n python3-optax
+# Optax
+
+![CI status](https://github.com/deepmind/optax/workflows/tests/badge.svg)
+[![Documentation Status](https://readthedocs.org/projects/optax/badge/?version=latest)](http://optax.readthedocs.io)
+![pypi](https://img.shields.io/pypi/v/optax)
+
+## Introduction
+
+Optax is a gradient processing and optimization library for JAX.
+
+Optax is designed to facilitate research by providing building blocks
+that can be easily recombined in custom ways.
+
+Our goals are to
+
+* Provide simple, well-tested, efficient implementations of core components.
+* Improve research productivity by making it easy to combine low-level
+  ingredients into custom optimisers (or other gradient processing components).
+* Accelerate adoption of new ideas by making it easy for anyone to contribute.
+
+We favour focusing on small composable building blocks that can be effectively
+combined into custom solutions. Others may build more complicated abstractions
+on top of these basic components. Whenever reasonable, implementations
+prioritise readability and structuring code to match standard equations over
+code reuse.
+
+An initial prototype of this library was made available in JAX's experimental
+folder as `jax.experimental.optix`. Given the wide adoption across DeepMind
+of `optix`, and after a few iterations on the API, `optix` was eventually moved
+out of `experimental` as a standalone open-source library, renamed `optax`.
+
+Documentation on Optax can be found at [optax.readthedocs.io](https://optax.readthedocs.io/).
+
+## Installation
+
+You can install the latest released version of Optax from PyPI via:
+
+```sh
+pip install optax
+```
+
+or you can install the latest development version from GitHub:
+
+```sh
+pip install git+https://github.com/deepmind/optax.git
+```
+
+## Quickstart
+
+Optax contains implementations of [many popular optimizers](https://optax.readthedocs.io/en/latest/api.html#Common-Optimizers) and
+[loss functions](https://optax.readthedocs.io/en/latest/api.html#common-losses).
+For example the following code snippet uses the Adam optimizer from `optax.adam`
+and the mean squared error from `optax.l2_loss`. We initialize the optimizer
+state using the `init` function and `params` of the model.
+
+```python
+optimizer = optax.adam(learning_rate)
+# Obtain the `opt_state` that contains statistics for the optimizer.
+params = {'w': jnp.ones((num_weights,))}
+opt_state = optimizer.init(params)
+```
+
+To write the update loop we need a loss function that can be differentiated by
+JAX (with `jax.grad` in this example) to obtain the gradients.
+
+```python
+compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
+grads = jax.grad(compute_loss)(params, xs, ys)
+```
+
+The gradients are then converted via `optimizer.update` to obtain the updates
+that should be applied to the current params to obtain the new ones.
+`optax.apply_updates` is a convenience utility to do this.
+
+```python
+updates, opt_state = optimizer.update(grads, opt_state)
+params = optax.apply_updates(params, updates)
+```
+
+You can continue the quick start in [the Optax quickstart notebook.](https://github.com/deepmind/optax/blob/master/examples/quick_start.ipynb)
+
+
+## Components
+
+We refer to the [docs](https://optax.readthedocs.io/en/latest/index.html)
+for a detailed list of available Optax components. Here, we highlight
+the main categories of building blocks provided by Optax.
+
+### Gradient Transformations ([transform.py](https://github.com/deepmind/optax/blob/master/optax/_src/transform.py))
+
+One of the key building blocks of `optax` is a `GradientTransformation`.
+
+Each transformation is defined by two functions:
+
+* `state = init(params)`
+* `grads, state = update(grads, state, params=None)`
+
+The `init` function initializes a (possibly empty) set of statistics (aka state)
+and the `update` function transforms a candidate gradient given some statistics,
+and (optionally) the current value of the parameters.
+
+For example:
+
+```python
+tx = scale_by_rms()
+state = tx.init(params) # init stats
+grads, state = tx.update(grads, state, params) # transform & update stats.
+```
+
+### Composing Gradient Transformations ([combine.py](https://github.com/deepmind/optax/blob/master/optax/_src/combine.py))
+
+The fact that transformations take candidate gradients as input and return
+processed gradients as output (in contrast to returning the updated parameters)
+is critical for combining arbitrary transformations into a custom
+optimiser / gradient processor, and also allows combining transformations for
+different gradients that operate on a shared set of variables.
+
+For instance, `chain` combines them sequentially, and returns a
+new `GradientTransformation` that applies several transformations in sequence.
+
+For example:
+
+```python
+my_optimiser = chain(
+ clip_by_global_norm(max_norm),
+ scale_by_adam(eps=1e-4),
+ scale(-learning_rate))
+```
+
+### Wrapping Gradient Transformations ([wrappers.py](https://github.com/deepmind/optax/blob/master/optax/_src/wrappers.py))
+
+Optax also provides several wrappers that take a `GradientTransformation` as
+input and return a new `GradientTransformation` that modifies the behaviour
+of the inner transformation in a specific way.
+
+For instance the `flatten` wrapper flattens gradients into a single large vector
+before applying the inner `GradientTransformation`. The transformed updates are
+then unflattened before being returned to the user. This can be used to reduce
+the overhead of performing many calculations on lots of small variables,
+at the cost of increasing memory usage.
+
+For example:
+```python
+my_optimiser = flatten(adam(learning_rate))
+```
+
+Other examples of wrappers include accumulating gradients over multiple steps,
+or applying the inner transformation only to specific parameters or at
+specific steps.
+
+### Schedules ([schedule.py](https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py))
+
+Many popular transformations use time dependent components, e.g. to anneal
+some hyper-parameter (e.g. the learning rate). Optax provides for this purpose
+`schedules` that can be used to decay scalars as a function of a `step` count.
+
+For example you may use a polynomial schedule (with `power=1`) to decay
+a hyper-parameter linearly over a number of steps:
+
+```python
+schedule_fn = polynomial_schedule(
+ init_value=1., end_value=0., power=1, transition_steps=5)
+
+for step_count in range(6):
+ print(schedule_fn(step_count)) # [1., 0.8, 0.6, 0.4, 0.2, 0.]
+```
+
+Schedules are used by certain gradient transformations, for instance:
+
+```python
+schedule_fn = polynomial_schedule(
+ init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
+optimiser = chain(
+ clip_by_global_norm(max_norm),
+ scale_by_adam(eps=1e-4),
+ scale_by_schedule(schedule_fn))
+```
+
+### Popular optimisers ([alias.py](https://github.com/deepmind/optax/blob/master/optax/_src/alias.py))
+
+In addition to the low-level building blocks, we also provide aliases for
+popular optimisers built using these components (e.g. RMSProp, Adam, AdamW).
+These are all still instances of a `GradientTransformation`, and can therefore
+be further combined with any of the individual building blocks.
+
+For example:
+
+```python
+def adamw(learning_rate, b1, b2, eps, weight_decay):
+  # Decoupled weight decay followed by scaling with the (negative) learning rate.
+  return chain(
+    scale_by_adam(b1=b1, b2=b2, eps=eps),
+    add_decayed_weights(weight_decay),
+    scale(-learning_rate))
+```
+
+### Applying updates ([update.py](https://github.com/deepmind/optax/blob/master/optax/_src/update.py))
+
+After transforming an update using a `GradientTransformation` or any custom
+manipulation of the update, you will typically apply the update to a set
+of parameters. This can be done trivially using `tree_map`.
+
+For convenience, we expose an `apply_updates` function to apply updates to
+parameters. The function just adds the updates and the parameters together,
+i.e. `tree_map(lambda p, u: p + u, params, updates)`.
+
+```python
+updates, state = tx.update(grads, state, params) # transform & update stats.
+new_params = optax.apply_updates(params, updates) # update the parameters.
+```
+
+Note that separating gradient transformations from the parameter update is
+critical to support composing sequences of transformations (e.g. `chain`), as
+well as combining multiple updates to the same parameters (e.g. in multi-task
+settings where different tasks need different sets of gradient transformations).
+
+### Losses ([loss.py](https://github.com/deepmind/optax/blob/master/optax/_src/loss.py))
+
+Optax provides a number of standard losses used in deep learning, such as
+`l2_loss`, `softmax_cross_entropy`, `cosine_distance`, etc.
+
+```python
+loss = huber_loss(predictions, targets)
+```
+
+The losses accept batches as inputs; however, they perform no reduction across
+the batch dimension(s). Reductions are trivial to do in JAX, for example:
+
+```python
+avg_loss = jnp.mean(huber_loss(predictions, targets))
+sum_loss = jnp.sum(huber_loss(predictions, targets))
+```
+
+### Second Order ([second_order.py](https://github.com/deepmind/optax/blob/master/optax/_src/second_order.py))
+
+Computing the Hessian or Fisher information matrices for neural networks is
+typically intractable due to the quadratic memory requirements. Solving for the
+diagonals of these matrices is often a better solution. The library offers
+functions for computing these diagonals with sub-quadratic memory requirements.
+
+### Stochastic gradient estimators ([stochastic_gradient_estimators.py](https://github.com/deepmind/optax/blob/master/optax/_src/stochastic_gradient_estimators.py))
+
+Stochastic gradient estimators compute Monte Carlo estimates of gradients of
+the expectation of a function under a distribution with respect to the
+distribution's parameters.
+
+Unbiased estimators, such as the score function estimator (REINFORCE),
+pathwise estimator (reparameterization trick) or measure valued estimator,
+are implemented: `score_function_jacobians`, `pathwise_jacobians` and
+`measure_valued_jacobians`. Their applicability (both in terms of functions and
+distributions) is discussed in their respective documentation.
+
+Stochastic gradient estimators can be combined with common control variates for
+variance reduction via `control_variates_jacobians`. For provided control
+variates see `delta` and `moving_avg_baseline`.
+
+The result of a gradient estimator or `control_variates_jacobians` contains the
+Jacobians of the function with respect to the samples from the input
+distribution. These can then be used to update distributional parameters, or
+to assess gradient variance.
+
+Example of how to use the `pathwise_jacobians` estimator:
+
+```python
+dist_params = [mean, log_scale]
+function = lambda x: jnp.sum(x * weights)
+jacobians = pathwise_jacobians(
+ function, dist_params,
+ utils.multi_normal, rng, num_samples)
+
+mean_grads = jnp.mean(jacobians[0], axis=0)
+log_scale_grads = jnp.mean(jacobians[1], axis=0)
+grads = [mean_grads, log_scale_grads]
+optim_update, optim_state = optim.update(grads, optim_state)
+updated_dist_params = optax.apply_updates(dist_params, optim_update)
+```
+
+where `optim` is an optax optimizer.
+
+## Citing Optax
+
+Optax is part of the [DeepMind JAX Ecosystem]; to cite Optax, please use
+the [DeepMind JAX Ecosystem citation].
+
+[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
+[DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
+
+
+%package help
+Summary: Development documents and examples for optax
+Provides: python3-optax-doc
+%description help
+# Optax
+
+![CI status](https://github.com/deepmind/optax/workflows/tests/badge.svg)
+[![Documentation Status](https://readthedocs.org/projects/optax/badge/?version=latest)](http://optax.readthedocs.io)
+![pypi](https://img.shields.io/pypi/v/optax)
+
+## Introduction
+
+Optax is a gradient processing and optimization library for JAX.
+
+Optax is designed to facilitate research by providing building blocks
+that can be easily recombined in custom ways.
+
+Our goals are to
+
+* Provide simple, well-tested, efficient implementations of core components.
+* Improve research productivity by making it easy to combine low-level
+  ingredients into custom optimisers (or other gradient processing components).
+* Accelerate adoption of new ideas by making it easy for anyone to contribute.
+
+We favour focusing on small composable building blocks that can be effectively
+combined into custom solutions. Others may build more complicated abstractions
+on top of these basic components. Whenever reasonable, implementations
+prioritise readability and structuring code to match standard equations over
+code reuse.
+
+An initial prototype of this library was made available in JAX's experimental
+folder as `jax.experimental.optix`. Given the wide adoption across DeepMind
+of `optix`, and after a few iterations on the API, `optix` was eventually moved
+out of `experimental` as a standalone open-source library, renamed `optax`.
+
+Documentation on Optax can be found at [optax.readthedocs.io](https://optax.readthedocs.io/).
+
+## Installation
+
+You can install the latest released version of Optax from PyPI via:
+
+```sh
+pip install optax
+```
+
+or you can install the latest development version from GitHub:
+
+```sh
+pip install git+https://github.com/deepmind/optax.git
+```
+
+## Quickstart
+
+Optax contains implementations of [many popular optimizers](https://optax.readthedocs.io/en/latest/api.html#Common-Optimizers) and
+[loss functions](https://optax.readthedocs.io/en/latest/api.html#common-losses).
+For example the following code snippet uses the Adam optimizer from `optax.adam`
+and the mean squared error from `optax.l2_loss`. We initialize the optimizer
+state using the `init` function and `params` of the model.
+
+```python
+optimizer = optax.adam(learning_rate)
+# Obtain the `opt_state` that contains statistics for the optimizer.
+params = {'w': jnp.ones((num_weights,))}
+opt_state = optimizer.init(params)
+```
+
+To write the update loop we need a loss function that can be differentiated by
+JAX (with `jax.grad` in this example) to obtain the gradients.
+
+```python
+compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
+grads = jax.grad(compute_loss)(params, xs, ys)
+```
+
+The gradients are then converted via `optimizer.update` to obtain the updates
+that should be applied to the current params to obtain the new ones.
+`optax.apply_updates` is a convenience utility to do this.
+
+```python
+updates, opt_state = optimizer.update(grads, opt_state)
+params = optax.apply_updates(params, updates)
+```
+
+You can continue the quick start in [the Optax quickstart notebook.](https://github.com/deepmind/optax/blob/master/examples/quick_start.ipynb)
+
+
+## Components
+
+We refer to the [docs](https://optax.readthedocs.io/en/latest/index.html)
+for a detailed list of available Optax components. Here, we highlight
+the main categories of building blocks provided by Optax.
+
+### Gradient Transformations ([transform.py](https://github.com/deepmind/optax/blob/master/optax/_src/transform.py))
+
+One of the key building blocks of `optax` is a `GradientTransformation`.
+
+Each transformation is defined by two functions:
+
+* `state = init(params)`
+* `grads, state = update(grads, state, params=None)`
+
+The `init` function initializes a (possibly empty) set of statistics (aka state)
+and the `update` function transforms a candidate gradient given some statistics,
+and (optionally) the current value of the parameters.
+
+For example:
+
+```python
+tx = scale_by_rms()
+state = tx.init(params) # init stats
+grads, state = tx.update(grads, state, params) # transform & update stats.
+```
+
+### Composing Gradient Transformations ([combine.py](https://github.com/deepmind/optax/blob/master/optax/_src/combine.py))
+
+The fact that transformations take candidate gradients as input and return
+processed gradients as output (in contrast to returning the updated parameters)
+is critical for combining arbitrary transformations into a custom
+optimiser / gradient processor, and also allows combining transformations for
+different gradients that operate on a shared set of variables.
+
+For instance, `chain` combines them sequentially, and returns a
+new `GradientTransformation` that applies several transformations in sequence.
+
+For example:
+
+```python
+my_optimiser = chain(
+ clip_by_global_norm(max_norm),
+ scale_by_adam(eps=1e-4),
+ scale(-learning_rate))
+```
+
+### Wrapping Gradient Transformations ([wrappers.py](https://github.com/deepmind/optax/blob/master/optax/_src/wrappers.py))
+
+Optax also provides several wrappers that take a `GradientTransformation` as
+input and return a new `GradientTransformation` that modifies the behaviour
+of the inner transformation in a specific way.
+
+For instance the `flatten` wrapper flattens gradients into a single large vector
+before applying the inner `GradientTransformation`. The transformed updates are
+then unflattened before being returned to the user. This can be used to reduce
+the overhead of performing many calculations on lots of small variables,
+at the cost of increasing memory usage.
+
+For example:
+```python
+my_optimiser = flatten(adam(learning_rate))
+```
+
+Other examples of wrappers include accumulating gradients over multiple steps,
+or applying the inner transformation only to specific parameters or at
+specific steps.
+
+### Schedules ([schedule.py](https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py))
+
+Many popular transformations use time dependent components, e.g. to anneal
+some hyper-parameter (e.g. the learning rate). Optax provides for this purpose
+`schedules` that can be used to decay scalars as a function of a `step` count.
+
+For example you may use a polynomial schedule (with `power=1`) to decay
+a hyper-parameter linearly over a number of steps:
+
+```python
+schedule_fn = polynomial_schedule(
+ init_value=1., end_value=0., power=1, transition_steps=5)
+
+for step_count in range(6):
+ print(schedule_fn(step_count)) # [1., 0.8, 0.6, 0.4, 0.2, 0.]
+```
+
+Schedules are used by certain gradient transformations, for instance:
+
+```python
+schedule_fn = polynomial_schedule(
+ init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
+optimiser = chain(
+ clip_by_global_norm(max_norm),
+ scale_by_adam(eps=1e-4),
+ scale_by_schedule(schedule_fn))
+```
+
+### Popular optimisers ([alias.py](https://github.com/deepmind/optax/blob/master/optax/_src/alias.py))
+
+In addition to the low-level building blocks, we also provide aliases for
+popular optimisers built using these components (e.g. RMSProp, Adam, AdamW).
+These are all still instances of a `GradientTransformation`, and can therefore
+be further combined with any of the individual building blocks.
+
+For example:
+
+```python
+def adamw(learning_rate, b1, b2, eps, weight_decay):
+  # Decoupled weight decay followed by scaling with the (negative) learning rate.
+  return chain(
+    scale_by_adam(b1=b1, b2=b2, eps=eps),
+    add_decayed_weights(weight_decay),
+    scale(-learning_rate))
+```
+
+### Applying updates ([update.py](https://github.com/deepmind/optax/blob/master/optax/_src/update.py))
+
+After transforming an update using a `GradientTransformation` or any custom
+manipulation of the update, you will typically apply the update to a set
+of parameters. This can be done trivially using `tree_map`.
+
+For convenience, we expose an `apply_updates` function to apply updates to
+parameters. The function just adds the updates and the parameters together,
+i.e. `tree_map(lambda p, u: p + u, params, updates)`.
+
+```python
+updates, state = tx.update(grads, state, params) # transform & update stats.
+new_params = optax.apply_updates(params, updates) # update the parameters.
+```
+
+Note that separating gradient transformations from the parameter update is
+critical to support composing sequences of transformations (e.g. `chain`), as
+well as combining multiple updates to the same parameters (e.g. in multi-task
+settings where different tasks need different sets of gradient transformations).
+
+### Losses ([loss.py](https://github.com/deepmind/optax/blob/master/optax/_src/loss.py))
+
+Optax provides a number of standard losses used in deep learning, such as
+`l2_loss`, `softmax_cross_entropy`, `cosine_distance`, etc.
+
+```python
+loss = huber_loss(predictions, targets)
+```
+
+The losses accept batches as inputs; however, they perform no reduction across
+the batch dimension(s). Reductions are trivial to do in JAX, for example:
+
+```python
+avg_loss = jnp.mean(huber_loss(predictions, targets))
+sum_loss = jnp.sum(huber_loss(predictions, targets))
+```
+
+### Second Order ([second_order.py](https://github.com/deepmind/optax/blob/master/optax/_src/second_order.py))
+
+Computing the Hessian or Fisher information matrices for neural networks is
+typically intractable due to the quadratic memory requirements. Solving for the
+diagonals of these matrices is often a better solution. The library offers
+functions for computing these diagonals with sub-quadratic memory requirements.
+
+### Stochastic gradient estimators ([stochastic_gradient_estimators.py](https://github.com/deepmind/optax/blob/master/optax/_src/stochastic_gradient_estimators.py))
+
+Stochastic gradient estimators compute Monte Carlo estimates of gradients of
+the expectation of a function under a distribution with respect to the
+distribution's parameters.
+
+Unbiased estimators, such as the score function estimator (REINFORCE),
+pathwise estimator (reparameterization trick) or measure valued estimator,
+are implemented: `score_function_jacobians`, `pathwise_jacobians` and
+`measure_valued_jacobians`. Their applicability (both in terms of functions and
+distributions) is discussed in their respective documentation.
+
+Stochastic gradient estimators can be combined with common control variates for
+variance reduction via `control_variates_jacobians`. For provided control
+variates see `delta` and `moving_avg_baseline`.
+
+The result of a gradient estimator or `control_variates_jacobians` contains the
+Jacobians of the function with respect to the samples from the input
+distribution. These can then be used to update distributional parameters, or
+to assess gradient variance.
+
+Example of how to use the `pathwise_jacobians` estimator:
+
+```python
+dist_params = [mean, log_scale]
+function = lambda x: jnp.sum(x * weights)
+jacobians = pathwise_jacobians(
+ function, dist_params,
+ utils.multi_normal, rng, num_samples)
+
+mean_grads = jnp.mean(jacobians[0], axis=0)
+log_scale_grads = jnp.mean(jacobians[1], axis=0)
+grads = [mean_grads, log_scale_grads]
+optim_update, optim_state = optim.update(grads, optim_state)
+updated_dist_params = optax.apply_updates(dist_params, optim_update)
+```
+
+where `optim` is an optax optimizer.
+
+## Citing Optax
+
+Optax is part of the [DeepMind JAX Ecosystem]; to cite Optax, please use
+the [DeepMind JAX Ecosystem citation].
+
+[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
+[DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
+
+
+%prep
+%autosetup -n optax-0.1.4
+
+%build
+%py3_build
+
+%install
+%py3_install
+install -d -m755 %{buildroot}/%{_pkgdocdir}
+if [ -d doc ]; then cp -arf doc %{buildroot}/%{_pkgdocdir}; fi
+if [ -d docs ]; then cp -arf docs %{buildroot}/%{_pkgdocdir}; fi
+if [ -d example ]; then cp -arf example %{buildroot}/%{_pkgdocdir}; fi
+if [ -d examples ]; then cp -arf examples %{buildroot}/%{_pkgdocdir}; fi
+pushd %{buildroot}
+if [ -d usr/lib ]; then
+ find usr/lib -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+if [ -d usr/lib64 ]; then
+ find usr/lib64 -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+if [ -d usr/bin ]; then
+ find usr/bin -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+if [ -d usr/sbin ]; then
+ find usr/sbin -type f -printf "/%h/%f\n" >> filelist.lst
+fi
+touch doclist.lst
+if [ -d usr/share/man ]; then
+ find usr/share/man -type f -printf "/%h/%f.gz\n" >> doclist.lst
+fi
+popd
+mv %{buildroot}/filelist.lst .
+mv %{buildroot}/doclist.lst .
+
+%files -n python3-optax -f filelist.lst
+%dir %{python3_sitelib}/*
+
+%files help -f doclist.lst
+%{_docdir}/*
+
+%changelog
+* Mon Apr 10 2023 Python_Bot <Python_Bot@openeuler.org> - 0.1.4-1
+- Package Spec generated