%global _empty_manifest_terminate_build 0

Name: python-jmp
Version: 0.0.4
Release: 1
Summary: JMP is a Mixed Precision library for JAX.
License: Apache 2.0
URL: https://github.com/deepmind/jmp
Source0: https://mirrors.nju.edu.cn/pypi/web/packages/ab/b0/e90fbbffef4b345329c878a69f0336d3edc5a1f9fcba193931aca2132d62/jmp-0.0.4.tar.gz
BuildArch: noarch

Requires: python3-numpy
Requires: python3-dataclasses
Requires: python3-jax
Requires: python3-jaxlib

%description
# Mixed precision training in [JAX]

![Test status](https://github.com/deepmind/jmp/workflows/pytest/badge.svg)
![PyPI version](https://img.shields.io/pypi/v/jmp)

[**Installation**](#installation)
| [**Examples**](#examples)
| [**Policies**](#policies)
| [**Loss scaling**](#loss-scaling)
| [**Citing JMP**](#citing-jmp)
| [**References**](#references)

Mixed precision training [[0]] is a technique that mixes the use of full and
half precision floating point numbers during training to reduce the memory
bandwidth requirements and improve the computational efficiency of a given
model.

This library implements support for mixed precision training in [JAX] by
providing two key abstractions (mixed precision "policies" and loss scaling).
Neural network libraries (such as [Haiku]) can integrate with `jmp` and provide
"Automatic Mixed Precision (AMP)" support (automating or simplifying applying
policies to modules).

All code examples below assume the following:

```python
import jax
import jax.numpy as jnp
import jmp

half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32
```
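If you would rather not hard-code the half dtype, a small alternative setup is
to use `jmp.half_dtype()` (which also appears in the loss-scaling example
below), assuming it returns `jnp.bfloat16` on TPU and `jnp.float16` on other
backends:

```python
# Alternative setup sketch: let jmp pick the half dtype for the current backend.
# Assumption: jmp.half_dtype() returns jnp.bfloat16 on TPU and jnp.float16 elsewhere.
half = jmp.half_dtype()
full = jnp.float32
```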
## Installation

JMP is written in pure Python, but depends on C++ code via JAX and NumPy.

Because JAX installation is different depending on your CUDA version, JMP does
not list JAX as a dependency in `requirements.txt`.

First, follow [these instructions](https://github.com/google/jax#installation)
to install JAX with the relevant accelerator support.

Then, install JMP using pip:

```bash
$ pip install git+https://github.com/deepmind/jmp
```

## Examples

You can find a
[fully worked JMP example in Haiku](https://github.com/deepmind/dm-haiku/tree/master/examples/imagenet)
which shows how to use mixed f32/f16 precision to halve training time on GPU
and mixed f32/bf16 precision to reduce training time on TPU by a third.

## Policies

A mixed precision policy encapsulates the configuration used in a mixed
precision experiment.

```python
# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)
```

The policy object can be used to cast pytrees:

```python
def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params
  y = x @ w + b
  return my_policy.cast_to_output(y)

# Parameters are created and stored in the policy's param dtype.
params = (jnp.ones([3, 3], dtype=my_policy.param_dtype),
          jnp.zeros([3], dtype=my_policy.param_dtype))
x = jnp.ones([1, 3], dtype=full)
y = layer(params, x)
assert y.dtype == half
```

You can replace the output type of a given policy:

```python
my_policy = my_policy.with_output_dtype(full)
```

You can also define a policy via a string, which may be useful for specifying a
policy as a command-line argument or as a hyperparameter to your experiment:

```python
my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
float16 = jmp.get_policy("float16")  # Everything in f16.
half = jmp.get_policy("half")        # Everything in half (f16 or bf16).
```
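For illustration, a minimal sketch of passing such a policy string in as a
command-line flag; the `--mp_policy` flag name and its default value are made
up for this example and are not part of `jmp`:

```python
# Hypothetical command-line wiring for a policy string (flag name is illustrative).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mp_policy",
                    default="params=float32,compute=float16,output=float32")
args = parser.parse_args()

my_policy = jmp.get_policy(args.mp_policy)
```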
## Loss scaling

When training with reduced precision, consider whether gradients will need to
be shifted into the representable range of the format that you are using. This
is particularly important when training with `float16` and less important for
`bfloat16`. See the NVIDIA mixed precision user guide [[1]] for more details.

The easiest way to shift gradients is with loss scaling, which scales your loss
and gradients by `S` and `1/S` respectively.

```python
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)
```

The appropriate value for `S` depends on your model, loss, batch size and
potentially other factors. You can determine this with trial and error. As a
rule of thumb you want the largest value of `S` that does not introduce
overflow during backprop. NVIDIA [[1]] recommends computing statistics about
the gradients of your model (in full precision) and picking `S` such that its
product with the maximum norm of your gradients is below `65,504` (the largest
finite value representable in `float16`).
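As a rough illustration of that recommendation (a sketch, not a `jmp` API): run
a few steps in full precision, record the largest absolute gradient value seen,
and pick a power-of-two `S` that keeps the scaled gradients below the `float16`
maximum. The `grad_fn`, `params` and `batches` names below are placeholders.

```python
# Hypothetical helper illustrating the NVIDIA heuristic above; not part of jmp.
# grad_fn(params, batch) is assumed to return full-precision (float32) gradients.
def estimate_loss_scale(grad_fn, params, batches, float16_max=65504.0):
  max_abs_grad = 1e-12  # Avoid division by zero if all gradients are zero.
  for batch in batches:
    grads = grad_fn(params, batch)
    for g in jax.tree_util.tree_leaves(grads):
      max_abs_grad = max(max_abs_grad, float(jnp.max(jnp.abs(g))))
  # Largest power of two S such that S * max_abs_grad stays below float16_max.
  exponent = int(jnp.floor(jnp.log2(float16_max / max_abs_grad)))
  return jmp.StaticLossScale(2.0 ** exponent)
```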
We provide a dynamic loss scale, which adjusts the loss scale periodically
during training to find the largest value for `S` that produces finite
gradients. This is more convenient and robust compared with picking a static
loss scale, but has a small performance impact (between 1 and 5%).

```python
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaN's when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)

    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)

    # Only apply our optimizer if grads are finite, if any element of any
    # gradient is non-finite the whole update is discarded.
    params = jmp.select_tree(grads_finite,
                             apply_optimizer(params, grads),
                             params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)
```
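Because loss scales are pytrees, a dynamic loss scale can be threaded through a
jitted step function like any other piece of training state. A minimal sketch,
assuming a concrete `train_step(params, loss_scale, batch)` in place of the
elided arguments above, with `params` and `batches` as placeholders:

```python
# Sketch only: assumes train_step(params, loss_scale, batch) with the logic above.
# DynamicLossScale is a pytree, so it can cross the jax.jit boundary directly.
train_step_jit = jax.jit(train_step)

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
print(jax.tree_util.tree_leaves(loss_scale))  # Leaves include the current scale and counter.

for batch in batches:
  params, loss_scale = train_step_jit(params, loss_scale, batch)
```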
In general using a static loss scale should offer the best speed, but we have
optimized dynamic loss scaling to make it competitive. We recommend you start
with dynamic loss scaling and move to static loss scaling if performance is an
issue.

Finally, we offer a no-op loss scale which you can use as a drop-in
replacement. It does nothing (apart from implementing the `jmp.LossScale` API):

```python
loss_scale = jmp.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1
```

## Citing JMP

This repository is part of the
[DeepMind JAX Ecosystem](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research).
To cite JMP, please use the
[DeepMind JAX Ecosystem citation](https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt).

## References

[[0]] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich
Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh
Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740,
https://arxiv.org/abs/1710.03740.

[[1]] "Training With Mixed Precision :: NVIDIA Deep Learning Performance
Documentation". Docs.Nvidia.Com, 2020,
https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/.

[0]: https://arxiv.org/abs/1710.03740
[1]: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html
[Haiku]: https://github.com/deepmind/dm-haiku
[JAX]: https://github.com/google/jax

%package -n python3-jmp
Summary: JMP is a Mixed Precision library for JAX.
Provides: python-jmp
BuildRequires: python3-devel
BuildRequires: python3-setuptools
BuildRequires: python3-pip

%description -n python3-jmp
JMP is a mixed precision library for JAX. It provides two key abstractions for
mixed precision training: mixed precision policies, which control the dtypes
used for parameters, compute and outputs, and loss scaling (static, dynamic and
no-op), which keeps float16 gradients within the representable range. Neural
network libraries such as Haiku can build on jmp to offer Automatic Mixed
Precision (AMP) support. See https://github.com/deepmind/jmp for full
documentation and examples.

%package help
Summary: Development documents and examples for jmp
Provides: python3-jmp-doc

%description help
Development documents and examples for jmp, installed under %{_docdir} when
present in the source distribution.
%prep
%autosetup -n jmp-0.0.4

%build
%py3_build

%install
%py3_install
install -d -m755 %{buildroot}/%{_pkgdocdir}
if [ -d doc ]; then cp -arf doc %{buildroot}/%{_pkgdocdir}; fi
if [ -d docs ]; then cp -arf docs %{buildroot}/%{_pkgdocdir}; fi
if [ -d example ]; then cp -arf example %{buildroot}/%{_pkgdocdir}; fi
if [ -d examples ]; then cp -arf examples %{buildroot}/%{_pkgdocdir}; fi
pushd %{buildroot}
if [ -d usr/lib ]; then
	find usr/lib -type f -printf "/%h/%f\n" >> filelist.lst
fi
if [ -d usr/lib64 ]; then
	find usr/lib64 -type f -printf "/%h/%f\n" >> filelist.lst
fi
if [ -d usr/bin ]; then
	find usr/bin -type f -printf "/%h/%f\n" >> filelist.lst
fi
if [ -d usr/sbin ]; then
	find usr/sbin -type f -printf "/%h/%f\n" >> filelist.lst
fi
touch doclist.lst
if [ -d usr/share/man ]; then
	find usr/share/man -type f -printf "/%h/%f.gz\n" >> doclist.lst
fi
popd
mv %{buildroot}/filelist.lst .
mv %{buildroot}/doclist.lst .

%files -n python3-jmp -f filelist.lst
%dir %{python3_sitelib}/*

%files help -f doclist.lst
%{_docdir}/*

%changelog
* Tue Apr 11 2023 Python_Bot - 0.0.4-1
- Package Spec generated