diff options
author | CoprDistGit <infra@openeuler.org> | 2023-05-31 04:20:07 +0000 |
---|---|---|
committer | CoprDistGit <infra@openeuler.org> | 2023-05-31 04:20:07 +0000 |
commit | 90858df672dc89ac58b00212b081854853a854ed (patch) | |
tree | ca2970fc88bee2d71ec553ab7d63f95a5dffe8db | |
parent | b2d47ff6674e78fb123be77fd864f3aaa178df8f (diff) |
automatic import of python-elegy
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | python-elegy.spec | 1102 | ||||
-rw-r--r-- | sources | 1 |
3 files changed, 1104 insertions, 0 deletions
@@ -0,0 +1 @@ +/elegy-0.8.6.tar.gz diff --git a/python-elegy.spec b/python-elegy.spec new file mode 100644 index 0000000..733987d --- /dev/null +++ b/python-elegy.spec @@ -0,0 +1,1102 @@ +%global _empty_manifest_terminate_build 0 +Name: python-elegy +Version: 0.8.6 +Release: 1 +Summary: Elegy is a Neural Networks framework based on Jax and Haiku. +License: APACHE +URL: https://poets-ai.github.io/elegy +Source0: https://mirrors.nju.edu.cn/pypi/web/packages/31/26/2efee9bb7bcb8b0d2f4d5d77a630f0cd5da71ee5e89d721c29855297c896/elegy-0.8.6.tar.gz +BuildArch: noarch + +Requires: python3-cloudpickle +Requires: python3-tensorboardx +Requires: python3-wandb +Requires: python3-treex + +%description +# Elegy + +<!-- [](https://pypi.org/project/elegy/) --> +<!-- [](https://pypi.org/project/elegy/) --> +<!-- [](https://poets-ai.github.io/elegy/) --> +<!-- [](https://github.com/psf/black) --> +[](https://codecov.io/gh/poets-ai/elegy) +[](https://github.com/poets-ai/elegy/actions?query=workflow%3A%22GitHub+CI%22) +[](https://github.com/poets-ai/elegy/issues) + +______________________________________________________________________ + +_A High Level API for Deep Learning in JAX_ + +#### Main Features + +- 😀 **Easy-to-use**: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks. +- 💪 **Flexible**: Elegy provides a Pytorch Lightning-like low-level API that offers maximum flexibility when needed. +- 🔌 **Compatible**: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, Pytorch DataLoaders, and more. +<!-- - 🤷 **Agnostic**: Elegy supports various frameworks, including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API. --> + +Elegy is built on top of [Treex](https://github.com/cgarciae/treex) and [Treeo](https://github.com/cgarciae/treeo) and reexports their APIs for convenience. + + +[Getting Started](https://poets-ai.github.io/elegy/getting-started/high-level-api) | [Examples](/examples) | [Documentation](https://poets-ai.github.io/elegy) + + +## What is included? +* A `Model` class with an Estimator-like API. +* A `callbacks` module with common Keras callbacks. + +**From Treex** + +* A `Module` class. +* A `nn` module for with common layers. +* A `losses` module with common loss functions. +* A `metrics` module with common metrics. + +## Installation + +Install using pip: + +```bash +pip install elegy +``` + +For Windows users, we recommend the Windows subsystem for Linux 2 [WSL2](https://docs.microsoft.com/es-es/windows/wsl/install-win10?redirectedfrom=MSDN) since [jax](https://github.com/google/jax/issues/438) does not support it yet. + +## Quick Start: High-level API + +Elegy's high-level API provides a straightforward interface you can use by implementing the following steps: + +**1.** Define the architecture inside a `Module`: + +```python +import jax +import elegy as eg + +class MLP(eg.Module): + @eg.compact + def __call__(self, x): + x = eg.Linear(300)(x) + x = jax.nn.relu(x) + x = eg.Linear(10)(x) + return x +``` + +**2.** Create a `Model` from this module and specify additional things like losses, metrics, and optimizers: + +```python +import optax optax +import elegy as eg + +model = eg.Model( + module=MLP(), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +**3.** Train the model using the `fit` method: + +```python +model.fit( + inputs=X_train, + labels=y_train, + epochs=100, + steps_per_epoch=200, + batch_size=64, + validation_data=(X_test, y_test), + shuffle=True, + callbacks=[eg.callbacks.TensorBoard("summaries")] +) +``` +#### Using Flax + +<details> +<summary>Show</summary> + +To use Flax with Elegy just create a `flax.linen.Module` and pass it to `Model`. + +```python +import jax +import elegy as eg +import optax optax +import flax.linen as nn + +class MLP(nn.Module): + @nn.compact + def __call__(self, x, training: bool): + x = nn.Dense(300)(x) + x = jax.nn.relu(x) + x = nn.Dense(10)(x) + return x + + +model = eg.Model( + module=MLP(), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +As shown here, Flax Modules can optionally request a `training` argument to `__call__` which will be provided by Elegy / Treex. + +</details> + +#### Using Haiku + +<details> +<summary>Show</summary> + +To use Haiku with Elegy do the following: + +* Create a `forward` function. +* Create a `TransformedWithState` object by feeding `forward` to `hk.transform_with_state`. +* Pass your `TransformedWithState` to `Model`. + +You can also optionally create your own `hk.Module` and use it in `forward` if needed. Putting everything together should look like this: + +```python +import jax +import elegy as eg +import optax optax +import haiku as hk + + +def forward(x, training: bool): + x = hk.Linear(300)(x) + x = jax.nn.relu(x) + x = hk.Linear(10)(x) + return x + + +model = eg.Model( + module=hk.transform_with_state(forward), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +As shown here, `forward` can optionally request a `training` argument which will be provided by Elegy / Treex. + +</details> + +## Quick Start: Low-level API + +Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom `Model` to implement a `LinearClassifier` with pure JAX: + +**1.** Define a custom `init_step` method: + +```python +class LinearClassifier(eg.Model): + # use treex's API to declare parameter nodes + w: jnp.ndarray = eg.Parameter.node() + b: jnp.ndarray = eg.Parameter.node() + + def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray): + self.w = jax.random.uniform( + key=key, + shape=[features_in, 10], + ) + self.b = jnp.zeros([10]) + + self.optimizer = self.optimizer.init(self) + + return self +``` +Here we declared the parameters `w` and `b` using Treex's `Parameter.node()` for pedagogical reasons, however normally you don't have to do this since you typically use a sub-`Module` instead. + +**2.** Define a custom `test_step` method: +```python + def test_step(self, inputs, labels): + # flatten + scale + inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255 + + # forward + logits = jnp.dot(inputs, self.w) + self.b + + # crossentropy loss + target = jax.nn.one_hot(labels["target"], 10) + loss = optax.softmax_cross_entropy(logits, target).mean() + + # metrics + logs = dict( + acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]), + loss=loss, + ) + + return loss, logs, self +``` + +**3.** Instantiate our `LinearClassifier` with an optimizer: + +```python +model = LinearClassifier( + optimizer=optax.rmsprop(1e-3), +) +``` + +**4.** Train the model using the `fit` method: + +```python +model.fit( + inputs=X_train, + labels=y_train, + epochs=100, + steps_per_epoch=200, + batch_size=64, + validation_data=(X_test, y_test), + shuffle=True, + callbacks=[eg.callbacks.TensorBoard("summaries")] +) +``` + +#### Using other JAX Frameworks + +<details> +<summary>Show</summary> + +It is straightforward to integrate other functional JAX libraries with this +low-level API, here is an example with Flax: + +```python +import elegy as eg +import flax.linen as nn + +class LinearClassifier(eg.Model): + params: Mapping[str, Any] = eg.Parameter.node() + batch_stats: Mapping[str, Any] = eg.BatchStat.node() + next_key: eg.KeySeq + + def __init__(self, module: nn.Module, **kwargs): + self.flax_module = module + super().__init__(**kwargs) + + def init_step(self, key, inputs): + self.next_key = eg.KeySeq(key) + + variables = self.flax_module.init( + {"params": self.next_key(), "dropout": self.next_key()}, x + ) + self.params = variables["params"] + self.batch_stats = variables["batch_stats"] + + self.optimizer = self.optimizer.init(self.parameters()) + + def test_step(self, inputs, labels): + # forward + variables = dict( + params=self.params, + batch_stats=self.batch_stats, + ) + logits, variables = self.flax_module.apply( + variables, + inputs, + rngs={"dropout": self.next_key()}, + mutable=True, + ) + self.batch_stats = variables["batch_stats"] + + # loss + target = jax.nn.one_hot(labels["target"], 10) + loss = optax.softmax_cross_entropy(logits, target).mean() + + # logs + logs = dict( + accuracy=accuracy, + loss=loss, + ) + return loss, logs, self +``` + +</details> + +### Examples + +Check out the [/example](/examples) directory for some inspiration. To run an example, first install some requirements: + +```bash +pip install -r examples/requirements.txt +``` + +And the run it normally with python e.g. + +```bash +python examples/flax/mnist_vae.py +``` + +## Contributing + +If your are interested in helping improve Elegy check out the [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing). + +## Sponsors 💚 +* [Quansight](https://www.quansight.com) - paid development time + +## Citing Elegy + + +**BibTeX** + +``` +@software{elegy2020repository, + title = {Elegy: A High Level API for Deep Learning in JAX}, + author = {PoetsAI}, + year = 2021, + url = {https://github.com/poets-ai/elegy}, + version = {0.8.1} +} +``` + + +%package -n python3-elegy +Summary: Elegy is a Neural Networks framework based on Jax and Haiku. +Provides: python-elegy +BuildRequires: python3-devel +BuildRequires: python3-setuptools +BuildRequires: python3-pip +%description -n python3-elegy +# Elegy + +<!-- [](https://pypi.org/project/elegy/) --> +<!-- [](https://pypi.org/project/elegy/) --> +<!-- [](https://poets-ai.github.io/elegy/) --> +<!-- [](https://github.com/psf/black) --> +[](https://codecov.io/gh/poets-ai/elegy) +[](https://github.com/poets-ai/elegy/actions?query=workflow%3A%22GitHub+CI%22) +[](https://github.com/poets-ai/elegy/issues) + +______________________________________________________________________ + +_A High Level API for Deep Learning in JAX_ + +#### Main Features + +- 😀 **Easy-to-use**: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks. +- 💪 **Flexible**: Elegy provides a Pytorch Lightning-like low-level API that offers maximum flexibility when needed. +- 🔌 **Compatible**: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, Pytorch DataLoaders, and more. +<!-- - 🤷 **Agnostic**: Elegy supports various frameworks, including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API. --> + +Elegy is built on top of [Treex](https://github.com/cgarciae/treex) and [Treeo](https://github.com/cgarciae/treeo) and reexports their APIs for convenience. + + +[Getting Started](https://poets-ai.github.io/elegy/getting-started/high-level-api) | [Examples](/examples) | [Documentation](https://poets-ai.github.io/elegy) + + +## What is included? +* A `Model` class with an Estimator-like API. +* A `callbacks` module with common Keras callbacks. + +**From Treex** + +* A `Module` class. +* A `nn` module for with common layers. +* A `losses` module with common loss functions. +* A `metrics` module with common metrics. + +## Installation + +Install using pip: + +```bash +pip install elegy +``` + +For Windows users, we recommend the Windows subsystem for Linux 2 [WSL2](https://docs.microsoft.com/es-es/windows/wsl/install-win10?redirectedfrom=MSDN) since [jax](https://github.com/google/jax/issues/438) does not support it yet. + +## Quick Start: High-level API + +Elegy's high-level API provides a straightforward interface you can use by implementing the following steps: + +**1.** Define the architecture inside a `Module`: + +```python +import jax +import elegy as eg + +class MLP(eg.Module): + @eg.compact + def __call__(self, x): + x = eg.Linear(300)(x) + x = jax.nn.relu(x) + x = eg.Linear(10)(x) + return x +``` + +**2.** Create a `Model` from this module and specify additional things like losses, metrics, and optimizers: + +```python +import optax optax +import elegy as eg + +model = eg.Model( + module=MLP(), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +**3.** Train the model using the `fit` method: + +```python +model.fit( + inputs=X_train, + labels=y_train, + epochs=100, + steps_per_epoch=200, + batch_size=64, + validation_data=(X_test, y_test), + shuffle=True, + callbacks=[eg.callbacks.TensorBoard("summaries")] +) +``` +#### Using Flax + +<details> +<summary>Show</summary> + +To use Flax with Elegy just create a `flax.linen.Module` and pass it to `Model`. + +```python +import jax +import elegy as eg +import optax optax +import flax.linen as nn + +class MLP(nn.Module): + @nn.compact + def __call__(self, x, training: bool): + x = nn.Dense(300)(x) + x = jax.nn.relu(x) + x = nn.Dense(10)(x) + return x + + +model = eg.Model( + module=MLP(), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +As shown here, Flax Modules can optionally request a `training` argument to `__call__` which will be provided by Elegy / Treex. + +</details> + +#### Using Haiku + +<details> +<summary>Show</summary> + +To use Haiku with Elegy do the following: + +* Create a `forward` function. +* Create a `TransformedWithState` object by feeding `forward` to `hk.transform_with_state`. +* Pass your `TransformedWithState` to `Model`. + +You can also optionally create your own `hk.Module` and use it in `forward` if needed. Putting everything together should look like this: + +```python +import jax +import elegy as eg +import optax optax +import haiku as hk + + +def forward(x, training: bool): + x = hk.Linear(300)(x) + x = jax.nn.relu(x) + x = hk.Linear(10)(x) + return x + + +model = eg.Model( + module=hk.transform_with_state(forward), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +As shown here, `forward` can optionally request a `training` argument which will be provided by Elegy / Treex. + +</details> + +## Quick Start: Low-level API + +Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom `Model` to implement a `LinearClassifier` with pure JAX: + +**1.** Define a custom `init_step` method: + +```python +class LinearClassifier(eg.Model): + # use treex's API to declare parameter nodes + w: jnp.ndarray = eg.Parameter.node() + b: jnp.ndarray = eg.Parameter.node() + + def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray): + self.w = jax.random.uniform( + key=key, + shape=[features_in, 10], + ) + self.b = jnp.zeros([10]) + + self.optimizer = self.optimizer.init(self) + + return self +``` +Here we declared the parameters `w` and `b` using Treex's `Parameter.node()` for pedagogical reasons, however normally you don't have to do this since you typically use a sub-`Module` instead. + +**2.** Define a custom `test_step` method: +```python + def test_step(self, inputs, labels): + # flatten + scale + inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255 + + # forward + logits = jnp.dot(inputs, self.w) + self.b + + # crossentropy loss + target = jax.nn.one_hot(labels["target"], 10) + loss = optax.softmax_cross_entropy(logits, target).mean() + + # metrics + logs = dict( + acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]), + loss=loss, + ) + + return loss, logs, self +``` + +**3.** Instantiate our `LinearClassifier` with an optimizer: + +```python +model = LinearClassifier( + optimizer=optax.rmsprop(1e-3), +) +``` + +**4.** Train the model using the `fit` method: + +```python +model.fit( + inputs=X_train, + labels=y_train, + epochs=100, + steps_per_epoch=200, + batch_size=64, + validation_data=(X_test, y_test), + shuffle=True, + callbacks=[eg.callbacks.TensorBoard("summaries")] +) +``` + +#### Using other JAX Frameworks + +<details> +<summary>Show</summary> + +It is straightforward to integrate other functional JAX libraries with this +low-level API, here is an example with Flax: + +```python +import elegy as eg +import flax.linen as nn + +class LinearClassifier(eg.Model): + params: Mapping[str, Any] = eg.Parameter.node() + batch_stats: Mapping[str, Any] = eg.BatchStat.node() + next_key: eg.KeySeq + + def __init__(self, module: nn.Module, **kwargs): + self.flax_module = module + super().__init__(**kwargs) + + def init_step(self, key, inputs): + self.next_key = eg.KeySeq(key) + + variables = self.flax_module.init( + {"params": self.next_key(), "dropout": self.next_key()}, x + ) + self.params = variables["params"] + self.batch_stats = variables["batch_stats"] + + self.optimizer = self.optimizer.init(self.parameters()) + + def test_step(self, inputs, labels): + # forward + variables = dict( + params=self.params, + batch_stats=self.batch_stats, + ) + logits, variables = self.flax_module.apply( + variables, + inputs, + rngs={"dropout": self.next_key()}, + mutable=True, + ) + self.batch_stats = variables["batch_stats"] + + # loss + target = jax.nn.one_hot(labels["target"], 10) + loss = optax.softmax_cross_entropy(logits, target).mean() + + # logs + logs = dict( + accuracy=accuracy, + loss=loss, + ) + return loss, logs, self +``` + +</details> + +### Examples + +Check out the [/example](/examples) directory for some inspiration. To run an example, first install some requirements: + +```bash +pip install -r examples/requirements.txt +``` + +And the run it normally with python e.g. + +```bash +python examples/flax/mnist_vae.py +``` + +## Contributing + +If your are interested in helping improve Elegy check out the [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing). + +## Sponsors 💚 +* [Quansight](https://www.quansight.com) - paid development time + +## Citing Elegy + + +**BibTeX** + +``` +@software{elegy2020repository, + title = {Elegy: A High Level API for Deep Learning in JAX}, + author = {PoetsAI}, + year = 2021, + url = {https://github.com/poets-ai/elegy}, + version = {0.8.1} +} +``` + + +%package help +Summary: Development documents and examples for elegy +Provides: python3-elegy-doc +%description help +# Elegy + +<!-- [](https://pypi.org/project/elegy/) --> +<!-- [](https://pypi.org/project/elegy/) --> +<!-- [](https://poets-ai.github.io/elegy/) --> +<!-- [](https://github.com/psf/black) --> +[](https://codecov.io/gh/poets-ai/elegy) +[](https://github.com/poets-ai/elegy/actions?query=workflow%3A%22GitHub+CI%22) +[](https://github.com/poets-ai/elegy/issues) + +______________________________________________________________________ + +_A High Level API for Deep Learning in JAX_ + +#### Main Features + +- 😀 **Easy-to-use**: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks. +- 💪 **Flexible**: Elegy provides a Pytorch Lightning-like low-level API that offers maximum flexibility when needed. +- 🔌 **Compatible**: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, Pytorch DataLoaders, and more. +<!-- - 🤷 **Agnostic**: Elegy supports various frameworks, including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API. --> + +Elegy is built on top of [Treex](https://github.com/cgarciae/treex) and [Treeo](https://github.com/cgarciae/treeo) and reexports their APIs for convenience. + + +[Getting Started](https://poets-ai.github.io/elegy/getting-started/high-level-api) | [Examples](/examples) | [Documentation](https://poets-ai.github.io/elegy) + + +## What is included? +* A `Model` class with an Estimator-like API. +* A `callbacks` module with common Keras callbacks. + +**From Treex** + +* A `Module` class. +* A `nn` module for with common layers. +* A `losses` module with common loss functions. +* A `metrics` module with common metrics. + +## Installation + +Install using pip: + +```bash +pip install elegy +``` + +For Windows users, we recommend the Windows subsystem for Linux 2 [WSL2](https://docs.microsoft.com/es-es/windows/wsl/install-win10?redirectedfrom=MSDN) since [jax](https://github.com/google/jax/issues/438) does not support it yet. + +## Quick Start: High-level API + +Elegy's high-level API provides a straightforward interface you can use by implementing the following steps: + +**1.** Define the architecture inside a `Module`: + +```python +import jax +import elegy as eg + +class MLP(eg.Module): + @eg.compact + def __call__(self, x): + x = eg.Linear(300)(x) + x = jax.nn.relu(x) + x = eg.Linear(10)(x) + return x +``` + +**2.** Create a `Model` from this module and specify additional things like losses, metrics, and optimizers: + +```python +import optax optax +import elegy as eg + +model = eg.Model( + module=MLP(), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +**3.** Train the model using the `fit` method: + +```python +model.fit( + inputs=X_train, + labels=y_train, + epochs=100, + steps_per_epoch=200, + batch_size=64, + validation_data=(X_test, y_test), + shuffle=True, + callbacks=[eg.callbacks.TensorBoard("summaries")] +) +``` +#### Using Flax + +<details> +<summary>Show</summary> + +To use Flax with Elegy just create a `flax.linen.Module` and pass it to `Model`. + +```python +import jax +import elegy as eg +import optax optax +import flax.linen as nn + +class MLP(nn.Module): + @nn.compact + def __call__(self, x, training: bool): + x = nn.Dense(300)(x) + x = jax.nn.relu(x) + x = nn.Dense(10)(x) + return x + + +model = eg.Model( + module=MLP(), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +As shown here, Flax Modules can optionally request a `training` argument to `__call__` which will be provided by Elegy / Treex. + +</details> + +#### Using Haiku + +<details> +<summary>Show</summary> + +To use Haiku with Elegy do the following: + +* Create a `forward` function. +* Create a `TransformedWithState` object by feeding `forward` to `hk.transform_with_state`. +* Pass your `TransformedWithState` to `Model`. + +You can also optionally create your own `hk.Module` and use it in `forward` if needed. Putting everything together should look like this: + +```python +import jax +import elegy as eg +import optax optax +import haiku as hk + + +def forward(x, training: bool): + x = hk.Linear(300)(x) + x = jax.nn.relu(x) + x = hk.Linear(10)(x) + return x + + +model = eg.Model( + module=hk.transform_with_state(forward), + loss=[ + eg.losses.Crossentropy(), + eg.regularizers.L2(l=1e-5), + ], + metrics=eg.metrics.Accuracy(), + optimizer=optax.rmsprop(1e-3), +) +``` + +As shown here, `forward` can optionally request a `training` argument which will be provided by Elegy / Treex. + +</details> + +## Quick Start: Low-level API + +Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom `Model` to implement a `LinearClassifier` with pure JAX: + +**1.** Define a custom `init_step` method: + +```python +class LinearClassifier(eg.Model): + # use treex's API to declare parameter nodes + w: jnp.ndarray = eg.Parameter.node() + b: jnp.ndarray = eg.Parameter.node() + + def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray): + self.w = jax.random.uniform( + key=key, + shape=[features_in, 10], + ) + self.b = jnp.zeros([10]) + + self.optimizer = self.optimizer.init(self) + + return self +``` +Here we declared the parameters `w` and `b` using Treex's `Parameter.node()` for pedagogical reasons, however normally you don't have to do this since you typically use a sub-`Module` instead. + +**2.** Define a custom `test_step` method: +```python + def test_step(self, inputs, labels): + # flatten + scale + inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255 + + # forward + logits = jnp.dot(inputs, self.w) + self.b + + # crossentropy loss + target = jax.nn.one_hot(labels["target"], 10) + loss = optax.softmax_cross_entropy(logits, target).mean() + + # metrics + logs = dict( + acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]), + loss=loss, + ) + + return loss, logs, self +``` + +**3.** Instantiate our `LinearClassifier` with an optimizer: + +```python +model = LinearClassifier( + optimizer=optax.rmsprop(1e-3), +) +``` + +**4.** Train the model using the `fit` method: + +```python +model.fit( + inputs=X_train, + labels=y_train, + epochs=100, + steps_per_epoch=200, + batch_size=64, + validation_data=(X_test, y_test), + shuffle=True, + callbacks=[eg.callbacks.TensorBoard("summaries")] +) +``` + +#### Using other JAX Frameworks + +<details> +<summary>Show</summary> + +It is straightforward to integrate other functional JAX libraries with this +low-level API, here is an example with Flax: + +```python +import elegy as eg +import flax.linen as nn + +class LinearClassifier(eg.Model): + params: Mapping[str, Any] = eg.Parameter.node() + batch_stats: Mapping[str, Any] = eg.BatchStat.node() + next_key: eg.KeySeq + + def __init__(self, module: nn.Module, **kwargs): + self.flax_module = module + super().__init__(**kwargs) + + def init_step(self, key, inputs): + self.next_key = eg.KeySeq(key) + + variables = self.flax_module.init( + {"params": self.next_key(), "dropout": self.next_key()}, x + ) + self.params = variables["params"] + self.batch_stats = variables["batch_stats"] + + self.optimizer = self.optimizer.init(self.parameters()) + + def test_step(self, inputs, labels): + # forward + variables = dict( + params=self.params, + batch_stats=self.batch_stats, + ) + logits, variables = self.flax_module.apply( + variables, + inputs, + rngs={"dropout": self.next_key()}, + mutable=True, + ) + self.batch_stats = variables["batch_stats"] + + # loss + target = jax.nn.one_hot(labels["target"], 10) + loss = optax.softmax_cross_entropy(logits, target).mean() + + # logs + logs = dict( + accuracy=accuracy, + loss=loss, + ) + return loss, logs, self +``` + +</details> + +### Examples + +Check out the [/example](/examples) directory for some inspiration. To run an example, first install some requirements: + +```bash +pip install -r examples/requirements.txt +``` + +And the run it normally with python e.g. + +```bash +python examples/flax/mnist_vae.py +``` + +## Contributing + +If your are interested in helping improve Elegy check out the [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing). + +## Sponsors 💚 +* [Quansight](https://www.quansight.com) - paid development time + +## Citing Elegy + + +**BibTeX** + +``` +@software{elegy2020repository, + title = {Elegy: A High Level API for Deep Learning in JAX}, + author = {PoetsAI}, + year = 2021, + url = {https://github.com/poets-ai/elegy}, + version = {0.8.1} +} +``` + + +%prep +%autosetup -n elegy-0.8.6 + +%build +%py3_build + +%install +%py3_install +install -d -m755 %{buildroot}/%{_pkgdocdir} +if [ -d doc ]; then cp -arf doc %{buildroot}/%{_pkgdocdir}; fi +if [ -d docs ]; then cp -arf docs %{buildroot}/%{_pkgdocdir}; fi +if [ -d example ]; then cp -arf example %{buildroot}/%{_pkgdocdir}; fi +if [ -d examples ]; then cp -arf examples %{buildroot}/%{_pkgdocdir}; fi +pushd %{buildroot} +if [ -d usr/lib ]; then + find usr/lib -type f -printf "/%h/%f\n" >> filelist.lst +fi +if [ -d usr/lib64 ]; then + find usr/lib64 -type f -printf "/%h/%f\n" >> filelist.lst +fi +if [ -d usr/bin ]; then + find usr/bin -type f -printf "/%h/%f\n" >> filelist.lst +fi +if [ -d usr/sbin ]; then + find usr/sbin -type f -printf "/%h/%f\n" >> filelist.lst +fi +touch doclist.lst +if [ -d usr/share/man ]; then + find usr/share/man -type f -printf "/%h/%f.gz\n" >> doclist.lst +fi +popd +mv %{buildroot}/filelist.lst . +mv %{buildroot}/doclist.lst . + +%files -n python3-elegy -f filelist.lst +%dir %{python3_sitelib}/* + +%files help -f doclist.lst +%{_docdir}/* + +%changelog +* Wed May 31 2023 Python_Bot <Python_Bot@openeuler.org> - 0.8.6-1 +- Package Spec generated @@ -0,0 +1 @@ +5d9c75b6b6ef1689624f6138f35e7945 elegy-0.8.6.tar.gz |