diff --git a/flax/experimental/__init__.py b/flax/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flax/experimental/nnx/.gitignore b/flax/experimental/nnx/.gitignore new file mode 100644 index 0000000000..2a90c3eca6 --- /dev/null +++ b/flax/experimental/nnx/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# project specific +.vscode +/tmp \ No newline at end of file diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md new file mode 100644 index 0000000000..0a93624c4e --- /dev/null +++ b/flax/experimental/nnx/README.md @@ -0,0 +1,451 @@ +[![codecov](https://codecov.io/gh/cgarciae/nnx/branch/main/graph/badge.svg?token=VqJjL474Z7)](https://codecov.io/gh/cgarciae/nnx) + +# NNX + +_**N**eural **N**etworks for JA**X**_ + +NNX is a Neural Networks library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of [Flax](https://flax.readthedocs.io/en/latest/) with a simplified, Pythonic API akin to that of [PyTorch](https://pytorch.org/). + +* **Pythonic**: Modules are just regular python classes, they contain their own state, are fully mutable, and allow sharing references between Modules. +* **Compatible**: Easily convert back and forth between Modules and pytrees using the Functional API to integrate with any JAX API. +* **Safe**: NNX incorporates mechanisms to try to prevent tracer leakage, avoid stale RNGs, and ensure proper state propagation in order to help produce correct JAX programs. 
+* **Semantic**: Partition a Module's state into different semantic collections, allowing for fine-grained control when applying JAX transformations.
+
+#### Table of Contents
+* [Installation](#installation)
+* [Getting Started](#getting-started)
+* [FAQs](#faqs)
+* [Examples](#examples)
+* [User Guide](#user-guide)
+
+## Installation
+
+To get started with `nnx`, install the package via pip:
+
+```
+pip install nnx
+```
+For the most recent version, install directly from our GitHub repository:
+
+```
+pip install git+https://github.com/cgarciae/nnx
+```
+
+## Getting Started
+
+The following example guides you through creating a basic `Linear` model with NNX and executing a forward pass. It also demonstrates how to handle mutable state by keeping track of the number of times the model has been called.
+
+```python
+from flax.experimental import nnx
+import jax
+import jax.numpy as jnp
+
+class Count(nnx.Variable): pass
+
+class Linear(nnx.Module):
+  def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
+    key = ctx.make_rng("params")
+    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
+    self.b = nnx.Param(jnp.zeros((dout,)))
+    self.count = Count(0)  # track the number of calls
+
+  def __call__(self, x):
+    self.count += 1
+    return x @ self.w + self.b
+
+model = Linear(din=12, dout=2, ctx=nnx.context(0))
+
+# Forward pass and verify the call count
+x = jnp.ones((8, 12))
+y = model(x)
+assert model.count == 1
+```
+
+In this example `nnx.context(0)` creates a `PRNGKey` for `params` with seed `0`, which is used by `make_rng`
+inside `__init__` to generate a random key to initialize the parameters.
+
+### Training with the Functional API
+
+The [Functional API](#functional-api) converts an NNX Module from Python semantics into a pure pytree object with functional semantics. It is the recommended way to use NNX, as it provides tight control over the state, allows you to use regular JAX transformations, and minimizes overhead. In this example the model will be trained using Stochastic Gradient Descent (SGD).
+
+```python
+(params, counts), moduledef = model.partition(nnx.Param, Count)
+
+@jax.jit
+def train_step(params, counts, x, y):
+  def loss_fn(params):
+    y_pred, (updates, _) = moduledef.apply(params, counts)(x)
+    loss = jax.numpy.mean((y_pred - y) ** 2)
+    return loss, updates.filter(Count)
+
+  # compute gradient
+  grads, counts = jax.grad(loss_fn, has_aux=True)(params)
+  # SGD update
+  params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads)
+
+  return params, counts
+
+# execute the training step
+params, counts = train_step(params, counts, x, y)
+model = moduledef.merge(params, counts)
+assert model.count == 2
+```
+
+### Training with Lifted Transforms
+
+[Lifted Transforms](#lifted-transforms) provide a convenient way to interact with NNX Modules. In this example, we use the `nnx.jit` and `nnx.grad` lifted transforms to define the training step. The model is trained using Stochastic Gradient Descent (SGD). Because lifted transforms automatically update the Module's state, `train_step` doesn't require a return statement.
+
+```python
+@nnx.jit
+def train_step(model, x, y):
+  def loss_fn(model):
+    y_pred = model(x)
+    return jax.numpy.mean((y_pred - y) ** 2)
+
+  # compute gradient
+  grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model)
+  # SGD update
+  model.update_state(
+    jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads)
+  )
+
+# execute the training step
+train_step(model, x, y)
+assert model.count == 2
+```
+
+**Note**: Using `nnx.jit` introduces some overhead when compared to using `jax.jit` directly. Use `nnx.jit` for simple prototypes, but for production code use `jax.jit` directly.
+
+## Examples
+
+* [Using the Functional API](https://github.com/cgarciae/nnx/blob/main/examples/01_functional_api.py): Shows how to train a simple model using the functional API.
+* [Using Lifted Transforms](https://github.com/cgarciae/nnx/blob/main/examples/02_lifted_transforms.py): Shows how to train a simple model using lifted transforms.
+* [Using TrainState](https://github.com/cgarciae/nnx/blob/main/examples/03_train_state.py): Shows how to train a simple model using the functional API with the help of `TrainState`.
+* [Using PureModule](https://github.com/cgarciae/nnx/blob/main/examples/04_pure.py) (experimental): Shows how to train a simple model using the functional API and leveraging `PureModule` to simplify the code.
+* [Training a VAE](https://github.com/cgarciae/nnx/blob/main/examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset using the functional API and `TrainState`, and shows how to capture intermediate values to retrieve `kl_loss`.
+* [Scan over layers](https://github.com/cgarciae/nnx/blob/main/examples/06_scan_over_layers.py): A contrived example that implements scan over layers with dropout and a shared BatchNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`.
+* [Creating a Transformer](https://github.com/cgarciae/nnx/blob/main/examples/07_transformer.py): Shows how to create a Transformer with an auto-regressive decoder that uses scan over layers and a kv-cache for fast inference. Credits to @levskaya.
+
+## FAQs
+
+### Status
+NNX is still in early development, so expect bugs and breaking changes. That said, the current API is the result of months of experimentation and we don't expect any major changes in the near future.
+
+### How is it different from Flax?
+NNX takes the best features that allow Flax to scale to large projects and integrates them into a much simpler Module system with pythonic semantics.
+
+One place in which NNX strongly deviates from Flax is that (currently) it avoids shape inference in favor of static initialization. This is not a technical limitation but rather a design choice. It both simplifies the internal implementation and makes the code easier for the user to reason about, at the cost of being more verbose at times. On the other hand, PyTorch users will feel right at home.
+
+### How is it different from Equinox?
+While they might look similar at a surface level, NNX's Module system is more powerful and flexible than Equinox's. It contains the following additional features:
+
+* Uses regular python classes (no mandatory dataclass behavior).
+* Modules are mutable.
+* Reference sharing between Modules is allowed (see the sketch below).
+* Mutable state lives inside the Module (no need for a separate [State container](https://docs.kidger.site/equinox/examples/stateful/)).
+* Supports node metadata and semantic partitioning.
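+
+To make the first three points concrete, here is a minimal sketch of mutation and reference sharing, assuming only the APIs shown elsewhere in this README (the `Owner` class is a hypothetical name made up for illustration):
+
+```python
+from flax.experimental import nnx
+import jax.numpy as jnp
+
+class Owner(nnx.Module):
+  def __init__(self, shared: nnx.Linear):
+    # the same Linear instance is assigned to two attributes
+    self.child_a = shared
+    self.child_b = shared
+
+shared = nnx.Linear(2, 2, ctx=nnx.context(0))
+owner = Owner(shared)
+
+# Modules are mutable: overwrite a parameter in place
+owner.child_a.kernel = nnx.Param(jnp.zeros((2, 2)))
+
+# both attributes point to the same object, so they observe the update
+assert owner.child_b.kernel is owner.child_a.kernel
+```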
+
+One major difference between the two frameworks is that, by design, NNX Modules are not Pytrees. This adds a safety layer, as it prevents state updates from being lost by accident due to referential transparency. It also removes the need to thread a separate [State container](https://docs.kidger.site/equinox/examples/stateful/) throughout the code in order to propagate state. In NNX, state updates are either always preserved or explicitly discarded by the user.
+
+## User Guide
+
+### Modules
+
+NNX Modules are normal python classes. They obey regular python semantics such as mutability and reference sharing, including reference cycles. They can contain 2 types of attributes: node attributes and static attributes. Node attributes include NNX `Variable`s (e.g. `nnx.Param`), Numpy arrays, JAX arrays, submodules, and other NNX types. All other types are treated as static attributes.
+
+```python
+class Foo(nnx.Module):
+  def __init__(self, ctx: nnx.Context):
+    # node attributes
+    self.variable = nnx.Param(jnp.array(1))
+    self.np_buffer = np.array(2)
+    self.jax_buffer = jnp.array(3)
+    self.node = nnx.Node([4, 5])
+    self.submodule = nnx.Linear(2, 4, ctx=ctx)
+    # static attributes
+    self.int = 1
+    self.float = 2.0
+    self.str = "hello"
+    self.list = [1, 2, 3]
+
+model = Foo(ctx=nnx.context(0))
+```
+As shown above, python container types such as `list`, `tuple`, and `dict` are treated as static attributes; if similar functionality is needed, NNX provides the `Sequence` and `Dict` Modules.
+
+### Functional API
+
+NNX Modules are not pytrees, so they cannot be passed to JAX transformations. In order to interact with JAX, a Module must be partitioned into `State` and `ModuleDef` objects. The `State` object is a flat dictionary-like pytree structure that contains all the deduplicated node attributes, and the `ModuleDef` contains the static attributes and structural information needed to reconstruct the Module.
+
+```python
+state, moduledef = model.partition()
+```
+```
+State({
+  ('jax_buffer',): Array(3),
+  ('node',): Node(value=[4, 5]),
+  ('np_buffer',): array(2),
+  ('submodule', 'bias'): Param(value=Array(...)),
+  ('submodule', 'kernel'): Param(value=Array(...)),
+  ('variable',): Param(value=Array(1))
+})
+```
+
+`State` and `ModuleDef` are pytrees, so they can be passed to JAX transformations. Moreover, `ModuleDef` provides 2 very important methods: `merge` and `apply`. The `merge` method can be used to create a new `Module` from a `State` object:
+
+```python
+model = moduledef.merge(state)
+```
+This can be used to e.g. recreate a module inside a JAX transformation. The `apply` method provides a functional interface to the module; it can be used to call any method or submodule and get the output and the updated state:
+
+```python
+# run __call__
+y, (state, moduledef) = moduledef.apply(state)(x)
+# run some_method
+y, (state, moduledef) = moduledef.apply(state).some_method(x)
+# run submodule
+y, (state, moduledef) = moduledef.apply(state).submodule(x)
+```
+
+`apply` can call any nested method or submodule as long as it can be accessed via the `.` or `[]` operators.
+
+### Partitioning State
+In NNX you can filter based on any node type; most commonly you will want to filter based on `nnx.Variable` subclasses such as `nnx.Param` or `nnx.BatchStat`, and you can define your own subclasses to create new collections, as sketched below.
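+
+Following the `Count` pattern from the Getting Started example, a custom collection is just a `Variable` subclass used as a filter; a minimal sketch (the `LoRAParam` name and the low-rank layout are hypothetical, made up for illustration):
+
+```python
+from flax.experimental import nnx
+import jax
+import jax.numpy as jnp
+
+class LoRAParam(nnx.Variable):  # a new semantic collection
+  pass
+
+class AdaptedLinear(nnx.Module):
+  def __init__(self, din: int, dout: int, rank: int, *, ctx: nnx.Context):
+    key = ctx.make_rng("params")
+    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
+    self.lora_a = LoRAParam(jnp.zeros((din, rank)))
+    self.lora_b = LoRAParam(jnp.zeros((rank, dout)))
+
+  def __call__(self, x):
+    return x @ self.w + x @ self.lora_a @ self.lora_b
+
+model = AdaptedLinear(12, 2, rank=4, ctx=nnx.context(0))
+# split the base parameters and the adapter parameters into separate States
+(params, lora_params), moduledef = model.partition(nnx.Param, LoRAParam)
+```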
+
+Here are various examples of how you can use the `partition` method to split a module into multiple substates:
+
+```python
+# partition the module into the state with all the nodes and the moduledef
+state, moduledef = model.partition()
+# verify that the state contains only params, else raise an error
+params, moduledef = model.partition(nnx.Param)
+# split the state into params and batch_stats, verify no nodes are left
+(params, batch_stats), moduledef = model.partition(nnx.Param, nnx.BatchStat)
+# if there are any nodes left, use the `...` filter to capture them
+(params, batch_stats, rest), moduledef = model.partition(nnx.Param, nnx.BatchStat, ...)
+# using `...` as the only filter is equivalent to not passing any filters
+model.partition(...)  # same as model.partition()
+```
+`partition` will make sure all nodes are matched by at least one filter, and will raise an error otherwise. If you have non-`Variable` nodes like `nnx.Node`, `jax.Array`, or `numpy.ndarray` attributes, you can use the `...` filter, which will match any node. For a more general filter you can pass a predicate function of the form:
+
+```python
+(path: Tuple[str, ...], value: Any) -> bool
+```
+
+To reconstruct the module from a set of substates, you can use `merge` as usual, passing the substates as additional arguments:
+
+```python
+model = moduledef.merge(params, batch_stats, rest)
+```
+
+The same is true for `apply`.
+
+```python
+y, (state, moduledef) = moduledef.apply(params, batch_stats, rest)(x)
+```
+
+Note that `apply` will return a single `state` object; if you need to re-partition the state you can use `State`'s own `partition` method:
+
+```python
+params, batch_stats, rest = state.partition(nnx.Param, nnx.BatchStat, ...)
+```
+
+Alternatively, if you are just interested in a subset of partitions, you can use the `State.filter` method, which will not raise an error if some nodes are not matched by any filter:
+
+```python
+# only get params
+params = state.filter(nnx.Param)
+# get params and batch_stats
+params, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
+```
+
+### Filters
+
+Filters let you select subsets of nodes based on some criteria. They are used throughout the API in methods like `partition`, `filter`, and `pop_state`. There are 4 types of filters:
+
+* `type`: matches all node instances of the given type.
+* `...`: matches all nodes.
+* `(path, any) -> bool`: a predicate function that takes a node path and value and returns a boolean.
+* `Tuple[Filter, ...]`: a tuple of filters, matches all nodes that match any of the filters.
+
+NNX also provides the following custom filters:
+
+* `nnx.Not(filter)`: matches all nodes that do not match the given filter
+* `nnx.buffers`: matches all `numpy.ndarray` and `jax.Array` nodes
+
+Here is an example of how to use `Not` and `buffers`:
+```python
+rest = module.filter(nnx.Not(nnx.Param))
+buffers = module.filter(nnx.buffers)
+```
+
+
+### Capturing Intermediate Values
+In NNX you can easily propagate intermediate values by simply assigning them to an attribute at runtime. For convenience, you should assign them to a `Variable` attribute with a `collection` name by using `nnx.var` so you can easily retrieve them later.
+
+Here is an example of how to create a `Linear` module that captures its output into a `Variable` attribute with the `intermediates` collection name:
+
+```python
+class Linear(nnx.Module):
+  def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
+    key = ctx.make_rng("params")
+    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
+    self.b = nnx.Param(jnp.zeros((dout,)))
+
+  def __call__(self, x):
+    y = x @ self.w + self.b
+    self.y = nnx.Intermediate(y)
+    return y
+
+model = Linear(12, 2, ctx=nnx.context(0))
+```
+Since `y` is only created when the module is called, it is not available upon initialization. However, once you call the module, `y` will be created. It is recommended that you use `pop_state` to retrieve temporary collections like `Intermediate`:
+
+```python
+y = model(jnp.ones((8, 12)))
+intermediates = model.pop_state(nnx.Intermediate)
+```
+`pop_state` will return a `State` object with the nodes that match the given filter and remove them from the module's attributes.
+
+```
+State({
+  ('y',): Intermediate(
+    value=Array(...)
+  )
+})
+```
+
+If you use the functional API to call the module instead, the `Intermediate` nodes will be present in the output `state`. To retrieve the `intermediates` nodes and optionally separate them from the output `state` you can use `State.partition`:
+
+```python
+state, moduledef = model.partition()
+y, (state, moduledef) = moduledef.apply(state)(jnp.ones((8, 12)))
+# "pop" the intermediates from the state
+intermediates, state = state.partition("intermediates", ...)
+```
+
+Alternatively, you can use `State.filter` to retrieve the `intermediates` nodes without removing them from the `state`.
+
+
+### Lifted Transforms
+
+NNX lifted transforms are analogous versions of JAX transforms that know how to work with Modules. They usually perform the following tasks:
+
+* Handle the Module's substates and the Context's RNG streams according to the transform's semantics.
+* Properly propagate state in and out of the transform, including updating the input Module's state with updates that happen inside the transform.
+
+Here's a diagram illustrating how lifted transformations work:
+
+![lifted-transforms](https://raw.githubusercontent.com/cgarciae/nnx/main/docs/images/stateful-transforms.png)
+
+Currently NNX provides the `jit`, `grad`, and `scan` lifted transforms.
+
+#### Manual Lifting
+
+In case you want to use JAX transforms directly, you can always use the functional API
+to manually lift your Modules.
+
+Here we will create an example of how to implement an MLP that uses "scan over layers" to efficiently process a sequence of inputs, assuming that each layer has the same parameters and input/output dimensions. The first thing we need to do is create a `Block` module that represents a single layer. This block will just contain a `Linear` layer, a `Dropout` layer, and a `GELU` activation function:
+
+```python
+class Block(nnx.Module):
+  def __init__(self, dim: int, *, ctx: nnx.Context):
+    self.linear = nnx.Linear(dim, dim, ctx=ctx)
+    self.dropout = nnx.Dropout(0.5)
+
+  def __call__(self, x: jax.Array, *, train: bool, ctx: nnx.Context) -> jax.Array:
+    x = self.linear(x)
+    x = self.dropout(x, deterministic=not train, ctx=ctx)
+    x = jax.nn.gelu(x)
+    return x
+```
+
+Now we will define `ScanMLP`. During `__init__`, instead of creating a list of `Block`s, we will use `jax.vmap` to create a single `Block` whose parameters have an additional `layer` axis.
This will allow us to pass the parameters as inputs to scan, so it will apply a layer at each step.
+
+```python
+class ScanMLP(nnx.Module):
+  def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context):
+    params_key = jax.random.split(ctx.make_rng("params"), n_layers)
+    self.n_layers = n_layers
+    self.layers = jax.vmap(
+      lambda key: Block(dim, ctx=nnx.context(params=key)).partition()
+    )(params_key).merge()
+```
+Note that we split the `params` key into `n_layers` keys so each layer has different parameters.
+
+Now we will define `__call__`. Here we need to split the `dropout` key into `n_layers` keys so each layer has a different dropout mask, and `partition` the layers to get their `params`. Both `params` and `dropout_key` will be passed as inputs, and `x` will be the carry value. Inside `scan_fn` we will merge the `params` back into a `Block` module and
+apply it to the input `x`, passing the sliced `dropout_key` as part of the `Context`.
+
+```python
+  def __call__(self, x: jax.Array, *, train: bool, ctx: nnx.Context) -> jax.Array:
+    dropout_key = jax.random.split(ctx.make_rng("dropout"), self.n_layers)
+    params, moduledef = self.layers.partition(nnx.Param)
+
+    def scan_fn(x, inputs):
+      params, dropout_key = inputs
+      module = moduledef.merge(params)
+      x = module(x, train=train, ctx=nnx.context(dropout=dropout_key))
+      return x, module.filter(nnx.Param)
+
+    x, params = jax.lax.scan(scan_fn, x, (params, dropout_key))
+    self.layers.update_state(params)
+    return x
+```
+Finally we apply `jax.lax.scan`, update the `layers` state with the new `params`, and return the final `x` value.
+
+Here is a simple way to test our `ScanMLP`:
+
+```python
+model = ScanMLP(10, n_layers=5, ctx=nnx.context(0))
+
+x = jnp.ones((3, 10))
+y = model(x, train=True, ctx=nnx.context(dropout=1))
+```
+
+For a more robust implementation with comments, take a look at the [Scan over layers](https://github.com/cgarciae/nnx/blob/main/examples/06_scan_over_layers.py) example.
+
+### Case Studies
+#### Shared State
+
+In NNX, you can create modules that share state between them. This is useful when designing complex neural network architectures, as it allows you to reuse certain layers and reduce the number of learnable parameters.
+
+Here's an example of creating a module with shared state:
+
+```python
+class Block(nnx.Module):
+  def __init__(self, linear: nnx.Linear, *, ctx: nnx.Context):
+    self.linear = linear
+    self.bn = nnx.BatchNorm(2, ctx=ctx)
+
+  def __call__(self, x, *, ctx: nnx.Context):
+    x = self.linear(x)
+    x = self.bn(x, ctx=ctx)
+    x = nnx.relu(x)
+    return x
+
+class Model(nnx.Module):
+  def __init__(self, *, ctx: nnx.Context):
+    shared = nnx.Linear(2, 2, ctx=ctx)
+    self.block1 = Block(shared, ctx=ctx)
+    self.block2 = Block(shared, ctx=ctx)
+
+  def __call__(self, x, *, ctx: nnx.Context):
+    x = self.block1(x, ctx=ctx)
+    x = self.block2(x, ctx=ctx)
+    return x
+```
+
+In this example, the `Model` module contains two instances of the `Block` module. Each instance shares the same `nnx.Linear` module. To run the model, you can use the Context `flags` argument to set the `use_running_average` flag for all `BatchNorm` modules.
+ +Here's an example of computing the loss for a `Model` instance: + +```python +def loss_fn(model: Model, x: jax.Array, y: jax.Array): + ctx = nnx.context(flags=dict(use_running_average=True)) + y_pred = model(x, ctx=ctx) + return jnp.mean((y - y_pred) ** 2) +``` + +It's important to note that the state for the shared `nnx.Linear` module will be kept in sync at all times on both `Block` instances, including during gradient updates. diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py new file mode 100644 index 0000000000..4620dfb879 --- /dev/null +++ b/flax/experimental/nnx/__init__.py @@ -0,0 +1,82 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .nnx.containers import ( + BatchStat, + Cache, + Container, + ContainerMetadata, + Intermediate, + Node, + Param, + Static, + Variable, + with_metadata, +) +from .nnx.contextlib import Context, context +from .nnx.dataclasses import ( + dataclass, + field, + node_field, + param_field, + static_field, + var_field, +) +from .nnx.errors import TraceContextError +from .nnx.helpers import Dict, Sequence, TrainState +from .nnx.module import Module, ModuleDef, Pure, PureModule +from .nnx.nn import initializers +from .nnx.nn.activations import ( + celu, + elu, + gelu, + glu, + hard_sigmoid, + hard_silu, + hard_swish, + hard_tanh, + leaky_relu, + log_sigmoid, + log_softmax, + logsumexp, + normalize, + one_hot, + relu, + relu6, + selu, + sigmoid, + silu, + soft_sign, + softmax, + softplus, + standardize, + swish, + tanh, +) +from .nnx.nn.linear import Conv, Embed, Linear +from .nnx.nn.normalization import BatchNorm, LayerNorm +from .nnx.nn.stochastic import Dropout +from .nnx.nodes import is_node, register_node_type +from .nnx.partitioning import All, Not, buffers +from .nnx.pytreelib import Pytree, TreeNode +from .nnx.spmd import ( + PARTITION_NAME, + get_partition_spec, + logical_axis_rules, + logical_to_mesh, + with_logical_constraint, + with_logical_partitioning, +) +from .nnx.state import State +from .nnx.transforms import Remat, Scan, grad, jit, remat, scan diff --git a/flax/experimental/nnx/docs/blog.md b/flax/experimental/nnx/docs/blog.md new file mode 100644 index 0000000000..5c3b2437e1 --- /dev/null +++ b/flax/experimental/nnx/docs/blog.md @@ -0,0 +1,7 @@ +### Do we need another JAX NN library? + +Hello, today I want to talk to you about a new JAX library that I have been working on, but before I do that, I wanted to discuss the topic: Do we need another JAX NN library? + +### JAX Libraries + +JAX NN libraries come in a wide variety ranging from functional like Flax and Haiku, to Pytree-based like Equinox. 
\ No newline at end of file
diff --git a/flax/experimental/nnx/docs/images/stateful-transforms.png b/flax/experimental/nnx/docs/images/stateful-transforms.png
new file mode 100644
index 0000000000..d7002fc163
Binary files /dev/null and b/flax/experimental/nnx/docs/images/stateful-transforms.png differ
diff --git a/flax/experimental/nnx/docs/quick_start.ipynb b/flax/experimental/nnx/docs/quick_start.ipynb
new file mode 100644
index 0000000000..89a809e991
--- /dev/null
+++ b/flax/experimental/nnx/docs/quick_start.ipynb
@@ -0,0 +1,561 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# NNX\n",
+    "\n",
+    "Welcome to NNX!\n",
+    "\n",
+    "NNX is an open source Python library for **N**eural **N**etworks in JA**X**. Its main feature is, much like PyTorch, allowing Python object semantics and reference sharing, which brings simplicity and familiarity, while easily crossing over into the functional world through a set of simple APIs.\n",
+    "\n",
+    "This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using NNX and train the network for image classification on the MNIST dataset."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Installation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "! pip install -q nnx"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load the MNIST dataset\n",
+    "We will use the `datasets` library to load MNIST and convert it to NumPy arrays."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/cris/nnx/.venv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "Found cached dataset mnist (/home/cris/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
+      "100%|██████████| 2/2 [00:00<00:00, 499.95it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import datasets\n",
+    "import numpy as np\n",
+    "\n",
+    "dataset = datasets.load_dataset(\"mnist\")\n",
+    "X_train = np.array(np.stack(dataset[\"train\"][\"image\"]), dtype=np.uint8)[..., None]\n",
+    "y_train = np.array(dataset[\"train\"][\"label\"], dtype=np.uint8)\n",
+    "X_test = np.array(np.stack(dataset[\"test\"][\"image\"]), dtype=np.uint8)[..., None]\n",
+    "y_test = np.array(dataset[\"test\"][\"label\"], dtype=np.uint8)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's visualize a few examples from the dataset using matplotlib:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "...(base64-encoded PNG of a 3x3 grid of MNIST digits omitted)...",
+      "text/plain": [
+       "<Figure size 600x600 with 9 Axes>
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# plot a 3x3 grid of MNIST digits\n", + "idxs = np.random.randint(0, len(X_train), size=(3, 3))\n", + "fig, axes = plt.subplots(3, 3, figsize=(3*2, 3*2))\n", + "\n", + "for i in range(3):\n", + " for j in range(3):\n", + " axes[i, j].imshow(X_train[idxs[i, j]], cmap=\"gray\")\n", + " axes[i, j].axis(\"off\")\n", + " axes[i, j].set_title(f\"Label: {y_train[idxs[i, j]]}\")\n", + "\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Model\n", + "\n", + "To create a convolutional neural network using NNX define a `nnx.Module` subclass. We define the model by subclassing `nnx.Module` and defining a `forward` method that returns the model output. Like in PyTorch, the `__init__` method instantiates all the modules that will be used in the model. The `__call__` in this case\n", + "will define the forward computation. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "data": { + "text/plain": [ + "(1, 10)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import nnx\n", + "\n", + "class CNN(nnx.Module):\n", + " def __init__(self, *, ctx: nnx.Context):\n", + " self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), ctx=ctx)\n", + " self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), ctx=ctx)\n", + " self.linear1 = nnx.Linear(7*7*64, 256, ctx=ctx)\n", + " self.linear2 = nnx.Linear(256, 10, ctx=ctx)\n", + " self.num_calls = nnx.var(\"counts\", 0)\n", + "\n", + " def __call__(self, x: jax.Array) -> jax.Array:\n", + " self.num_calls += 1\n", + " x = self.conv1(x)\n", + " x = nnx.relu(x)\n", + " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", + " x = self.conv2(x)\n", + " x = nnx.relu(x)\n", + " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", + " x = x.reshape((x.shape[0], -1)) # flatten\n", + " x = self.linear1(x)\n", + " x = nnx.relu(x)\n", + " x = self.linear2(x)\n", + " return x\n", + " \n", + "model = CNN(ctx=nnx.context(0))\n", + "\n", + "y = model(X_train[:1])\n", + "y.shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One notable difference with other frameworks is that `__init__`, by convention, accepts a `ctx: nnx.Context` keyword-only argument. This object is passed around to generate PRNG keys as random state is explicit in JAX.\n", + "\n", + "One of the nice things about NNX is that Module contain their own state, are fully inspectable, and you can run them eargerly. For example, we can easily check out the kernel shape of the first `Conv` layer:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 3, 1, 32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "model.conv1.kernel.shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also view the entire `State` of the model using the `.filter()` method. TODO: talk about collections." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "State({\n",
+       "  'conv1/bias': Variable(\n",
+       "    collection='params',\n",
+       "    value=(32,)\n",
+       "  ),\n",
+       "  'conv1/kernel': Variable(\n",
+       "    collection='params',\n",
+       "    value=(3, 3, 1, 32)\n",
+       "  ),\n",
+       "  'conv2/bias': Variable(\n",
+       "    collection='params',\n",
+       "    value=(64,)\n",
+       "  ),\n",
+       "  'conv2/kernel': Variable(\n",
+       "    collection='params',\n",
+       "    value=(3, 3, 32, 64)\n",
+       "  ),\n",
+       "  'linear1/bias': Variable(\n",
+       "    collection='params',\n",
+       "    value=(256,)\n",
+       "  ),\n",
+       "  'linear1/kernel': Variable(\n",
+       "    collection='params',\n",
+       "    value=(3136, 256)\n",
+       "  ),\n",
+       "  'linear2/bias': Variable(\n",
+       "    collection='params',\n",
+       "    value=(10,)\n",
+       "  ),\n",
+       "  'linear2/kernel': Variable(\n",
+       "    collection='params',\n",
+       "    value=(256, 10)\n",
+       "  )\n",
+       "})"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "jax.tree_map(jnp.shape, model.filter(\"params\"))"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Training in eager mode\n",
+    "\n",
+    "For pedagogical purposes, we first train the model in eager mode. This is useful for taking a look at some of NNX's features, it's more approachable for new users, and it's great for debugging, but it is not the recommended way to train models in JAX.\n",
+    "\n",
+    "Here we will run a simple `for` loop for just 10 iterations; at each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients, we will compute a stochastic gradient descent (SGD) update via a simple `tree_map` operation. Finally, we will set the new parameters on the model using the `.update_state` method."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Step 0: loss=58.7676\n",
+      "Step 1: loss=80.0420\n",
+      "Step 2: loss=108.3005\n",
+      "Step 3: loss=26.6188\n",
+      "Step 4: loss=10.7236\n",
+      "Step 5: loss=4.7499\n",
+      "Step 6: loss=3.9177\n",
+      "Step 7: loss=2.9419\n",
+      "Step 8: loss=2.4733\n",
+      "Step 9: loss=1.8060\n"
+     ]
+    }
+   ],
+   "source": [
+    "import optax\n",
+    "\n",
+    "for step in range(10):\n",
+    "    idxs = np.random.randint(0, len(X_train), size=32)\n",
+    "    x = jnp.array(X_train[idxs])\n",
+    "    y = jnp.array(y_train[idxs])\n",
+    "\n",
+    "    def loss_fn(model: CNN):\n",
+    "        logits = model(x)\n",
+    "        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
+    "\n",
+    "    loss, grads = nnx.value_and_grad(loss_fn, wrt=\"params\")(model)\n",
+    "    params = model.filter(\"params\")\n",
+    "    params = jax.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n",
+    "\n",
+    "    model.update_state(params)\n",
+    "    print(f\"Step {step}: loss={loss:.4f}\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The loss is going down 🎉.\n",
+    "\n",
+    "### Training with the Functional API\n",
+    "\n",
+    "Now that we have a working model, let's see how to train it with `jax.jit` using NNX's Functional API.
The `Module.partition` method allows you to convert a Module into pytrees with functional semantics, which makes it possible to integrate with JAX's functional APIs like `jax.jit` and `jax.grad`.\n",
+    "\n",
+    "In this next example we will use the `.partition` method to split the model into `params: State` and `moduledef: ModuleDef` objects. We pass the `\"params\"` filter to check that the Module's state only contains `Variable`s in the `params` collection. With `params` and `moduledef` it's pretty easy to implement a jitted `train_step`, much like you would in Flax or Haiku. `ModuleDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "params, moduledef = model.partition(\"params\")\n",
+    "\n",
+    "@jax.jit\n",
+    "def train_step(params: nnx.State, x, y):\n",
+    "    def loss_fn(params):\n",
+    "        logits, _updates = moduledef.apply(params)(x)\n",
+    "        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
+    "\n",
+    "    loss, grads = jax.value_and_grad(loss_fn)(params)\n",
+    "    params = jax.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n",
+    "\n",
+    "    return loss, params"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Using `train_step` we can run a few more iterations and see that the loss is still going down; this time, however, execution should be much faster."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Step 0: loss=1.4396\n",
+      "Step 1: loss=1.4127\n",
+      "Step 2: loss=1.8718\n",
+      "Step 3: loss=1.7080\n",
+      "Step 4: loss=1.7984\n",
+      "Step 5: loss=1.0350\n",
+      "Step 6: loss=1.2076\n",
+      "Step 7: loss=0.9081\n",
+      "Step 8: loss=0.8217\n",
+      "Step 9: loss=0.6687\n"
+     ]
+    }
+   ],
+   "source": [
+    "for step in range(10):\n",
+    "    idxs = np.random.randint(0, len(X_train), size=32)\n",
+    "    x = jnp.array(X_train[idxs])\n",
+    "    y = jnp.array(y_train[idxs])\n",
+    "\n",
+    "    loss, params = train_step(params, x, y)\n",
+    "    print(f\"Step {step}: loss={loss:.4f}\")\n",
+    "\n",
+    "model.update_state(params)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Realistic Training using TrainState\n",
+    "\n",
+    "For real training scenarios, we recommend using `TrainState` to manage the state of your training loop. `TrainState` manages the `params` of your network along with other types of state, and uses `optax` to update the parameters according to the gradients.\n",
+    "\n",
+    "Next, we will define a `train_step` function that accepts a `TrainState` and a batch of data and returns a new `TrainState` with updated parameters. The `apply_gradients` method will return a new `state` with the updated parameters. Flax users should be familiar with this API. We will also define an `eval_step` function that evaluates the model on the test set and returns some metrics."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "state = nnx.TrainState(\n",
+    "  apply_fn=moduledef.apply,\n",
+    "  params=params,\n",
+    "  tx=optax.adam(0.001),\n",
+    ")\n",
+    "\n",
+    "@jax.jit\n",
+    "def train_step(state: nnx.TrainState, x, y):\n",
+    "  def loss_fn(params):\n",
+    "    logits, _updates = state.apply_fn(params)(x)\n",
+    "    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
+    "\n",
+    "  grads = jax.grad(loss_fn)(state.params)\n",
+    "\n",
+    "  state = state.apply_gradients(grads=grads)\n",
+    "\n",
+    "  return state\n",
+    "\n",
+    "@jax.jit\n",
+    "def eval_step(state: nnx.TrainState, x, y):\n",
+    "  logits, _updates = state.apply_fn(state.params)(x)\n",
+    "  metrics = {\n",
+    "    'accuracy': jnp.mean(jnp.argmax(logits, axis=-1) == y),\n",
+    "    'loss': optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
+    "  }\n",
+    "  return metrics"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now let's create a simple training loop that runs for 1000 iterations and prints the metrics every 100 steps. At the end of training we will compute the final metrics."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Step 0: {'accuracy': Array(0.63119996, dtype=float32), 'loss': Array(1.1837534, dtype=float32)}\n",
+      "Step 100: {'accuracy': Array(0.9492, dtype=float32), 'loss': Array(0.16359854, dtype=float32)}\n",
+      "Step 200: {'accuracy': Array(0.9564, dtype=float32), 'loss': Array(0.14198248, dtype=float32)}\n",
+      "Step 300: {'accuracy': Array(0.96279997, dtype=float32), 'loss': Array(0.12757339, dtype=float32)}\n",
+      "Step 400: {'accuracy': Array(0.97169995, dtype=float32), 'loss': Array(0.09900841, dtype=float32)}\n",
+      "Step 500: {'accuracy': Array(0.96889997, dtype=float32), 'loss': Array(0.10143881, dtype=float32)}\n",
+      "Step 600: {'accuracy': Array(0.9745, dtype=float32), 'loss': Array(0.08513925, dtype=float32)}\n",
+      "Step 700: {'accuracy': Array(0.96379995, dtype=float32), 'loss': Array(0.11632324, dtype=float32)}\n",
+      "Step 800: {'accuracy': Array(0.97679996, dtype=float32), 'loss': Array(0.07204168, dtype=float32)}\n",
+      "Step 900: {'accuracy': Array(0.9765, dtype=float32), 'loss': Array(0.08413408, dtype=float32)}\n",
+      "Final metrics: {'accuracy': Array(0.9819, dtype=float32), 'loss': Array(0.05711861, dtype=float32)}\n"
+     ]
+    }
+   ],
+   "source": [
+    "total_steps = 1000\n",
+    "eval_every = 100\n",
+    "\n",
+    "for step in range(total_steps):\n",
+    "  if step % eval_every == 0:\n",
+    "    metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n",
+    "    print(f\"Step {step}: {metrics}\")\n",
+    "\n",
+    "  idxs = np.random.randint(0, len(X_train), size=32)\n",
+    "  x = jnp.array(X_train[idxs])\n",
+    "  y = jnp.array(y_train[idxs])\n",
+    "\n",
+    "  state = train_step(state, x, y)\n",
+    "\n",
+    "metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n",
+    "print(f\"Final metrics: {metrics}\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Inference\n",
+    "\n",
+    "Finally, now that we have a trained model, let's use it to make some predictions. We will update the `model` object with the trained parameters and run it on a few images from the test set.\n",
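+    "\n",
+    "Since the model returns logits, the predicted class for each image is simply `jnp.argmax(logits)` over the class axis, the same reduction `eval_step` used to compute accuracy."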
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAH4CAYAAACbup4ZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABBzklEQVR4nO3de1hVVf7H8S83lXukqJiGaGrmJW9Zk5cUUUa8JOakZY1iTVTeffLalKWOllrpoJk2hdXgpKaMk6GOlk6ieSnJUrPMMDXGS5PiDS/A+v3hD2qztnA4HDgLeL+exz/Wh7X3XpxWfNnnLNb2UEopAQAAbuXp7gEAAAAKMgAARqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBggEpXkOvXry9Dhw7Nb2/ZskU8PDxky5YtLruGh4eHvPDCCy47Hyou5iNMwnx0rzItyEuXLhUPD4/8f9WqVZPGjRvLiBEj5OTJk2U5lBJLSUkpN5Nq165d8vTTT0vbtm3Fx8dHPDw83D0kIzAfy15ubq4sXbpU+vbtK/Xq1RN/f39p3ry5zJgxQy5fvuzu4bkV89F9FixYIE2bNpWqVavKLbfcIuPGjZOLFy+W+Ti8y/yKIjJt2jSJiIiQy5cvS2pqqixatEhSUlJk37594ufnV6Zj6dy5s2RlZUmVKlWKdVxKSoosXLjQdtJlZWWJt7dbXlpbKSkp8re//U1atmwpDRo0kO+++87dQzIK87HsXLp0SeLi4uSee+6RJ598UmrWrCmfffaZTJ06VT7++GP55JNPKv0vjMzHsjVx4kSZPXu2DBgwQEaPHi0HDhyQhIQE2b9/v2zYsKFsB6PKUGJiohIRtXv3bks+btw4JSJq2bJlNzz2woULLhlDeHi4GjJkSInPM3z4cFXGL5/TTpw4oS5duqSUKl/jLm3Mx7J35coVtW3bNi1/8cUXlYiojRs3umFUZmA+lr2MjAzl7e2tHn30UUuekJCgRET961//KtPxGPEZcmRkpIiIpKeni4jI0KFDJSAgQA4fPiwxMTESGBgogwcPFpHrb3nNmzdPmjVrJtWqVZNatWpJfHy8nDlzxnJOpZTMmDFD6tatK35+ftK1a1fZv3+/du0bfUayc+dOiYmJkZCQEPH395eWLVvK/Pnz88e3cOFCERHLW0x57D4jSUtLk549e0pQUJAEBARIt27dZMeOHZY+eW9Zbdu2TcaNGyehoaHi7+8vsbGxcvr0aUvfzMxMOXjwoGRmZhb5+taqVUt8fX2L7IfrmI/XlcZ8rFKlitx7771aHhsbKyIi33zzTaHHV0bMx+tKYz5+9tlnkp2dLYMGDbLkee3333+/0ONdzYj3DQ4fPiwiItWrV8/PsrOzJTo6Wjp27Chz587Nf6smPj5eli5dKnFxcTJq1ChJT0+XBQsWSFpammzbtk18fHxEROT555+XGTNmSExMjMTExMiePXukR48ecvXq1SLHs3HjRundu7eEhYXJ6NGjpXbt2vLNN9/I2rVrZfTo0RIfHy8ZGRmyceNGee+994o83/79+6VTp04SFBQkEyZMEB8fH1m8eLF06dJF/vOf/8jdd99t6T9y5EgJCQmRqVOnypEjR2TevHkyYsQIWb58eX6f5ORkiYuLk8TERMsiDJQc87Hs5+OJEydERKRGjRrFPraiYz6W3ny8cuWKiIh2w5L3en7xxRdFjt+lyvJ2PO8tmU2bNqnTp0+rY8eOqffff19Vr15d+fr6quPHjyullBoyZIgSETVp0iTL8Vu3blUiopKSkiz5+vXrLfmpU6dUlSpVVK9evVRubm5+vylTpigRsbwls3nzZiUiavPmzUoppbKzs1VERIQKDw9XZ86csVznt+cq7C0ZEVFTp07Nb/fr109VqVJFHT58OD/LyMhQgYGBqnPnztrrExUVZbnW2LFjlZeXlzp79qzWNzEx0XYMN1Je3koqC8xH98/HPFFRUSooKEj7HisT5mPZz8cvvvhCiYiaPn26Jc97zQICAgo93tXc8pZ1VFSUhIaGSr169WTQoEESEBAgycnJcsstt1j6PfXUU5b2ypUrJTg4WLp37y4///xz/r+2bdtKQECAbN68WURENm3aJFevXpWRI0da3ioZM2ZMkWNLS0uT9PR0GTNmjNx0002Wrzmz2CQnJ0f+/e9/S79+/aRBgwb5eVhYmDz88MOSmpoq586dsxzzxBNPWK7VqVMnycnJkR9//DE/Gzp0qCiluDt2Aeaje+fjzJkzZdOmTfLSSy9p32NlxHwsu/nYpk0bufvuu+Xll1+WxMREOXLkiKxbt07i4+PFx8dHsrKyiv09lYRb3rJeuHChNG7cWLy9vaVWrVrSpEkT8fS0/m7g7e0tdevWtWSHDh2SzMxMqVmzpu15T506JSKS/x+mUaNGlq+HhoZKSEhIoWPLe3uoefPmjn9DhTh9+rRcunRJmjRpon2tadOmkpubK8eOHZNmzZrl57feequlX96YC34OBNdgPl7njvm4fPly+fOf/yyPPfaYVmAqK+bjdWU1H1etWiUDBw6UYcOGiYiIl5eXjBs3Tv7zn//It99+69Q5neWWgty+fXtp165doX2qVq2qTcLc3FypWbOmJCUl2R4TGhrqsjG6k5eXl22ulCrjkVQOzMfCldZ83Lhxo/zxj3+UXr16yRtvvFGic1UkzMfCuXo+3nLLLZKamiqHDh2SEydOSKNGjaR27dpSp04dady4cUmGWmxGLOpyVMOGDWXTpk3SoUOHQlcNh4eHi8j13xh/+zbI6dOni/wtqmHDhiIism/fPomKirphP0ffngkNDRU/Pz/b37QOHjwonp6eUq9ePYfOBbMwH523c+dOiY2NlXbt2smKFSuM+rvU8or5WDKNGjXKf9fgwIED8t///rfMPxI04s+eHPXggw9KTk6OTJ8+Xftadna2nD17VkSufwbj4+MjCQkJlt+a5s2bV+Q12rRpIxERETJv3rz88+X57bn8/f1FRLQ+BXl5eUmPHj1kzZo1cuTIkfz85MmTsmzZMunYsaMEBQUVOa6CivNnTygdzMdfFWc+fvPNN9KrVy+pX7++rF27lj/JcxHm469K8vMxNzdXJkyYIH5+fvLkk08W+/iSKFe/lt53330SHx8vs2bNki+//FJ69OghPj4+cujQIVm5cqXMnz9fBgwYIKGhofLMM8/IrFmzpHfv3hITEyNpaWmybt26Iv+swtPTUxYtWiR9+vSRVq1aSVxcnISFhcnBgwctO7e0bdtWRERGjRol0dHR4uXlpf0tW54ZM2bIxo0bpWPHjvL000+Lt7e3LF68WK5cuSKzZ8926rUozp+Z/Pjjj/l/fvD555/nj0nk+m/Ljz76qFNjqOyYj79ydD6eP39eoqOj5cyZMzJ+/Hj56KOPLF9v2LCh/O53v
3NqDJUd8/FXxfn5OHr0aLl8+bK0atVKrl27JsuWLZNdu3bJO++8o31eXerKckn3jXaiKWjIkCHK39//hl9fsmSJatu2rfL19VWBgYGqRYsWasKECSojIyO/T05OjnrxxRdVWFiY8vX1VV26dFH79u3TdqIpuKw/T2pqqurevbsKDAxU/v7+qmXLliohISH/69nZ2WrkyJEqNDRUeXh4WJb4S4Fl/UoptWfPHhUdHa0CAgKUn5+f6tq1q9q+fbtDr4/dGIvzZyZ5x9v9u++++4o8vqJiPpb9fExPT7/hXJQCf3JT2TAf3fPzMTExUd15553K399fBQYGqm7duqlPPvmkyONKg4dSrBQCAMDdytVnyAAAVFQUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAO79TlzKO1UHmU9Z+zMx9RGOYjTOLofOQOGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwgMOPXwTgvOjoaC0bP368pR0ZGenSa9o9EnDNmjVatn37dkt73rx5Wp+rV6+6bFwA7HGHDACAASjIAAAYgIIMAIABPJRSyqGONp9HAXkcnEYuY/J8rF+/vpbt379fy6pVq1YGoym+lJQULXv11Ve1bPPmzWUxHKcwH2ESR+cjd8gAABiAggwAgAEoyAAAGICCDACAAVjUBZdgEc2vGjZsqGXfffedG0biOufPn9eyyZMna9nq1au17OTJk6UypsIwH2ESFnUBAFCOUJABADAABRkAAANQkAEAMACLugpo27atlv3lL3/RMn9/fy2bPn26lv373/92zcAMxyKaX/n6+mrZzJkztSwzM9PS3rRpk9bnz3/+s5YlJiZqWUxMjJY1a9ZMy1q3bq1lrmS3y1efPn1K9Zp2mI8wCYu6AAAoRyjIAAAYgIIMAIABKMgAABiARV0FxMXFadmbb77p0LHZ2dla1qNHDy379NNPiz8ww7GIxjx2i8vsFli9/vrrlnZISIjT1/z++++1rF27dlpmt/OXKzEfCzdp0qQis6CgIIfO9dprr2nZ7NmztczRHdsK/sxs06aN1uell15y6FymYFEXAADlCAUZAAADUJABADAABRkAAAOwqKuA2267TcvWr1+vZfXr19cyu9dow4YNWma3q1J5xyKa8qvgohm7HcOCg4OdPv9jjz2mZUuXLnX6fI5gPv5q8ODBWvb3v/9dyy5cuGBpZ2VlaX3sdii0Wzz41VdfadnEiRO1zM/PT8vefffdIq/ZqlUrh65pChZ1AQBQjlCQAQAwAAUZAAADUJABADAAi7ocMHLkSC2z253G7jXKyMjQsnr16rlmYAZhEU3F8eSTT2rZwoULnT7f8uXLtezhhx92+nyOYD7+auvWrVrWoUMHLSu4u5bdbl533HGHlr3yyitaFh0drWVXr17Vsl9++UXLateurWUF/elPf9Kyt956q8jj3IVFXQAAlCMUZAAADEBBBgDAABRkAAAM4O3uAQCo2CIjI909BDjAkd3YDhw4oGWxsbFaZrcw8NVXX9UyRxZwVSbcIQMAYAAKMgAABqAgAwBgAD5DBoAK7NChQ1pmtzHIsGHDLO2kpCStT2pqqpZdvnxZy3bv3l2cIeL/cYcMAIABKMgAABiAggwAgAEoyAAAGIBFXQBK1fz58909hErtpZde0rJevXppWc2aNS3tTz/9VOszbtw4LVu7dq2WlfbTr06ePFmq53cX7pABADAABRkAAANQkAEAMAAFGQAAA7CoC4DF448/7vSxdrs27d+/vyTDQQl99913WtajRw8t+/jjjy3tm2++Wetj98SmqVOnatnnn39enCEWym7HMLuFZBUBd8gAABiAggwAgAEoyAAAGICCDACAAVjUVcoWLFjg7iHACf7+/loWEhKiZfHx8VrWoEGDUhlTYVJSUrQsMzNTy8LDw7Xs3nvvtbRbtGjh9Dh++OEHLfvXv/7l9PlQOvbu3atljz76qKX91ltvaX1q166tZcHBwVrWrVs3p8d2/vx5S/svf/mL0+cqb7hDBgDAABRkAAAMQEEGAMAAFGQAAAzAoq5SdvHiRXcPAUWoWrWqlv3973/Xsr59+5bFcJwyaNAgdw9BROwX/XTo0EHLdu7cqWXZ2dmlMiY4Zt26dZZ2kyZNtD5jxozRsoceekjL/Pz8tOzWW291aBw//fSTpX3w4EGHjqsIuEMGAMAAFGQAAAxAQQYAwAAUZAAADMCiLgd4eHg4lHl66r/f2PWDWbKysrRMKeWGkZR/do/s+/TTT7Vs4sSJWjZ37txSGROcU3DHLBGR6dOnO5TZLeTbunWrQ9e1e1xkZcEdMgAABqAgAwBgAAoyAAAGoCADAGAAFnU5wG6Bj12Wk5OjZRcuXCiVMcFc586d07J//OMfTp3rkUce0TK7R0MCJinJYtbt27e7cCTlC3fIAAAYgIIMAIABKMgAABiAz5BdyO7JTomJiW4YCcrKl19+qWX9+vXTsmPHjjl1/q+++krLWrdurWWPP/64U+e388MPP2hZWlqalkVGRmpZSEiIy8aB8qskG+v079/f0p49e3ZJh1NucIcMAIABKMgAABiAggwAgAEoyAAAGIBFXUAJ3HbbbVq2YsUKl53fbgGXj4+Py84vInL06FFL+/XXX9f6vPbaa1rWp08fLRs3bpxD1zx+/LiDo0NlExQU5O4huA13yAAAGICCDACAASjIAAAYgIIMAIABWNTlQsuXL3f3EOCEb775RssaNWqkZV5eXloWEBCgZe3bt3fNwErBTz/9pGW///3vLe1vv/3WoXN9+OGHDmVAcezevdvdQ3Ab7pABADAABRkAAANQkAEAMAAFGQAAA7CoywGrVq3SspkzZ2qZh4dHWQwHLtasWTMtmzt3rpY99thjWuaOXYWuXbumZWfPntWyf/zjH1q2ePFiLXN0ERdQFvbt2+fuIbgNd8gAABiAggwAgAEoyAAAGICCDACAAVjU5YCMjAwty83N1bK6deuWxXBQBp555hktW7RokZZt2LBByyIiIpy6ZmpqqpbZ7Xz1448/atnKlSuduiZQGkqywLUyL47lDhkAAANQkAEAMAAFGQAAA1CQAQAwgIdSSjnUsRJ/0G4nMzPToX7BwcGlPBIzODiNXIb5iMIwH92rQ4cOWrZ161aHjt21a5elfc8997hkTO7k6HzkDhkAAANQkAEAMAAFGQAAA1CQAQAwADt1Oenzzz/Xsnbt2rlhJABQcbjjkaam4A4ZAAADUJABADAABRkAAAPwGbKTpk+frmUTJ050w0gAoOJwdNOliog7ZAAADEBBBgDAABRkAAAMQEEGAMAAPO0JLsHTdWAS5qN73XbbbVqWlJSkZXfddZeWxcbGWtpr1qxx3cDchKc9AQBQjlCQAQAwAAUZAAADUJABADAAi7rgEiyigUmYjzAJi7oAAChHKMgAABiAggwAgAEoyAAAGMDhRV0AAKD0cIcMAIABKMgAABiAggwAgAEoyAAA
GKDSFeT69evL0KFD89tbtmwRDw8P2bJli8uu4eHhIS+88ILLzoeKi/kIkzAf3atMC/LSpUvFw8Mj/1+1atWkcePGMmLECDl58mRZDqXEUlJSytWkys3NlUWLFkmrVq3E19dXqlevLpGRkbJ37153D81tmI/ud+3aNbnjjjvEw8ND5s6d6+7huBXz0X1M+fnoXaZX+3/Tpk2TiIgIuXz5sqSmpsqiRYskJSVF9u3bJ35+fmU6ls6dO0tWVpZUqVKlWMelpKTIwoULbSddVlaWeHu75aW9oWHDhklSUpL88Y9/lBEjRsjFixclLS1NTp065e6huR3z0X0SEhLk6NGj7h6GUZiPZc+Un49ueVV69uwp7dq1ExGRxx9/XKpXry6vvvqqrFmzRh566CHbYy5evCj+/v4uH4unp6dUq1bNped09flKasWKFfLOO+/I6tWrJTY21t3DMQ7z0T1OnTol06ZNk4kTJ8rzzz/v7uEYg/lYtkz6+WjEZ8iRkZEiIpKeni4iIkOHDpWAgAA5fPiwxMTESGBgoAwePFhErr+1MG/ePGnWrJlUq1ZNatWqJfHx8XLmzBnLOZVSMmPGDKlbt674+flJ165dZf/+/dq1b/QZyc6dOyUmJkZCQkLE399fWrZsKfPnz88f38KFC0VELG8x5bH7jCQtLU169uwpQUFBEhAQIN26dZMdO3ZY+uS9ZbVt2zYZN26chIaGir+/v8TGxsrp06ctfTMzM+XgwYOSmZlZ5Ov76quvSvv27SU2NlZyc3Pl4sWLRR5TmTEfryut+Zhn0qRJ0qRJE3nkkUccPqYyYj5eVxl+PhpRkA8fPiwiItWrV8/PsrOzJTo6WmrWrClz586VBx54QERE4uPjZfz48dKhQweZP3++xMXFSVJSkkRHR8u1a9fyj3/++eflueeekzvvvFPmzJkjDRo0kB49ejj0Ym/cuFE6d+4sBw4ckNGjR8srr7wiXbt2lbVr1+aPoXv37iIi8t577+X/u5H9+/dLp06dZO/evTJhwgR57rnnJD09Xbp06SI7d+7U+o8cOVL27t0rU6dOlaeeeko+/PBDGTFihKVPcnKyNG3aVJKTkwv9Xs6dOye7du2Su+66S6ZMmSLBwcESEBAgDRo0kBUrVhT5WlRGzEcrV87HPLt27ZJ33nlH5s2bx6MLi8B8tKrQPx9VGUpMTFQiojZt2qROnz6tjh07pt5//31VvXp15evrq44fP66UUmrIkCFKRNSkSZMsx2/dulWJiEpKSrLk69evt+SnTp1SVapUUb169VK5ubn5/aZMmaJERA0ZMiQ/27x5sxIRtXnzZqWUUtnZ2SoiIkKFh4erM2fOWK7z23MNHz5c3ejlExE1derU/Ha/fv1UlSpV1OHDh/OzjIwMFRgYqDp37qy9PlFRUZZrjR07Vnl5eamzZ89qfRMTE23HkGfPnj1KRFT16tVVrVq11Ouvv66SkpJU+/btlYeHh1q3bl2hx1dkzMeyn495427fvr166KGHlFJKpaenKxFRc+bMKfLYioz5yM9HtxTkgv/Cw8PV+vXr8/vlTbgff/zRcvyoUaNUcHCwOnXqlDp9+rTlX0BAgHr88ceVUkotW7ZMiYjlnEpdn4hFTbjdu3crEVGvvfZaod+LoxMuOztb+fn5qQcffFDrFx8frzw9PVVmZqbl9VmxYoWl3+rVq5WIqL179xY6Jjuffvpp/uu8Y8eO/Pz8+fOqRo0aqkOHDsU+Z0XBfLQqi/molFJvv/228vX1VUePHlVKUZDzMB+tKuPPR7cs6lq4cKE0btxYvL29pVatWtKkSRPx9LS+e+7t7S1169a1ZIcOHZLMzEypWbOm7XnzVsT9+OOPIiLSqFEjy9dDQ0MlJCSk0LHlvT3UvHlzx7+hQpw+fVouXbokTZo00b7WtGlTyc3NlWPHjkmzZs3y81tvvdXSL2/MBT8HcoSvr6+IiERERMjdd9+dnwcEBEifPn3k73//u2RnZxu36rEsMR+vK4v5eO7cOZk8ebKMHz9e6tWrV+zjKwPm43WV8eejW34Kt2/fPn8V4Y1UrVpVm4S5ublSs2ZNSUpKsj0mNDTUZWN0Jy8vL9tcOfFgrjp16oiISK1atbSv1axZU65duyYXL16U4ODgYp+7omA+Fs6V83Hu3Lly9epVGThwoBw5ckRERI4fPy4i13+gHjlyROrUqVPsP7OpSJiPhavIPx/L1W1Rw4YNZdOmTdKhQ4f832zshIeHi8j13xgbNGiQn58+fbrI36IaNmwoIiL79u2TqKioG/ZzdCFKaGio+Pn5ybfffqt97eDBg+Lp6Vmqdwp16tSR2rVry08//aR9LSMjQ6pVqyaBgYGldv2KjPlYfEePHpUzZ85Y7njyzJw5U2bOnClpaWnSqlWrUhtDRcV8LD7Tfj4ascraUQ8++KDk5OTI9OnTta9lZ2fL2bNnRUQkKipKfHx8JCEhwfJb07x584q8Rps2bSQiIkLmzZuXf748vz1X3t/8FexTkJeXl/To0UPWrFmTf0cgInLy5ElZtmyZdOzYUYKCgoocV0HFWdY/cOBAOXbsmGzcuDE/+/nnn2XNmjUSGRmp/aYNxzAff+XofBw1apQkJydb/i1evFhErv+5THJyskRERBT7+mA+/lZ5/flYru6Q77vvPomPj5dZs2bJl19+KT169BAfHx85dOiQrFy5UubPny8DBgyQ0NBQeeaZZ2TWrFnSu3dviYmJkbS0NFm3bp3UqFGj0Gt4enrKokWLpE+fPtKqVSuJi4uTsLAwOXjwoOzfv182bNggIiJt27YVkes/YKKjo8XLy0sGDRpke84ZM2bIxo0bpWPHjvL000+Lt7e3LF68WK5cuSKzZ8926rVITk6WuLg4SUxMtOw9a2fy5MmyYsUKeeCBB2TcuHESHBwsb7zxhly7dk1mzpzp1PXBfPwtR+djmzZtpE2bNpYs7wdxs2bNpF+/fk5dH8zH3yq3Px/LcgVZ3iq53bt3F9pvyJAhyt/f/4ZfX7JkiWrbtq3y9fVVgYGBqkWLFmrChAkqIyMjv09OTo568cUXVVhYmPL19VVdunRR+/btU+Hh4YWuIsyTmpqqunfvrgIDA5W/v79q2bKlSkhIyP96dna2GjlypAoNDVUeHh6WFYVSYFm/UteX10dHR6uAgADl5+enunbtqrZv3+7Q62M3xuL8mYlSSh0+fFjFxsaqoKAg5evrqyIjI9WuXbscOraiYj66bz7+Fqusr2M+8vPRQyknPgkHAAAuxYeHAAAYgIIMAIABKMgAABiAggwAgAEoyAAAGICCDACAASjIAAAYwOGduniIOApT1n/OznxEYZiPMImj85E7ZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAAA4/frEiCgoK0rKvvvpKy0aOHKllH374YamMCQDcoeAjAvfu3av1mThxopZt2LCh1MZU2XCHDACAASjIAAAYgIIMAIABPFTBDw5u1NHDo7T
HUuqaNWtmaSckJGh9unbtqmXvvvuulg0ZMsR1A6sAHJxGLlMR5iNKD/Ox+FauXGlp9+/fX+tz4MABLWvfvr2WZWVluW5gFYCj85E7ZAAADEBBBgDAABRkAAAMQEEGAMAAlWpjkObNm1vadgu47Jw6dao0hgNUCnPmzNGyvn37almTJk3KYji4gczMzCL7nD9/XstycnJKYziVEnfIAAAYgIIMAIABKMgAABiAggwAgAEq1aIuR/zvf//TMrsnnADQjR07VsuefvppLfPx8dGyNm3aWNp79uxx3cDgEidOnNCyq1evumEkFRN3yAAAGICCDACAASjIAAAYgIIMAIABKuyiLi8vLy17+OGHizxu586dWpabm+uSMRVHSEiIlp07d07L2CWnYqtWrZqWFXyMqIjIF198UarjsPv/adq0aVo2efJkLbN79NyyZcu0zO7Rfig7hw4dcvcQKj3ukAEAMAAFGQAAA1CQAQAwAAUZAAADVNhFXYMHD9aygo98++WXX7Q+Tz75ZKmN6UZuv/12Ldu8ebOW2S04mzFjhpZ9/vnnrhkY3O7ZZ5/VsuzsbC1z5aIuuwVcL7zwgpZNmjTJofPZLeD605/+pGWXL1926HwoHR07drS0PTw8tD7bt28vq+FUStwhAwBgAAoyAAAGoCADAGAACjIAAAaosIu6srKyiuxjtxtWYGBgaQzHomXLlpb2ggULtD61a9fWsvvvv1/LYmJitKzg4gwRkV27dhVniHCDdu3aadn48eO1bNasWaU6jsjISC2bMmWKQ8euXLlSy4YNG6Zl165dK/7A4DJ2O8D16tXL0v7LX/6i9dm0aVOpjQncIQMAYAQKMgAABqAgAwBgAAoyAAAGqBCLum655RYti4uLK/K4o0ePapnd7l2O8vTUf78ZMGCAlg0fPtzS7tSpk9PXtNtVyW4nJ5hv9OjRWlalSpVSv25sbKylvWrVKoeOs9sRbuDAgS4ZE1zHbqFqfHx8kceNGDFCyxYuXOiSMcEed8gAABiAggwAgAEoyAAAGKBCfIb82GOPaVnPnj2LPK7gZ7kiIidOnHB6HM2bN9ey5cuXO30+R+Tm5mrZnj17SvWacI0aNWpY2nYbgyiltOznn392+pp33HGHlr399ttFXvPQoUNa1qFDB6fHgbJz6dIlLXvjjTe0rODmL+vWrdP6lOTnI4rGHTIAAAagIAMAYAAKMgAABqAgAwBggAqxqMtu8w07q1evtrRTUlKcvqafn5+WJSUlOXUuu0U6O3bs0LIWLVpoWXh4uJY9+OCDWrZixQqnxgbXqFWrlpZ99NFHlnaTJk20Plu3btWyJUuWOHTNoKAgLVu8eHGR/Y4cOaL16devn5bxxKbyoW7dulr20EMPadm+ffssbbuNX0zRuHFjh/p99913pTwS1+IOGQAAA1CQAQAwAAUZAAADUJABADBAuVvUZff0G7vdh+wU3LHGbkciOz4+Plq2dOlSLbPbqctOwd1uRo4cqfX54IMPHMrq1aunZT/88IND40DpsHsC1zvvvKNlrVu3trSvXr2q9Zk+fbqWObqY6oEHHtAyu921Cv5/cODAAa1P9+7dtcxul7jytoimMnj66ae17P7779eyw4cPW9qXL18utTEVJjo62tKePXu21sdugauHh4eWJSYmatmwYcNKMLrSxR0yAAAGoCADAGAACjIAAAagIAMAYIByt6jLbhcqu0U0X3/9tZY988wzRZ7fbtHYsmXLtMxuwYyjRo8ebWnbLday8/3332uZ3UIGu13EUHbee+89LbNbFFXQm2++qWV2/83tFuS0adNGy+Lj44u8pp2YmBiHsm3btmlZ586dnbomSs/Bgwe1LDIyUsvWrl1raS9atKjUxpTHbrfEggsPAwICtD6OLsjt1auXcwNzE+6QAQAwAAUZAAADUJABADAABRkAAAOUu0VdGzdu1DK7HYPsHj138uTJIs9/6623allJFnC9//77WuboIq6C7rrrLi2zW9TVv39/Lfv000+duiaKr127dlpm99+poBEjRmjZ8OHDXTKm4ozDUdu3b3fZuVB6du7cqWV2i2MffvhhS9vVi7rsdtyy2zmu4CKuPXv2aH08PfV7Sbuf3f/85z+LMUL34w4ZAAADUJABADAABRkAAANQkAEAMEC5W9Rlt+OR3Qf86enpRZ6rTp06WpacnOzcwEQkLS1Ny15++WUtK7gIzW4nmilTpmiZ3S5Ido/i27BhQ6HjhOvUqFFDy0qys5CrjruRs2fPalnBnZzsdrmzW4i4ZcsWVw0LpSgwMFDL7B716ayqVatq2aOPPqpldjslZmdna9nChQst7QsXLmh9/vCHP2iZ3Y5kBXdFNB13yAAAGICCDACAASjIAAAYgIIMAIAByt2iri5dujjUb8eOHUX22bRpk5Y1bdrUofOnpqZq2ciRI7UsLCxMy/r27WtpjxkzRusTEhLi0DhefPFFLVu3bp1Dx6Lkfv75Zy373e9+p2U+Pj5Fnstu1yK7RyjefffdWma3WGvOnDla9u6772pZRkZGkWND+WW3c9yVK1e07LvvvnPq/Hbz0W6xlp2pU6dq2axZsyxtu0WGDRo00LITJ05o2eXLlx0ahym4QwYAwAAUZAAADEBBBgDAAB7KwZ0HXPmUmJJYvXq1lsXGxmrZpUuXtOz8+fOWdq1atZweR8Fz3YjdH+U7y27jkXvuuUfLXPlH/45y9QYWRTFlPrpSlSpVtMzu87PbbrtNy+bOnatlEydOdM3AyiHmY/H17t3b0l67dq3Wx+6JTXabHz3xxBNadu+992rZV199pWU9e/a0tI8cOaL1sXuCld04TOHofOQOGQAAA1CQAQAwAAUZAAADUJABADBAudsYZNWqVVpmt6jLz8/PocxZrlysZcduIcOzzz6rZe5YwIXSMX36dC2zW8C1bds2LSu4mQKQJzw8XMvsFljt37+/yHN9//33Wma3wNXbWy8t9erV07KCT74TETlz5oylbbfJiN2GPBUBd8gAABiAggwAgAEoyAAAGICCDACAAcrdTl12Pv74Yy2LjIx0w0gcc+jQIUv7r3/9q9ZnwYIFZTUcl2BnpOL7wx/+YGmvWLFC62P3uto9AerNN9903cAqAOZj4aKiorSs4C5cdk/Da9SokZb997//1TK7OXrs2DEtCw0N1bK4uDhL+5dfftH62I3NZOzUBQBAOUJBBgDAABRkAAAMQEEGAMAAFWJRV0hIiJbZLSoYO3aspV2zZk2nr2m3Q1ZmZqaWHThwQMsGDBhgaVeEXWdYRFM4f39/LSv4aMX69etrfVauXKlldnP77NmzTo+tImI+Fl/BXbMcfQ3tfhb+8MMPWmb3Gi1btkzLCi4uqwi7EbKoCwCAcoSCDACAASjIAAAYgIIMAIABKsSiLkfdfffdlrbdY+y8vLwcOpfd4+7sdtw6ceKEg6Mr31hEU7hXXnlFy8aMGWNpX7p0SevTs2dPLUtNTXXZuCoq5mPxFdzt7YEHHtD6BAcHa9mHH36oZf369XPZuCoCFnUBAFCOUJABADAABRkAAANQkAEAME
ClWtSF0sMiml917txZy5YuXapl4eHhlvbw4cO1Pm+88YbLxlWZMB9hEhZ1AQBQjlCQAQAwAAUZAAADUJABADAAi7rgEiyi+dWGDRu0LCoqSsvefvttS/tPf/pTqY2psmE+wiQs6gIAoByhIAMAYAAKMgAABuAzZLgEn9nBJMxHmITPkAEAKEcoyAAAGICCDACAASjIAAAYgIIMAIABKMgAABiAggwAgAEoyAAAGICCDACAASjIAAAYgIIMAIABKMgAABiAggwAgAEoyAAAGMDhxy8CAIDSwx0yAAAGoCADAGAACjIAAAagIAMAYIBKV5Dr168vQ4cOzW9v2bJFPDw8ZMuWLS67hoeHh7zwwgsuOx8qLuYjTMJ8dK8yLchLly4VDw+P/H/VqlWTxo0by4gRI+TkyZNlOZQSS0lJKZeT6tq1a3LHHXeIh4eHzJ07193DcSvmo/usWLFC7rnnHrnpppukevXqct9998lHH33k7mG5FfPRPd5880257777pFatWlK1alWJiIiQuLg4OXLkSJmPxbvMrygi06ZNk4iICLl8+bKkpqbKokWLJCUlRfbt2yd+fn5lOpbOnTtLVlaWVKlSpVjHpaSkyMKFC20nXVZWlnh7u+WlLVJCQoIcPXrU3cMwCvOxbCUkJMioUaOkV69e8tJLL8nly5dl6dKl0rt3b1m1apX079/f3UN0K+Zj2UpLS5OIiAjp27evhISESHp6urz55puydu1a2bt3r9SpU6fsBqPKUGJiohIRtXv3bks+btw4JSJq2bJlNzz2woULLhlDeHi4GjJkSInPM3z4cFXGL1+JnTx5UgUHB6tp06YpEVFz5sxx95DcivnoHo0aNVJ33XWXys3Nzc8yMzNVQECA6tu3rxtH5l7MR3N8/vnnSkTUrFmzyvS6RnyGHBkZKSIi6enpIiIydOhQCQgIkMOHD0tMTIwEBgbK4MGDRUQkNzdX5s2bJ82aNZNq1apJrVq1JD4+Xs6cOWM5p1JKZsyYIXXr1hU/Pz/p2rWr7N+/X7v2jT4j2blzp8TExEhISIj4+/tLy5YtZf78+fnjW7hwoYiI5S2mPHafkaSlpUnPnj0lKChIAgICpFu3brJjxw5Ln7y3rLZt2ybjxo2T0NBQ8ff3l9jYWDl9+rSlb2Zmphw8eFAyMzMdeYlFRGTSpEnSpEkTeeSRRxw+pjJiPl5XWvPx3LlzUrNmTcsY88bh6+tb5PGVDfPxutL++fhb9evXFxGRs2fPOnW8s4x43+Dw4cMiIlK9evX8LDs7W6Kjo6Vjx44yd+7c/Ldq4uPjZenSpRIXFyejRo2S9PR0WbBggaSlpcm2bdvEx8dHRESef/55mTFjhsTExEhMTIzs2bNHevToIVevXi1yPBs3bpTevXtLWFiYjB49WmrXri3ffPONrF27VkaPHi3x8fGSkZEhGzdulPfee6/I8+3fv186deokQUFBMmHCBPHx8ZHFixdLly5d5D//+Y/cfffdlv4jR46UkJAQmTp1qhw5ckTmzZsnI0aMkOXLl+f3SU5Olri4OElMTLQswriRXbt2yTvvvCOpqamW/zmgYz6W7nzs0qWLfPDBB5KQkCB9+vSRy5cvS0JCgmRmZsro0aOLHH9lw3ws/Z+PIiL/+9//JCcnR44ePSrTpk0TEZFu3bo5dKzLlOXteN5bMps2bVKnT59Wx44dU++//76qXr268vX1VcePH1dKKTVkyBAlImrSpEmW47du3apERCUlJVny9evXW/JTp06pKlWqqF69elneFpsyZYoSEctbMps3b1YiojZv3qyUUio7O1tFRESo8PBwdebMGct1fnuuwt6SERE1derU/Ha/fv1UlSpV1OHDh/OzjIwMFRgYqDp37qy9PlFRUZZrjR07Vnl5eamzZ89qfRMTE23HUHDc7du3Vw899JBSSqn09HTeslbMR3fNx5MnT6pu3bopEcn/V6NGDbV9+/Yij63ImI/umY95qlatmj8fq1evrv761786fKyruOUt66ioKAkNDZV69erJoEGDJCAgQJKTk+WWW26x9Hvqqacs7ZUrV0pwcLB0795dfv755/x/bdu2lYCAANm8ebOIiGzatEmuXr0qI0eOtNwNjhkzpsixpaWlSXp6uowZM0Zuuukmy9ecubPMycmRf//739KvXz9p0KBBfh4WFiYPP/ywpKamyrlz5yzHPPHEE5ZrderUSXJycuTHH3/Mz4YOHSpKKYd++1u6dKl8/fXX8vLLLxd7/JUB87Fs56Ofn580adJEhgwZIitXrpS3335bwsLCpH///vL9998X+3uqaJiPZTsf86xbt05SUlLklVdekVtvvVUuXrxY7O+npNzylvXChQulcePG4u3tLbVq1ZImTZqIp6f1dwNvb2+pW7euJTt06JBkZmZKzZo1bc976tQpEZH8/zCNGjWyfD00NFRCQkIKHVve20PNmzd3/BsqxOnTp+XSpUvSpEkT7WtNmzaV3NxcOXbsmDRr1iw/v/XWWy398sZc8HMgR5w7d04mT54s48ePl3r16hX7+MqA+XhdWcxHEZE//OEP4u3tLR9++GF+dv/990ujRo3k2Weftbz1WBkxH68rq/mYp2vXriIi0rNnT7n//vulefPmEhAQICNGjCjReYvDLQW5ffv20q5du0L7VK1aVZuEubm5UrNmTUlKSrI9JjQ01GVjdCcvLy/bXDnxYK65c+fK1atXZeDAgfl/V3f8+HERuT6Bjxw5InXq1Cn2nzVUJMzHwrlyPv7www+yfv16WbJkiSW/+eabpWPHjrJt2zanxliRMB8L58r5eCMNGzaU1q1bS1JSUsUvyM5q2LChbNq0STp06FDoaszw8HARuf4b42/fBjl9+nSRv0U1bNhQRET27dsnUVFRN+zn6NszoaGh4ufnJ99++632tYMHD4qnp2ep3rkePXpUzpw5Y/kNM8/MmTNl5syZkpaWJq1atSq1MVRUzMfiy9vgIicnR/vatWvXJDs7u9SuXdExH10rKytLrly5UqbXNOLPnhz14IMPSk5OjkyfPl37WnZ2dv4S9aioKPHx8ZGEhATLb03z5s0r8hpt2rSRiIgImTdvnrbk/bfn8vf3F5Gil8V7eXlJjx49ZM2aNZadX06ePCnLli2Tjh07SlBQUJHjKsjRZf2jRo2S5ORky7/FixeLyPXPWZKTkyUiIqLY1wfz8bccnY+33XabeHp6yvLlyy3jP378uGzdulVat25d7GvjOubjrxydj9nZ2ba/hOzatUu+/vrrIt+pcLVydYd83333SXx8vMyaNUu+/PJL6dGjh/j4+MihQ4dk5cqVMn/+fBkwYICEhobKM888I7NmzZLevXtLTEyMpKWlybp166RGjRqFXsPT01MWLVokffr0kVatWklcXJyEhYXJwYMHZf/+/bJhwwYREWnbtq2IXC940dHR4uXlJYMGDbI954wZM2Tjxo3SsWNHefrpp8Xb21sWL14sV65ckdmzZzv1Wji6rL9NmzbSpk0bS5Y38Zs1ayb9+vVz6vpgPv6Wo/MxNDRUhg0bJn/729+kW
7du0r9/fzl//ry8/vrrkpWVJZMnT3bq+mA+/paj8/HChQtSr149GThwoDRr1kz8/f3l66+/lsTERAkODpbnnnvOqes7rSyXdN9oJ5qChgwZovz9/W/49SVLlqi2bdsqX19fFRgYqFq0aKEmTJigMjIy8vvk5OSoF198UYWFhSlfX1/VpUsXtW/fPm0nmoLL+vOkpqaq7t27q8DAQOXv769atmypEhIS8r+enZ2tRo4cqUJDQ5WHh4dlib8UWNavlFJ79uxR0dHRKiAgQPn5+amuXbtqf+Zxo9fHbozOLOvPw589Xcd8dM98vHbtmkpISFCtWrVSAQEBKiAgQHXt2lV98sknRR5bkTEfy34+XrlyRY0ePVq1bNlSBQUFKR8fHxUeHq4ee+wxlZ6eXuixpcFDKRd+Eg4AAJxSrj5DBgCgoqIgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABnB4py4eao/ClPWfszMfURjmI0zi6HzkDhkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADCAt7sHAJRn3t6O/S90//33a9mdd95paa9evVrr07x5cy37/PPPtezgwYMOjQOAubhDBgDAABRkAAAMQEEGAMAAFGQAAAzgoZRSDnX08CjtsaAI9evX17LZs2drWWRkpJY1atRIy86cOeOScYmIODiNXMYd8zEwMFDLPvroIy3z9/fXsoiICC07evSopd2iRQuHxrF7924tW7p0qZZt3rxZy7799luHrlHeVYb5aIoOHTo41C8qKkrLJk6cqGWbNm2ytJOTk7U+dnP7yJEjDo3DHRydj9whAwBgAAoyAAAGoCADAGAACjIAAAZgUZehfH19tSwpKUnL7HaAevnll7VsypQprhnYDVS0RTQPPPCAltktQGnbtm2pjsNR2dnZWrZu3Tot++Mf/6hl586dK5UxuVNFm48l4ePjo2XBwcFadvnyZUt73LhxWp9BgwZp2e23365lrnz97V5bu59xkydPdtk1XY1FXQAAlCMUZAAADEBBBgDAABRkAAAMwOMXnVS1alUts9uNyW53mvfee8/SzszM1PosWbJEy+wWcH3xxRdaNmfOHC1D8XzwwQdalpub64aROObUqVNa9sknn2iZ3S5fBefosWPHXDcwuN2iRYu0LC4uTssK7hx36623ltqY8mzdulXLOnXqVOrXNRV3yAAAGICCDACAASjIAAAYgIIMAIABWNTlgAYNGmjZSy+9pGV2uzvZadeunaWdkJCg9RkwYIBD53r22We1zJWPVayszp49q2VBQUFlPxAH2T3y0c4vv/yiZdWrVy+yT2hoqJYdP35cy+x2DEPZee2117Rs2LBhWma3c1TBRVyHDh3S+tjtVmfXz24BZHx8vJa1adNGywrau3evlq1fv77I48oj7pABADAABRkAAANQkAEAMAAFGQAAA7Coq4D69etrmd0CLrtFV+fPn9eyzZs3a1nBx4SlpqZqfex2AktLS9OyjRs3ahlKzm6hnd0COjt2u2YV3AVJRGTZsmWW9lNPPaX1OXHihJa99dZbWvbVV19pmd0OcHY7NBVc1PW73/1O67NgwQIts9uZzu77RNk5efKk08cW3DVr8ODBWp+ffvrJoXO98MILWvbII49o2c0336xl3333naX9+9//XutTku/TZNwhAwBgAAoyAAAGoCADAGAACjIAAAbwUHZbtth19PAo7bGUuTp16mhZSkqKlrVo0ULLLly4oGWvv/66lj3//PNaNmLECEv7lVde0fpcvHhRy+wWRaxZs0bL3MHBaeQypT0fW7VqpWV2j7q0s2fPHoeOLbhQyu6xh3YLs+zYLUbs3r27lr3xxhtaVnChTmBgoNbHbpcykxd1VbT5WBJ2r4VdtnDhQkvb19dX69O4cWMts3tcot35v/nmGy1buXKlltktCCvvHJ2P3CEDAGAACjIAAAagIAMAYIBK9Rny7bffbmnbPTGkXr16WnbgwAEtGzNmjJZ9/PHHWmb32d7hw4ct7atXr2p9hg4dqmXLly/XMlNUtM/sPD3131XtPu/q16+fy645ZcoULbPb+KVu3bpatmTJEi2ze0KTI+zOtWvXLi2zm4+XLl1y6pquVtHmY0l07NhRy5KTk7XMbpMOR8yZM0fLPvjgAy07ePCgltmtxamI+AwZAIByhIIMAIABKMgAABiAggwAgAEq1aKugk8zuffeex06bvr06Vrm6B+vF/xjexGRJ5980tK2W0Rj9+Qfk1WGRTQtW7bUsnXr1mlZ7dq1y2I4LrN9+3ZLOyYmRutj9yQzk1WG+VgSdgta7TaSccTf/vY3LfvnP/+pZXb/r1QWLOoCAKAcoSADAGAACjIAAAagIAMAYIAKu6hr4sSJWvbSSy9Z2ufOndP6tG7dWst++OEHh645bdo0Lfvzn/+sZdu2bbO07Z6WUt5U1kU0Y8eO1bK5c+e6YSTOmz9/vqU9btw4N43EdSrrfHSU3ZOc+vbta2n3799f62O361dYWJiW5eTkaNnevXu1rODPZBGRjz76yNLOysrS+pQ3LOoCAKAcoSADAGAACjIAAAagIAMAYIAKsajLboHCjh07tKxFixaW9rVr17Q+//vf/xy6pt3rUaNGDS2ze4zf5cuXLe3OnTtrfb744guHxmGKyrqIxsfHR8tmzZqlZXaLv0wxfPhwSzs1NVXrs2/fvrIajktU1vlY2mrVqqVldjsevvXWW1oWHBzs0DVWrVplab/77rtan7Vr1zp0LlOwqAsAgHKEggwAgAEoyAAAGICCDACAASrsoq7XX39dyx588EFLu1q1ak5f0+71sHsp7RaJJSYmWtp2u9WcOXPG6bG5A4toflWlShUtW7FihZb16dOnLIZTbHZz9vDhw1r217/+Vcu+/vprLXPHgjDmo3mioqK0bNGiRVrWsGFDS9vutZ0yZYqW2S2mNAWLugAAKEcoyAAAGICCDACAASjIAAAYoEIs6nJU8+bNLW27x4bZiYyM1DK7xzvavZQFd0ESEXnjjTccum55wiKawj3zzDNa9vLLLzt1rp9++knL7BbHzJgxQ8u+/fZbLWvSpIlT47Czfv16LYuNjdWyq1evuuyadpiP5UNoaKiWPfLII5b2c889p/UJCAjQsmeffVbLXnnlFS3Lzc0tzhBdgkVdAACUIxRkAAAMQEEGAMAAFGQAAAxQqRZ1OWvDhg1a1r17dy3bu3evlrVu3bpUxmQaFtEUzm7xyttvv21px8TEOH3+5ORkLbN7VN6kSZO0rOBcvvPOO7U+ffv21bIDBw5o2R133KFldgvJtm7dqmVPPfWUpV2SxTfMx4qjQ4cOWvbpp586dGzNmjW1zNFH7LoSi7oAAChHKMgAABiAggwAgAH4DLmA3//+91pm9/lc1apVtcxuExC7DRsqIj6zK76CT4VatWqV1qcknytfunRJyzIyMrTsrbfesrTt1kIs
X75cy7Kzs7UsJCSkOEO0aN++vaX9xRdfOH2u8jwfvb29tczuyXQXLlxw2TVN5uPjo2V2TxC77bbbtGzs2LFaZveUstLGZ8gAAJQjFGQAAAxAQQYAwAAUZAAADKCvHqhE7BYL2D0xpODiGxGRzz77TMvefPNN1wwMlULBJx5t27ZN61OSRV1+fn5aZrfwZdasWZa23cIvX19fLbNbfFQSU6dOtbTtNiOpDAYPHqxldk8Lmzlzppb94x//KJUxudO1a9e0LCcnx6Fj7Rbfmow7ZAAADEBBBgDAABRkAAAMQEEGAMAAlXqnLrtdXObOnevQsQMHDtSyDz74oMRjKq/K885IprBbZFi9enUtGz16tJZNmDChVMbkTl5eXk4fW57n40033aRldgv+mjZtqmVr167VstmzZ2tZamqqc4NzA7uFiLt379ay4OBgLYuLi9Oyd955xzUDKwZ26gIAoByhIAMAYAAKMgAABqAgAwBggEq9qGv79u1ads8992jZe++9p2VDhgwplTGVV+V5EU15Y7dD1owZM7Rs/PjxZTEcl9m5c6elfe+99zp9roo2HwMCArTM7pGY3bp107JffvlFy+wWtK5bt87SPnfuXHGG6DIFHzX5r3/9S+tj930W3PlOROSWW27RMrvXo7SxqAsAgHKEggwAgAEoyAAAGICCDACAASrV4xcfeughS7tFixZaH7uFAVu2bCmtIQHFlp2drWV2jw21W2gUGxurZatWrbK0H3vsMa2P3Y5hrjZt2rRSv0Z5deHCBS3r1auXlnXs2FHL7BalLlu2TMtOnDhhaT/++ONan4ILv4ojLCxMy7p3765lBXeia926tdbHbpHU4sWLtcwdC7hKgjtkAAAMQEEGAMAAFGQAAAxAQQYAwACVaqeugo/satOmjdbnrbfe0rInnnii1MZUUVS0nZEqAk9P/fdtu6zgIrE777xT69O/f38ts3uEn92iIjt2jwS8dOmSpV2SOcV8/JXdLl92/52WLFliadeuXVvrk5SUpGWnTp3SMrsFs3Y7rwUGBmpZQV999ZWWTZo0Scs2b96sZXaLdN2BnboAAChHKMgAABiAggwAgAEoyAAAGKBSLer6/vvvLW27XVxeffVVLXv//fdLbUwVBYtoYBLmY/HdfPPNlvbtt9/u0HF2C6zsdhGzY7dIbPXq1Zb2Z599pvU5efKkQ+c3BYu6AAAoRyjIAAAYgIIMAIABKtVnyCg9fGYHkzAfYRI+QwYAoByhIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYACHH78IAABKD3fIAAAYgIIMAIABKMgAABiAggwAgAEoyAAAGICCDACAASjIAAAYgIIMAIABKMgAABjg/wCJ6yDm5w+D/QAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.update_state(state.params)\n", + "\n", + "# plot a 3x3 grid of MNIST digits\n", + "idxs = np.random.randint(0, len(X_test), size=(3, 3))\n", + "fig, axes = plt.subplots(3, 3, figsize=(3*2, 3*2))\n", + "\n", + "for i in range(3):\n", + " for j in range(3):\n", + " logits = model(jnp.array([X_test[idxs[i, j]]]))\n", + " axes[i, j].imshow(X_test[idxs[i, j]], cmap=\"gray\")\n", + " axes[i, j].axis(\"off\")\n", + " axes[i, j].set_title(f\"Prediction: {jnp.argmax(logits)}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Awesome! We hope you've enjoyed this tutorial and learned the basics of NNX." + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/experimental/nnx/docs/tiny_nnx.ipynb b/flax/experimental/nnx/docs/tiny_nnx.ipynb new file mode 100644 index 0000000000..549d63cfc4 --- /dev/null +++ b/flax/experimental/nnx/docs/tiny_nnx.ipynb @@ -0,0 +1,459 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tiny NNX\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cgarciae/nnx/blob/main/docs/tiny_nnx.ipynb)\n", + "\n", + "A pedagogical implementation of NNX's core APIs.\n", + "\n", + "## Core APIs" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "import hashlib\n", + "import typing as tp\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import random\n", + "import dataclasses\n", + "\n", + "A = tp.TypeVar(\"A\")\n", + "M = tp.TypeVar(\"M\", bound=\"Module\")\n", + "Sharding = tp.Tuple[tp.Optional[str], ...]\n", + "KeyArray = random.KeyArray\n", + "\n", + "\n", + "class Variable(tp.Generic[A]):\n", + "\n", + " def __init__(\n", + " self,\n", + " value: A,\n", + " *,\n", + " sharding: tp.Optional[Sharding] = None,\n", + " ):\n", + " self.value = value\n", + " self.sharding = sharding\n", + "\n", + " def __repr__(self) -> str:\n", + " return f\"{type(self).__name__}(value={self.value}, sharding={self.sharding})\"\n", + "\n", + " def __init_subclass__(cls):\n", + " super().__init_subclass__()\n", + " jax.tree_util.register_pytree_node(\n", + " cls,\n", + " lambda x: ((x.value,), (x.sharding,)),\n", + " lambda metadata, value: Variable(value[0], sharding=metadata[0]),\n", + " )\n", + "\n", + "\n", + "class State(dict[str, Variable[tp.Any]]):\n", + "\n", + " def filter(self, variable_type: tp.Type[Variable]) -> \"State\":\n", + " return State(\n", + " {\n", + " path: variable\n", + " for path, variable in self.items()\n", + " if isinstance(variable, variable_type)\n", + " }\n", + " )\n", + "\n", + " def __repr__(self) -> str:\n", + " elems = \",\\n \".join(\n", + " f\"'{path}': {variable}\".replace(\"\\n\", \"\\n \")\n", + " for path, variable in self.items()\n", + " )\n", + " return f\"State({{\\n {elems}\\n}})\"\n", + "\n", + "\n", + "jax.tree_util.register_pytree_node(\n", + " State,\n", + " # in reality, values and paths should be sorted by path\n", + " lambda x: (tuple(x.values()), tuple(x.keys())),\n", + " lambda paths, values: State(dict(zip(paths, 
values))),\n", + ")\n", + "\n", + "\n", + "@dataclasses.dataclass\n", + "class ModuleDef(tp.Generic[M]):\n", + " type: tp.Type[M]\n", + " index: int\n", + " submodules: tp.Dict[str, tp.Union[\"ModuleDef[Module]\", int]]\n", + " static_fields: tp.Dict[str, tp.Any]\n", + "\n", + " def merge(self, state: State) -> M:\n", + " module = ModuleDef._build_module_recursive(self, {})\n", + " module.update_state(state)\n", + " return module\n", + "\n", + " @staticmethod\n", + " def _build_module_recursive(\n", + " moduledef: tp.Union[\"ModuleDef[M]\", int],\n", + " index_to_module: tp.Dict[int, \"Module\"],\n", + " ) -> M:\n", + " if isinstance(moduledef, int):\n", + " return index_to_module[moduledef] # type: ignore\n", + "\n", + " assert moduledef.index not in index_to_module\n", + "\n", + " # add a dummy module to the index to avoid infinite recursion\n", + " module = object.__new__(moduledef.type)\n", + " index_to_module[moduledef.index] = module\n", + "\n", + " submodules = {\n", + " name: ModuleDef._build_module_recursive(submodule, index_to_module)\n", + " for name, submodule in moduledef.submodules.items()\n", + " }\n", + " vars(module).update(moduledef.static_fields)\n", + " vars(module).update(submodules)\n", + " return module\n", + "\n", + " def apply(\n", + " self, state: State\n", + " ) -> tp.Callable[..., tuple[tp.Any, tuple[State, \"ModuleDef[M]\"]]]:\n", + " def _apply(*args, **kwargs):\n", + " module = self.merge(state)\n", + " out = module(*args, **kwargs) # type: ignore\n", + " return out, module.partition()\n", + "\n", + " return _apply\n", + "\n", + "\n", + "class Module:\n", + "\n", + " def partition(self: M) -> tp.Tuple[State, ModuleDef[M]]:\n", + " state = State()\n", + " moduledef = Module._partition_recursive(\n", + " module=self, module_id_to_index={}, path_parts=(), state=state\n", + " )\n", + " assert isinstance(moduledef, ModuleDef)\n", + " return state, moduledef\n", + "\n", + " @staticmethod\n", + " def _partition_recursive(\n", + " module: M,\n", + " module_id_to_index: tp.Dict[int, int],\n", + " path_parts: tp.Tuple[str, ...],\n", + " state: State,\n", + " ) -> tp.Union[ModuleDef[M], int]:\n", + " if id(module) in module_id_to_index:\n", + " return module_id_to_index[id(module)]\n", + "\n", + " index = len(module_id_to_index)\n", + " module_id_to_index[id(module)] = index\n", + "\n", + " submodules = {}\n", + " static_fields = {}\n", + "\n", + " # iterate fields sorted by name to ensure deterministic order\n", + " for name, value in sorted(vars(module).items(), key=lambda x: x[0]):\n", + " value_path = (*path_parts, name)\n", + " # if value is a Module, recurse\n", + " if isinstance(value, Module):\n", + " submoduledef = Module._partition_recursive(\n", + " value, module_id_to_index, value_path, state\n", + " )\n", + " submodules[name] = submoduledef\n", + " # if value is a Variable, add to state\n", + " elif isinstance(value, Variable):\n", + " state[\"/\".join(value_path)] = value\n", + " else: # otherwise, add to static fields\n", + " static_fields[name] = value\n", + "\n", + " return ModuleDef(\n", + " type=type(module),\n", + " index=index,\n", + " submodules=submodules,\n", + " static_fields=static_fields,\n", + " )\n", + "\n", + " def update_state(self, state: State) -> None:\n", + " for path, value in state.items():\n", + " path_parts = path.split(\"/\")\n", + " Module._set_value_at_path(self, path_parts, value)\n", + "\n", + " @staticmethod\n", + " def _set_value_at_path(\n", + " module: \"Module\", path_parts: tp.Sequence[str], value: Variable[tp.Any]\n", + 
" ) -> None:\n", + " if len(path_parts) == 1:\n", + " setattr(module, path_parts[0], value)\n", + " else:\n", + " Module._set_value_at_path(getattr(module, path_parts[0]), path_parts[1:], value)\n", + "\n", + "\n", + "@dataclasses.dataclass\n", + "class Context:\n", + " key: KeyArray\n", + " count: int = 0\n", + " count_path: tuple[int, ...] = ()\n", + "\n", + " def fork(self) -> \"Context\":\n", + " \"\"\"Forks the context, guaranteeing that all the random numbers generated\n", + " will be different from the ones generated in the original context. Fork is\n", + " used to create a new Context that can be passed to a JAX transform\"\"\"\n", + " count_path = self.count_path + (self.count,)\n", + " self.count += 1\n", + " return Context(self.key, count_path=count_path)\n", + "\n", + " def make_rng(self) -> jax.Array:\n", + " fold_data = self._stable_hash(self.count_path + (self.count,))\n", + " self.count += 1\n", + " return random.fold_in(self.key, fold_data) # type: ignore\n", + "\n", + " @staticmethod\n", + " def _stable_hash(data: tuple[int, ...]) -> int:\n", + " hash_str = \" \".join(str(x) for x in data)\n", + " _hash = hashlib.blake2s(hash_str.encode())\n", + " hash_bytes = _hash.digest()\n", + " # uint32 is represented as 4 bytes in big endian\n", + " return int.from_bytes(hash_bytes[:4], byteorder=\"big\")\n", + "\n", + "\n", + "# in the real NNX Context is not a pytree, instead\n", + "# it has a partition/merge API similar to Module\n", + "# but for simplicity we use a pytree here\n", + "jax.tree_util.register_pytree_node(\n", + " Context,\n", + " lambda x: ((x.key,), (x.count, x.count_path)),\n", + " lambda metadata, value: Context(value[0], *metadata),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "class Param(Variable[A]):\n", + " pass\n", + "\n", + "\n", + "class BatchStat(Variable[A]):\n", + " pass\n", + "\n", + "\n", + "class Linear(Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, ctx: Context):\n", + " self.din = din\n", + " self.dout = dout\n", + " key = ctx.make_rng()\n", + " self.w = Param(random.uniform(key, (din, dout)))\n", + " self.b = Param(jnp.zeros((dout,)))\n", + "\n", + " def __call__(self, x: jax.Array) -> jax.Array:\n", + " return x @ self.w.value + self.b.value\n", + "\n", + "\n", + "class BatchNorm(Module):\n", + "\n", + " def __init__(self, din: int, mu: float = 0.95):\n", + " self.mu = mu\n", + " self.scale = Param(jax.numpy.ones((din,)))\n", + " self.bias = Param(jax.numpy.zeros((din,)))\n", + " self.mean = BatchStat(jax.numpy.zeros((din,)))\n", + " self.var = BatchStat(jax.numpy.ones((din,)))\n", + "\n", + " def __call__(self, x, train: bool) -> jax.Array:\n", + " if train:\n", + " axis = tuple(range(x.ndim - 1))\n", + " mean = jax.numpy.mean(x, axis=axis)\n", + " var = jax.numpy.var(x, axis=axis)\n", + " # ema update\n", + " self.mean.value = self.mu * self.mean.value + (1 - self.mu) * mean\n", + " self.var.value = self.mu * self.var.value + (1 - self.mu) * var\n", + " else:\n", + " mean, var = self.mean.value, self.var.value\n", + "\n", + " scale, bias = self.scale.value, self.bias.value\n", + " x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias\n", + " return x\n", + "\n", + "\n", + "class Dropout(Module):\n", + "\n", + " def __init__(self, rate: float):\n", + " self.rate = rate\n", + "\n", + " def __call__(self, x: jax.Array, *, train: 
bool, ctx: Context) -> jax.Array:\n", + " if train:\n", + " mask = random.bernoulli(ctx.make_rng(), (1 - self.rate), x.shape)\n", + " x = x * mask / (1 - self.rate)\n", + " return x" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scan Over Layers Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Block(Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, ctx: Context):\n", + " self.linear = Linear(din, dout, ctx=ctx)\n", + " self.bn = BatchNorm(dout)\n", + " self.dropout = Dropout(0.1)\n", + "\n", + " def __call__(self, x: jax.Array, *, train: bool, ctx: Context) -> jax.Array:\n", + " x = self.linear(x)\n", + " x = self.bn(x, train=train)\n", + " x = jax.nn.gelu(x)\n", + " x = self.dropout(x, train=train, ctx=ctx)\n", + " return x\n", + "\n", + "\n", + "class ScanMLP(Module):\n", + "\n", + " def __init__(self, hidden_size: int, n_layers: int, *, ctx: Context):\n", + " self.n_layers = n_layers\n", + "\n", + " # lift init\n", + " key = random.split(ctx.make_rng(), n_layers - 1)\n", + " moduledef: ModuleDef[Block] = None # type: ignore\n", + "\n", + " def init_fn(key):\n", + " nonlocal moduledef\n", + " state, moduledef = Block(hidden_size, hidden_size, ctx=Context(key)).partition()\n", + " return state\n", + "\n", + " state = jax.vmap(init_fn)(key)\n", + " self.layers = moduledef.merge(state)\n", + " self.linear = Linear(hidden_size, hidden_size, ctx=ctx)\n", + "\n", + " def __call__(self, x: jax.Array, *, train: bool, ctx: Context) -> jax.Array:\n", + " # lift call\n", + " key: jax.Array = random.split(ctx.make_rng(), self.n_layers - 1) # type: ignore\n", + " state, moduledef = self.layers.partition()\n", + "\n", + " def scan_fn(x, inputs: tuple[jax.Array, State]):\n", + " key, state = inputs\n", + " x, (state, _) = moduledef.apply(state)(x, train=train, ctx=Context(key))\n", + " return x, state\n", + "\n", + " x, state = jax.lax.scan(scan_fn, x, (key, state))\n", + " self.layers.update_state(state)\n", + " x = self.linear(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state = State({\n", + " 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", + " 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", + " 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n", + " 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n", + " 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n", + "})\n", + "moduledef = ModuleDef(type=, index=0, submodules={'layers': ModuleDef(type=, index=1, submodules={'bn': ModuleDef(type=, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': ModuleDef(type=, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': ModuleDef(type=, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': ModuleDef(type=, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})\n" + ] + } + ], + "source": [ + "module = 
ScanMLP(hidden_size=10, n_layers=5, ctx=Context(random.PRNGKey(0)))\n", + "x = jax.random.normal(random.PRNGKey(0), (2, 10))\n", + "y = module(x, train=True, ctx=Context(random.PRNGKey(1)))\n", + "\n", + "state, moduledef = module.partition()\n", + "print(\"state =\", jax.tree_map(jnp.shape, state))\n", + "print(\"moduledef =\", moduledef)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Filtering State" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params = State({\n", + " 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n", + " 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n", + " 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n", + "})\n", + "batch_stats = State({\n", + " 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", + " 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None)\n", + "})\n" + ] + } + ], + "source": [ + "# split\n", + "params = state.filter(Param)\n", + "batch_stats = state.filter(BatchStat)\n", + "# merge\n", + "state = State({**params, **batch_stats})\n", + "\n", + "print(\"params =\", jax.tree_map(jnp.shape, params))\n", + "print(\"batch_stats =\", jax.tree_map(jnp.shape, batch_stats))" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/experimental/nnx/examples/00_demo.ipynb b/flax/experimental/nnx/examples/00_demo.ipynb new file mode 100644 index 0000000000..ada0c4bc85 --- /dev/null +++ b/flax/experimental/nnx/examples/00_demo.ipynb @@ -0,0 +1,298 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " din=2,\n", + " dout=2\n", + ")\n", + "[[0.63114893 1.2928092 ]\n", + " [0.63114893 1.2928092 ]]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from flax.experimental import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "class Linear(nnx.Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, ctx: nnx.Context):\n", + " # static attributes\n", + " self.din = din\n", + " self.dout = dout\n", + " # variables\n", + " self.w = nnx.Param(jax.random.uniform(ctx.make_rng(\"params\"), (din, dout)))\n", + " self.b = nnx.Param(jnp.zeros((dout,)))\n", + " # other state\n", + " self.jax_array = jnp.array(1)\n", + " self.numpy_array = np.array(1)\n", + "\n", + " def __call__(self, x):\n", + " return x @ self.w + self.b\n", + "\n", + "\n", + "linear = Linear(2, 2, ctx=nnx.context(0))\n", + "\n", + "y = linear(jnp.ones((2, 2)))\n", + "\n", + "print(linear)\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State({\n", + " 'b': Param(\n", + " sharding=None,\n", + " value=Array([0., 0.], dtype=float32)\n", + " ),\n", + " 'jax_array': Array(1, dtype=int32, weak_type=True),\n", + " 'numpy_array': array(1),\n", + " 'w': Param(\n", + " sharding=None,\n", + " value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " )\n", + "})\n", + "ModuleDef(\n", + " type=Linear,\n", + " index=0,\n", + " submodules=(),\n", + " static_fields=(('din', 2), ('dout', 2))\n", + ")\n" + ] + } + ], + "source": [ + "state, moduledef = linear.partition()\n", + "\n", + "print(state)\n", + "print(moduledef)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " din=2,\n", + " dout=2,\n", + " submodule=Linear(...)\n", + ")\n", + "[[0.63114893 1.2928092 ]\n", + " [0.63114893 1.2928092 ]]\n" + ] + } + ], + "source": [ + "class Linear(nnx.Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, ctx: nnx.Context):\n", + " self.din = din\n", + " self.dout = dout\n", + " self.w = nnx.Param(jax.random.uniform(ctx.make_rng(\"params\"), (din, dout)))\n", + " self.b = nnx.Param(jnp.zeros((dout,)))\n", + " # introduce a self-reference\n", + " self.submodule = self\n", + "\n", + " def __call__(self, x):\n", + " return x @ self.submodule.w + self.submodule.b\n", + "\n", + "\n", + "linear = Linear(2, 2, ctx=nnx.context(0))\n", + "\n", + "y = linear(jnp.ones((2, 2)))\n", + "\n", + "print(linear)\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State({\n", + " 'b': Param(\n", + " sharding=None,\n", + " value=Array([0., 0.], dtype=float32)\n", + " ),\n", + " 'w': Param(\n", + " sharding=None,\n", + " value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " )\n", + "})\n", + "ModuleDef(\n", + " type=Linear,\n", + " index=0,\n", + " submodules=(\n", + " ('submodule', 0)\n", + " ),\n", + " static_fields=(('din', 2), ('dout', 2))\n", + ")\n" + ] + } + ], + "source": [ + "state, moduledef = linear.partition()\n", + "\n", + "print(state)\n", + "print(moduledef)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "linear2 = moduledef.merge(state)\n", + "\n", + "linear2.submodule is linear2" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " din=2,\n", + " dout=2\n", + ")\n", + "[[0.63114893 1.2928092 ]\n", + " [0.63114893 1.2928092 ]]\n" + ] + } + ], + "source": [ + "class Linear(nnx.Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, ctx: nnx.Context):\n", + " # static attributes\n", + " self.din = din\n", + " self.dout = dout\n", + " # variables\n", + " self.w = nnx.Param(jax.random.uniform(ctx.make_rng(\"params\"), (din, dout)))\n", + " self.b = nnx.Param(jnp.zeros((dout,)))\n", + "\n", + " def __call__(self, x):\n", + " y = x @ self.w + self.b\n", + " self.y = nnx.Intermediate(y)\n", + " return y\n", + "\n", + "\n", + "linear = Linear(2, 2, ctx=nnx.context(0))\n", + "\n", + "y = linear(jnp.ones((2, 2)))\n", + "\n", + "print(linear)\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State({\n", + " 'y': Intermediate(\n", + " sharding=None,\n", + " value=Array([[0.63114893, 1.2928092 ],\n", + " [0.63114893, 1.2928092 ]], dtype=float32)\n", + " )\n", + "})\n", + "State({\n", + " 'b': Param(\n", + " sharding=None,\n", + " value=Array([0., 0.], dtype=float32)\n", + " ),\n", + " 'w': Param(\n", + " sharding=None,\n", + " value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " )\n", + "})\n" + ] + } + ], + "source": [ + "intermediates = linear.pop_state(nnx.Intermediate)\n", + "state, moduledef = linear.partition()\n", + "\n", + "print(intermediates)\n", + "print(state)" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/experimental/nnx/examples/01_functional_api.py b/flax/experimental/nnx/examples/01_functional_api.py new file mode 100644 index 0000000000..62d1179cfe --- /dev/null +++ b/flax/experimental/nnx/examples/01_functional_api.py @@ -0,0 +1,106 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# %% +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + +from flax.experimental import nnx + +X = np.linspace(0, 1, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class MLP(nnx.Module): + + def __init__(self, din, dhidden, dout, *, ctx: nnx.Context): + self.count = jnp.array(0) + self.linear1 = Linear(din, dhidden, ctx=ctx) + self.linear2 = Linear(dhidden, dout, ctx=ctx) + + def __call__(self, x): + self.count += 1 + x = self.linear1(x) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +(params, buffers), modeldef = MLP( + din=1, dhidden=32, dout=1, ctx=nnx.context(0) +).partition(nnx.Param, ...) + + +@jax.jit +def train_step(params, buffers, batch): + x, y = batch + + def loss_fn(params): + y_pred, (updates, _) = modeldef.apply(params, buffers)(x) + _state = updates.filter(nnx.buffers) + loss = jnp.mean((y - y_pred) ** 2) + return loss, _state + + grad, buffers = jax.grad(loss_fn, has_aux=True)(params) + # |-------- sgd ---------| + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad) + + return params, buffers + + +@jax.jit +def test_step(params: nnx.State, buffers: nnx.State, batch): + x, y = batch + y_pred, _ = modeldef.apply(params, buffers)(x) + loss = jnp.mean((y - y_pred) ** 2) + return {"loss": loss} + + +total_steps = 10_000 +for step, batch in enumerate(dataset(32)): + params, buffers = train_step(params, buffers, batch) + + if step % 1000 == 0: + logs = test_step(params, buffers, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + +model = modeldef.merge(params, buffers) +print("times called:", model.count) + +y_pred = model(X) + +plt.scatter(X, Y, color="blue") +plt.plot(X, y_pred, color="black") +plt.show() diff --git a/flax/experimental/nnx/examples/02_lifted_transforms.py b/flax/experimental/nnx/examples/02_lifted_transforms.py new file mode 100644 index 0000000000..6f00be7622 --- /dev/null +++ b/flax/experimental/nnx/examples/02_lifted_transforms.py @@ -0,0 +1,108 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# %% +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + +from flax.experimental import nnx + +X = np.linspace(0, 1, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Count(nnx.Variable): + pass + + +class MLP(nnx.Module): + + def __init__(self, din, dhidden, dout, *, ctx: nnx.Context): + self.count = Count(jnp.array(0)) + self.linear1 = Linear(din, dhidden, ctx=ctx) + self.linear2 = Linear(dhidden, dout, ctx=ctx) + + def __call__(self, x): + self.count += 1 + x = self.linear1(x) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +model = MLP(din=1, dhidden=32, dout=1, ctx=nnx.context(0)) + + +@nnx.jit +def train_step(model: MLP, batch): + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + # |--default--| + grad: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) + # sdg update + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grad) + ) + + # no return!!! + + +@nnx.jit +def test_step(model: MLP, batch): + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return {"loss": loss} + + +total_steps = 10_000 +for step, batch in enumerate(dataset(32)): + train_step(model, batch) + + if step % 1000 == 0: + logs = test_step(model, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + +print("times called:", model.count) + +y_pred = model(X) + +plt.scatter(X, Y, color="blue") +plt.plot(X, y_pred, color="black") +plt.show() diff --git a/flax/experimental/nnx/examples/03_train_state.py b/flax/experimental/nnx/examples/03_train_state.py new file mode 100644 index 0000000000..2eaa38c708 --- /dev/null +++ b/flax/experimental/nnx/examples/03_train_state.py @@ -0,0 +1,116 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# %% +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax +from flax.training import train_state + +from flax.experimental import nnx + +X = np.linspace(0, 1, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class MLP(nnx.Module): + + def __init__(self, din, dhidden, dout, *, ctx: nnx.Context): + self.count = jnp.array(0) + self.linear1 = Linear(din, dhidden, ctx=ctx) + self.linear2 = Linear(dhidden, dout, ctx=ctx) + + def __call__(self, x): + self.count += 1 + x = self.linear1(x) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +(params, buffers), moduledef = MLP( + din=1, dhidden=32, dout=1, ctx=nnx.context(0) +).partition(nnx.Param, ...) + +state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(0.1), + buffers=buffers, +) +del params, buffers + + +@jax.jit +def train_step(state: nnx.TrainState, batch): + x, y = batch + + def loss_fn(params): + y_pred, (updates, _) = state.apply(params, "buffers")(x) + buffers = updates.filter(nnx.buffers) + loss = jnp.mean((y - y_pred) ** 2) + return loss, buffers + + grads, buffers = jax.grad(loss_fn, has_aux=True)(state.params) + # sdg update + state = state.apply_gradients(grads=grads, buffers=buffers) + + return state + + +@jax.jit +def test_step(state: nnx.TrainState, batch): + x, y = batch + y_pred, _ = state.apply("params", "buffers")(x) + loss = jnp.mean((y - y_pred) ** 2) + return {"loss": loss} + + +total_steps = 10_000 +for step, batch in enumerate(dataset(32)): + state = train_step(state, batch) + + if step % 1000 == 0: + logs = test_step(state, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + +model = moduledef.merge(state.params, state.buffers) +print("times called:", model.count) + +y_pred = model(X) + +plt.scatter(X, Y, color="blue") +plt.plot(X, y_pred, color="black") +plt.show() diff --git a/flax/experimental/nnx/examples/04_pure.py b/flax/experimental/nnx/examples/04_pure.py new file mode 100644 index 0000000000..f140c4ed43 --- /dev/null +++ b/flax/experimental/nnx/examples/04_pure.py @@ -0,0 +1,109 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# %%
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+
+from flax.experimental import nnx
+
+X = np.linspace(0, 1, 100)[:, None]
+Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
+
+
+def dataset(batch_size):
+  while True:
+    idx = np.random.choice(len(X), size=batch_size)
+    yield X[idx], Y[idx]
+
+
+class Linear(nnx.Module):
+
+  def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
+    self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout)))
+    self.b = nnx.Param(jnp.zeros((dout,)))
+
+  def __call__(self, x):
+    return x @ self.w + self.b
+
+
+class Count(nnx.Variable):
+  pass
+
+
+class MLP(nnx.Module):
+
+  def __init__(self, din, dhidden, dout, *, ctx: nnx.Context):
+    self.count = Count(jnp.array(0))
+    self.linear1 = Linear(din, dhidden, ctx=ctx)
+    self.linear2 = Linear(dhidden, dout, ctx=ctx)
+
+  def __call__(self, x) -> jax.Array:
+    self.count += 1
+    x = self.linear1(x)
+    x = jax.nn.relu(x)
+    x = self.linear2(x)
+    return x
+
+
+pure_model = MLP(din=1, dhidden=32, dout=1, ctx=nnx.context(0)).partition()
+
+
+@jax.jit
+def train_step(pure_model: nnx.PureModule[MLP], batch):
+  x, y = batch
+  model = pure_model.merge()
+
+  def loss_fn(model: MLP):
+    y_pred = model(x)
+    return jnp.mean((y - y_pred) ** 2)
+
+  grad: nnx.State = nnx.grad(loss_fn)(model)
+  # SGD update
+  model.update_state(
+    jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grad)
+  )
+
+  return model.partition()
+
+
+@jax.jit
+def test_step(pure_model: nnx.PureModule[MLP], batch):
+  x, y = batch
+  y_pred = pure_model.call(x)
+  loss = jnp.mean((y - y_pred) ** 2)
+  return {"loss": loss}
+
+
+total_steps = 10_000
+for step, batch in enumerate(dataset(32)):
+  pure_model = train_step(pure_model, batch)
+
+  if step % 1000 == 0:
+    logs = test_step(pure_model, (X, Y))
+    print(f"step: {step}, loss: {logs['loss']}")
+
+  if step >= total_steps - 1:
+    break
+
+model = pure_model.merge()
+print("times called:", model.count)
+
+y_pred = model(X)
+
+plt.scatter(X, Y, color="blue")
+plt.plot(X, y_pred, color="black")
+plt.show()
diff --git a/flax/experimental/nnx/examples/05_vae.py b/flax/experimental/nnx/examples/05_vae.py
new file mode 100644
index 0000000000..20e9c0a070
--- /dev/null
+++ b/flax/experimental/nnx/examples/05_vae.py
@@ -0,0 +1,222 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
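+
+# NOTE: background for the Encoder below. Sampling z ~ N(mean, std^2) uses
+# the reparameterization trick,
+#
+#   z = mean + std * eps,  eps ~ N(0, I)
+#
+# so gradients flow through mean and std. The KL divergence of a diagonal
+# Gaussian from N(0, I) is 0.5 * sum(mean^2 + std^2 - 1 - log(std^2)); the
+# code averages instead of summing over the latent axis, which only rescales
+# the 0.1 weight applied to `kl_loss` in the train step.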
+
+# %%
+import typing as tp
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+import optax
+from datasets import load_dataset
+
+from flax.experimental import nnx
+
+np.random.seed(42)
+latent_size = 32
+image_shape: tp.Sequence[int] = (28, 28)
+steps_per_epoch: int = 200
+batch_size: int = 64
+epochs: int = 20
+
+
+dataset = load_dataset("mnist")
+X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
+X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)
+# Now binarize data
+X_train = (X_train > 0).astype(jnp.float32)
+X_test = (X_test > 0).astype(jnp.float32)
+
+print("X_train:", X_train.shape, X_train.dtype)
+print("X_test:", X_test.shape, X_test.dtype)
+
+
+class Loss(nnx.Variable):
+  pass
+
+
+# %%
+class Encoder(nnx.Module):
+
+  def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context):
+    self.linear1 = nnx.Linear(din, dmid, ctx=ctx)
+    self.linear_mean = nnx.Linear(dmid, dout, ctx=ctx)
+    self.linear_std = nnx.Linear(dmid, dout, ctx=ctx)
+
+  def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array:
+    x = x.reshape((x.shape[0], -1))  # flatten
+    x = self.linear1(x)
+    x = jax.nn.relu(x)
+
+    mean = self.linear_mean(x)
+    std = jnp.exp(self.linear_std(x))
+
+    self.kl_loss = Loss(
+      jnp.mean(
+        0.5
+        * jnp.mean(-jnp.log(std**2) - 1.0 + std**2 + mean**2, axis=-1)
+      )
+    )
+    key = ctx.make_rng("noise")
+    z = mean + std * jax.random.normal(key, mean.shape)
+    return z
+
+
+class Decoder(nnx.Module):
+
+  def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context):
+    self.linear1 = nnx.Linear(din, dmid, ctx=ctx)
+    self.linear2 = nnx.Linear(dmid, dout, ctx=ctx)
+
+  def __call__(self, z: jax.Array) -> jax.Array:
+    z = self.linear1(z)
+    z = jax.nn.relu(z)
+    logits = self.linear2(z)
+    return logits
+
+
+class VAE(nnx.Module):
+
+  def __init__(
+    self,
+    din: int,
+    hidden_size: int,
+    latent_size: int,
+    output_shape: tp.Sequence[int],
+    *,
+    ctx: nnx.Context,
+  ):
+    self.output_shape = output_shape
+    self.encoder = Encoder(din, hidden_size, latent_size, ctx=ctx)
+    self.decoder = Decoder(
+      latent_size, hidden_size, int(np.prod(output_shape)), ctx=ctx
+    )
+
+  def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array:
+    z = self.encoder(x, ctx=ctx)
+    logits = self.decoder(z)
+    logits = jnp.reshape(logits, (-1, *self.output_shape))
+    return logits
+
+  def generate(self, z):
+    logits = self.decoder(z)
+    logits = jnp.reshape(logits, (-1, *self.output_shape))
+    return nnx.sigmoid(logits)
+
+
+params, moduledef = VAE(
+  din=int(np.prod(image_shape)),
+  hidden_size=256,
+  latent_size=latent_size,
+  output_shape=image_shape,
+  ctx=nnx.context(0),
+).partition(nnx.Param)
+
+state = nnx.TrainState(
+  moduledef,
+  params=params,
+  tx=optax.adam(1e-3),
+)
+
+
+# %%
+@jax.jit
+def train_step(state: nnx.TrainState[VAE], x: jax.Array, key: jax.Array):
+  def loss_fn(params: nnx.State):
+    ctx = nnx.context(noise=jax.random.fold_in(key, state.step))
+    logits, (updates, _) = state.apply(params)(x, ctx=ctx)
+
+    losses = updates.filter(Loss)
+    kl_loss = sum(jax.tree_util.tree_leaves(losses), 0.0)
+    reconstruction_loss = jnp.mean(
+      optax.sigmoid_binary_cross_entropy(logits, x)
+    )
+
+    loss = reconstruction_loss + 0.1 * kl_loss
+    return loss, loss
+
+  grad_fn = jax.grad(loss_fn, has_aux=True)
+  grads, loss = grad_fn(state.params)
+  state = state.apply_gradients(grads=grads)
+
+  return state, loss
+
+
+@partial(jax.jit, donate_argnums=(0,))
+def forward(
+    state: 
nnx.TrainState[VAE], x: jax.Array, key: jax.Array +) -> jax.Array: + ctx = nnx.context(noise=key) + y_pred = state.apply("params")(x, ctx=ctx)[0] + return jax.nn.sigmoid(y_pred) + + +@jax.jit +def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array: + return state.apply("params").generate(z)[0] + + +# %% +key = jax.random.PRNGKey(0) + +for epoch in range(epochs): + losses = [] + for step in range(steps_per_epoch): + idxs = np.random.randint(0, len(X_train), size=(batch_size,)) + x_batch = X_train[idxs] + + state, loss = train_step(state, x_batch, key) + losses.append(np.asarray(loss)) + + print(f"Epoch {epoch} loss: {np.mean(losses)}") + +exit() +# %% +# get random samples +idxs = np.random.randint(0, len(X_test), size=(5,)) +x_sample = X_test[idxs] + +# get predictions +y_pred = forward(state, x_sample, key) + +# plot reconstruction +figure = plt.figure(figsize=(3 * 5, 3 * 2)) +plt.title("Reconstruction Samples") +for i in range(5): + plt.subplot(2, 5, i + 1) + plt.imshow(x_sample[i], cmap="gray") + plt.subplot(2, 5, 5 + i + 1) + plt.imshow(y_pred[i], cmap="gray") + # # tbwriter.add_figure("VAE Example", figure, epochs) + +plt.show() + +# %% +# plot generative samples +z_samples = np.random.normal(scale=1.5, size=(12, latent_size)) +samples = sample(state, z_samples) + +figure = plt.figure(figsize=(3 * 5, 3 * 2)) +plt.title("Generative Samples") +for i in range(5): + plt.subplot(2, 5, 2 * i + 1) + plt.imshow(samples[i], cmap="gray") + plt.subplot(2, 5, 2 * i + 2) + plt.imshow(samples[i + 1], cmap="gray") + +plt.show() + +# %% diff --git a/flax/experimental/nnx/examples/06_scan_over_layers.py b/flax/experimental/nnx/examples/06_scan_over_layers.py new file mode 100644 index 0000000000..891a0e5c84 --- /dev/null +++ b/flax/experimental/nnx/examples/06_scan_over_layers.py @@ -0,0 +1,92 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import jax +import jax.numpy as jnp + +from flax.experimental import nnx + + +class Block(nnx.Module): + + def __init__(self, dim: int, *, ctx: nnx.Context): + self.linear = nnx.Linear(dim, dim, ctx=ctx) + self.dropout = nnx.Dropout(0.5) + + def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: + x = self.linear(x) + x = self.dropout(x, ctx=ctx) + x = jax.nn.gelu(x) + return x + + +class ScanMLP(nnx.Module): + """ + An MLP that uses `vmap` during `__init__` to create a Block instance + with an additional `layer` axis, and `scan` during `__call__` to apply + the sequence of layers iteratively over the input / output `x`. 
+ """ + + def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context): + self.n_layers = n_layers + # partition Context and split the `params` key + keys, ctxdef = ctx.partition() + params_key = jax.random.split(keys["params"], n_layers) + + def create_block(params_key): + # merge back Context using the sliced `params` key + ctx = ctxdef.merge({"params": params_key}) + # create Block instance and return its partition + return Block(dim, ctx=ctx).partition() + + # call vmap over create_block, passing the split `params` key + # and immediately merge to get a Block instance + self.layers = jax.vmap(create_block)(params_key).merge() + + def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: + # partition Context and split the `dropout` key + keys, ctxdef = ctx.partition() + dropout_key = jax.random.split(keys["dropout"], self.n_layers) + # partition Module to get params + params, moduledef = self.layers.partition(nnx.Param) + + def scan_fn( + x: jax.Array, inputs: Tuple[nnx.State, jax.Array] + ) -> Tuple[jax.Array, nnx.State]: + params, dropout_key = inputs + # merge back Module and Context + ctx = ctxdef.merge({"dropout": dropout_key}) + module = moduledef.merge(params) + # forward pass + x = module(x, ctx=ctx) + # partition state and return + params, _ = module.partition(nnx.Param) + return x, params + + # call scan passing x as the carry, and params + dropout_key as the input + x, params = jax.lax.scan(scan_fn, x, (params, dropout_key)) + # update layers state and return + self.layers.update_state(params) + return x + + +model = ScanMLP(10, n_layers=5, ctx=nnx.context(0)) + +x = jnp.ones((3, 10)) +y = model(x, ctx=nnx.context(dropout=1, flags=dict(deterministic=False))) + +print(jax.tree_map(jnp.shape, model.get_state())) +print(y.shape) diff --git a/flax/experimental/nnx/examples/07_transformer.py b/flax/experimental/nnx/examples/07_transformer.py new file mode 100644 index 0000000000..efe8305d47 --- /dev/null +++ b/flax/experimental/nnx/examples/07_transformer.py @@ -0,0 +1,419 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
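+
+# NOTE: when `cfg.scanned` is enabled, this file uses the same
+# scan-over-layers pattern as 06_scan_over_layers.py: stack per-layer
+# parameters with jax.vmap at init time, then iterate them with jax.lax.scan
+# at call time. Minimal sketch (assumed names, mirroring the code below):
+#
+#   states = jax.vmap(
+#     lambda key: DecoderBlock(cfg, ctx=nnx.context(key)).partition()
+#   )(jax.random.split(root_key, cfg.layers))
+#   x, states = jax.lax.scan(apply_block, x, states)
+#
+# where `root_key` and `apply_block` are hypothetical; `apply_block` merges
+# one layer's state, applies it, and re-partitions.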
+
+import dataclasses
+import typing as tp
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.sharding import PartitionSpec as P
+
+from flax.experimental import nnx
+
+ShardSpec = tp.Union[str, tp.Tuple[str, ...], None]
+
+
+# Sharding
+@dataclasses.dataclass
+class Sharding:
+  batch: ShardSpec = "data"
+  sequence: ShardSpec = None
+  layers: ShardSpec = None
+  vocab: ShardSpec = "model"
+  embed: ShardSpec = None
+  heads: ShardSpec = "model"
+  depth: ShardSpec = None
+  hidden: ShardSpec = "model"
+
+
+# Config
+@dataclasses.dataclass
+class Config:
+  # mode
+  decode: bool = False
+  # shapes
+  batch: int = 16
+  layers: int = 2
+  vocab: int = 1024
+  embed: int = 64
+  heads: int = 12
+  depth: int = 64
+  hidden: int = 256
+  max_length: int = 256
+  # dtypes
+  param_dtype: tp.Any = jnp.float32
+  dtype: tp.Any = jnp.float32
+  # sharding (a mutable dataclass default needs default_factory)
+  sharding: Sharding = dataclasses.field(default_factory=Sharding)
+  scanned: bool = False
+  # layer params
+  epsilon: float = 1e-6
+  dropout_rate: float = 0.0
+  rp_num_buckets: int = 32
+  rp_max_distance: int = 128
+
+
+cfg = Config()
+
+
+def nd_dense_init(scale, mode, distribution):
+  """Initializer with in_axis, out_axis set at call time."""
+
+  def init_fn(key, shape, dtype, in_axis, out_axis) -> jax.Array:
+    fn = jax.nn.initializers.variance_scaling(
+      scale, mode, distribution, in_axis, out_axis
+    )
+    return fn(key, shape, dtype)
+
+  return init_fn
+
+
+dense_init = nd_dense_init(1.0, "fan_in", "truncated_normal")
+embed_init = nd_dense_init(1.0, "fan_in", "normal")
+
+
+def make_attention_mask(
+  query_input: tp.Any,
+  key_input: tp.Any,
+  pairwise_fn: tp.Callable = jnp.multiply,
+  dtype: tp.Any = jnp.float32,
+):
+  mask = pairwise_fn(
+    jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2)
+  )
+  return jnp.expand_dims(mask, axis=-3).astype(dtype)
+
+
+def make_causal_mask(x, dtype=jnp.float32):
+  idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
+  return make_attention_mask(idxs, idxs, jnp.greater_equal, dtype=dtype)
+
+
+# padding mask
+# make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype)
+# packing mask
+# make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype)
+
+
+def sine_table(features, length, min_timescale=1.0, max_timescale=10000.0):
+  fraction = jnp.arange(0, features, 2, dtype=jnp.float32) / features
+  timescale = min_timescale * (max_timescale / min_timescale) ** fraction
+  rotational_frequency = 1.0 / timescale
+  # Must use high precision einsum here, bfloat16 rounding is catastrophic.
+  sinusoid_inp = jnp.einsum(
+    "i,j->ij",
+    jnp.arange(length),
+    rotational_frequency,
+    precision=jax.lax.Precision.HIGHEST,
+  )
+  sinusoid_inp = jnp.concatenate([sinusoid_inp, sinusoid_inp], axis=-1)
+  return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)
+
+
+def rotate_half(x):
+  x1, x2 = jnp.split(x, 2, axis=-1)
+  x = jnp.concatenate([-x2, x1], axis=-1)
+  return x
+
+
+def apply_rotary_embedding(q, k, cos, sin, index=None):
+  """Helper function to apply Rotary Embeddings."""
+  batch, qlen, qheads, d = q.shape
+  kbatch, klen, kheads, kd = k.shape
+  if index is not None:
+    qcos = jax.lax.broadcast_in_dim(
+      cos[index, :], (batch, qlen, qheads, d), (3,)
+    )
+    qsin = jax.lax.broadcast_in_dim(
+      sin[index, :], (batch, qlen, qheads, d), (3,)
+    )
+  else:
+    qcos = jax.lax.broadcast_in_dim(
+      cos[:qlen, :], (batch, qlen, qheads, d), (1, 3)
+    )
+    qsin = jax.lax.broadcast_in_dim(
+      sin[:qlen, :], (batch, qlen, qheads, d), (1, 3)
+    )
+  kcos = jax.lax.broadcast_in_dim(
+    cos[:klen, :], (batch, klen, kheads, d), (1, 3)
+  )
+  ksin = jax.lax.broadcast_in_dim(
+    sin[:klen, :], (batch, klen, kheads, d), (1, 3)
+  )
+  out_q = (q * qcos) + (rotate_half(q) * qsin)
+  out_k = (k * kcos) + (rotate_half(k) * ksin)
+  return out_q, out_k
+
+
+def rms_norm(cfg, scale, x):
+  x = jnp.asarray(x, jnp.float32)
+  mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
+  y = jnp.asarray(x * jax.lax.rsqrt(mean2 + cfg.epsilon), cfg.dtype)
+  return y * jnp.asarray(scale, cfg.dtype)
+
+
+def dropout(cfg: Config, x, broadcast_dims=(-2,), *, ctx: nnx.Context):
+  if cfg.dropout_rate == 0.0:
+    return x
+  broadcast_shape = list(x.shape)
+  for dim in broadcast_dims:
+    broadcast_shape[dim] = 1
+  keep_rate = 1.0 - cfg.dropout_rate
+  key = ctx.make_rng("dropout")
+  mask = jax.random.bernoulli(key, p=keep_rate, shape=broadcast_shape)
+  return jax.lax.select(
+    jnp.broadcast_to(mask, x.shape), x / keep_rate, jnp.zeros_like(x)
+  )
+
+
+class Attention(nnx.Module):
+
+  def __init__(self, cfg: Config, *, ctx: nnx.Context):
+    sharding = cfg.sharding
+
+    key = ctx.make_rng("params")
+    self.WQ = nnx.Param(
+      dense_init(
+        key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2)
+      ),
+      P(sharding.embed, sharding.heads, sharding.depth),
+    )
+    key = ctx.make_rng("params")
+    self.WK = nnx.Param(
+      dense_init(
+        key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2)
+      ),
+      P(sharding.embed, sharding.heads, sharding.depth),
+    )
+    key = ctx.make_rng("params")
+    self.WV = nnx.Param(
+      dense_init(
+        key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2)
+      ),
+      P(sharding.embed, sharding.heads, sharding.depth),
+    )
+    key = ctx.make_rng("params")
+    self.WO = nnx.Param(
+      dense_init(
+        key, (cfg.heads, cfg.depth, cfg.embed), cfg.param_dtype, (0, 1), 2
+      ),
+      P(sharding.heads, sharding.depth, sharding.embed),
+    )
+    # cache
+    self.index = nnx.variable("cache", jnp.array(0, dtype=jnp.int32), P())
+    self.key = nnx.variable(
+      "cache",
+      jnp.zeros(
+        (cfg.batch, cfg.heads, cfg.depth, cfg.max_length),
+        jnp.bfloat16,
+      ),
+      P(sharding.batch, sharding.heads, sharding.depth, None),
+    )
+    self.value = nnx.variable(
+      "cache",
+      jnp.zeros(
+        (cfg.batch, cfg.heads, cfg.depth, cfg.max_length),
+        jnp.bfloat16,
+      ),
+      P(sharding.batch, sharding.heads, sharding.depth, None),
+    )
+
+  # We combine the cache and params into "vs", but it would be no harder at all
+  # to thread through a separate "cache" argument storing cache entries.
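+  #
+  # Decode-cache sketch (shapes as in __init__ above): at each decode step,
+  # one_hot(index, max_length) scatters the new key/value column into the
+  # (batch, heads, depth, max_length) cache buffers, a length mask hides
+  # positions beyond `index`, and `index` is then incremented by one.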
+  def __call__(self, cfg: Config, x_q, x_kv, mask=None, *, ctx: nnx.Context):
+    q = jnp.einsum("bse,enh->bsnh", x_q, self.WQ.astype(cfg.dtype)).astype(
+      jnp.float32
+    )
+    k = jnp.einsum("bte,enh->btnh", x_kv, self.WK.astype(cfg.dtype)).astype(
+      jnp.float32
+    )
+    v = jnp.einsum("bte,enh->btnh", x_kv, self.WV.astype(cfg.dtype))
+
+    index = None
+    if cfg.decode:
+      index = self.index
+      one_hot_indices = jax.nn.one_hot(
+        self.index, cfg.max_length, dtype=cfg.dtype
+      )
+      self.key = self.key + jnp.moveaxis(k, -3, -1) * one_hot_indices
+      self.value = self.value + jnp.moveaxis(v, -3, -1) * one_hot_indices
+      k = jnp.moveaxis(self.key, -1, -3)
+      v = jnp.moveaxis(self.value, -1, -3)
+      cache_mask = jnp.broadcast_to(
+        jnp.arange(cfg.max_length) <= self.index,
+        (cfg.batch, 1, 1, cfg.max_length),
+      )
+      mask = jnp.logical_and(
+        cache_mask if mask is None else mask, cache_mask
+      ).astype(cfg.dtype)
+      self.index = self.index + 1
+
+    attention_bias = 0.0
+    if mask is None:  # Hack in lieu of general mask routing.
+      mask = make_causal_mask(x_q, jnp.float32)
+    if mask is not None:
+      attention_bias = jax.lax.select(
+        mask > 0,
+        jnp.full(mask.shape, 0.0, cfg.dtype),
+        jnp.full(mask.shape, -1e10, cfg.dtype),
+      )
+
+    sin, cos = sine_table(q.shape[-1], max(q.shape[1], k.shape[1]))
+    q, k = apply_rotary_embedding(q, k, cos, sin, index=index)
+
+    l = (
+      jnp.einsum("bsnh,btnh->bnst", q, k) / np.sqrt(cfg.depth)
+      + attention_bias
+    )
+    s = jax.nn.softmax(l).astype(cfg.dtype)
+    s = dropout(cfg, s, ctx=ctx)
+    a = jnp.einsum("bnst,btnh->bsnh", s, v)
+    o = jnp.einsum("bsnh,nhe->bse", a, self.WO.astype(cfg.dtype))
+
+    return o
+
+
+class MLP(nnx.Module):
+
+  def __init__(self, cfg: Config, *, ctx: nnx.Context):
+    sharding = cfg.sharding
+    self.Win1 = nnx.Param(
+      dense_init(
+        ctx.make_rng("params"),
+        (cfg.embed, cfg.hidden),
+        cfg.param_dtype,
+        0,
+        1,
+      ),
+      P(sharding.embed, sharding.hidden),
+    )
+    self.Win2 = nnx.Param(
+      dense_init(
+        ctx.make_rng("params"),
+        (cfg.embed, cfg.hidden),
+        cfg.param_dtype,
+        0,
+        1,
+      ),
+      P(sharding.embed, sharding.hidden),
+    )
+    self.Wout = nnx.Param(
+      dense_init(
+        ctx.make_rng("params"),
+        (cfg.hidden, cfg.embed),
+        cfg.param_dtype,
+        0,
+        1,
+      ),
+      P(sharding.hidden, sharding.embed),
+    )
+
+  def __call__(self, cfg: Config, x, *, ctx: nnx.Context):
+    h1 = jnp.einsum("bse,eh->bsh", x, self.Win1.astype(cfg.dtype))
+    h2 = jnp.einsum("bse,eh->bsh", x, self.Win2.astype(cfg.dtype))
+    h = jax.nn.gelu(h1) * h2
+    h = dropout(cfg, h, ctx=ctx)
+    o = jnp.einsum("bsh,he->bse", h, self.Wout.astype(cfg.dtype))
+    return o
+
+
+class DecoderBlock(nnx.Module):
+
+  def __init__(self, cfg: Config, *, ctx: nnx.Context):
+    sharding = cfg.sharding
+    self.attn = Attention(cfg, ctx=ctx)
+    self.mlp = MLP(cfg, ctx=ctx)
+    self.scale1 = nnx.Param(
+      jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed)
+    )
+    self.scale2 = nnx.Param(
+      jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed)
+    )
+
+  def __call__(self, cfg: Config, input, *, ctx: nnx.Context):
+    x = rms_norm(cfg, self.scale1, input)
+    x = self.attn(cfg, x, x, mask=None, ctx=ctx)
+    x = dropout(cfg, x, ctx=ctx)
+    x = x + input
+    y = rms_norm(cfg, self.scale2, x)
+    y = self.mlp(cfg, y, ctx=ctx)
+    y = dropout(cfg, y, ctx=ctx)
+    return y + x
+
+
+class Decoder(nnx.Module):
+
+  def __init__(self, cfg: Config, *, ctx: nnx.Context):
+    sharding = cfg.sharding
+    self.embed = nnx.Param(
+      embed_init(
+        ctx.make_rng("params"),
+        (cfg.vocab, cfg.embed),
+        cfg.param_dtype,
+        1,
+        0,
+      ),
+      P(sharding.vocab, sharding.embed),
+    )
+    self.unembed = nnx.Param(
+      
dense_init( + ctx.make_rng("params"), (cfg.embed, cfg.vocab), jnp.float32, 0, 1 + ), + P(sharding.embed, sharding.vocab), + ) + self.scale1 = nnx.Param( + jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed) + ) + + if cfg.scanned: + self.layers = jax.vmap( + lambda key: DecoderBlock(cfg, ctx=nnx.context(key)).partition() + )(jax.random.split(ctx.make_rng("params"), cfg.layers)).merge() + else: + self.layers = nnx.Sequence( + DecoderBlock(cfg, ctx=ctx) for _ in range(cfg.layers) + ) + + def __call__(self, cfg: Config, x, *, ctx: nnx.Context): + # TODO: handle right-shifting for training: here or in train loop. + # TODO: handle general mask routing. + x = self.embed.astype(cfg.dtype)[x] + + if cfg.scanned: + assert isinstance(self.layers, DecoderBlock) + + state, moduledef = self.layers.partition() + rngs, ctxdef = ctx.partition() + dropout_key = jax.random.split(rngs["dropout"], cfg.layers) + + def scan_fn(x, s: tp.Tuple[jax.random.KeyArray, nnx.State]): + dropout_key, state = s + ctx = ctxdef.merge({"dropout": dropout_key}) + y, (state, _) = moduledef.apply(state)(cfg, x, ctx=ctx) + return y, state + + x, state = jax.lax.scan( + scan_fn, + x, + (dropout_key, state), + ) + self.layers.update_state(state) + else: + assert isinstance(self.layers, nnx.Sequence) + for decoder_block in self.layers: + x = decoder_block(cfg, x, ctx=ctx) + + x = jnp.einsum("bse,ev->bsv", x, self.unembed) + return x diff --git a/flax/experimental/nnx/examples/08_save_load_checkpoints.py b/flax/experimental/nnx/examples/08_save_load_checkpoints.py new file mode 100644 index 0000000000..ac22eb4fce --- /dev/null +++ b/flax/experimental/nnx/examples/08_save_load_checkpoints.py @@ -0,0 +1,68 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
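+
+# NOTE: the restore path below needs a pytree of abstract arrays to know the
+# structure and dtypes it is loading into; `jax.eval_shape` provides it
+# without allocating real parameters. Sketch (same API as this file):
+#
+#   state, moduledef = jax.eval_shape(lambda: create_model(0).partition())
+#   state = orbax.PyTreeCheckpointer().restore(path, item=state)
+#   model = moduledef.merge(state)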
+
+from tempfile import TemporaryDirectory
+
+import jax
+import jax.numpy as jnp
+import orbax.checkpoint as orbax
+
+from flax.experimental import nnx
+
+
+class MLP(nnx.Module):
+
+  def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context):
+    self.dense1 = nnx.Linear(din, dmid, ctx=ctx)
+    self.dense2 = nnx.Linear(dmid, dout, ctx=ctx)
+
+  def __call__(self, x: jax.Array) -> jax.Array:
+    x = self.dense1(x)
+    x = jax.nn.relu(x)
+    x = self.dense2(x)
+    return x
+
+
+def create_model(seed: int):
+  return MLP(10, 20, 30, ctx=nnx.context(seed))
+
+
+def create_and_save(seed: int, path: str):
+  model = create_model(seed)
+  state = model.get_state()
+  # Save the parameters
+  checkpointer = orbax.PyTreeCheckpointer()
+  checkpointer.save(f"{path}/state", state)
+
+
+def load_model(path: str) -> MLP:
+  # create that model with abstract shapes
+  state, moduledef = jax.eval_shape(lambda: create_model(0).partition())
+  # Load the parameters
+  checkpointer = orbax.PyTreeCheckpointer()
+  state = checkpointer.restore(f"{path}/state", item=state)
+  # Merge the parameters into the model
+  model = moduledef.merge(state)
+  return model
+
+
+with TemporaryDirectory() as tmpdir:
+  # create a checkpoint
+  create_and_save(42, tmpdir)
+  # load model from checkpoint
+  model = load_model(tmpdir)
+  # run the model
+  y = model(jnp.ones((1, 10)))
+  print(model)
+  print(y)
diff --git a/flax/experimental/nnx/examples/09_parameter_surgery.py b/flax/experimental/nnx/examples/09_parameter_surgery.py
new file mode 100644
index 0000000000..f5cc43cf1d
--- /dev/null
+++ b/flax/experimental/nnx/examples/09_parameter_surgery.py
@@ -0,0 +1,62 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable
+
+import jax
+
+from flax.experimental import nnx
+
+
+# let's pretend this function loads a pretrained model from a checkpoint
+def load_backbone():
+  return nnx.Linear(784, 128, ctx=nnx.context(0))
+
+
+# create a simple linear classifier using a pretrained backbone
+class Classifier(nnx.Module):
+
+  def __init__(
+    self, backbone: Callable[[jax.Array], jax.Array], *, ctx: nnx.Context
+  ):
+    self.backbone = backbone
+    self.head = nnx.Linear(128, 10, ctx=ctx)
+
+  def __call__(self, x):
+    x = self.backbone(x)
+    x = nnx.relu(x)
+    x = self.head(x)
+    return x
+
+
+backbone = load_backbone()
+
+# create the classifier from the pretrained backbone. This is technically
+# "parameter surgery", but unlike in Haiku/Flax, where you must manually
+# construct the parameter structure, in NNX it happens automatically.
+model = Classifier(backbone, ctx=nnx.context(42))
+
+# create a filter that selects all parameters that are not part of the
+# backbone, i.e. the trainable classifier parameters
+is_trainable = nnx.All(
+  nnx.Param, lambda path, node: not path.startswith("backbone")
+)
+
+# partition the parameters into trainable and non-trainable parameters
+(trainable_params, non_trainable), moduledef = model.partition(
+  is_trainable, ...
+) + +print("trainable_params =", jax.tree_map(jax.numpy.shape, trainable_params)) +print("non_trainable = ", jax.tree_map(jax.numpy.shape, non_trainable)) diff --git a/flax/experimental/nnx/ideas/nnx_example.py b/flax/experimental/nnx/ideas/nnx_example.py new file mode 100644 index 0000000000..2706acb58c --- /dev/null +++ b/flax/experimental/nnx/ideas/nnx_example.py @@ -0,0 +1,163 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp + +from flax.experimental import nnx + + +class Linear(nnx.Module): + kernel: jax.Array = nnx.Param() + bias: jax.Array = nnx.Param() + + def __init__(self, din: int, dout: int): + self.kernel = jax.random.uniform(nnx.make_rng("params"), (din, dout)) + self.bias = jax.numpy.zeros((dout,)) + + def __call__(self, x): + return x @ self.kernel + self.bias + + +class BatchNorm(nnx.Module): + scale: jax.Array = nnx.Param() + bias: jax.Array = nnx.Param() + mean: jax.Array = nnx.variable("batch_stats") + var: jax.Array = nnx.variable("batch_stats") + mu: float = nnx.static_field() + + def __init__(self, din: int, mu: float = 0.95): + self.scale = jax.random.uniform(nnx.make_rng("params"), (din,)) + self.bias = jax.numpy.zeros((din,)) + self.mean = jax.numpy.zeros((din,)) + self.var = jax.numpy.ones((din,)) + self.mu = mu + + def __call__(self, x, *, use_running_averages: bool) -> jax.Array: + scale, bias = self.scale, self.bias + if use_running_averages: + mean, var = self.mean, self.var + else: + axis = tuple(range(0, x.ndim - 1)) + mean = jax.numpy.mean(x, axis=axis) + var = jax.numpy.var(x, axis=axis) + # ema update + self.mean = self.mu * self.mean + (1 - self.mu) * mean + self.var = self.mu * self.var + (1 - self.mu) * var + + x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias + + return x + + +@nnx.dataclasses +class Dropout(nnx.Module): + rate: float + + def __call__(self, inputs, *, deterministic: bool): + if (self.rate == 0.0) or deterministic: + return inputs + rng = nnx.make_rng("dropout") + keep_prob = 1.0 - self.rate + mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape) + return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + +class MLP(nnx.Module): + + def __init__(self, din: int, dmid: int, dout: int): + self.linear1 = Linear(din, dmid) + self.bn1 = BatchNorm(dmid) + self.dropout = Dropout(0.5) + self.linear2 = Linear(dmid, dout) + + def __call__(self, x: jax.Array, *, train: bool) -> jax.Array: + x = self.linear1(x) + x = self.bn1(x, use_running_averages=not train) + x = self.dropout(x, deterministic=not train) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +rngs = nnx.Context(jax.random.PRNGKey(0)) +model = MLP.init(rngs)(10, 20, 30) + + +@nnx.jit +def train_step(model: MLP, key, batch): + x, y = batch + + def loss(model: MLP): + rngs = nnx.Context(dropout=key) + y_pred = model.apply(rngs=rngs)(x, train=True) + loss = jax.numpy.mean((y_pred - y) ** 2) + return loss + + 
grads = nnx.grad(loss, wrt=nnx.Param)(model) + model[:] = jax.tree_map(lambda w, g: w - 0.1 * g, model["params"], grads) + + +# ---------------------------------------- +# scan over layers + shared batchnorm +# ---------------------------------------- + +n_layers = 10 +params_keys = jax.random.PRNGKey(0) +params_keys = jax.random.split(params_keys, n_layers) + + +@partial(jax.vmap, in_axes=0, out_axes=(0, None, None)) +def create_state(params_key: jax.random.KeyArray): + rngs = nnx.Context(params=params_key) + model = MLP.init(rngs)(10, 20, 10) + (params, batch_stats), modeldef = model.partition(nnx.Param, "batch_stats") + return params, batch_stats, modeldef + + +params, batch_stats, modeldef = create_state(params_keys) +x = jax.numpy.zeros((32, 10)) +dropout_key = jax.random.PRNGKey(1) +dropout_stream = nnx.RngStream(jax.random.split(dropout_key, n_layers)) + + +def scan_fn( + carry: Tuple[jax.Array, nnx.State], + inputs: Tuple[nnx.State, nnx.RngStream], +): + # extract args + x, batch_stats = carry + params, dropout_stream = inputs + + # create state and rngs + model = modeldef.merge([params, batch_stats]) + rngs = nnx.Context(dropout=dropout_stream) + + # forward pass + x = model.apply(rngs=rngs)(x, train=True) + + # partition state + params, batch_stats = model.partition(nnx.Param, "batch_stats")[0] + + return (x, batch_stats), params + + +(y, batch_stats), params = jax.lax.scan( + scan_fn, (x, batch_stats), (params, dropout_stream) +) +model = modeldef.merge([params, batch_stats]) diff --git a/flax/experimental/nnx/ideas/pure/__init__.py b/flax/experimental/nnx/ideas/pure/__init__.py new file mode 100644 index 0000000000..32467d6f22 --- /dev/null +++ b/flax/experimental/nnx/ideas/pure/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .module import Initializer, Module +from .partitioning import ( + NOTHING, + Partition, + get_partition, + merge_partitions, + tree_partition, +) +from .rngs import Rngs, RngStream +from .state import State, Variable, merge diff --git a/flax/experimental/nnx/ideas/pure/full/partitioning_full.py b/flax/experimental/nnx/ideas/pure/full/partitioning_full.py new file mode 100644 index 0000000000..62f2055f3f --- /dev/null +++ b/flax/experimental/nnx/ideas/pure/full/partitioning_full.py @@ -0,0 +1,269 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
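+
+# NOTE: illustrative behavior of the functions below (values made up):
+#
+#   tree = {"w": Variable(1.0, "params"), "s": Variable(2.0, "batch_stats")}
+#   (params, rest), statedef = tree_partition(tree, lambda c: c == "params")
+#
+# Every Partition maps the full set of string paths; positions owned by a
+# different partition hold the NOTHING sentinel, so
+#
+#   params == {("w",): Variable(1.0, "params"), ("s",): NOTHING}
+#
+# and merge_partitions([params, rest], statedef.treedef) rebuilds the tree.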
+ +import typing as tp + +import jax +import jax.tree_util as jtu + +A = tp.TypeVar("A") +CollectionPredicate = tp.Callable[[str], bool] +Leaf = tp.Any +Leaves = tp.List[Leaf] +KeyPath = tp.Tuple[tp.Hashable, ...] +LeafPredicate = tp.Callable[[tp.Any], bool] + + +class Variable: + __slots__ = ("_value", "_collection") + + def __init__(self, value: tp.Any, collection: str = "params"): + self._value = value + self._collection = collection + + @property + def value(self) -> tp.Any: + return self._value + + @property + def collection(self) -> str: + return self._collection + + @classmethod + def from_value(cls, value: tp.Any) -> "Variable": + return value if isinstance(value, Variable) else Variable(value) + + def copy(self) -> "Variable": + return Variable(self.value, self.collection) + + def update(self, value: tp.Any) -> "Variable": + if isinstance(value, Variable): + if value.collection != self.collection: + raise ValueError( + "Cannot update variable with value from a different collection. " + f"Expected collection {self.collection}, got {value.collection}" + ) + value = value.value + return Variable(value, self.collection) + + def __repr__(self) -> str: + return f"Variable({self.value}, collection={self.collection})" + + +def _flatten_variable_with_keys(variable: Variable): + node = (jtu.GetAttrKey("value"), variable.value) + return (node,), variable.collection + + +def _flatten_variable(variable: Variable): + return (variable.value,), variable.collection + + +def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): + return Variable(nodes[0], collection) + + +jax.tree_util.register_pytree_with_keys( + Variable, + _flatten_variable_with_keys, + _unflatten_variable, + flatten_func=_flatten_variable, +) + + +class Nothing: + + def __repr__(self) -> str: + return "Nothing" # pragma: no cover + + +def _nothing_flatten(x): + return (), None + + +def _nothing_unflatten(aux_data, children): + return NOTHING + + +NOTHING = Nothing() + +jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) + + +class StrPath(tp.Tuple[str, ...]): + pass + + +class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): + + def __setitem__(self, key, value): + raise TypeError("Partition is immutable") + + +def _partition_flatten_with_keys( + x: Partition, +) -> tp.Tuple[ + tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] +]: + children = tuple((StrPath(key), value) for key, value in x.items()) + return children, tuple(x.keys()) + + +def _partition_unflatten( + keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...] 
+): + return Partition(zip(keys, leaves)) + + +jax.tree_util.register_pytree_with_keys( + Partition, _partition_flatten_with_keys, _partition_unflatten +) + + +def _key_path_to_str_gen(key_path: KeyPath) -> tp.Generator[str, None, None]: + for key_entry in key_path: + if isinstance(key_entry, StrPath): + yield from key_entry + elif isinstance(key_entry, jtu.SequenceKey): + yield str(key_entry.idx) + elif isinstance(key_entry, jtu.DictKey): # "['a']" + yield str(key_entry.key) + elif isinstance(key_entry, jtu.GetAttrKey): + yield str(key_entry.name) + elif isinstance(key_entry, jtu.FlattenedIndexKey): + yield str(key_entry.key) + elif hasattr(key_entry, "__dict__") and len(key_entry.__dict__) == 1: + yield str(next(iter(key_entry.__dict__.values()))) + else: + yield str(key_entry) + + +def _key_path_to_str_path(key_path: KeyPath) -> StrPath: + return StrPath(_key_path_to_str_gen(key_path)) + + +class StateDef(tp.Generic[A]): + __slots__ = ("treedef",) + + def __init__(self, treedef: jtu.PyTreeDef): + self.treedef = treedef + + def merge(self, *partitions: Partition) -> A: + raise NotImplementedError + + +def statedef_flatten(x: StateDef): + return (), x.treedef + + +def statedef_unflatten(treedef, children): + return StateDef(treedef) + + +jtu.register_pytree_node(StateDef, statedef_flatten, statedef_unflatten) + + +def tree_partition( + pytree: A, + *predicates: CollectionPredicate, + is_leaf: tp.Optional[LeafPredicate] = None, +) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef[A]]: + paths_leaves: tp.List[tp.Tuple[KeyPath, Leaf]] + paths_leaves, treedef = jax.tree_util.tree_flatten_with_path( + pytree, + is_leaf=lambda x: (isinstance(x, Variable) or x is NOTHING) + or (False if is_leaf is None else is_leaf(x)), + ) + + leaves: tp.Tuple[Leaf, ...] + paths, leaves = zip(*paths_leaves) + paths = tuple(map(_key_path_to_str_path, paths)) + + # we have n + 1 partitions, where n is the number of predicates + # the last partition is for values that don't match any predicate + partition_leaves: tp.Tuple[Leaves, ...] 
= tuple(
+    [NOTHING] * len(leaves) for _ in range(len(predicates) + 1)
+  )
+  for j, leaf in enumerate(leaves):
+    for i, predicate in enumerate(predicates):
+      if isinstance(leaf, Variable) and predicate(leaf.collection):
+        partition_leaves[i][j] = leaf
+        break
+    else:
+      # if we didn't break, set leaf to last partition
+      partition_leaves[-1][j] = leaf
+
+  partitions = tuple(
+    Partition(zip(paths, partition)) for partition in partition_leaves
+  )
+  return partitions, StateDef(treedef)
+
+
+def get_partition(
+  pytree,
+  predicate: CollectionPredicate,
+  is_leaf: tp.Optional[LeafPredicate] = None,
+) -> Partition:
+  (partition, _rest), _treedef = tree_partition(
+    pytree, predicate, is_leaf=is_leaf
+  )
+  return partition
+
+
+def _get_non_nothing(
+  paths: tp.Tuple[StrPath, ...],
+  leaves: tp.Tuple[tp.Union[Leaf, Nothing], ...],
+  position: int,
+):
+  # check that all paths are the same
+  paths_set = set(paths)
+  if len(paths_set) != 1:
+    raise ValueError(
+      "All partitions must have the same paths, "
+      f"at position [{position}] got "
+      + "".join(f"\n- {path}" for path in paths_set)
+    )
+  non_null = [option for option in leaves if option is not NOTHING]
+  if len(non_null) == 0:
+    raise ValueError(
+      f"Expected at least one non-null value for position [{position}]"
+    )
+  elif len(non_null) > 1:
+    raise ValueError(
+      f"Expected at most one non-null value for position [{position}]"
+    )
+  return non_null[0]
+
+
+def merge_partitions(
+  partitions: tp.Sequence[Partition], treedef: jax.tree_util.PyTreeDef
+):
+  lengths = [len(partition) for partition in partitions]
+  if not all(length == lengths[0] for length in lengths):
+    raise ValueError(
+      "All partitions must have the same length, got "
+      f"{', '.join(str(length) for length in lengths)}"
+    )
+
+  partition_paths = (list(partition.keys()) for partition in partitions)
+  partition_leaves = (list(partition.values()) for partition in partitions)
+
+  merged_leaves = [
+    _get_non_nothing(paths, leaves, i)
+    for i, (paths, leaves) in enumerate(
+      zip(zip(*partition_paths), zip(*partition_leaves))
+    )
+  ]
+
+  return jax.tree_util.tree_unflatten(treedef, merged_leaves)
diff --git a/flax/experimental/nnx/ideas/pure/full/state_full.py b/flax/experimental/nnx/ideas/pure/full/state_full.py
new file mode 100644
index 0000000000..40373eaf53
--- /dev/null
+++ b/flax/experimental/nnx/ideas/pure/full/state_full.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
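+
+# NOTE: intended usage of State (illustrative): it is an immutable Mapping
+# whose attribute access unwraps the stored Variable, and it is registered as
+# a pytree, so
+#
+#   state = State(kernel=jax.numpy.zeros((3, 4)), bias=jax.numpy.zeros((4,)))
+#   state.kernel                       # -> the raw array
+#   jax.tree_util.tree_leaves(state)   # -> the raw arrays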
+ +import typing as tp +from types import MappingProxyType + +import jax +import jax.tree_util as jtu +from pure.partitioning import Partition, StateDef, Variable + +Node = tp.Union[Variable, "State"] +S = tp.TypeVar("S", bound="State") + + +class State(tp.Mapping[str, Node]): + __slots__ = ("_variables",) + + def __init__(self, *args, **kwargs: tp.Union[Node, jax.Array]): + self._variables = { + k: self._create_node_field(v) for k, v in dict(*args, **kwargs).items() + } + + @staticmethod + def _create_node_field(value: tp.Any) -> Node: + if isinstance(value, State): + return value + else: + return Variable.from_value(value) + + @staticmethod + def _update_node_field(node: Node, value: tp.Any) -> Node: + if isinstance(node, State) and isinstance(value, State): + return value + elif isinstance(node, Variable) and isinstance(value, Variable): + return node.update(value) + else: + raise ValueError( + f"Cannot update node of type {type(node).__name__} with " + f"value of type {type(value).__name__}" + ) + + def __getitem__(self, name: str) -> tp.Any: + return self._variables[name].value + + __getattr__ = __getitem__ + + def __iter__(self) -> tp.Iterator[str]: + return iter(self._variables) + + def __len__(self) -> int: + return len(self._variables) + + def keys(self) -> tp.KeysView[str]: + return self._variables.keys() + + def values(self) -> tp.ValuesView[Node]: + return self._variables.values() + + def __repr__(self) -> str: + return f"State({self._variables})" + + def update(self, *args, **kwargs: tp.Union[Node, tp.Any]) -> "State": + raise NotImplementedError + + @tp.overload + def partition(self) -> tp.Tuple[tp.Dict[str, Partition], StateDef["State"]]: + ... + + @tp.overload + def partition( + self, collection: str + ) -> tp.Tuple[Partition, StateDef["State"]]: + ... + + @tp.overload + def partition( + self, collection: str, *collections: str + ) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef["State"]]: + ... + + def partition( + self, *collections: str + ) -> tp.Tuple[ + tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition], + StateDef["State"], + ]: + raise NotImplementedError + + @tp.overload + def get_partition(self, collection: str) -> Partition: + ... + + @tp.overload + def get_partition( + self, collection: str, *collections: str + ) -> tp.Tuple[Partition, ...]: + ... + + def get_partition( + self, *collections: str + ) -> tp.Union[Partition, tp.Tuple[Partition, ...]]: + raise NotImplementedError + + def update_partition( + self, partition: Partition, *partitions: Partition + ) -> "State": + raise NotImplementedError + + @tp.overload + def pop(self, name: str) -> Node: + ... + + @tp.overload + def pop(self, name: str, *names: str) -> tp.Tuple[Node, ...]: + ... 
+ + def pop(self, *names: str) -> tp.Union[Node, tp.Tuple[Node, ...]]: + if len(names) == 0: + raise ValueError("pop expected at least 1 argument, got 0") + elif len(names) == 1: + name = names[0] + return self._variables.pop(name) + else: + return tuple(self._variables.pop(name) for name in names) + + +def _state_flatten_with_keys(state: State): + nodes = tuple( + (jtu.GetAttrKey(name), variable) for name, variable in state.items() + ) + names = tuple(state) + return nodes, names + + +def _state_unflatten(names: tp.Tuple[str, ...], nodes: tp.Tuple[Variable, ...]): + return State(zip(names, nodes)) + + +def _state_flatten(state: State): + return tuple(state.values()), tuple(state) + + +jtu.register_pytree_with_keys( + State, + _state_flatten_with_keys, + _state_unflatten, + flatten_func=_state_flatten, +) diff --git a/flax/experimental/nnx/ideas/pure/module.py b/flax/experimental/nnx/ideas/pure/module.py new file mode 100644 index 0000000000..7eb399e0d0 --- /dev/null +++ b/flax/experimental/nnx/ideas/pure/module.py @@ -0,0 +1,84 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp +from dataclasses import dataclass +from typing import Any + +import jax +import jax.tree_util as jtu +from pure.partitioning import Partition +from pure.rngs import KeyArray, Rngs +from pure.state import State, Variable + +A = tp.TypeVar("A", contravariant=True) + + +class InitFn(tp.Protocol, tp.Generic[A]): + + @tp.overload + def __call__(self, __key_or_stream: A) -> tp.Any: + ... + + def __call__(self, __key_or_stream: A, *args: tp.Any) -> tp.Any: + ... + + +class Initializer: + + @tp.overload + def __init__( + self, + initializer: InitFn[KeyArray], + *args, + collection: str = "params", + ): + ... + + @tp.overload + def __init__( + self, + initializer: InitFn[Rngs], + *args, + stream: None, + collection: str = "params", + ): + ... + + def __init__( + self, + initializer: tp.Union[InitFn[KeyArray], InitFn[Rngs]], + *args, + stream: tp.Optional[str] = "params", + collection: str = "params", + ): + ... + + def create_variable(self, rngs: Rngs) -> Variable: + ... + + +class Module: + + def create_state(self, rngs: Rngs) -> State: + return State( + ( + name, + v.create_state(rngs) + if isinstance(v, Module) + else v.create_variable(rngs), + ) + for name, v in vars(self).items() + if isinstance(v, (Initializer, Module)) + ) diff --git a/flax/experimental/nnx/ideas/pure/partitioning.py b/flax/experimental/nnx/ideas/pure/partitioning.py new file mode 100644 index 0000000000..385850503b --- /dev/null +++ b/flax/experimental/nnx/ideas/pure/partitioning.py @@ -0,0 +1,168 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.tree_util as jtu + +A = tp.TypeVar("A") +CollectionPredicate = tp.Callable[[str], bool] +Leaf = tp.Any +Leaves = tp.List[Leaf] +KeyPath = tp.Tuple[tp.Hashable, ...] +LeafPredicate = tp.Callable[[tp.Any], bool] + + +class Variable: + __slots__ = ("_value", "_collection") + + def __init__(self, value: tp.Any, collection: str = "params"): + ... + + @property + def value(self) -> tp.Any: + ... + + @property + def collection(self) -> str: + ... + + @classmethod + def from_value(cls, value: tp.Any) -> "Variable": + ... + + def copy(self) -> "Variable": + ... + + def update(self, value: tp.Any) -> "Variable": + ... + + def __repr__(self) -> str: + return f"Variable({self.value}, collection={self.collection})" + + +def _flatten_variable_with_keys(variable: Variable): + ... + + +def _flatten_variable(variable: Variable): + ... + + +def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): + ... + + +jax.tree_util.register_pytree_with_keys( + Variable, + _flatten_variable_with_keys, + _unflatten_variable, + flatten_func=_flatten_variable, +) + + +class Nothing: + + def __repr__(self) -> str: + ... + + +def _nothing_flatten(x): + ... + + +def _nothing_unflatten(aux_data, children): + ... + + +NOTHING = Nothing() + +jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) + + +class StrPath(tp.Tuple[str, ...]): + pass + + +class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): + + def __setitem__(self, key, value): + raise TypeError("Partition is immutable") + + +def _partition_flatten_with_keys( + x: Partition, +) -> tp.Tuple[ + tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] +]: + ... + + +def _partition_unflatten( + keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...] +): + ... + + +jax.tree_util.register_pytree_with_keys( + Partition, _partition_flatten_with_keys, _partition_unflatten +) + + +class StateDef(tp.Generic[A]): + __slots__ = ("treedef",) + + def __init__(self, treedef: jtu.PyTreeDef): + ... + + @property + def treedef(self) -> jtu.PyTreeDef: + ... + + def merge(self, *partitions: Partition) -> A: + ... + + +def statedef_flatten(x: StateDef): + ... + + +def statedef_unflatten(treedef, children): + ... + + +jtu.register_pytree_node(StateDef, statedef_flatten, statedef_unflatten) + + +def tree_partition( + pytree: A, + *predicates: CollectionPredicate, + is_leaf: tp.Optional[LeafPredicate] = None, +) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef[A]]: + ... + + +def get_partition( + pytree, + predicate: CollectionPredicate, + is_leaf: tp.Optional[LeafPredicate] = None, +) -> Partition: + ... + + +def merge_partitions( + partitions: tp.Sequence[Partition], treedef: jax.tree_util.PyTreeDef +): + ... diff --git a/flax/experimental/nnx/ideas/pure/rngs.py b/flax/experimental/nnx/ideas/pure/rngs.py new file mode 100644 index 0000000000..6780a06761 --- /dev/null +++ b/flax/experimental/nnx/ideas/pure/rngs.py @@ -0,0 +1,67 @@ +# Copyright 2023 The Flax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax + +KeyArray = tp.Union[jax.Array, jax.random.KeyArray] + + +class RngStream: + + def __init__( + self, key: KeyArray, count: int = 0, count_path: tp.Tuple[int, ...] = () + ): + ... + + @property + def key(self) -> jax.random.KeyArray: + ... + + @property + def count(self) -> int: + ... + + @property + def count_path(self) -> tp.Tuple[int, ...]: + ... + + def next(self) -> jax.random.KeyArray: + ... + + def fork(self) -> "RngStream": + ... + + +class Rngs: + + def __init__(self, **streams: tp.Union[KeyArray, RngStream]): + ... + + def make_rng(self, stream: str) -> jax.Array: + ... + + @tp.overload + def fork(self, stream: str) -> RngStream: + ... + + @tp.overload + def fork(self, stream: str, *streams: str) -> tp.Tuple[RngStream, ...]: + ... + + def fork( + self, *streams: str + ) -> tp.Union[RngStream, tp.Tuple[RngStream, ...]]: + ... diff --git a/flax/experimental/nnx/ideas/pure/state.py b/flax/experimental/nnx/ideas/pure/state.py new file mode 100644 index 0000000000..416e16e15e --- /dev/null +++ b/flax/experimental/nnx/ideas/pure/state.py @@ -0,0 +1,110 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp +from types import MappingProxyType + +import jax +import jax.tree_util as jtu +from pure.partitioning import Partition, StateDef, Variable + +Node = tp.Union[Variable, "State"] +S = tp.TypeVar("S", bound="State") + + +class State(tp.Mapping[str, Node]): + __slots__ = ("_variables",) + + def __init__(self, *args, **kwargs: tp.Union[Node, jax.Array]): + ... + + def __getitem__(self, name: str) -> tp.Any: + ... + + __getattr__ = __getitem__ + + def __iter__(self) -> tp.Iterator[str]: + ... + + def __len__(self) -> int: + ... + + def keys(self) -> tp.KeysView[str]: + ... + + def values(self) -> tp.ValuesView[Node]: + ... + + def __repr__(self) -> str: + ... + + def update(self, *args, **kwargs: tp.Union[Node, tp.Any]) -> "State": + ... + + @tp.overload + def partition(self) -> tp.Dict[str, Partition]: + ... + + @tp.overload + def partition(self, collection: str) -> Partition: + ... + + @tp.overload + def partition( + self, collection: str, *collections: str + ) -> tp.Tuple[Partition, ...]: + ... + + def partition( + self, *collections: str + ) -> tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition]: + ... + + def merge(self, partition: Partition, *partitions: Partition) -> "State": + ... 
+ + @tp.overload + def pop(self, name: str) -> Node: + ... + + @tp.overload + def pop(self, name: str, *names: str) -> tp.Tuple[Node, ...]: + ... + + def pop(self, *names: str) -> tp.Union[Node, tp.Tuple[Node, ...]]: + ... + + +def _state_flatten_with_keys(state: State): + ... + + +def _state_unflatten(names: tp.Tuple[str, ...], nodes: tp.Tuple[Variable, ...]): + ... + + +def _state_flatten(state: State): + ... + + +jtu.register_pytree_with_keys( + State, + _state_flatten_with_keys, + _state_unflatten, + flatten_func=_state_flatten, +) + + +def merge(partition: Partition, other: Partition, *rest: Partition) -> State: + ... diff --git a/flax/experimental/nnx/ideas/pure_example.py b/flax/experimental/nnx/ideas/pure_example.py new file mode 100644 index 0000000000..601209f8bf --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_example.py @@ -0,0 +1,175 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from functools import partial +from typing import Tuple + +import jax +import pure +from pure.rngs import Rngs +from pure.state import State + + +@dataclass +class Linear: + din: int + dout: int + + def create_state(self, rngs: Rngs) -> State: + key = rngs.make_rng("params") + return State( + kernel=jax.random.uniform(key, (self.din, self.dout)), + bias=jax.numpy.zeros((self.dout,)), + ) + + def __call__(self, state: pure.State, x): + return x @ state.kernel + state.bias + + +class BatchNorm(pure.Module): + + def __init__(self, din: int, mu: float = 0.95): + self.scale = pure.Initializer(jax.random.uniform, (din,)) + self.bias = pure.Initializer(lambda _: jax.numpy.zeros((din,))) + self.mean = pure.Initializer( + lambda _: jax.numpy.zeros((din,)), collection="batch_stats" + ) + self.var = pure.Initializer( + lambda _: jax.numpy.ones((din,)), collection="batch_stats" + ) + self.mu = mu + + def __call__( + self, state: pure.State, x, use_running_averages: bool + ) -> Tuple[jax.Array, pure.State]: + scale, bias = state.scale, state.bias + if use_running_averages: + mean, var = state.mean, state.var + else: + axis = tuple(range(0, x.ndim - 1)) + mean = jax.numpy.mean(x, axis=axis) + var = jax.numpy.var(x, axis=axis) + # ema update + state = state.update( + mean=self.mu * state.mean + (1 - self.mu) * mean, + var=self.mu * state.var + (1 - self.mu) * var, + ) + + x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias + + return x, state + + +class Dropout(pure.Module): + + def __init__(self, rate: float): + raise NotImplementedError + + def __call__(self, state, rngs: Rngs, x, *, deterministic: bool) -> jax.Array: + key = rngs.make_rng("dropout") + raise NotImplementedError + + +class MLP(pure.Module): + + def __init__(self, din: int, dmid: int, dout: int): + self.linear1 = Linear(din, dmid) + self.bn1 = BatchNorm(dmid) + self.dropout = Dropout(0.5) + self.linear2 = Linear(dmid, dout) + + def __call__( + self, state: pure.State, rngs: pure.Rngs, x: jax.Array, *, train: bool + ) -> Tuple[jax.Array, pure.State]: + x = 
self.linear1(state.linear1, x)
+    x, bn1 = self.bn1(state.bn1, x, use_running_averages=not train)
+    x = self.dropout(state.dropout, rngs, x, deterministic=not train)
+    x = jax.nn.relu(x)
+    x = self.linear2(state.linear2, x)
+    return x, state.update(bn1=bn1)
+
+
+model = MLP(10, 20, 30)
+rngs = pure.Rngs(params=jax.random.PRNGKey(0))
+state = model.create_state(rngs)
+
+
+@jax.jit
+def train_step(state: pure.State, key, batch):
+  x, y = batch
+  params = state.partition("params")
+  rngs = pure.Rngs(dropout=key)
+
+  def loss(params):
+    _state = state.merge(params)
+    y_pred, _state = model(_state, rngs, x, train=True)
+    loss = jax.numpy.mean((y_pred - y) ** 2)
+    return loss, _state
+
+  grads, state = jax.grad(loss, has_aux=True)(params)
+  params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads)
+  state = state.merge(params)
+
+  return state
+
+
+# ----------------------------------------
+# scan over layers + shared batch_stats
+# ----------------------------------------
+
+model = MLP(10, 20, 10)
+n_layers = 10
+params_keys = jax.random.PRNGKey(0)
+params_keys = jax.random.split(params_keys, n_layers)
+
+
+@partial(jax.vmap, in_axes=0, out_axes=(0, None))
+def create_state(params_key: jax.random.KeyArray):
+  state = model.create_state(pure.Rngs(params=params_key))
+  params, batch_stats = state.partition("params", "batch_stats")
+  return params, batch_stats
+
+
+params, batch_stats = create_state(params_keys)
+x = jax.numpy.zeros((32, 10))
+dropout_key = jax.random.PRNGKey(1)
+dropout_stream = pure.RngStream(jax.random.split(dropout_key, n_layers))
+
+
+def scan_fn(
+    carry: Tuple[jax.Array, pure.Partition],
+    inputs: Tuple[pure.Partition, pure.RngStream],
+):
+  # extract args
+  x, batch_stats = carry
+  params, dropout_stream = inputs
+
+  # create state and rngs
+  state = pure.merge(params, batch_stats)
+  rngs = pure.Rngs(dropout=dropout_stream)
+
+  # forward pass
+  x, state = model(state, rngs, x, train=True)
+
+  # partition state
+  params, batch_stats = state.partition("params", "batch_stats")
+
+  return (x, batch_stats), params
+
+
+(y, batch_stats), params = jax.lax.scan(
+    scan_fn, (x, batch_stats), (params, dropout_stream)
+)
+state = pure.merge(params, batch_stats)
diff --git a/flax/experimental/nnx/ideas/pure_nnx_example.py b/flax/experimental/nnx/ideas/pure_nnx_example.py
new file mode 100644
index 0000000000..9ecef21ccd
--- /dev/null
+++ b/flax/experimental/nnx/ideas/pure_nnx_example.py
@@ -0,0 +1,161 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
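+
+# Idea sketch: the same MLP example as pure_example.py above, but written
+# against the mutable NNX Module API (nnx.Module + nnx.Context) instead of
+# threading an explicit State. The Context construction below is illustrative
+# and may not match the final contextlib API.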
+ +import dataclasses +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp + +from flax.experimental import nnx + + +class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + self.kernel = nnx.Param( + jax.random.uniform(ctx.make_rng("params"), (din, dout)) + ) + self.bias = nnx.Param(jax.numpy.zeros((dout,))) + + def __call__(self, x): + return x @ self.kernel + self.bias + + +class BatchNorm(nnx.Module): + + def __init__(self, din: int, mu: float = 0.95, *, ctx: nnx.Context): + self.scale = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din,))) + self.bias = nnx.Param(jax.numpy.zeros((din,))) + self.mean = nnx.BatchStat(jax.numpy.zeros((din,))) + self.var = nnx.BatchStat(jax.numpy.ones((din,))) + self.mu = mu + + def __call__(self, x, *, use_running_averages: bool) -> jax.Array: + scale, bias = self.scale, self.bias + if use_running_averages: + mean, var = self.mean, self.var + else: + axis = tuple(range(0, x.ndim - 1)) + mean = jax.numpy.mean(x, axis=axis) + var = jax.numpy.var(x, axis=axis) + # ema update + self.mean = self.mu * self.mean + (1 - self.mu) * mean + self.var = self.mu * self.var + (1 - self.mu) * var + + x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias + + return x + + +@dataclasses.dataclass +class Dropout(nnx.Module): + rate: float + + def __call__(self, inputs, *, deterministic: bool, ctx: nnx.Context): + if (self.rate == 0.0) or deterministic: + return inputs + rng = ctx.make_rng("dropout") + keep_prob = 1.0 - self.rate + mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape) + return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + +class MLP(nnx.Module): + + def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context): + self.linear1 = Linear(din, dmid, ctx=ctx) + self.bn1 = BatchNorm(dmid, ctx=ctx) + self.dropout = Dropout(0.5) + self.linear2 = Linear(dmid, dout, ctx=ctx) + + def __call__( + self, x: jax.Array, *, train: bool, ctx: nnx.Context + ) -> jax.Array: + x = self.linear1(x) + x = self.bn1(x, use_running_averages=not train) + x = self.dropout(x, deterministic=not train, ctx=ctx) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +ctx = nnx.Context(jax.random.PRNGKey(0)) +model = MLP(10, 20, 30, ctx=ctx) + + +@nnx.jit +def train_step(model: MLP, key, batch): + x, y = batch + + def loss(model: MLP): + ctx = nnx.Context(rngs=dict(dropout=key)) + y_pred = model(x, train=True, ctx=ctx) + loss = jax.numpy.mean((y_pred - y) ** 2) + return loss + + grads = nnx.grad(loss, wrt=nnx.Param)(model) + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) + ) + + +# ---------------------------------------- +# scan over layers + shared batchnorm +# ---------------------------------------- + +n_layers = 10 +params_keys = jax.random.PRNGKey(0) +params_keys = jax.random.split(params_keys, n_layers) + + +@partial(jax.vmap, in_axes=0, out_axes=(0, None, None)) +def create_state(params_key: jax.random.KeyArray): + ctx = nnx.Context(rngs=dict(params=params_key)) + model = MLP(10, 20, 10, ctx=ctx) + (params, batch_stats), modeldef = model.partition(nnx.Param, "batch_stats") + return params, batch_stats, modeldef + + +params, batch_stats, modeldef = create_state(params_keys) +x = jax.numpy.zeros((32, 10)) +dropout_key = jax.random.split(jax.random.PRNGKey(1), n_layers) + + +def scan_fn( + carry: Tuple[jax.Array, nnx.State], + inputs: Tuple[nnx.State, jax.random.KeyArray], +): + # extract args + x, 
batch_stats = carry + params, dropout_key = inputs + + # create state and ctx + model = modeldef.merge(params, batch_stats) + ctx = nnx.Context(dropout=dropout_key) + + # forward pass + x = model(x, train=True, ctx=ctx) + + # partition state + (params, batch_stats), _ = model.partition(nnx.Param, "batch_stats") + + return (x, batch_stats), params + + +(y, batch_stats), params = jax.lax.scan( + scan_fn, (x, batch_stats), (params, dropout_key) +) +model = modeldef.merge(params, batch_stats) diff --git a/flax/experimental/nnx/ideas/pure_pytree/__init__.py b/flax/experimental/nnx/ideas/pure_pytree/__init__.py new file mode 100644 index 0000000000..77f9a57aec --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_pytree/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dataclass import VariableField, dataclass, field, param, static_field, variable +from .module import Initializer, Module +from .partitioning import NOTHING, Partition, get_partition +from .partitioning import merge_partitions as merge +from .partitioning import tree_partition as partition +from .rngs import Rngs, RngStream diff --git a/flax/experimental/nnx/ideas/pure_pytree/dataclass.py b/flax/experimental/nnx/ideas/pure_pytree/dataclass.py new file mode 100644 index 0000000000..4910de7615 --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_pytree/dataclass.py @@ -0,0 +1,132 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import typing as tp +from dataclasses import field + +import typing_extensions as tpe +from simple_pytree import static_field + +A = tp.TypeVar("A") +K = tp.TypeVar("K", bound=tp.Hashable) + + +class VariableField(dataclasses.Field, tp.Generic[A]): + + def __init__( + self, + *, + collection: tp.Hashable = None, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[tp.Any, tp.Any]] = None, + ): + ... + + def __set_name__(self, cls, name): + ... + + def __get__(self, obj, objtype=None) -> A: + ... + + def __set__(self, obj, value: A): + ... 
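+
+# Usage sketch (hypothetical, since VariableField above is an interface stub).
+# The descriptor is meant to be created through the field helpers defined
+# below, e.g.:
+#
+#   @dataclass
+#   class Linear(Module):
+#     kernel: jax.Array = param()
+#     mean: jax.Array = variable("batch_stats")
+#     din: int = static_field()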
+ + +# ---------------------------------------- +# fields +# ---------------------------------------- + + +def variable( + collection: str, + default: tp.Any = dataclasses.MISSING, + *, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +) -> tp.Any: + return VariableField( + collection=collection, + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def param( + default: tp.Any = dataclasses.MISSING, + *, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +) -> tp.Any: + return variable( + "params", + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +@tp.overload +def dataclass(cls: tp.Type[A]) -> tp.Type[A]: + ... + + +@tp.overload +def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Callable[[tp.Type[A]], tp.Type[A]]: + ... + + +@tpe.dataclass_transform( + field_specifiers=(variable, param, field, static_field) +) +def dataclass( + cls: tp.Optional[tp.Type[A]] = None, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Union[tp.Type[A], tp.Callable[[tp.Type[A]], tp.Type[A]]]: + ... diff --git a/flax/experimental/nnx/ideas/pure_pytree/full/partitioning_full.py b/flax/experimental/nnx/ideas/pure_pytree/full/partitioning_full.py new file mode 100644 index 0000000000..62f2055f3f --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_pytree/full/partitioning_full.py @@ -0,0 +1,269 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.tree_util as jtu + +A = tp.TypeVar("A") +CollectionPredicate = tp.Callable[[str], bool] +Leaf = tp.Any +Leaves = tp.List[Leaf] +KeyPath = tp.Tuple[tp.Hashable, ...] 
+LeafPredicate = tp.Callable[[tp.Any], bool] + + +class Variable: + __slots__ = ("_value", "_collection") + + def __init__(self, value: tp.Any, collection: str = "params"): + self._value = value + self._collection = collection + + @property + def value(self) -> tp.Any: + return self._value + + @property + def collection(self) -> str: + return self._collection + + @classmethod + def from_value(cls, value: tp.Any) -> "Variable": + return value if isinstance(value, Variable) else Variable(value) + + def copy(self) -> "Variable": + return Variable(self.value, self.collection) + + def update(self, value: tp.Any) -> "Variable": + if isinstance(value, Variable): + if value.collection != self.collection: + raise ValueError( + "Cannot update variable with value from a different collection. " + f"Expected collection {self.collection}, got {value.collection}" + ) + value = value.value + return Variable(value, self.collection) + + def __repr__(self) -> str: + return f"Variable({self.value}, collection={self.collection})" + + +def _flatten_variable_with_keys(variable: Variable): + node = (jtu.GetAttrKey("value"), variable.value) + return (node,), variable.collection + + +def _flatten_variable(variable: Variable): + return (variable.value,), variable.collection + + +def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): + return Variable(nodes[0], collection) + + +jax.tree_util.register_pytree_with_keys( + Variable, + _flatten_variable_with_keys, + _unflatten_variable, + flatten_func=_flatten_variable, +) + + +class Nothing: + + def __repr__(self) -> str: + return "Nothing" # pragma: no cover + + +def _nothing_flatten(x): + return (), None + + +def _nothing_unflatten(aux_data, children): + return NOTHING + + +NOTHING = Nothing() + +jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) + + +class StrPath(tp.Tuple[str, ...]): + pass + + +class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): + + def __setitem__(self, key, value): + raise TypeError("Partition is immutable") + + +def _partition_flatten_with_keys( + x: Partition, +) -> tp.Tuple[ + tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] +]: + children = tuple((StrPath(key), value) for key, value in x.items()) + return children, tuple(x.keys()) + + +def _partition_unflatten( + keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...] 
+): + return Partition(zip(keys, leaves)) + + +jax.tree_util.register_pytree_with_keys( + Partition, _partition_flatten_with_keys, _partition_unflatten +) + + +def _key_path_to_str_gen(key_path: KeyPath) -> tp.Generator[str, None, None]: + for key_entry in key_path: + if isinstance(key_entry, StrPath): + yield from key_entry + elif isinstance(key_entry, jtu.SequenceKey): + yield str(key_entry.idx) + elif isinstance(key_entry, jtu.DictKey): # "['a']" + yield str(key_entry.key) + elif isinstance(key_entry, jtu.GetAttrKey): + yield str(key_entry.name) + elif isinstance(key_entry, jtu.FlattenedIndexKey): + yield str(key_entry.key) + elif hasattr(key_entry, "__dict__") and len(key_entry.__dict__) == 1: + yield str(next(iter(key_entry.__dict__.values()))) + else: + yield str(key_entry) + + +def _key_path_to_str_path(key_path: KeyPath) -> StrPath: + return StrPath(_key_path_to_str_gen(key_path)) + + +class StateDef(tp.Generic[A]): + __slots__ = ("treedef",) + + def __init__(self, treedef: jtu.PyTreeDef): + self.treedef = treedef + + def merge(self, *partitions: Partition) -> A: + raise NotImplementedError + + +def statedef_flatten(x: StateDef): + return (), x.treedef + + +def statedef_unflatten(treedef, children): + return StateDef(treedef) + + +jtu.register_pytree_node(StateDef, statedef_flatten, statedef_unflatten) + + +def tree_partition( + pytree: A, + *predicates: CollectionPredicate, + is_leaf: tp.Optional[LeafPredicate] = None, +) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef[A]]: + paths_leaves: tp.List[tp.Tuple[KeyPath, Leaf]] + paths_leaves, treedef = jax.tree_util.tree_flatten_with_path( + pytree, + is_leaf=lambda x: (isinstance(x, Variable) or x is NOTHING) + or (False if is_leaf is None else is_leaf(x)), + ) + + leaves: tp.Tuple[Leaf, ...] + paths, leaves = zip(*paths_leaves) + paths = tuple(map(_key_path_to_str_path, paths)) + + # we have n + 1 partitions, where n is the number of predicates + # the last partition is for values that don't match any predicate + partition_leaves: tp.Tuple[Leaves, ...] 
= tuple(
+      [NOTHING] * len(leaves) for _ in range(len(predicates) + 1)
+  )
+  for j, leaf in enumerate(leaves):
+    for i, predicate in enumerate(predicates):
+      if isinstance(leaf, Variable) and predicate(leaf.collection):
+        partition_leaves[i][j] = leaf
+        break
+    else:
+      # if we didn't break, set leaf to last partition
+      partition_leaves[-1][j] = leaf
+
+  partitions = tuple(
+      Partition(zip(paths, partition)) for partition in partition_leaves
+  )
+  return partitions, StateDef(treedef)
+
+
+def get_partition(
+    pytree,
+    predicate: CollectionPredicate,
+    is_leaf: tp.Optional[LeafPredicate] = None,
+) -> Partition:
+  (partition, _rest), _treedef = tree_partition(
+      pytree, predicate, is_leaf=is_leaf
+  )
+  return partition
+
+
+def _get_non_nothing(
+    paths: tp.Tuple[StrPath, ...],
+    leaves: tp.Tuple[tp.Union[Leaf, Nothing], ...],
+    position: int,
+):
+  # check that all paths are the same
+  paths_set = set(paths)
+  if len(paths_set) != 1:
+    raise ValueError(
+        "All partitions must have the same paths, "
+        f"at position [{position}] got "
+        + "".join(f"\n- {path}" for path in paths_set)
+    )
+  non_null = [option for option in leaves if option is not NOTHING]
+  if len(non_null) == 0:
+    raise ValueError(
+        f"Expected at least one non-null value for position [{position}]"
+    )
+  elif len(non_null) > 1:
+    raise ValueError(
+        f"Expected at most one non-null value for position [{position}]"
+    )
+  return non_null[0]
+
+
+def merge_partitions(
+    partitions: tp.Sequence[Partition], treedef: jax.tree_util.PyTreeDef
+):
+  lengths = [len(partition) for partition in partitions]
+  if not all(length == lengths[0] for length in lengths):
+    raise ValueError(
+        "All partitions must have the same length, got "
+        f"{', '.join(str(length) for length in lengths)}"
+    )
+
+  partition_paths = (list(partition.keys()) for partition in partitions)
+  partition_leaves = (list(partition.values()) for partition in partitions)
+
+  merged_leaves = [
+      _get_non_nothing(paths, leaves, i)
+      for i, (paths, leaves) in enumerate(
+          zip(zip(*partition_paths), zip(*partition_leaves))
+      )
+  ]
+
+  return jax.tree_util.tree_unflatten(treedef, merged_leaves)
diff --git a/flax/experimental/nnx/ideas/pure_pytree/full/state_full.py b/flax/experimental/nnx/ideas/pure_pytree/full/state_full.py
new file mode 100644
index 0000000000..40373eaf53
--- /dev/null
+++ b/flax/experimental/nnx/ideas/pure_pytree/full/state_full.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
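+
+# Full (non-stub) version of the State container sketched in pure/state.py: a
+# pytree-registered mapping from attribute names to Variables or nested States.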
+ +import typing as tp +from types import MappingProxyType + +import jax +import jax.tree_util as jtu +from pure.partitioning import Partition, StateDef, Variable + +Node = tp.Union[Variable, "State"] +S = tp.TypeVar("S", bound="State") + + +class State(tp.Mapping[str, Node]): + __slots__ = ("_variables",) + + def __init__(self, *args, **kwargs: tp.Union[Node, jax.Array]): + self._variables = { + k: self._create_node_field(v) for k, v in dict(*args, **kwargs).items() + } + + @staticmethod + def _create_node_field(value: tp.Any) -> Node: + if isinstance(value, State): + return value + else: + return Variable.from_value(value) + + @staticmethod + def _update_node_field(node: Node, value: tp.Any) -> Node: + if isinstance(node, State) and isinstance(value, State): + return value + elif isinstance(node, Variable) and isinstance(value, Variable): + return node.update(value) + else: + raise ValueError( + f"Cannot update node of type {type(node).__name__} with " + f"value of type {type(value).__name__}" + ) + + def __getitem__(self, name: str) -> tp.Any: + return self._variables[name].value + + __getattr__ = __getitem__ + + def __iter__(self) -> tp.Iterator[str]: + return iter(self._variables) + + def __len__(self) -> int: + return len(self._variables) + + def keys(self) -> tp.KeysView[str]: + return self._variables.keys() + + def values(self) -> tp.ValuesView[Node]: + return self._variables.values() + + def __repr__(self) -> str: + return f"State({self._variables})" + + def update(self, *args, **kwargs: tp.Union[Node, tp.Any]) -> "State": + raise NotImplementedError + + @tp.overload + def partition(self) -> tp.Tuple[tp.Dict[str, Partition], StateDef["State"]]: + ... + + @tp.overload + def partition( + self, collection: str + ) -> tp.Tuple[Partition, StateDef["State"]]: + ... + + @tp.overload + def partition( + self, collection: str, *collections: str + ) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef["State"]]: + ... + + def partition( + self, *collections: str + ) -> tp.Tuple[ + tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition], + StateDef["State"], + ]: + raise NotImplementedError + + @tp.overload + def get_partition(self, collection: str) -> Partition: + ... + + @tp.overload + def get_partition( + self, collection: str, *collections: str + ) -> tp.Tuple[Partition, ...]: + ... + + def get_partition( + self, *collections: str + ) -> tp.Union[Partition, tp.Tuple[Partition, ...]]: + raise NotImplementedError + + def update_partition( + self, partition: Partition, *partitions: Partition + ) -> "State": + raise NotImplementedError + + @tp.overload + def pop(self, name: str) -> Node: + ... + + @tp.overload + def pop(self, name: str, *names: str) -> tp.Tuple[Node, ...]: + ... 
+ + def pop(self, *names: str) -> tp.Union[Node, tp.Tuple[Node, ...]]: + if len(names) == 0: + raise ValueError("pop expected at least 1 argument, got 0") + elif len(names) == 1: + name = names[0] + return self._variables.pop(name) + else: + return tuple(self._variables.pop(name) for name in names) + + +def _state_flatten_with_keys(state: State): + nodes = tuple( + (jtu.GetAttrKey(name), variable) for name, variable in state.items() + ) + names = tuple(state) + return nodes, names + + +def _state_unflatten(names: tp.Tuple[str, ...], nodes: tp.Tuple[Variable, ...]): + return State(zip(names, nodes)) + + +def _state_flatten(state: State): + return tuple(state.values()), tuple(state) + + +jtu.register_pytree_with_keys( + State, + _state_flatten_with_keys, + _state_unflatten, + flatten_func=_state_flatten, +) diff --git a/flax/experimental/nnx/ideas/pure_pytree/module.py b/flax/experimental/nnx/ideas/pure_pytree/module.py new file mode 100644 index 0000000000..d17bfad3db --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_pytree/module.py @@ -0,0 +1,72 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +from pure_pytree.partitioning import Partition, PartitionDef, Variable + +A = tp.TypeVar("A", contravariant=True) +M = tp.TypeVar("M", bound="Module") + + +class Pytree: + ... + + +class Module(Pytree): + + def replace(self: M, **kwargs: tp.Any) -> M: + ... + + @tp.overload + def partition(self: M) -> tp.Tuple[tp.Dict[str, Partition], PartitionDef[M]]: + ... + + @tp.overload + def partition( + self: M, collection: str + ) -> tp.Tuple[Partition, PartitionDef[M]]: + ... + + @tp.overload + def partition( + self: M, collection: str, *collections: str + ) -> tp.Tuple[tp.Tuple[Partition, ...], PartitionDef[M]]: + ... + + def partition( + self: M, *collections: str + ) -> tp.Tuple[ + tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition], + PartitionDef[M], + ]: + ... + + @tp.overload + def get_partition(self, collection: str) -> Partition: + ... + + @tp.overload + def get_partition( + self, collection: str, *collections: str + ) -> tp.Tuple[Partition, ...]: + ... + + def get_partition( + self, *collections: str + ) -> tp.Union[Partition, tp.Tuple[Partition, ...]]: + ... + + def merge(self: M, partition: Partition, *partitions: Partition) -> M: + ... diff --git a/flax/experimental/nnx/ideas/pure_pytree/partitioning.py b/flax/experimental/nnx/ideas/pure_pytree/partitioning.py new file mode 100644 index 0000000000..27adebc682 --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_pytree/partitioning.py @@ -0,0 +1,168 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.tree_util as jtu + +A = tp.TypeVar("A") +CollectionPredicate = tp.Callable[[str], bool] +Leaf = tp.Any +Leaves = tp.List[Leaf] +KeyPath = tp.Tuple[tp.Hashable, ...] +LeafPredicate = tp.Callable[[tp.Any], bool] + + +class Variable: + __slots__ = ("_value", "_collection") + + def __init__(self, value: tp.Any, collection: str = "params"): + ... + + @property + def value(self) -> tp.Any: + ... + + @property + def collection(self) -> str: + ... + + @classmethod + def from_value(cls, value: tp.Any) -> "Variable": + ... + + def copy(self) -> "Variable": + ... + + def update(self, value: tp.Any) -> "Variable": + ... + + def __repr__(self) -> str: + return f"Variable({self.value}, collection={self.collection})" + + +def _flatten_variable_with_keys(variable: Variable): + ... + + +def _flatten_variable(variable: Variable): + ... + + +def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): + ... + + +jax.tree_util.register_pytree_with_keys( + Variable, + _flatten_variable_with_keys, + _unflatten_variable, + flatten_func=_flatten_variable, +) + + +class Nothing: + + def __repr__(self) -> str: + ... + + +def _nothing_flatten(x): + ... + + +def _nothing_unflatten(aux_data, children): + ... + + +NOTHING = Nothing() + +jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) + + +class StrPath(tp.Tuple[str, ...]): + pass + + +class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): + + def __setitem__(self, key, value): + raise TypeError("Partition is immutable") + + +def _partition_flatten_with_keys( + x: Partition, +) -> tp.Tuple[ + tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] +]: + ... + + +def _partition_unflatten( + keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...] +): + ... + + +jax.tree_util.register_pytree_with_keys( + Partition, _partition_flatten_with_keys, _partition_unflatten +) + + +class PartitionDef(tp.Generic[A]): + __slots__ = ("treedef",) + + def __init__(self, treedef: jtu.PyTreeDef): + ... + + @property + def treedef(self) -> jtu.PyTreeDef: + ... + + def merge(self, *partitions: Partition) -> A: + ... + + +def partitiondef_flatten(x: PartitionDef): + ... + + +def statedef_unflatten(treedef, children): + ... + + +jtu.register_pytree_node(PartitionDef, partitiondef_flatten, statedef_unflatten) + + +def tree_partition( + pytree: A, + *predicates: CollectionPredicate, + is_leaf: tp.Optional[LeafPredicate] = None, +) -> tp.Tuple[tp.Tuple[Partition, ...], PartitionDef[A]]: + ... + + +def get_partition( + pytree, + predicate: CollectionPredicate, + is_leaf: tp.Optional[LeafPredicate] = None, +) -> Partition: + ... + + +def merge_partitions( + partitions: tp.Sequence[Partition], partitiondef: PartitionDef[A] +) -> A: + ... diff --git a/flax/experimental/nnx/ideas/pure_pytree/rngs.py b/flax/experimental/nnx/ideas/pure_pytree/rngs.py new file mode 100644 index 0000000000..6780a06761 --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_pytree/rngs.py @@ -0,0 +1,67 @@ +# Copyright 2023 The Flax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax + +KeyArray = tp.Union[jax.Array, jax.random.KeyArray] + + +class RngStream: + + def __init__( + self, key: KeyArray, count: int = 0, count_path: tp.Tuple[int, ...] = () + ): + ... + + @property + def key(self) -> jax.random.KeyArray: + ... + + @property + def count(self) -> int: + ... + + @property + def count_path(self) -> tp.Tuple[int, ...]: + ... + + def next(self) -> jax.random.KeyArray: + ... + + def fork(self) -> "RngStream": + ... + + +class Rngs: + + def __init__(self, **streams: tp.Union[KeyArray, RngStream]): + ... + + def make_rng(self, stream: str) -> jax.Array: + ... + + @tp.overload + def fork(self, stream: str) -> RngStream: + ... + + @tp.overload + def fork(self, stream: str, *streams: str) -> tp.Tuple[RngStream, ...]: + ... + + def fork( + self, *streams: str + ) -> tp.Union[RngStream, tp.Tuple[RngStream, ...]]: + ... diff --git a/flax/experimental/nnx/ideas/pure_pytree_example.py b/flax/experimental/nnx/ideas/pure_pytree_example.py new file mode 100644 index 0000000000..8482db1ba4 --- /dev/null +++ b/flax/experimental/nnx/ideas/pure_pytree_example.py @@ -0,0 +1,170 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
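+
+# Idea sketch: the MLP example written against the pure_pytree API, where
+# Modules are immutable pytrees and state updates are expressed through
+# `replace` and `merge` rather than mutation.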
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Tuple
+
+import jax
+import pure_pytree as pure
+
+
+class Linear(pure.Module):
+  kernel: jax.Array = pure.param()
+  bias: jax.Array = pure.param()
+
+  def __init__(self, din: int, dout: int, *, rngs: pure.Rngs):
+    self.kernel = jax.random.uniform(rngs.make_rng("params"), (din, dout))
+    self.bias = jax.numpy.zeros((dout,))
+
+  def __call__(self, x):
+    return x @ self.kernel + self.bias
+
+
+class BatchNorm(pure.Module):
+  scale: jax.Array = pure.param()
+  bias: jax.Array = pure.param()
+  mean: jax.Array = pure.variable("batch_stats")
+  var: jax.Array = pure.variable("batch_stats")
+  mu: float = pure.static_field()
+
+  def __init__(self, din: int, mu: float = 0.95, *, rngs: pure.Rngs):
+    self.scale = jax.random.uniform(rngs.make_rng("params"), (din,))
+    self.bias = jax.numpy.zeros((din,))
+    self.mean = jax.numpy.zeros((din,))
+    self.var = jax.numpy.ones((din,))
+    self.mu = mu
+
+  def __call__(
+      self, x, use_running_averages: bool
+  ) -> Tuple[jax.Array, "BatchNorm"]:
+    scale, bias = self.scale, self.bias
+    if use_running_averages:
+      mean, var = self.mean, self.var
+    else:
+      axis = tuple(range(0, x.ndim - 1))
+      mean = jax.numpy.mean(x, axis=axis)
+      var = jax.numpy.var(x, axis=axis)
+      # ema update
+      self = self.replace(
+          mean=self.mu * self.mean + (1 - self.mu) * mean,
+          var=self.mu * self.var + (1 - self.mu) * var,
+      )
+
+    x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias
+
+    return x, self
+
+
+class Dropout(pure.Module):
+
+  def __init__(self, rate: float):
+    raise NotImplementedError
+
+  def __call__(self, x, *, deterministic: bool, rngs: pure.Rngs) -> jax.Array:
+    key = rngs.make_rng("dropout")
+    ...
+    raise NotImplementedError
+
+
+class MLP(pure.Module):
+
+  def __init__(self, din: int, dmid: int, dout: int, *, rngs: pure.Rngs):
+    self.linear1 = Linear(din, dmid, rngs=rngs)
+    self.bn1 = BatchNorm(dmid, rngs=rngs)
+    self.dropout = Dropout(0.5)
+    self.linear2 = Linear(dmid, dout, rngs=rngs)
+
+  def __call__(
+      self, x: jax.Array, *, train: bool, rngs: pure.Rngs
+  ) -> Tuple[jax.Array, "MLP"]:
+    x = self.linear1(x)
+    x, bn1 = self.bn1(x, use_running_averages=not train)
+    x = self.dropout(x, deterministic=not train, rngs=rngs)
+    x = jax.nn.relu(x)
+    x = self.linear2(x)
+    return x, self.replace(bn1=bn1)
+
+
+rngs = pure.Rngs(params=jax.random.PRNGKey(0))
+model = MLP(10, 20, 30, rngs=rngs)
+
+
+@jax.jit
+def train_step(model: MLP, key, batch):
+  x, y = batch
+  params = model.get_partition("params")
+  rngs = pure.Rngs(dropout=key)
+
+  def loss(params):
+    _model = model.merge(params)
+    y_pred, _model = _model(x, train=True, rngs=rngs)
+    loss = jax.numpy.mean((y_pred - y) ** 2)
+    return loss, _model
+
+  grads, model = jax.grad(loss, has_aux=True)(params)
+  params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads)
+  model = model.merge(params)
+
+  return model
+
+
+# ----------------------------------------
+# scan over layers + shared batchnorm
+# ----------------------------------------
+
+n_layers = 10
+params_keys = jax.random.PRNGKey(0)
+params_keys = jax.random.split(params_keys, n_layers)
+
+
+@partial(jax.vmap, in_axes=0, out_axes=(0, None, None))
+def create_state(params_key: jax.random.KeyArray):
+  rngs = pure.Rngs(params=params_key)
+  model = MLP(10, 20, 10, rngs=rngs)
+  (params, batch_stats), modeldef = model.partition("params", "batch_stats")
+  return params, batch_stats, modeldef
+
+
+params, batch_stats, modeldef = create_state(params_keys)
+x = jax.numpy.zeros((32, 10))
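+
+# NOTE: the scan below consumes one dropout key per layer; RngStream is assumed
+# here to accept a batched key array so jax.lax.scan can slice one key per step.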
+dropout_key = jax.random.PRNGKey(1) +dropout_stream = pure.RngStream(jax.random.split(dropout_key, n_layers)) + + +def scan_fn( + carry: Tuple[jax.Array, pure.Partition], + inputs: Tuple[pure.Partition, pure.RngStream], +): + # extract args + x, batch_stats = carry + params, dropout_stream = inputs + + # create state and rngs + model = pure.merge([params, batch_stats], modeldef) + rngs = pure.Rngs(dropout=dropout_stream) + + # forward pass + x, model = model(x, train=True, rngs=rngs) + + # partition state + params, batch_stats = model.get_partition("params", "batch_stats") + + return (x, batch_stats), params + + +(y, batch_stats), params = jax.lax.scan( + scan_fn, (x, batch_stats), (params, dropout_stream) +) +model = pure.merge([params, batch_stats], modeldef) diff --git a/flax/experimental/nnx/ideas/shape_inference.py b/flax/experimental/nnx/ideas/shape_inference.py new file mode 100644 index 0000000000..45660e7746 --- /dev/null +++ b/flax/experimental/nnx/ideas/shape_inference.py @@ -0,0 +1,216 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +from jax import random + +from flax.experimental import nnx + + +class Linear(nnx.Module): + + @tp.overload + def __init__(self, *, din: int, dout: int, ctx: nnx.Context): + ... + + @tp.overload + def __init__(self, *, dout: int): + ... + + @tp.overload + def __init__( + self, + *, + din: tp.Optional[int] = None, + dout: int, + ctx: tp.Optional[nnx.Context] = None, + ): + ... + + def __init__( + self, + *, + din: tp.Optional[int] = None, + dout: int, + ctx: tp.Optional[nnx.Context] = None, + ): + self.dout = dout + if din is not None: + if ctx is None: + raise ValueError("ctx must be provided if din is provided") + self.init_variables(din, ctx) + + def init_variables(self, din: int, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(random.uniform(key, (din, self.dout))) + self.b = nnx.Param(jnp.zeros((self.dout,))) + + def __call__( + self, x: jax.Array, *, ctx: tp.Optional[nnx.Context] = None + ) -> jax.Array: + if self.is_initializing and not hasattr(self, "w"): + if ctx is None: + raise ValueError("ctx must be provided to initialize module") + self.init_variables(x.shape[-1], ctx) + + return x @ self.w + self.b + + +class BatchNorm(nnx.Module): + + @tp.overload + def __init__(self, *, mu: float = 0.95): + ... + + @tp.overload + def __init__(self, *, din: int, mu: float = 0.95, ctx: nnx.Context): + ... + + @tp.overload + def __init__( + self, + *, + din: tp.Optional[int] = None, + mu: float = 0.95, + ctx: tp.Optional[nnx.Context] = None, + ): + ... 
+ + def __init__( + self, + *, + din: tp.Optional[int] = None, + mu: float = 0.95, + ctx: tp.Optional[nnx.Context] = None, + ): + self.mu = mu + + if din is not None: + if ctx is None: + raise ValueError("ctx must be provided if din is provided") + self.init_variables(din, ctx) + + def init_variables(self, din: int, ctx: nnx.Context): + self.scale = nnx.Param(jax.numpy.ones((din,))) + self.bias = nnx.Param(jax.numpy.zeros((din,))) + self.mean = nnx.BatchStat(jax.numpy.zeros((din,))) + self.var = nnx.BatchStat(jax.numpy.ones((din,))) + + def __call__( + self, x, *, train: bool, ctx: tp.Optional[nnx.Context] = None + ) -> jax.Array: + if self.is_initializing and not hasattr(self, "scale"): + if ctx is None: + raise ValueError("ctx must be provided to initialize module") + self.init_variables(x.shape[-1], ctx) + + if train: + axis = tuple(range(x.ndim - 1)) + mean = jax.numpy.mean(x, axis=axis) + var = jax.numpy.var(x, axis=axis) + # ema update + self.mean = self.mu * self.mean + (1 - self.mu) * mean + self.var = self.mu * self.var + (1 - self.mu) * var + else: + mean, var = self.mean, self.var + + scale, bias = self.scale, self.bias + x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias + return x + + +class Dropout(nnx.Module): + + def __init__(self, rate: float): + self.rate = rate + + def __call__( + self, x: jax.Array, *, train: bool, ctx: nnx.Context + ) -> jax.Array: + if train: + mask = random.bernoulli(ctx.make_rng("dropout"), (1 - self.rate), x.shape) + x = x * mask / (1 - self.rate) + return x + + +# ---------------------------- +# test Linear +# ---------------------------- +print("test Linear") + +# eager +m1 = Linear(din=32, dout=10, ctx=nnx.context(params=0)) +y = m1(x=jnp.ones((1, 32))) +print(jax.tree_map(jnp.shape, m1.get_state())) + +# lazy +m2 = Linear(dout=10) +y = m2.init(x=jnp.ones((1, 32)), ctx=nnx.context(params=0)) +print(jax.tree_map(jnp.shape, m2.get_state())) + +# usage +y1 = m1(x=jnp.ones((1, 32))) +y2 = m2(x=jnp.ones((1, 32))) + +# ---------------------------- +# Test scan +# ---------------------------- +print("\ntest scan") + + +class Block(nnx.Module): + + def __init__( + self, + din: tp.Optional[int] = None, + dout: int = 10, + ctx: tp.Optional[nnx.Context] = None, + ): + self.linear = Linear(din=din, dout=dout, ctx=ctx) + self.bn = BatchNorm(din=dout if din is not None else None, ctx=ctx) + self.dropout = Dropout(0.5) + + def __call__(self, x: jax.Array, _, *, train: bool, ctx: nnx.Context): + x = self.linear(x, ctx=ctx) + x = self.bn(x, train=train, ctx=ctx) + x = self.dropout(x, train=train, ctx=ctx) + x = jax.nn.gelu(x) + return x, None + + +MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + variable_carry=nnx.BatchStat, + split_rngs={"params": True, "dropout": True}, + length=5, +) + + +# eager +mlp = MLP(din=10, dout=10, ctx=nnx.context(params=0)) +y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, ctx=nnx.context(dropout=1)) +print(f"{y.shape=}") +print("state =", jax.tree_map(jnp.shape, mlp.get_state())) +print() + +# lazy +mlp = MLP(dout=10) +mlp.init(jnp.ones((1, 10)), None, train=False, ctx=nnx.context(params=0)) +y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, ctx=nnx.context(dropout=1)) +print(f"{y.shape=}") +print("state =", jax.tree_map(jnp.shape, mlp.get_state())) diff --git a/flax/experimental/nnx/nnx/__init__.py b/flax/experimental/nnx/nnx/__init__.py new file mode 100644 index 0000000000..e80ba0b35f --- /dev/null +++ b/flax/experimental/nnx/nnx/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The Flax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/flax/experimental/nnx/nnx/containers.py b/flax/experimental/nnx/nnx/containers.py new file mode 100644 index 0000000000..0bc69b9c82 --- /dev/null +++ b/flax/experimental/nnx/nnx/containers.py @@ -0,0 +1,223 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import functools +import typing as tp +from abc import ABCMeta +from functools import partial +from typing import Any + +import jax.tree_util as jtu + +from flax.experimental.nnx.nnx import nodes, reprlib + +A = tp.TypeVar("A") +B = tp.TypeVar("B") +F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any]) +Sharding = tp.Tuple[tp.Optional[str], ...] + + +@dataclasses.dataclass +class ContainerMetadata(tp.Generic[A]): + value: A + metadata: tp.Mapping[str, tp.Any] + + +class ContainerMetaclass(ABCMeta): + + def __call__(self, value: A, **metadata: tp.Any) -> A: + if isinstance(value, Container): + container = value + value = container.value + else: + container = None + + obj = super().__call__(value, **metadata) + + if container is not None and not container.is_equivalent(obj): + raise ValueError( + f"input value of type '{type(container).__name__}' is not compatible " + f"with return type '{type(obj).__name__}'" + ) + + return obj + + +class Container( + tp.Generic[A], reprlib.Representable, metaclass=ContainerMetaclass +): + value: A + + def __init__( + self, value: tp.Union[A, ContainerMetadata[A]], **metadata: tp.Any + ): + if isinstance(value, ContainerMetadata): + metadata.update(value.metadata) + value = tp.cast(A, value.value) + + vars(self).update(metadata, value=value) + + if tp.TYPE_CHECKING: + + def __getattr__(self, name: str) -> tp.Any: + ... + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Container): + return False + return type(self) is type(other) and vars(other) == vars(self) + + @tp.overload + def replace(self, *, value: B, **kwargs) -> "Container[B]": + ... + + @tp.overload + def replace(self, **kwargs) -> "Container[A]": + ... 
+ + def replace(self, **kwargs) -> "Container[tp.Any]": + if "value" in kwargs: + value = kwargs["value"] + if isinstance(value, Container): + if not self.is_equivalent(value): + raise ValueError( + "Cannot replace value from incompatible container, " + f"expected {self}, got {value}" + ) + kwargs["value"] = value.value + + attributes = vars(self).copy() + # validate keys + for key in kwargs: + if key not in attributes: + raise ValueError(f"Unknown metadata key {key!r}") + attributes.update(**kwargs) + node_type = type(self) + return node_type(**attributes) + + def is_equivalent(self, other: tp.Any) -> bool: + def metadata_fields(container: Container[tp.Any]) -> tp.Dict[str, tp.Any]: + return {k: v for k, v in vars(container).items() if k != "value"} + + return type(self) is type(other) and metadata_fields( + self + ) == metadata_fields(other) + + def copy(self: "Container[A]") -> "Container[A]": + return type(self)(**vars(self)) + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + for name, value in vars(self).items(): + yield reprlib.Attr(name, repr(value)) + + +class NodeBase(Container[A]): + + def __init_subclass__(cls): + super().__init_subclass__() + + def _node_flatten( + x: NodeBase[tp.Any], + *, + with_keys: bool, + ): + attributes = vars(x).copy() + value = attributes.pop("value") + if with_keys: + node = (jtu.GetAttrKey("value"), value) + else: + node = value + + return (node,), attributes + + def _node_unflatten( + metadata: tp.Mapping[str, tp.Any], children: tp.Tuple[A] + ) -> NodeBase[A]: + return cls(children[0], **metadata) + + jtu.register_pytree_with_keys( + cls, + partial(_node_flatten, with_keys=True), + _node_unflatten, + flatten_func=partial(_node_flatten, with_keys=False), + ) + + +class Node(NodeBase[A]): + pass + + +class Variable(Node[A]): + sharding: tp.Optional[Sharding] + + def __init__( + self, + value: tp.Union[A, ContainerMetadata[A]], + sharding: tp.Optional[Sharding] = None, + **metadata: Any, + ): + super().__init__(value, sharding=sharding, **metadata) + + +class Param(Variable[A]): + pass + + +class BatchStat(Variable[A]): + pass + + +class Cache(Variable[A]): + pass + + +class Intermediate(Variable[A]): + pass + + +class Static(Container[A], reprlib.Representable): + + def __init__(self, value: A): + super().__init__(value) + + def __hash__(self) -> int: + return hash(self.value) + + +def _static_flatten(x: Static[tp.Any]): + return (), x.value + + +def _static_unflatten(metadata: A, _) -> Static[A]: + return Static(metadata) + + +jtu.register_pytree_node(Static, _static_flatten, _static_unflatten) + + +def with_metadata( + initializer: F, + **metadata: tp.Any, +) -> F: + @functools.wraps(initializer) + def wrapper(*args): + return ContainerMetadata(initializer(*args), metadata=metadata) + + return wrapper # type: ignore + + +# register nodes +nodes.register_node_type(Node) diff --git a/flax/experimental/nnx/nnx/contextlib.py b/flax/experimental/nnx/nnx/contextlib.py new file mode 100644 index 0000000000..5fe5d93c3d --- /dev/null +++ b/flax/experimental/nnx/nnx/contextlib.py @@ -0,0 +1,214 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import dataclasses +import hashlib +import typing as tp + +import jax +import jax.tree_util as jtu + +from flax.experimental.nnx.nnx import errors, tracers + +KeyArray = jax.Array +Counts = tp.Tuple[int, ...] + + +def _stable_hash(data: Counts) -> int: + hash_str = " ".join(str(x) for x in data) + _hash = hashlib.blake2s(hash_str.encode()) + hash_bytes = _hash.digest() + # uint32 is represented as 4 bytes in big endian + return int.from_bytes(hash_bytes[:4], byteorder="big") + + +class ContextDef: + __slots__ = ("_rng_counts", "_flags") + + def __init__( + self, + rng_counts: tp.Tuple[tp.Tuple[str, Counts], ...], + flags: tp.Tuple[tp.Tuple[str, bool], ...], + ): + self._rng_counts = rng_counts + self._flags = flags + + def merge(self, keys: tp.Mapping[str, KeyArray]) -> "Context": + rngs = { + name: RngStream(keys[name], count=0, count_path=count_path) + for name, count_path in self._rng_counts + } + return Context(rngs=rngs, flags=dict(self._flags)) + + +class PureContext(tp.Tuple[tp.Dict[str, KeyArray], ContextDef]): + + @classmethod + def new(cls, keys: tp.Dict[str, KeyArray], contextdef: ContextDef): + return cls((keys, contextdef)) + + @property + def keys(self) -> tp.Dict[str, KeyArray]: + return self[0] + + @property + def contextdef(self) -> ContextDef: + return self[1] + + def merge(self): + return self.contextdef.merge(self.keys) + + +def _pure_context_flatten(pure_context: PureContext): + return tuple(pure_context), None + + +def _pure_context_unflatten( + aux_data: None, + children: tp.Tuple[tp.Dict[str, KeyArray], ContextDef], +) -> PureContext: + return PureContext(children) + + +jtu.register_pytree_node( + PureContext, _pure_context_flatten, _pure_context_unflatten +) + + +@dataclasses.dataclass +class RngStream: + key: KeyArray + count: int = 0 + count_path: Counts = () + + +class Context: + __slots__ = ("_rngs", "_flags", "_trace_state") + + def __init__( + self, + rngs: tp.Mapping[str, RngStream], + flags: tp.Mapping[str, bool], + ): + self._rngs = rngs + self._flags = flags + self._trace_state = tracers.TraceState() + + def has_rng(self, name: str) -> bool: + return name in self._rngs + + def make_rng(self, name: str) -> KeyArray: + if name not in self._rngs: + raise ValueError(f"Unknown Rng Stream: {name}") + elif not self._trace_state.is_valid(): + raise errors.TraceContextError( + "Cannot use Context from a different trace level" + ) + + stream = self._rngs[name] + fold_data = _stable_hash(stream.count_path + (stream.count,)) + stream.count += 1 + return jax.random.fold_in(stream.key, fold_data) + + def copy(self) -> "Context": + return Context(rngs=self._rngs, flags=self._flags) + + def has_flag(self, name: str) -> bool: + return name in self._flags + + def get_flag(self, name: str) -> tp.Optional[bool]: + return self._flags.get(name, None) + + def partition(self) -> PureContext: + if not self._trace_state.is_valid(): + raise errors.TraceContextError( + "Cannot use Context from a different trace level" + ) + + def fork(stream) -> "RngStream": + count_path = stream.count_path + (stream.count,) + stream.count += 1 + return 
RngStream(stream.key, count_path=count_path) + + rngs = {name: fork(stream) for name, stream in self._rngs.items()} + keys = {name: stream.key for name, stream in rngs.items()} + rng_counts = tuple( + (name, stream.count_path) for name, stream in rngs.items() + ) + return PureContext.new( + keys, ContextDef(rng_counts, tuple(self._flags.items())) + ) + + +def context( + params: tp.Union[int, KeyArray, RngStream, None] = None, + *, + flags: tp.Optional[tp.Mapping[str, bool]] = None, + **rngs: tp.Union[int, KeyArray, RngStream], +) -> Context: + _flags = flags or {} + + if params is not None: + rngs["params"] = params + + _rngs = { + name: ( + RngStream(jax.random.PRNGKey(value)) + if isinstance(value, int) + else RngStream(value) + if isinstance(value, jax.Array) + else value + ) + for name, value in rngs.items() + } + + return Context(rngs=_rngs, flags=_flags) + + +if tp.TYPE_CHECKING: + ellipsis = builtins.ellipsis +else: + ellipsis = tp.Any + +RngPredicate = tp.Callable[[str], bool] +RngFilterLiteral = tp.Union[str, RngPredicate, ellipsis, None] +RngFilter = tp.Union[ + RngFilterLiteral, + tp.Sequence[RngFilterLiteral], + tp.Mapping[RngFilterLiteral, bool], +] + + +def to_rng_predicate(filter: RngFilter) -> RngPredicate: + if filter is None: + return lambda _: False + elif filter is ...: + return lambda _: True + elif callable(filter): + return filter + elif isinstance(filter, str): + return lambda name: name == filter + elif isinstance(filter, tp.Mapping): + predicates = tuple( + to_rng_predicate(filter) + for filter, include in filter.items() + if include + ) + return lambda name: any(predicate(name) for predicate in predicates) + elif isinstance(filter, tp.Sequence): + predicates = tuple(map(to_rng_predicate, filter)) + return lambda name: any(predicate(name) for predicate in predicates) + else: + raise TypeError(f"Invalid rng filter: {filter}") diff --git a/flax/experimental/nnx/nnx/dataclasses.py b/flax/experimental/nnx/nnx/dataclasses.py new file mode 100644 index 0000000000..264f4fa238 --- /dev/null +++ b/flax/experimental/nnx/nnx/dataclasses.py @@ -0,0 +1,207 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
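+
+# Usage sketch (hypothetical): these helpers store an `nnx_container_fn` in the
+# dataclass field metadata, presumably so assigned values can be wrapped in the
+# corresponding Container, e.g.:
+#
+#   @dataclass
+#   class Linear(Module):
+#     w: jax.Array = param_field()
+#     din: int = static_field()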
+ +import dataclasses +import typing as tp + +import typing_extensions as tpe + +from flax.experimental.nnx.nnx import containers + +A = tp.TypeVar("A") + + +def field( + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +): + return dataclasses.field( # type: ignore + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def node_field( + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +): + if metadata is None: + metadata = {} + else: + metadata = dict(metadata) + + if "nnx_container_fn" in metadata: + raise ValueError("'nnx_container_fn' found in metadata") + + metadata["nnx_container_fn"] = lambda value: containers.Node(value) + + return field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def static_field( + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +): + if metadata is None: + metadata = {} + else: + metadata = dict(metadata) + + if "nnx_container_fn" in metadata: + raise ValueError("'nnx_container_fn' found in metadata") + + metadata["nnx_container_fn"] = lambda value: containers.Static(value) + + return field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def var_field( + variable_type: tp.Type[containers.Variable[tp.Any]], + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, + sharding: tp.Optional[containers.Sharding] = None, +) -> tp.Any: + if metadata is None: + metadata = {} + else: + metadata = dict(metadata) + + if "nnx_container_fn" in metadata: + raise ValueError("'nnx_container_fn' found in metadata") + + metadata["nnx_container_fn"] = lambda value: variable_type( + value, sharding=sharding + ) + + return field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def param_field( + default: tp.Any = dataclasses.MISSING, + *, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +) -> tp.Any: + return var_field( + containers.Param, + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +@tp.overload +def dataclass(cls: tp.Type[A]) -> tp.Type[A]: + ... + + +@tp.overload +def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Callable[[tp.Type[A]], tp.Type[A]]: + ... 
+ + +@tpe.dataclass_transform( + field_specifiers=(field, node_field, static_field, var_field, param_field) +) +def dataclass( + cls: tp.Optional[tp.Type[A]] = None, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Union[tp.Type[A], tp.Callable[[tp.Type[A]], tp.Type[A]]]: + decorator = dataclasses.dataclass( + init=init, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen, + ) + + if cls is None: + return decorator + + return decorator(cls) diff --git a/flax/experimental/nnx/nnx/errors.py b/flax/experimental/nnx/nnx/errors.py new file mode 100644 index 0000000000..c72305e62d --- /dev/null +++ b/flax/experimental/nnx/nnx/errors.py @@ -0,0 +1,17 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TraceContextError(Exception): + pass diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py new file mode 100644 index 0000000000..85558e6833 --- /dev/null +++ b/flax/experimental/nnx/nnx/helpers.py @@ -0,0 +1,171 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import typing as tp + +import jax.numpy as jnp +import optax + +from flax.experimental.nnx.nnx import pytreelib +from flax.experimental.nnx.nnx.contextlib import Context +from flax.experimental.nnx.nnx.module import ApplyCaller, Module, ModuleDef, Pure +from flax.experimental.nnx.nnx.state import State + +A = tp.TypeVar("A") +M = tp.TypeVar("M", bound=Module) + + +class Dict(Module, tp.Mapping[str, A]): + + @tp.overload + def __init__(self, __iterable: tp.Iterable[tp.Tuple[str, A]]): + ... + + @tp.overload + def __init__( + self, __mapping: tp.Optional[tp.Mapping[str, A]] = None, **kwargs: A + ): + ... 
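  # Usage sketch (illustrative; `encoder`/`decoder` are hypothetical
  # submodules): values stored in a `Dict` are plain attributes, so item and
  # attribute access are interchangeable:
  #
  #   layers = Dict(encoder=encoder, decoder=decoder)
  #   assert layers["encoder"] is layers.encoder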
+ + def __init__(self, *args, **kwargs): + for name, value in dict(*args, **kwargs).items(): + setattr(self, name, value) + + def __getitem__(self, key) -> A: + return getattr(self, key) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __getattr__(self, key) -> A: + return super().__getattribute__(key) + + def __setattr__(self, key, value): + super().__setattr__(key, value) + + def __iter__(self) -> tp.Iterator[str]: + return iter(vars(self)) + + def __len__(self) -> int: + return len(vars(self)) + + +class Sequence(Module, tp.Generic[A]): + + def __init__(self, iterable: tp.Iterable[A]): + i = -1 # start at -1 so an empty iterable yields length 0 + for i, value in enumerate(iterable): + setattr(self, str(i), value) + self._length = i + 1 + + def __getitem__(self, key: int) -> A: + if key >= len(self): + raise IndexError(f"index {key} out of range for {self}") + return getattr(self, str(key)) + + def __iter__(self) -> tp.Iterator[A]: + for i in range(len(self)): + yield getattr(self, str(i)) + + def __len__(self) -> int: + return self._length + + def __call__( + self, *args, ctx: tp.Optional[Context] = None, **kwargs + ) -> tp.Any: + output: tp.Any = None + + for i, f in enumerate(self): + if not callable(f): + raise TypeError(f"Sequence[{i}] is not callable: {f}") + if i > 0: + if isinstance(output, tuple): + args = output + kwargs = {} + elif isinstance(output, dict): + args = () + kwargs = output + else: + args = (output,) + kwargs = {} + if ctx is not None and has_keyword_arg(f, "ctx"): + kwargs["ctx"] = ctx + + output = f(*args, **kwargs) + + return output + + +class ModuleDefApply(tp.Protocol, tp.Generic[M]): + + def __call__(self, state: State, *states: State) -> ApplyCaller["Pure[State, M]"]: + ... + + +class TrainState(pytreelib.Pytree, tp.Generic[M]): + + def __init__( + self, + moduledef: ModuleDef[M], + *, + params: State, + tx: optax.GradientTransformation, + step: int = 0, + **kwargs, + ): + self.moduledef = moduledef + self.params: State = pytreelib.TreeNode(params) + self.tx = tx + self.opt_state = pytreelib.TreeNode(tx.init(self.params)) + self.step = pytreelib.TreeNode(jnp.asarray(step)) + for name, value in kwargs.items(): + setattr(self, name, value) + + if tp.TYPE_CHECKING: + + def __getattr__(self, key: str) -> tp.Any: + ... + + def apply( + self, state: tp.Union[State, str], *states: tp.Union[State, str] + ) -> ApplyCaller[Pure[State, M]]: + states = (state, *states) + + _states = ( + getattr(self, state) if isinstance(state, str) else state + for state in states + ) + + return self.moduledef.apply(*_states) + + def apply_gradients(self, grads: State, **kwargs) -> "TrainState[M]": + updates, opt_state = self.tx.update(grads, self.opt_state, self.params) + params = optax.apply_updates(self.params, updates) # type: ignore + step = self.step + 1 + return self.replace( + params=params, + opt_state=opt_state, + step=step, + **kwargs, + ) + + +def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool: + """Return True if func accepts a keyword argument with the given name.""" + return any( + param.name == name + and param.kind in (param.KEYWORD_ONLY, param.POSITIONAL_OR_KEYWORD) + for param in inspect.signature(func).parameters.values() + ) diff --git a/flax/experimental/nnx/nnx/ids.py b/flax/experimental/nnx/nnx/ids.py new file mode 100644 index 0000000000..5ca7e99843 --- /dev/null +++ b/flax/experimental/nnx/nnx/ids.py @@ -0,0 +1,65 @@ +# Copyright 2023 The Flax Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""UUIDs for Flax internals.""" + +import threading + + +class UUIDManager: + """Globally unique counter-based id manager. + + We need globally unique key ids for Module and Variable object instances + to preserve and recreate sharing-by-reference relationship when lifting + transforms and adopting outside Modules. + - Use of id() is unacceptable because these identifiers are literally + pointers which can be recycled, so we rely on a globally unique counter id + instead. + - We need to handle copy/deepcopy uniqueness via a wrapped type. + """ + + def __init__(self): + self._lock = threading.Lock() + self._id = 0 + + def __call__(self): + with self._lock: + self._id += 1 + return UUID(self._id) + + +uuid = UUIDManager() + + +class UUID: + """Hashable wrapper for ids that handles uniqueness of copies.""" + + def __init__(self, rawid): + self.id = rawid + + def __eq__(self, other): + return isinstance(other, UUID) and other.id == self.id + + def __hash__(self): + return hash(self.id) + + def __repr__(self): + return f"UUID({self.id})" + + def __deepcopy__(self, memo): + del memo + return uuid() + + def __copy__(self): + return uuid() diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py new file mode 100644 index 0000000000..7ca5ba7c22 --- /dev/null +++ b/flax/experimental/nnx/nnx/module.py @@ -0,0 +1,993 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
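# Sketch of the copy semantics provided by `ids.UUID` above (illustrative):
# equality and hashing follow the raw counter id, while `copy`/`deepcopy`
# deliberately mint fresh ids so that copied Modules stop aliasing each other:
#
#   >>> import copy
#   >>> from flax.experimental.nnx.nnx.ids import uuid
#   >>> a = uuid()
#   >>> a == copy.copy(a), a == copy.deepcopy(a)
#   (False, False)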
+ +import dataclasses +import typing as tp +from abc import ABCMeta +from functools import partial +from typing import Any + +import jax.tree_util as jtu + +from flax.experimental.nnx.nnx import containers, errors, ids, nodes, partitioning, reprlib, tracers +from flax.experimental.nnx.nnx.containers import Container, Node, Sharding +from flax.experimental.nnx.nnx.state import State + +A = tp.TypeVar("A") +M = tp.TypeVar("M", bound="Module") +S = tp.TypeVar("S", bound=tp.Union[State, tp.Tuple[State, ...]]) +V = tp.TypeVar("V", bound=containers.Variable[tp.Any]) + +Path = str +PathParts = tp.Tuple[str, ...] +StateDict = tp.Dict[Path, tp.Any] +StateMapping = tp.Mapping[Path, tp.Any] + + +class _ProxyContext(tp.Protocol): + + def __call__(self, __fn: tp.Callable[..., tp.Any], *args, **kwargs) -> tp.Any: + ... + + +@dataclasses.dataclass +class CallableProxy: + _proxy_context: _ProxyContext + _proxy_callable: tp.Callable[..., tp.Any] + + def __call__(self, *args, **kwargs): + return self._proxy_context(self._proxy_callable, *args, **kwargs) + + def __getattr__(self, name) -> "CallableProxy": + return CallableProxy( + self._proxy_context, getattr(self._proxy_callable, name) + ) + + def __getitem__(self, key) -> "CallableProxy": + return CallableProxy(self._proxy_context, self._proxy_callable[key]) + + +def _identity(x): + return x + + +@dataclasses.dataclass +class DelayedAccessor: + accessor: tp.Callable[[tp.Any], tp.Any] = _identity + + def __call__(self, x): + return self.accessor(x) + + def __getattr__(self, name): + return DelayedAccessor(lambda x: getattr(x, name)) + + def __getitem__(self, key): + return DelayedAccessor(lambda x: x[key]) + + +class ApplyCaller(tp.Protocol, tp.Generic[A]): + + def __getattr__(self, __name) -> "ApplyCaller[A]": + ... + + def __getitem__(self, __name) -> "ApplyCaller[A]": + ... + + def __call__(self, *args, **kwargs) -> tp.Tuple[tp.Any, A]: + ... + + +@dataclasses.dataclass(repr=False) +class _SubmodulesRepr(reprlib.Representable): + submodules: tp.Tuple[tp.Tuple[str, tp.Union["ModuleDef[Module]", int]], ...] + + def __nnx_repr__(self): + yield reprlib.Object(type="", value_sep=", ") + + for name, submodule in self.submodules: + yield reprlib.Attr(repr(name), submodule, start="(", end=")") + + +class ModuleDef(tp.Generic[M], reprlib.Representable): + __slots__ = ( + "_type", + "_index", + "_submodules", + "_static_fields", + "_module_state", + ) + + def __init__( + self, + type: tp.Type[M], + index: int, + submodules: tp.Tuple[ + tp.Tuple[str, tp.Union["ModuleDef[Module]", int]], ... 
+ ], + static_fields: tp.Tuple[tp.Tuple[str, tp.Any], ...], + module_state: "ModuleStateTuple", + ): + self._type = type + self._index = index + self._submodules = submodules + self._static_fields = static_fields + self._module_state = module_state + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + + yield reprlib.Attr("type", self._type.__name__) + yield reprlib.Attr("index", self._index) + yield reprlib.Attr("submodules", _SubmodulesRepr(self._submodules)) + yield reprlib.Attr("static_fields", self._static_fields) + + def __hash__(self) -> int: + return hash((self._type, self._submodules, self._static_fields)) + + def __eq__(self, other: tp.Any) -> bool: + if not isinstance(other, ModuleDef): + return False + return ( + self._type == other._type + and self._submodules == other._submodules + and self._static_fields == other._static_fields + ) + + @property + def type(self) -> tp.Type[M]: + return self._type + + @property + def index(self) -> int: + return self._index + + @property + def submodules( + self, + ) -> tp.Tuple[tp.Tuple[str, tp.Union["ModuleDef[Module]", int]], ...]: + return self._submodules + + @property + def static_fields(self) -> tp.Tuple[tp.Tuple[str, tp.Any], ...]: + return self._static_fields + + @property + def module_state(self) -> "ModuleStateTuple": + return self._module_state + + def merge(self, state: State, *states: State) -> M: + states = (state, *states) + module = _build_module(self) + current_state = State({}) + + _update_module(module, current_state, states) + + return module + + def apply( + self, state: State, *states: State + ) -> ApplyCaller["Pure[State, M]"]: + accessesor = DelayedAccessor() + + def _context( + accessesor, *args, **kwargs + ) -> tp.Tuple[tp.Any, Pure[State, M]]: + module = self.merge(state, *states) + fn = accessesor(module) + out = fn(*args, **kwargs) + return out, module.partition() + + return CallableProxy(_context, accessesor) # type: ignore + + +def _moddef_flatten(moduledef: ModuleDef[M]): + return (), ( + moduledef._type, + moduledef._index, + moduledef._submodules, + moduledef._static_fields, + moduledef._module_state, + ) + + +def _moddef_unflatten( + metadata: tp.Tuple[ + tp.Type[M], + int, + tp.Tuple[tp.Tuple[str, tp.Union["ModuleDef[Module]", int]], ...], + tp.Tuple[tp.Tuple[str, tp.Any], ...], + "ModuleStateTuple", + ], + _, +) -> ModuleDef[M]: + return ModuleDef(*metadata) + + +jtu.register_pytree_node(ModuleDef, _moddef_flatten, _moddef_unflatten) + + +class Pure(tp.Tuple[S, ModuleDef[M]]): + + @classmethod + def new(cls, states: S, moduledef: ModuleDef[M]) -> "Pure[S, M]": + return cls((states, moduledef)) + + @property + def states(self) -> S: + return self[0] + + @property + def moduledef(self) -> ModuleDef[M]: + return self[1] + + def merge(self) -> M: + if isinstance(self.states, tuple): + return self.moduledef.merge(*self.states) + else: + return self.moduledef.merge(self.states) + + @property + def apply(self) -> ApplyCaller["Pure[State, M]"]: + if isinstance(self.states, tuple): + return self.moduledef.apply(*self.states) + else: + return self.moduledef.apply(self.states) + + @property + def call(self) -> M: + accessesor = DelayedAccessor() + + def _context(accessesor, *args, **kwargs): + module = self.merge() + fn = accessesor(module) + return fn(*args, **kwargs) + + return CallableProxy(_context, accessesor) # type: ignore + + def get_state(self) -> State: + if isinstance(self.states, tuple): + return State.merge(*self.states) + return self.states + + @tp.overload + def filter( + self, + 
filter: partitioning.Filter, + /, + ) -> State: + ... + + @tp.overload + def filter( + self, + filter: partitioning.Filter, + filter2: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Tuple[State, ...]: + ... + + def filter( + self, filter: partitioning.Filter, *filters: partitioning.Filter + ) -> tp.Union[State, tp.Tuple[State, ...]]: + filters = (filter, *filters) + state = self.get_state() + + if len(filters) == 1: + return state.filter(filters[0]) + else: + return state.filter(filters[0], filters[1], *filters[2:]) + + @tp.overload + def partition(self) -> "Pure[State, M]": + ... + + @tp.overload + def partition(self, first: partitioning.Filter, /) -> "Pure[State, M]": + ... + + @tp.overload + def partition( + self, + first: partitioning.Filter, + second: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> "Pure[tp.Tuple[State, ...], M]": + ... + + def partition( + self, *filters: partitioning.Filter + ) -> tp.Union["Pure[State, M]", "Pure[tp.Tuple[State, ...], M]"]: + state = self.get_state() + + if len(filters) == 0: + states = state + elif len(filters) == 1: + states = state.partition(filters[0]) + else: + states = state.partition(filters[0], filters[1], *filters[2:]) + + if isinstance(states, State): + return Pure.new(states, self.moduledef) + else: + return Pure.new(states, self.moduledef) + + @tp.overload + def pop_state( + self, filter: partitioning.Filter, / + ) -> tp.Tuple[State, "Pure[State, M]"]: + ... + + @tp.overload + def pop_state( + self, + filter: partitioning.Filter, + filter2: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Tuple[tp.Tuple[State, ...], "Pure[State, M]"]: + ... + + def pop_state( + self, filter: partitioning.Filter, *filters: partitioning.Filter + ) -> tp.Tuple[tp.Union[State, tp.Tuple[State, ...]], "Pure[State, M]"]: + filters = (filter, *filters) + + state = self.get_state() + *states, rest = state.partition(*filters, ...) 
+ + if len(states) == 1: + states = states[0] + else: + states = tuple(states) + + return states, Pure.new(rest, self.moduledef) + + def update_state( + self, + updates: tp.Union[M, "Pure[S, M]", State, tp.Tuple[State, ...]], + ) -> "Pure[State, M]": + if isinstance(updates, Module): + states = (updates.get_state(),) + elif isinstance(updates, Pure): + if isinstance(updates.states, tuple): + states = updates.states + elif isinstance(updates.states, State): + states = (updates.states,) + else: + raise TypeError( + "Expected Module, PureModule, State or tuple of State, " + f"got {type(updates.states).__name__}" + ) + elif isinstance(updates, State): + states = (updates,) + elif isinstance(updates, tuple): + states = updates + else: + raise TypeError( + "Expected Module, PureModule, State or tuple of State, " + f"got {type(updates).__name__}" + ) + + if isinstance(self.states, tuple): + states += self.states + else: + states += (self.states,) + + state = State.merge(*states) + return Pure.new(state, self.moduledef) + + +def _pure_module_flatten(bounded: Pure[S, M]): + return tuple(bounded), None + + +def _pure_module_unflatten(_, values: tp.Tuple[S, ModuleDef[M]]): + return Pure(values) + + +jtu.register_pytree_node(Pure, _pure_module_flatten, _pure_module_unflatten) + +PureModule = Pure[tp.Union[State, tp.Tuple[State, ...]], M] + + +SEEN_MODULES_REPR: tp.Optional[tp.Set[ids.UUID]] = None + +ModuleStateTuple = tuple[bool] + + +class ModuleState(reprlib.Representable): + __slots__ = ("_trace_state", "_id", "is_initializing") + + def __init__(self, is_initializing: bool = False): + self._trace_state = tracers.TraceState() + self._id = ids.uuid() + self.is_initializing = is_initializing + + @property + def trace_state(self) -> tracers.TraceState: + return self._trace_state + + @property + def id(self) -> ids.UUID: + return self._id + + def to_tuple(self) -> ModuleStateTuple: + return (self.is_initializing,) + + @classmethod + def from_tuple(cls, tup: ModuleStateTuple) -> "ModuleState": + return cls(*tup) + + def __nnx_repr__(self): + yield reprlib.Object(type(self)) + yield reprlib.Attr("trace_state", self._trace_state) + + +class ModuleMeta(ABCMeta): + if not tp.TYPE_CHECKING: + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self._meta_call(*args, **kwargs) + + def _meta_call(self: tp.Type[M], *args, **kwargs) -> M: + module = self.__new__(self, *args, **kwargs) + vars(module)["_module__state"] = ModuleState() + module.__init__(*args, **kwargs) + + if dataclasses.is_dataclass(module): + assert isinstance(module, Module) + for field in dataclasses.fields(module): + if "nnx_container_fn" not in field.metadata: + continue + + container_fn = field.metadata["nnx_container_fn"] + value = vars(module)[field.name] + value = container_fn(value) + vars(module)[field.name] = value + + return module + + +@tp.runtime_checkable +class HasUnboxFn(tp.Protocol): + unbox_fn: tp.Callable[["Container[tp.Any]"], tp.Any] + + +class Module(reprlib.Representable, metaclass=ModuleMeta): + if tp.TYPE_CHECKING: + _module__state: ModuleState + + if not tp.TYPE_CHECKING: + + def __getattribute__(self, name: str) -> Any: + value = object.__getattribute__(self, name) + if isinstance(value, Container): + if isinstance(value, HasUnboxFn): + return value.unbox_fn(value) + return value.value + return value + + def __setattr__(self, name: str, value: Any) -> None: + self._setattr(name, value) + + def _setattr(self, name: str, value: Any) -> None: + if not self._module__state.trace_state.is_valid(): + raise 
errors.TraceContextError( + "Cannot mutate Module from different trace level" + ) + + vars_dict = vars(self) + if name in vars_dict and isinstance(vars_dict[name], Container): + vars_dict[name] = vars_dict[name].replace(value=value) + else: + if isinstance(value, Container): + value = value.copy() + vars_dict[name] = value + + def __hash__(self) -> int: + return hash(self._module__state.id) + + def __nnx_repr__(self): + global SEEN_MODULES_REPR + + if SEEN_MODULES_REPR is None: + SEEN_MODULES_REPR = set() + clear_seen = True + else: + clear_seen = False + + if self._module__state.id in SEEN_MODULES_REPR: + yield reprlib.Object(type=type(self), empty_repr="...") + return + + yield reprlib.Object(type=type(self)) + SEEN_MODULES_REPR.add(self._module__state.id) + + try: + for name, value in vars(self).items(): + if isinstance(value, Module) or ( + not nodes.is_node(value) and not name.startswith("_") + ): + yield reprlib.Attr(name, value) + finally: + if clear_seen: + SEEN_MODULES_REPR = None + + @property + def init(self: M) -> M: + accessor = DelayedAccessor() + + def _init(accessor, *args, **kwargs): + def _set_is_initializing(module: Module, value: bool): + module._module__state.is_initializing = value + + self.for_each(Module, partial(_set_is_initializing, value=True)) + + try: + return accessor(self)(*args, **kwargs) + finally: + self.for_each(Module, partial(_set_is_initializing, value=False)) + + return CallableProxy(_init, accessor) # type: ignore + + @property + def is_initializing(self) -> bool: + return self._module__state.is_initializing + + def clone(self: M) -> M: + return self.partition().merge() + + @tp.overload + def partition(self: M) -> Pure[State, M]: + ... + + @tp.overload + def partition(self: M, first: partitioning.Filter, /) -> Pure[State, M]: + ... + + @tp.overload + def partition( + self: M, + first: partitioning.Filter, + second: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> Pure[tp.Tuple[State, ...], M]: + ... + + def partition( + self: M, *filters: partitioning.Filter + ) -> tp.Union[Pure[State, M], Pure[tp.Tuple[State, ...], M]]: + moduledef = _get_module_def(self) + state = _get_module_state(self) + + if len(filters) == 0: + states = state + elif len(filters) == 1: + states = state.partition(filters[0]) + else: + states = state.partition(filters[0], filters[1], *filters[2:]) + + if isinstance(states, tuple): + return Pure.new(states, moduledef) + else: + return Pure.new(states, moduledef) + + def get_state(self) -> State: + return _get_module_state(self) + + def get_module_def(self: M) -> ModuleDef[M]: + return _get_module_def(self) + + @tp.overload + def filter(self, first: partitioning.Filter, /) -> State: + ... + + @tp.overload + def filter( + self, + first: partitioning.Filter, + second: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Tuple[State, ...]: + ... + + def filter( + self, + first: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Union[State, tp.Tuple[State, ...]]: + state = _get_module_state(self) + + if len(filters) == 0: + states = state.filter(first) + else: + states = state.filter(first, filters[0], *filters[1:]) + + return states + + @tp.overload + def pop_state( + self, + filter: partitioning.Filter, + /, + ) -> State: + ... + + @tp.overload + def pop_state( + self, + filter: partitioning.Filter, + filter2: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Tuple[State, ...]: + ... 
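  # Usage sketch (illustrative; `Intermediate` is a hypothetical
  # `containers.Variable` subclass): values recorded with `sow` (defined
  # below) can later be extracted and removed with `pop_state`:
  #
  #   y = model(x)  # the model calls self.sow(Intermediate, "hidden", h)
  #   intermediates = model.pop_state(Intermediate)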
+ + def pop_state( + self, *filters: partitioning.Filter + ) -> tp.Union[State, tp.Tuple[State, ...]]: + if len(filters) == 0: + raise ValueError("Expected at least one filter") + + states = _pop(self, filters) + + if len(states) == 1: + return states[0] + else: + return states + + @property + def apply(self: M) -> ApplyCaller[M]: + accessesor = DelayedAccessor() + + def _context(accessesor, *args, **kwargs) -> tp.Tuple[tp.Any, M]: + module = self.clone() + fn = accessesor(module) + out = fn(*args, **kwargs) + return out, module + + return CallableProxy(_context, accessesor) # type: ignore + + def update_state( + self: M, + updates: tp.Union[M, Pure[S, M], State, tp.Tuple[State, ...]], + ) -> None: + current_state = _get_module_state(self) + + if isinstance(updates, Pure): + updates = updates.states + elif isinstance(updates, Module): + assert type(self) == type(updates) + updates = _get_module_state(updates) + + _update_module(self, current_state, updates) + + def mutable_state_dict(self) -> tp.Dict[Path, "MutableLeaf"]: + return {path: MutableLeaf(self, path) for path, _ in _iter_state(self)} + + def sow( + self, + variable_type: tp.Type[containers.Variable], + name: str, + value: tp.Any, + ) -> None: + if hasattr(self, name): + variable = vars(self)[name] + if not isinstance(variable, containers.Variable): + raise ValueError( + f"Expected '{name}' to be a Variable, got {type(variable).__name__}" + ) + elif type(variable) != variable_type: + raise ValueError( + f"Expected '{name}' to be of type '{variable_type.__name__}', " + f"got '{type(variable).__name__}'" + ) + current_value = variable.value + if not isinstance(current_value, tuple): + raise ValueError( + f"Expected '{name}' to be a tuple, " + f"got {type(current_value).__name__}" + ) + value = current_value + (value,) + setattr(self, name, value) + else: + setattr(self, name, variable_type((value,))) + + def for_each( + self, module_type: tp.Type[M], fn: tp.Callable[[M], None] + ) -> None: + visited: tp.Set[ids.UUID] = set() + self._on_all(module_type, fn, visited) + + def _on_all( + self, + module_type: tp.Type[M], + fn: tp.Callable[[M], None], + visited: tp.Set[ids.UUID], + ) -> None: + if self._module__state.id in visited: + return + + visited.add(self._module__state.id) + + if isinstance(self, module_type): + fn(self) + + for value in vars(self).values(): + if isinstance(value, Module): + value._on_all(module_type, fn, visited) + + # Pytree Definition + # def __init_subclass__(cls) -> None: + # super().__init_subclass__() + + # def _flatten(module: Module, *, with_keys: bool): + # state, moduledef = module.partition() + # paths = tuple(state.keys()) + + # if with_keys: + # nodes = tuple( + # (jtu.DictKey(path), value) for path, value in state.items() + # ) + # else: + # nodes = tuple(state.values()) + + # return nodes, (paths, moduledef) + + # def _unflatten( + # paths_moddef: tp.Tuple[tp.Tuple[Path, ...], ModuleDef[M]], + # nodes: tp.Tuple[tp.Any, ...], + # ) -> M: + # paths, moduledef = paths_moddef + # return moduledef.merge(State(zip(paths, nodes))) + + # jtu.register_pytree_with_keys( + # cls, + # partial(_flatten, with_keys=True), + # _unflatten, + # flatten_func=partial(_flatten, with_keys=False), + # ) + + +class MutableLeaf(reprlib.Representable): + __slots__ = ("_module", "_name") + + def __init__(self, root: Module, path: Path): + path_parts = path.split("/") + *module_path, name = path_parts + module = _get_value_path(root, module_path) + if not isinstance(module, Module): + raise ValueError( + f"Expected a module 
at path {path_parts[:-1]}, ", + f" got {type(module).__name__}", + ) + self._module = module + self._name = name + + @property + def value(self) -> tp.Any: + return getattr(self._module, self._name) + + @value.setter + def value(self, value: tp.Any) -> None: + setattr(self._module, self._name, value) + + @property + def sharding(self) -> tp.Optional[Sharding]: + attr = vars(self._module)[self._name] + if not isinstance(attr, Node): + return None + + return attr.metadata.get("sharding", None) + + +def _get_module_state(module: Module) -> State: + return State(_iter_state(module)) + + +def _get_module_def(module: M) -> ModuleDef[M]: + module_index: tp.Dict[ids.UUID, int] = {} + path: PathParts = () + + moduledef = _make_module_def_recursive(module, module_index, path) + assert isinstance(moduledef, ModuleDef) + + return moduledef + + +def _make_module_def_recursive( + module: M, + module_index: tp.Dict[ids.UUID, int], + path: PathParts, +) -> tp.Union[ModuleDef[M], int]: + if module._module__state.id in module_index: + return module_index[module._module__state.id] + + index = len(module_index) + module_index[module._module__state.id] = index + + submodules = [] + static_fields = [] + + for name, value in sorted(vars(module).items(), key=lambda x: x[0]): + value_path = (*path, name) + if isinstance(value, Module): + submodule_def = _make_module_def_recursive( + value, module_index, value_path + ) + submodules.append((name, submodule_def)) + elif not nodes.is_node(value) and not name.startswith("_module__"): + static_fields.append((name, value)) + + module_def = ModuleDef( + type=type(module), + index=index, + submodules=tuple(submodules), + static_fields=tuple(static_fields), + module_state=module._module__state.to_tuple(), + ) + return module_def + + +def _iter_state(module: Module) -> tp.Iterator[tp.Tuple[Path, tp.Any]]: + seen_modules: tp.Set[ids.UUID] = set() + path_parts: PathParts = () + + yield from _iter_state_recursive(module, seen_modules, path_parts) + + +def _iter_state_recursive( + module: Module, seen_modules: tp.Set[ids.UUID], path_parts: PathParts +) -> tp.Iterator[tp.Tuple[Path, tp.Any]]: + if module._module__state.id in seen_modules: + return + + seen_modules.add(module._module__state.id) + + for name, value in sorted(vars(module).items(), key=lambda x: x[0]): + new_path_parts = (*path_parts, name) + if isinstance(value, Module): + yield from _iter_state_recursive(value, seen_modules, new_path_parts) + elif nodes.is_node(value): + path = "/".join(new_path_parts) + yield path, value + + +def _set_value_at_path( + module: tp.Any, path_parts: tp.Union[PathParts, tp.List[str]], value: tp.Any +): + if len(path_parts) == 1: + setattr(module, path_parts[0], value) + else: + _set_value_at_path(vars(module)[path_parts[0]], path_parts[1:], value) + + +def _get_value_path(module: tp.Any, path: tp.Sequence[str]) -> tp.Any: + if len(path) == 0: + return module + else: + return _get_value_path(vars(module)[path[0]], path[1:]) + + +def _build_module(moduledef: ModuleDef[M]) -> M: + index_module: tp.Dict[int, Module] = {} + module = _build_module_recursive(moduledef, index_module) + return module + + +def _build_module_recursive( + moduledef: tp.Union[ModuleDef[M], int], + index_module: tp.Dict[int, Module], +) -> M: + if isinstance(moduledef, int): + return index_module[moduledef] # type: ignore + + assert moduledef.index not in index_module + + # add a dummy module to the index to avoid infinite recursion + module = object.__new__(moduledef.type) + index_module[moduledef.index] = module + 
+ submodules = { + name: _build_module_recursive(submodule, index_module) + for name, submodule in moduledef.submodules + } + + vars(module).update(moduledef.static_fields) + vars(module).update(submodules) + vars(module)["_module__state"] = ModuleState.from_tuple( + moduledef.module_state + ) + + return module + + +def _pop( + module: Module, + filters: tp.Tuple[partitioning.Filter, ...], +) -> tp.Tuple[State, ...]: + module_index: tp.Dict[ids.UUID, int] = {} + path_parts: PathParts = () + predicates = tuple(partitioning.to_predicate(filter) for filter in filters) + states = tuple({} for _ in predicates) + _pop_recursive(module, module_index, path_parts, states, predicates) + + return tuple(State(x) for x in states) + + +def _pop_recursive( + module: Module, + module_index: tp.Dict[ids.UUID, int], + path_parts: PathParts, + states: tp.Tuple[tp.Dict[Path, tp.Any]], + predicates: tp.Tuple[partitioning.Predicate, ...], +) -> None: + if module._module__state.id in module_index: + return + + for name, value in list(vars(module).items()): + if isinstance(value, Module): + _pop_recursive( + value, module_index, (*path_parts, name), states, predicates + ) + continue + elif not nodes.is_node(value): + continue + + path = "/".join((*path_parts, name)) + for state, predicate in zip(states, predicates): + if predicate(path, value): + state[path] = value + delattr(module, name) + break + + module_index[module._module__state.id] = len(module_index) + + +def _update_module( + module: Module, + current_state: State, + updates: tp.Union[State, tp.Tuple[State, ...]], +) -> None: + new_states = [current_state] + + if isinstance(updates, State): + new_states.append(updates) + else: + new_states.extend(updates) + + state: StateDict = {} + for new_state in new_states: + state.update(new_state) + + for path, value in state.items(): + path_parts = path.split("/") + _set_value_at_path(module, path_parts, value) + + +def first_from(*args: tp.Optional[A]) -> A: + """Return the first non-None argument.""" + for arg in args: + if arg is not None: + return arg + raise ValueError("No non-None arguments found.") + + +# register nodes +nodes.register_node_type(Pure) diff --git a/flax/experimental/nnx/nnx/nn/__init__.py b/flax/experimental/nnx/nnx/nn/__init__.py new file mode 100644 index 0000000000..e80ba0b35f --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/flax/experimental/nnx/nnx/nn/activations.py b/flax/experimental/nnx/nnx/nn/activations.py new file mode 100644 index 0000000000..6dd783963e --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/activations.py @@ -0,0 +1,69 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax.nn import ( + celu, + elu, + gelu, + glu, + hard_sigmoid, + hard_silu, + hard_swish, + hard_tanh, + leaky_relu, + log_sigmoid, + log_softmax, + logsumexp, + normalize, + one_hot, + relu, + relu6, + selu, + sigmoid, + silu, + soft_sign, + softmax, + softplus, + standardize, + swish, +) +from jax.numpy import tanh + +__all__ = [ + "celu", + "elu", + "gelu", + "glu", + "hard_sigmoid", + "hard_silu", + "hard_swish", + "hard_tanh", + "leaky_relu", + "log_sigmoid", + "log_softmax", + "logsumexp", + "normalize", + "one_hot", + "relu", + "relu6", + "selu", + "sigmoid", + "silu", + "soft_sign", + "softmax", + "softplus", + "standardize", + "swish", + "tanh", +] diff --git a/flax/experimental/nnx/nnx/nn/dtypes.py b/flax/experimental/nnx/nnx/nn/dtypes.py new file mode 100644 index 0000000000..32ee84a465 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/dtypes.py @@ -0,0 +1,81 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional + +import jax +from jax import numpy as jnp + +Dtype = Any +Array = Any + + +def canonicalize_dtype( + *args, dtype: Optional[Dtype] = None, inexact: bool = True +) -> Dtype: + """Canonicalize an optional dtype to the definitive dtype. + + If ``dtype`` is None, this function will infer the dtype from the input + arguments using ``jnp.result_type``. If it is not None, it will be returned + unmodified, or an exception is raised if the dtype is invalid. + + Args: + *args: JAX array compatible values. None values + are ignored. + dtype: Optional dtype override. If specified the arguments are cast to + the specified dtype instead and dtype inference is disabled. + inexact: When True, the output dtype must be a subdtype + of `jnp.inexact`. Inexact dtypes are real or complex floating points. This + is useful when you want to apply operations that don't work directly on + integers like taking a mean for example. + Returns: + The dtype that *args should be cast to. + """ + if dtype is None: + args_filtered = [jnp.asarray(x) for x in args if x is not None] + dtype = jnp.result_type(*args_filtered) + if inexact and not jnp.issubdtype(dtype, jnp.inexact): + dtype = jnp.promote_types(jnp.float32, dtype) + if inexact and not jnp.issubdtype(dtype, jnp.inexact): + raise ValueError(f"Dtype must be inexact: {dtype}") + return dtype + + +def promote_dtype(*args, dtype=None, inexact=True) -> List[Array]: + """Promotes input arguments to a specified or inferred dtype. + + All args are cast to the same dtype. See ``canonicalize_dtype`` for how + this dtype is determined.
+ + The behavior of promote_dtype is mostly a convenience wrapper around + ``jax.numpy.promote_types``. The difference is that it automatically casts + all inputs to the inferred dtype, allows inference to be overridden by a + forced dtype, and has an optional check to guarantee the resulting dtype is + inexact. + + Args: + *args: JAX array compatible values. None values + are returned as is. + dtype: Optional dtype override. If specified the arguments are cast to + the specified dtype instead and dtype inference is disabled. + inexact: When True, the output dtype must be a subdtype + of `jnp.inexact`. Inexact dtypes are real or complex floating points. This + is useful when you want to apply operations that don't work directly on + integers like taking a mean for example. + Returns: + The arguments cast to arrays of the same dtype. + """ + dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact) + return [jnp.asarray(x, dtype) if x is not None else None for x in args] diff --git a/flax/experimental/nnx/nnx/nn/initializers.py b/flax/experimental/nnx/nnx/nn/initializers.py new file mode 100644 index 0000000000..2d2e8587e0 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/initializers.py @@ -0,0 +1,75 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +from jax.nn.initializers import constant as constant +from jax.nn.initializers import delta_orthogonal as delta_orthogonal +from jax.nn.initializers import glorot_normal as glorot_normal +from jax.nn.initializers import glorot_uniform as glorot_uniform +from jax.nn.initializers import he_normal as he_normal +from jax.nn.initializers import he_uniform as he_uniform +from jax.nn.initializers import kaiming_normal as kaiming_normal +from jax.nn.initializers import kaiming_uniform as kaiming_uniform +from jax.nn.initializers import lecun_normal as lecun_normal +from jax.nn.initializers import lecun_uniform as lecun_uniform +from jax.nn.initializers import normal as normal +from jax.nn.initializers import orthogonal as orthogonal +from jax.nn.initializers import uniform as uniform +from jax.nn.initializers import variance_scaling as variance_scaling +from jax.nn.initializers import xavier_normal as xavier_normal +from jax.nn.initializers import xavier_uniform as xavier_uniform + +Shape = tp.Sequence[int] +DTypeLikeInexact = tp.Any +KeyArray = jax.random.KeyArray +Array = jax.Array + + +class Initializer(tp.Protocol): + + @staticmethod + def __call__( + key: KeyArray, shape: Shape, dtype: DTypeLikeInexact = jnp.float_ + ) -> Array: + ... + + +def zeros() -> Initializer: + """Builds an initializer that returns a constant array full of zeros.
+ + >>> import jax, jax.numpy as jnp + >>> zeros_initializer = zeros() + >>> zeros_initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + Array([[0., 0., 0.], + [0., 0., 0.]], dtype=float32) + """ + return jax.nn.initializers.zeros + + +def ones() -> Initializer: + """Builds an initializer that returns a constant array full of ones. + + >>> import jax, jax.numpy as jnp + >>> ones_initializer = ones() + >>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32) + Array([[1., 1.], + [1., 1.], + [1., 1.]], dtype=float32) + """ + return jax.nn.initializers.ones diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py new file mode 100644 index 0000000000..c182a6a955 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/linear.py @@ -0,0 +1,447 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np +from jax import lax + +from flax.experimental import nnx +from flax.experimental.nnx.nnx import contextlib +from flax.experimental.nnx.nnx.module import Module +from flax.experimental.nnx.nnx.nn import dtypes, initializers + +Array = jax.Array +PRNGKey = tp.Any +Shape = tp.Tuple[int, ...] +Dtype = tp.Any # this could be a real type? +PrecisionLike = tp.Union[ + None, + str, + lax.Precision, + tp.Tuple[str, str], + tp.Tuple[lax.Precision, lax.Precision], +] +ConvGeneralDilatedT = tp.Callable[..., Array] +PaddingLike = tp.Union[str, int, tp.Sequence[tp.Union[int, tp.Tuple[int, int]]]] +LaxPadding = tp.Union[str, tp.Sequence[tp.Tuple[int, int]]] +DotGeneralT = tp.Callable[..., Array] + + +default_kernel_init = initializers.lecun_normal() + + +def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: + """Canonicalizes conv padding to a jax.lax supported format.""" + if isinstance(padding, str): + return padding + if isinstance(padding, int): + return [(padding, padding)] * rank + if isinstance(padding, tp.Sequence) and len(padding) == rank: + new_pad = [] + for p in padding: + if isinstance(p, int): + new_pad.append((p, p)) + elif isinstance(p, tuple) and len(p) == 2: + new_pad.append(p) + else: + break + if len(new_pad) == rank: + return new_pad + raise ValueError( + f"Invalid padding format: {padding}, should be str, int," + f" or a sequence of len {rank} where each element is an" + " int or pair of ints." + ) + + +def _conv_dimension_numbers(input_shape): + """Computes the dimension numbers based on the input shape.""" + ndim = len(input_shape) + lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) + rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) + out_spec = lhs_spec + return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) + + +class Linear(Module): + """A linear transformation applied over the last dimension of the input. + + Attributes: + in_features: the number of input features. + out_features: the number of output features.
+ use_bias: whether to add a bias to the output (default: True). + dtype: the dtype of the computation (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer function for the weight matrix. + bias_init: initializer function for the bias. + """ + + def __init__( + self, + in_features: int, + out_features: int, + *, + use_bias: bool = True, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + precision: PrecisionLike = None, + kernel_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = default_kernel_init, + bias_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = initializers.zeros(), + dot_general: DotGeneralT = lax.dot_general, + ctx: contextlib.Context, + ): + kernel_key = ctx.make_rng("params") + self.kernel = nnx.Param( + kernel_init(kernel_key, (in_features, out_features), param_dtype) + ) + if use_bias: + bias_key = ctx.make_rng("params") + self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype)) + else: + self.bias = nnx.Param(None) + + self.in_features = in_features + self.out_features = out_features + self.use_bias = use_bias + self.dtype = dtype + self.param_dtype = param_dtype + self.precision = precision + self.kernel_init = kernel_init + self.bias_init = bias_init + self.dot_general = dot_general + + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along the last dimension. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + kernel = self.kernel + bias = self.bias + + inputs, kernel, bias = dtypes.promote_dtype( + inputs, kernel, bias, dtype=self.dtype + ) + y = self.dot_general( + inputs, + kernel, + (((inputs.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) + if bias is not None: + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + return y + + +class Conv(Module): + """Convolution Module wrapping `lax.conv_general_dilated[_local]`. + + Attributes: + features: number of convolution filters. + kernel_size: shape of the convolutional kernel. For 1D convolution, + the kernel size can be passed as an integer. For all other cases, it must + be a sequence of integers. + strides: an integer or a sequence of `n` integers, representing the + inter-window strides (default: 1). + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. `'CAUSAL'` padding for a 1D convolution will + left-pad the convolution axis, resulting in same-sized output. + input_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + kernel_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + feature_group_count: integer, default 1. 
If specified divides the input + features into groups. + use_bias: whether to add a bias to the output (default: True). + mask: Optional mask for the weights during masked convolution. The mask must + be the same shape as the convolution weight matrix. + dtype: the dtype of the computation (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. + """ + + def __init__( + self, + in_features: int, + out_features: int, + kernel_size: tp.Sequence[int], + strides: tp.Union[None, int, tp.Sequence[int]] = 1, + *, + padding: PaddingLike = "SAME", + input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1, + kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1, + feature_group_count: int = 1, + use_bias: bool = True, + mask_fn: tp.Optional[tp.Callable[[Array], Array]] = None, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + precision: PrecisionLike = None, + kernel_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = default_kernel_init, + bias_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = initializers.zeros(), + conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated, + ctx: contextlib.Context, + ): + if isinstance(kernel_size, int): + raise TypeError( + "Expected Conv kernel_size to be a" + " tuple/list of integers (eg.: [3, 3]) but got" + f" {kernel_size}." + ) + else: + kernel_size = tuple(kernel_size) + + kernel_shape = kernel_size + ( + in_features // feature_group_count, + out_features, + ) + kernel_key = ctx.make_rng("params") + self.kernel = nnx.Param(kernel_init(kernel_key, kernel_shape, param_dtype)) + + if use_bias: + bias_shape = (out_features,) + bias_key = ctx.make_rng("params") + self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype)) + else: + self.bias = nnx.Param(None) + + self.in_features = in_features + self.out_features = out_features + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding + self.input_dilation = input_dilation + self.kernel_dilation = kernel_dilation + self.feature_group_count = feature_group_count + self.use_bias = use_bias + self.mask_fn = mask_fn + self.dtype = dtype + self.param_dtype = param_dtype + self.precision = precision + self.kernel_init = kernel_init + self.bias_init = bias_init + self.conv_general_dilated = conv_general_dilated + + def __call__(self, inputs: Array) -> Array: + """Applies a (potentially unshared) convolution to the inputs. + + Args: + inputs: input data with dimensions (*batch_dims, spatial_dims..., + features). This is the channels-last convention, i.e. NHWC for a 2d + convolution and NDHWC for a 3D convolution. Note: this is different from + the input convention used by `lax.conv_general_dilated`, which puts the + spatial dimensions last. + Note: If the input has more than 1 batch dimension, all batch dimensions + are flattened into a single dimension for the convolution and restored + before returning. In some cases directly vmap'ing the layer may yield + better performance than this default flattening approach. If the input + lacks a batch dimension it will be added for the convolution and removed + n return, an allowance made to enable writing single-example code. + + Returns: + The convolved data. 
+ """ + + assert isinstance(self.kernel_size, tuple) + kernel_size = self.kernel_size + + def maybe_broadcast( + x: tp.Optional[tp.Union[int, tp.Sequence[int]]] + ) -> tp.Tuple[int, ...]: + if x is None: + # backward compatibility with using None as sentinel for + # broadcast 1 + x = 1 + if isinstance(x, int): + return (x,) * len(kernel_size) + return tuple(x) + + # Combine all input batch dimensions into a single leading batch axis. + num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) + if num_batch_dimensions != 1: + input_batch_shape = inputs.shape[:num_batch_dimensions] + total_batch_size = int(np.prod(input_batch_shape)) + flat_input_shape = (total_batch_size,) + inputs.shape[ + num_batch_dimensions: + ] + inputs = jnp.reshape(inputs, flat_input_shape) + + # self.strides or (1,) * (inputs.ndim - 2) + strides = maybe_broadcast(self.strides) + input_dilation = maybe_broadcast(self.input_dilation) + kernel_dilation = maybe_broadcast(self.kernel_dilation) + + padding_lax = canonicalize_padding(self.padding, len(kernel_size)) + if padding_lax == "CIRCULAR": + kernel_size_dilated = [ + (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) + ] + zero_pad: tp.List[tp.Tuple[int, int]] = [(0, 0)] + pads = ( + zero_pad + + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + + [(0, 0)] + ) + inputs = jnp.pad(inputs, pads, mode="wrap") + padding_lax = "VALID" + elif padding_lax == "CAUSAL": + if len(kernel_size) != 1: + raise ValueError( + "Causal padding is only implemented for 1D convolutions." + ) + left_pad = kernel_dilation[0] * (kernel_size[0] - 1) + pads = [(0, 0), (left_pad, 0), (0, 0)] + inputs = jnp.pad(inputs, pads) + padding_lax = "VALID" + + dimension_numbers = _conv_dimension_numbers(inputs.shape) + + # One shared convolutional kernel for all pixels in the output. + assert self.in_features % self.feature_group_count == 0 + + kernel = self.kernel + + if self.mask_fn is not None: + kernel = self.mask_fn(kernel) + + bias = self.bias + + inputs, kernel, bias = dtypes.promote_dtype( + inputs, kernel, bias, dtype=self.dtype + ) + + y = self.conv_general_dilated( + inputs, + kernel, + strides, + padding_lax, + lhs_dilation=input_dilation, + rhs_dilation=kernel_dilation, + dimension_numbers=dimension_numbers, + feature_group_count=self.feature_group_count, + precision=self.precision, + ) + + if self.use_bias: + bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) + y += bias + + if num_batch_dimensions != 1: + output_shape = input_batch_shape + y.shape[1:] + y = jnp.reshape(y, output_shape) + return y + + +default_embed_init = initializers.variance_scaling( + 1.0, "fan_in", "normal", out_axis=0 +) + + +class Embed(Module): + """Embedding Module. + + A parameterized function from integers [0, n) to d-dimensional vectors. + + Attributes: + num_embeddings: number of embeddings. + features: number of feature dimensions for each embedding. + dtype: the dtype of the embedding vectors (default: same as embedding). + param_dtype: the dtype passed to parameter initializers (default: float32). + embedding_init: embedding initializer. 
+ """ + + def __init__( + self, + num_embeddings: int, + features: int, + *, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + embedding_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = default_embed_init, + ctx: contextlib.Context, + ): + self.embedding = nnx.Param( + embedding_init( + ctx.make_rng("params"), (num_embeddings, features), param_dtype + ) + ) + + self.num_embeddings = num_embeddings + self.features = features + self.dtype = dtype or self.embedding.dtype + self.param_dtype = param_dtype + self.embedding_init = embedding_init + + def __call__(self, inputs: Array) -> Array: + """Embeds the inputs along the last dimension. + + Args: + inputs: input data, all dimensions are considered batch dimensions. + + Returns: + Output which is embedded input data. The output shape follows the input, + with an additional `features` dimension appended. + """ + if not jnp.issubdtype(inputs.dtype, jnp.integer): + raise ValueError("Input type must be an integer or unsigned integer.") + # Use take because fancy indexing numpy arrays with JAX indices does not + # work correctly. + (embedding,) = dtypes.promote_dtype( + self.embedding, dtype=self.dtype, inexact=False + ) + return jnp.take(embedding, inputs, axis=0) + + def attend(self, query: Array) -> Array: + """Attend over the embedding using a query array. + + Args: + query: array with last dimension equal the feature depth `features` of the + embedding. + Returns: + An array with final dim `num_embeddings` corresponding to the batched + inner-product of the array of query vectors against each embedding. + Commonly used for weight-sharing between embeddings and logit transform + in NLP models. + """ + query, embedding = dtypes.promote_dtype( + query, self.embedding, dtype=self.dtype + ) + return jnp.dot(query, embedding.T) diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/experimental/nnx/nnx/nn/normalization.py new file mode 100644 index 0000000000..d3712c94b4 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/normalization.py @@ -0,0 +1,403 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +from jax import lax + +from flax.experimental import nnx +from flax.experimental.nnx.nnx import contextlib +from flax.experimental.nnx.nnx.module import Module, first_from +from flax.experimental.nnx.nnx.nn import dtypes, initializers + +PRNGKey = jax.Array +Array = jax.Array +Shape = tp.Tuple[int, ...] +Dtype = tp.Any # this could be a real type? 
+
+Axes = tp.Union[int, tp.Any]
+
+
+def _canonicalize_axes(rank: int, axes: Axes) -> tp.Tuple[int, ...]:
+  """Returns a tuple of deduplicated, sorted, and positive axes."""
+  if not isinstance(axes, tp.Iterable):
+    axes = (axes,)
+  return tuple(sorted(set(rank + axis if axis < 0 else axis for axis in axes)))
+
+
+def _abs_sq(x):
+  """Computes the elementwise square of the absolute value |x|^2."""
+  if jnp.iscomplexobj(x):
+    return lax.square(lax.real(x)) + lax.square(lax.imag(x))
+  else:
+    return lax.square(x)
+
+
+def _compute_stats(
+  x: Array,
+  axes: tp.Optional[Axes],
+  dtype: tp.Optional[Dtype],
+  axis_name: tp.Optional[str] = None,
+  axis_index_groups: tp.Any = None,
+  use_mean: bool = True,
+):
+  """Computes mean and variance statistics.
+
+  This implementation takes care of a few important details:
+  - Computes in float32 precision for stability in half precision training.
+  - mean and variance are computable in a single XLA fusion,
+    by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2].
+  - Clips negative variances to zero which can happen due to
+    roundoff errors. This avoids downstream NaNs.
+  - Supports averaging across a parallel axis and subgroups of a parallel axis
+    with a single `lax.pmean` call to avoid latency.
+
+  Arguments:
+    x: Input array.
+    axes: The axes in ``x`` to compute mean and variance statistics for.
+    dtype: Optional dtype specifying the minimal precision. Statistics are
+      always at least float32 for stability (default: dtype of x).
+    axis_name: Optional name for the pmapped axis to compute mean over.
+    axis_index_groups: Optional axis indices.
+    use_mean: If true, calculate the mean from the input and use it when
+      computing the variance. If false, set the mean to zero and compute the
+      variance without subtracting the mean.
+
+  Returns:
+    A pair ``(mean, var)``.
+  """
+  if dtype is None:
+    dtype = jnp.result_type(x)
+  # promote x to at least float32, this avoids half precision computation
+  # but preserves double or complex floating points
+  dtype = jnp.promote_types(dtype, jnp.float32)
+  x = jnp.asarray(x, dtype)
+
+  mean2 = jnp.mean(_abs_sq(x), axes)
+  if use_mean:
+    mean = jnp.mean(x, axes)
+  else:
+    mean = jnp.zeros(mean2.shape, dtype=dtype)
+
+  if axis_name is not None:
+    concatenated_mean = jnp.concatenate([mean, mean2])
+    mean, mean2 = jnp.split(
+      lax.pmean(
+        concatenated_mean,
+        axis_name=axis_name,
+        axis_index_groups=axis_index_groups,
+      ),
+      2,
+    )
+  # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+  # to floating point round-off errors.
+  var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
+  return mean, var
+
+
+def _normalize(
+  x: Array,
+  mean: Array,
+  var: Array,
+  scale: tp.Optional[Array],
+  bias: tp.Optional[Array],
+  reduction_axes: Axes,
+  feature_axes: Axes,
+  dtype: Dtype,
+  epsilon: float,
+):
+  """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+
+  Arguments:
+    x: The input.
+    mean: Mean to use for normalization.
+    var: Variance to use for normalization.
+    scale: Optional scale to apply to the normalized input.
+    bias: Optional bias to add to the normalized input.
+    reduction_axes: The axes in ``x`` to reduce.
+    feature_axes: Axes containing features. A separate bias and scale is
+      learned for each specified feature.
+    dtype: The dtype of the result (default: infer from input and params).
+    epsilon: Normalization epsilon.
+
+  Returns:
+    The normalized input.
+ """ + reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) + feature_axes = _canonicalize_axes(x.ndim, feature_axes) + stats_shape = list(x.shape) + for axis in reduction_axes: + stats_shape[axis] = 1 + mean = mean.reshape(stats_shape) + var = var.reshape(stats_shape) + feature_shape = [1] * x.ndim + reduced_feature_shape = [] + for ax in feature_axes: + feature_shape[ax] = x.shape[ax] + reduced_feature_shape.append(x.shape[ax]) + y = x - mean + mul = lax.rsqrt(var + epsilon) + args = [x] + if scale is not None: + scale = scale.reshape(feature_shape) + mul *= scale + args.append(scale) + y *= mul + if bias is not None: + bias = bias.reshape(feature_shape) + y += bias + args.append(bias) + dtype = dtypes.canonicalize_dtype(*args, dtype=dtype) + return jnp.asarray(y, dtype) + + +class BatchNorm(Module): + """BatchNorm Module. + + Attributes: + use_running_average: if True, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of + the batch statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). + When the next layer is linear (also e.g. nn.relu), this can be disabled + since the scaling will be done by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. 
+ """ + + def __init__( + self, + num_features: int, + *, + use_running_average: tp.Optional[bool] = None, + axis: int = -1, + momentum: float = 0.99, + epsilon: float = 1e-5, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = initializers.zeros(), + scale_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = initializers.ones(), + axis_name: tp.Optional[str] = None, + axis_index_groups: tp.Any = None, + ctx: contextlib.Context, + ): + feature_shape = (num_features,) + self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32)) + self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32)) + + if use_scale: + key = ctx.make_rng("params") + self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) + else: + self.scale = nnx.Param(None) + + if use_bias: + key = ctx.make_rng("params") + self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) + else: + self.bias = nnx.Param(None) + + self.num_features = num_features + self.use_running_average = use_running_average + self.axis = axis + self.momentum = momentum + self.epsilon = epsilon + self.dtype = dtype + self.param_dtype = param_dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + self.axis_name = axis_name + self.axis_index_groups = axis_index_groups + + def __call__( + self, + x, + use_running_average: tp.Optional[bool] = None, + *, + ctx: tp.Optional[contextlib.Context] = None, + ): + """Normalizes the input using batch statistics. + + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). + """ + + use_running_average = first_from( + use_running_average, + self.use_running_average, + ctx and ctx.get_flag("use_running_average"), + ) + feature_axes = _canonicalize_axes(x.ndim, self.axis) + reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) + + if use_running_average: + mean, var = self.mean, self.var + else: + mean, var = _compute_stats( + x, + reduction_axes, + dtype=self.dtype, + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups, + ) + + self.mean = self.momentum * self.mean + (1 - self.momentum) * mean + self.var = self.momentum * self.var + (1 - self.momentum) * var + + return _normalize( + x, + mean, + var, + self.scale, + self.bias, + reduction_axes, + feature_axes, + self.dtype, + self.epsilon, + ) + + +class LayerNorm(Module): + """Layer normalization (https://arxiv.org/abs/1607.06450). + + LayerNorm normalizes the activations of the layer for each given example in a + batch independently, rather than across a batch like Batch Normalization. + i.e. applies a transformation that maintains the mean activation within + each example close to 0 and the activation standard deviation close to 1. + + Attributes: + epsilon: A small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: If True, bias (beta) is added. + use_scale: If True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. 
+ bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + reduction_axes: Axes for computing normalization statistics. + feature_axes: Feature axes for learned bias and scaling. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + This is only needed if the model is subdivided across devices, i.e. the + array being normalized is sharded across devices within a pmap. + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. + """ + + def __init__( + self, + num_features: int, + *, + epsilon: float = 1e-6, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = initializers.zeros(), + scale_init: tp.Callable[ + [PRNGKey, Shape, Dtype], Array + ] = initializers.ones(), + reduction_axes: Axes = -1, + feature_axes: Axes = -1, + axis_name: tp.Optional[str] = None, + axis_index_groups: tp.Any = None, + ctx: contextlib.Context, + ): + feature_shape = (num_features,) + + if use_scale: + key = ctx.make_rng("params") + self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) + else: + self.scale = nnx.Param(None) + + if use_bias: + key = ctx.make_rng("params") + self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) + else: + self.bias = nnx.Param(None) + + self.num_features = num_features + self.epsilon = epsilon + self.dtype = dtype + self.param_dtype = param_dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + self.reduction_axes = reduction_axes + self.feature_axes = feature_axes + self.axis_name = axis_name + self.axis_index_groups = axis_index_groups + + def __call__(self, x): + """Applies layer normalization on the input. + + Args: + x: the inputs + + Returns: + Normalized inputs (the same shape as inputs). + """ + mean, var = _compute_stats( + x, + self.reduction_axes, + self.dtype, + self.axis_name, + self.axis_index_groups, + ) + + return _normalize( + x, + mean, + var, + self.scale, + self.bias, + self.reduction_axes, + self.feature_axes, + self.dtype, + self.epsilon, + ) diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py new file mode 100644 index 0000000000..d0b76bea63 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -0,0 +1,86 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
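+
+# NOTE: the following usage sketch is illustrative only. It assumes that
+# `nnx.context` accepts per-collection seeds (e.g. `nnx.context(dropout=0)`)
+# to provide the "dropout" RNG stream consumed by the `Dropout` module below:
+#
+#   from flax.experimental import nnx
+#   import jax.numpy as jnp
+#
+#   dropout = nnx.Dropout(rate=0.5)
+#   x = jnp.ones((4, 8))
+#   y_train = dropout(x, deterministic=False, ctx=nnx.context(dropout=0))
+#   y_eval = dropout(x, deterministic=True)  # no ctx needed when deterministic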
+
+import dataclasses
+from typing import Optional, Sequence
+
+import jax.numpy as jnp
+from jax import lax, random
+
+from flax.experimental.nnx.nnx import contextlib
+from flax.experimental.nnx.nnx.module import Module, first_from
+
+
+@dataclasses.dataclass
+class Dropout(Module):
+  """Create a dropout layer.
+
+  Attributes:
+    rate: the dropout probability. (_not_ the keep rate!)
+    broadcast_dims: dimensions that will share the same dropout mask
+    deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
+      masked, whereas if true, no mask is applied and the inputs are returned
+      as is.
+    rng_collection: the rng collection name to use when requesting an rng key.
+  """
+
+  rate: float
+  broadcast_dims: Sequence[int] = ()
+  deterministic: Optional[bool] = None
+  rng_collection: str = "dropout"
+
+  def __call__(
+    self,
+    inputs,
+    *,
+    deterministic: Optional[bool] = None,
+    ctx: Optional[contextlib.Context] = None,
+  ):
+    """Applies a random dropout mask to the input.
+
+    Args:
+      inputs: the inputs that should be randomly masked.
+      deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
+        masked, whereas if true, no mask is applied and the inputs are
+        returned as is.
+
+    Returns:
+      The masked inputs reweighted to preserve mean.
+    """
+    deterministic = first_from(
+      deterministic,
+      self.deterministic,
+      ctx and ctx.get_flag("deterministic"),
+    )
+
+    if (self.rate == 0.0) or deterministic:
+      return inputs
+
+    # Prevent gradient NaNs in 1.0 edge-case.
+    if self.rate == 1.0:
+      return jnp.zeros_like(inputs)
+
+    if ctx is None:
+      raise ValueError(
+        "Dropout needs to generate a random mask but no 'ctx' was provided."
+      )
+
+    keep_prob = 1.0 - self.rate
+    rng = ctx.make_rng(self.rng_collection)
+    broadcast_shape = list(inputs.shape)
+    for dim in self.broadcast_dims:
+      broadcast_shape[dim] = 1
+    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+    mask = jnp.broadcast_to(mask, inputs.shape)
+    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
diff --git a/flax/experimental/nnx/nnx/nodes.py b/flax/experimental/nnx/nnx/nodes.py
new file mode 100644
index 0000000000..5d62aa6051
--- /dev/null
+++ b/flax/experimental/nnx/nnx/nodes.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing as tp
+
+import jax
+import numpy as np
+
+node_types: tp.Tuple[type, ...] = ()
+
+
+def register_node_type(node_type: type) -> None:
+  global node_types
+  node_types += (node_type,)
+
+
+def is_node(obj: object) -> bool:
+  return isinstance(obj, node_types)
+
+
+# register nodes
+register_node_type(jax.Array)
+register_node_type(np.ndarray)
diff --git a/flax/experimental/nnx/nnx/partitioning.py b/flax/experimental/nnx/nnx/partitioning.py
new file mode 100644
index 0000000000..d875845ebd
--- /dev/null
+++ b/flax/experimental/nnx/nnx/partitioning.py
@@ -0,0 +1,103 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import builtins
+import dataclasses
+import typing as tp
+
+import jax
+import numpy as np
+
+from flax.experimental import nnx
+
+if tp.TYPE_CHECKING:
+  ellipsis = builtins.ellipsis
+else:
+  ellipsis = tp.Any
+
+Path = str
+Predicate = tp.Callable[[Path, tp.Any], bool]
+FilterLiteral = tp.Union[type, Predicate, ellipsis, None]
+Filter = tp.Union[FilterLiteral, tp.Tuple[FilterLiteral, ...]]
+
+
+def to_predicate(filter: Filter) -> Predicate:
+  if isinstance(filter, str):
+    raise TypeError(f"Invalid filter of type '{type(filter).__name__}'")
+  elif isinstance(filter, type):
+    return OfType(filter)
+  elif filter is Ellipsis:
+    return Everything()
+  elif filter is None:
+    return Nothing()
+  elif callable(filter):
+    return filter
+  elif isinstance(filter, tp.Tuple):
+    return Any(*filter)
+  else:
+    raise TypeError(f"Invalid collection filter: {filter!r}.")
+
+
+@dataclasses.dataclass
+class OfType:
+  type: type
+
+  def __call__(self, path: Path, x: tp.Any):
+    return isinstance(x, self.type)
+
+
+class Any:
+
+  def __init__(self, *filters: Filter):
+    self.predicates = tuple(
+      to_predicate(collection_filter) for collection_filter in filters
+    )
+
+  def __call__(self, path: Path, x: tp.Any):
+    return any(predicate(path, x) for predicate in self.predicates)
+
+
+class All:
+
+  def __init__(self, *filters: Filter):
+    self.predicates = tuple(
+      to_predicate(collection_filter) for collection_filter in filters
+    )
+
+  def __call__(self, path: Path, x: tp.Any):
+    return all(predicate(path, x) for predicate in self.predicates)
+
+
+class Not:
+
+  def __init__(self, collection_filter: Filter):
+    self.predicate = to_predicate(collection_filter)
+
+  def __call__(self, path: Path, x: tp.Any):
+    return not self.predicate(path, x)
+
+
+class Everything:
+
+  def __call__(self, path: Path, x: tp.Any):
+    return True
+
+
+class Nothing:
+
+  def __call__(self, path: Path, x: tp.Any):
+    return False
+
+
+buffers = (jax.Array, np.ndarray)
diff --git a/flax/experimental/nnx/nnx/pytreelib.py b/flax/experimental/nnx/nnx/pytreelib.py
new file mode 100644
index 0000000000..ad1ac688eb
--- /dev/null
+++ b/flax/experimental/nnx/nnx/pytreelib.py
@@ -0,0 +1,288 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
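+
+# Illustrative sketch of the `Pytree` base class defined below (field names
+# are hypothetical): registered node types (e.g. jax.Array) become pytree
+# leaves, everything else is carried as static metadata, and instances are
+# immutable after construction unless declared with `mutable=True`:
+#
+#   import jax
+#   import jax.numpy as jnp
+#
+#   class Weights(Pytree):
+#     def __init__(self, din: int, dout: int):
+#       self.kernel = jnp.zeros((din, dout))  # leaf (jax.Array is a node type)
+#       self.activation = "relu"              # static metadata
+#
+#   jax.tree_util.tree_leaves(Weights(3, 4))  # -> [kernel]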
+ +import contextlib +import dataclasses +import importlib.util +import inspect +import typing as tp +from abc import ABCMeta +from copy import copy +from functools import partial +from types import MappingProxyType + +import jax + +from flax.experimental.nnx.nnx import containers, nodes, reprlib + +A = tp.TypeVar("A") +P = tp.TypeVar("P", bound="Pytree") + + +class TreeNode(containers.NodeBase[A]): + pass + + +@contextlib.contextmanager +def _mutable(obj: P) -> tp.Iterator[None]: + vars(obj)["_pytree__is_mutable"] = True + try: + yield + finally: + del vars(obj)["_pytree__is_mutable"] + + +@contextlib.contextmanager +def _initializing(obj: P) -> tp.Iterator[None]: + vars(obj)["_pytree__initializing"] = True + try: + yield + finally: + del vars(obj)["_pytree__initializing"] + + +class PytreeMeta(ABCMeta): + if not tp.TYPE_CHECKING: + + def __call__(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: + return cls.call(*args, **kwargs) + + def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: + obj: P = cls.__new__(cls, *args, **kwargs) + vars(obj)["_pytree__sorted_fields"] = ["_pytree__sorted_fields"] + + with _mutable(obj), _initializing(obj): + obj.__init__(*args, **kwargs) + + if dataclasses.is_dataclass(obj): + assert isinstance(obj, Pytree) + for field in dataclasses.fields(obj): + if "nnx_container_fn" not in field.metadata: + continue + + container_fn = field.metadata["nnx_container_fn"] + value = vars(obj)[field.name] + value = container_fn(value) + vars(obj)[field.name] = value + + vars(obj)["_pytree__sorted_fields"] = sorted(vars(obj)) + + return obj + + +class Pytree(reprlib.Representable, metaclass=PytreeMeta): + _pytree__is_mutable: bool + _pytree__class_is_mutable: bool + _pytree__sorted_fields: tp.Tuple[str, ...] + + if not tp.TYPE_CHECKING: + + def __getattribute__(self, name: str) -> tp.Any: + value = object.__getattribute__(self, name) + if isinstance(value, containers.Container): + return value.value + return value + + def __setattr__(self, name: str, value: tp.Any) -> None: + self._setattr(name, value) + + def _setattr(self: P, name: str, value: tp.Any): + vars_dict = vars(self) + if "_pytree__initializing" in vars_dict: + pass + elif name not in vars_dict: + raise AttributeError(r"Cannot add new fields to an initialized Pytree") + elif ( + "_pytree__is_mutable" not in vars_dict + and not self._pytree__class_is_mutable + ): + raise AttributeError( + f"{type(self)} is immutable, trying to update field {name}" + ) + + if name in vars_dict and isinstance(vars_dict[name], containers.Container): + vars_dict[name] = vars_dict[name].replace(value=value) + else: + if isinstance(value, containers.Container): + value = value.copy() + vars_dict[name] = value + + def __init_subclass__(cls, mutable: bool = False): + super().__init_subclass__() + # init class variables + cls._pytree__is_mutable = False + cls._pytree__class_is_mutable = mutable + + # TODO: clean up this in the future once minimal supported version is 0.4.7 + if hasattr(jax.tree_util, "register_pytree_with_keys"): + if ( + "flatten_func" + in inspect.signature( + jax.tree_util.register_pytree_with_keys + ).parameters + ): + jax.tree_util.register_pytree_with_keys( + cls, + partial( + cls._pytree__flatten, + with_key_paths=True, + ), + cls._pytree__unflatten, + flatten_func=partial( + cls._pytree__flatten, + with_key_paths=False, + ), + ) + else: + jax.tree_util.register_pytree_with_keys( + cls, + partial( + cls._pytree__flatten, + with_key_paths=True, + ), + cls._pytree__unflatten, + ) + else: + 
jax.tree_util.register_pytree_node( + cls, + partial( + cls._pytree__flatten, + with_key_paths=False, + ), + cls._pytree__unflatten, + ) + + # flax serialization support + if importlib.util.find_spec("flax") is not None: + from flax import serialization + + serialization.register_serialization_state( + cls, cls._to_flax_state_dict, cls._from_flax_state_dict + ) + + @classmethod + def _pytree__flatten( + cls, + pytree: "Pytree", + *, + with_key_paths: bool, + ): + all_vars = vars(pytree) + static = {} + node_values = [] + node_names = [] + + for field in pytree._pytree__sorted_fields: + value = all_vars[field] + + if nodes.is_node(value): + node_names.append(field) + if with_key_paths: + node_values.append((jax.tree_util.GetAttrKey(field), value)) + else: + node_values.append(value) + else: + static[field] = value + + return node_values, (tuple(node_names), MappingProxyType(static)) + + @classmethod + def _pytree__unflatten( + cls: tp.Type[P], + metadata: tp.Tuple[tp.Tuple[str, ...], tp.Mapping[str, tp.Any]], + node_values: tp.Tuple[tp.Any, ...], + ) -> P: + node_names, static_fields = metadata + pytree = object.__new__(cls) + pytree.__dict__.update(zip(node_names, node_values)) + pytree.__dict__.update(static_fields) + return pytree + + @classmethod + def _to_flax_state_dict(cls, pytree: "Pytree") -> tp.Dict[str, tp.Any]: + from flax import serialization + + state_dict = { + name: serialization.to_state_dict(getattr(pytree, name)) + for name, value in vars(pytree).items() + if nodes.is_node(value) + } + return state_dict + + @classmethod + def _from_flax_state_dict( + cls, + pytree: P, + state: tp.Dict[str, tp.Any], + ) -> P: + """Restore the state of a data class.""" + from flax import serialization + + state = state.copy() # copy the state so we can pop the restored fields. + updates = {} + for name, value in vars(pytree).items(): + if not nodes.is_node(value): + continue + if name not in state: + raise ValueError( + f"Missing field {name} in state dict while restoring" + f" an instance of {type(pytree).__name__}," + f" at path {serialization.current_path()}" + ) + value_state = state.pop(name) + updates[name] = serialization.from_state_dict( + value, value_state, name=name + ) + if state: + names = ",".join(state.keys()) + raise ValueError( + f'Unknown field(s) "{names}" in state dict while' + f" restoring an instance of {type(pytree).__name__}" + f" at path {serialization.current_path()}" + ) + return pytree.replace(**updates) + + def replace(self: P, **kwargs: tp.Any) -> P: + """ + Replace the values of the fields of the object with the values of the + keyword arguments. If the object is a dataclass, `dataclasses.replace` + will be used. Otherwise, a new object will be created with the same + type as the original object. 
+ """ + if dataclasses.is_dataclass(self): + return dataclasses.replace(self, **kwargs) + + unknown_keys = set(kwargs) - set(vars(self)) + if unknown_keys and not self._pytree__class_is_mutable: + raise ValueError( + f"Trying to replace unknown fields {unknown_keys} " + f"for '{type(self).__name__}'" + ) + + pytree = copy(self) + with _mutable(pytree): + for key, value in kwargs.items(): + setattr(pytree, key, value) + + return pytree + + def __nnx_repr__(self): + yield reprlib.Object(type(self)) + for name, value in vars(self).items(): + yield reprlib.Attr(name, repr(value)) + + +# register node types +nodes.register_node_type(Pytree) +nodes.register_node_type(TreeNode) diff --git a/flax/experimental/nnx/nnx/reprlib.py b/flax/experimental/nnx/nnx/reprlib.py new file mode 100644 index 0000000000..f52b3552d3 --- /dev/null +++ b/flax/experimental/nnx/nnx/reprlib.py @@ -0,0 +1,106 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import dataclasses +import threading +import typing as tp +from abc import abstractmethod + + +@dataclasses.dataclass +class ReprContext(threading.local): + indent_stack: tp.List[str] = dataclasses.field(default_factory=lambda: [""]) + + +REPR_CONTEXT = ReprContext() + + +@dataclasses.dataclass +class Object: + type: tp.Union[str, type] + start: str = "(" + end: str = ")" + value_sep: str = "=" + elem_indent: str = " " + empty_repr: str = "" + + +@dataclasses.dataclass +class Attr: + key: str + value: tp.Union[str, tp.Any] + start: str = "" + end: str = "" + + +class Representable: + __slots__ = () + + @abstractmethod + def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]: + raise NotImplementedError + + def __repr__(self) -> str: + return get_repr(self) + + +@contextlib.contextmanager +def add_indent(indent: str) -> tp.Iterator[None]: + REPR_CONTEXT.indent_stack.append(REPR_CONTEXT.indent_stack[-1] + indent) + + try: + yield + finally: + REPR_CONTEXT.indent_stack.pop() + + +def get_indent() -> str: + return REPR_CONTEXT.indent_stack[-1] + + +def get_repr(obj: Representable) -> str: + if not isinstance(obj, Representable): + raise TypeError(f"Object {obj!r} is not representable") + + iterator = obj.__nnx_repr__() + config = next(iterator) + if not isinstance(config, Object): + raise TypeError(f"First item must be Config, got {type(config).__name__}") + + def _repr_elem(elem: tp.Any) -> str: + if not isinstance(elem, Attr): + raise TypeError(f"Item must be Elem, got {type(elem).__name__}") + + value = elem.value if isinstance(elem.value, str) else repr(elem.value) + + if "\n" in value and not isinstance(elem.value, Representable): + value = value.replace("\n", "\n" + get_indent()) + + return f"{get_indent()}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}" + + with add_indent(config.elem_indent): + elems = list(map(_repr_elem, iterator)) + elems = ",\n".join(elems) + + if elems: + elems = "\n" + elems + "\n" + get_indent() + else: + elems = config.empty_repr + + type_repr = ( + config.type if 
isinstance(config.type, str) else config.type.__name__ + ) + + return f"{type_repr}{config.start}{elems}{config.end}" diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/experimental/nnx/nnx/spmd.py new file mode 100644 index 0000000000..8e64d5fb2f --- /dev/null +++ b/flax/experimental/nnx/nnx/spmd.py @@ -0,0 +1,440 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import contextlib +import dataclasses +import enum +import functools +import threading +import typing as tp + +import jax +from jax.experimental import maps +from jax.sharding import Mesh, PartitionSpec + +from flax.experimental.nnx.nnx import containers +from flax.experimental.nnx.nnx.nn import initializers +from flax.experimental.nnx.nnx.state import State + +# Real types and dummy aliases for documentation +LogicalRules = tp.Sequence[tuple[str, tp.Union[str, tuple[str, ...], None]]] +Array = tp.Any # pylint: disable=invalid-name +ArrayPytree = tp.Any # pylint: disable=invalid-name +LogicalPartitionSpec = tp.Any # pylint: disable=invalid-name +LogicalPartitionSpecPytree = tp.Any # pylint: disable=invalid-name +PartitionSpecPytree = tp.Any # pylint: disable=invalid-name + +A = tp.TypeVar("A") +F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any]) +PARTITION_NAME = "partition_name" +Sharding = tuple[tp.Optional[str], ...] 
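+
+# Sketch of the sharding metadata handled below (hypothetical leaf): a
+# `containers.Node` may carry a `sharding` tuple of logical axis names, e.g.
+#
+#   node.sharding == ("embed", "mlp")
+#   get_partition_spec(state)                       # PartitionSpec("embed", "mlp") per leaf
+#   add_axis(state, 0, {PARTITION_NAME: "layers"})  # sharding becomes ("layers", "embed", "mlp")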
+ + +@tp.runtime_checkable +class HasSharding(tp.Protocol): + sharding: tp.Optional[Sharding] + + +def add_axis( + state: State, index: int, params: tp.Mapping[tp.Any, tp.Any] +) -> State: + axis_name = _get_partition_name(params) + + def _add_axis(x: tp.Any): + if ( + isinstance(x, containers.Node) + and isinstance(x, HasSharding) + and x.sharding is not None + ): + sharding = list(x.sharding) + while len(sharding) < index: + sharding.append(None) + sharding.insert(index, axis_name) + return x.replace(sharding=tuple(sharding)) + return x + + return jax.tree_map( + _add_axis, state, is_leaf=lambda x: isinstance(x, containers.Node) + ) + + +def remove_axis( + state: State, index: int, params: tp.Mapping[tp.Any, tp.Any] +) -> State: + axis_name = _get_partition_name(params) + + def _remove_axis(x: tp.Any): + if ( + isinstance(x, containers.Node) + and isinstance(x, HasSharding) + and x.sharding is not None + ): + sharding = list(x.sharding) + assert sharding.pop(index) == axis_name + return x.replace(sharding=tuple(sharding)) + return x + + return jax.tree_map( + _remove_axis, state, is_leaf=lambda x: isinstance(x, containers.Node) + ) + + +def _get_partition_name(params: tp.Mapping[tp.Any, tp.Any]) -> str: + if PARTITION_NAME not in params: + raise ValueError( + 'Trying to transform a Partitioned variable but "partition_name"' + f" is not specified in metadata_params: {params}" + ) + return params[PARTITION_NAME] + + +def get_partition_spec(tree: A) -> A: + """Extracts a PartitionSpec tree from a PyTree containing ``Node`` values.""" + + def f(x): + if isinstance(x, containers.Node): + if isinstance(x, HasSharding) and x.sharding: + return PartitionSpec(*x.sharding) + else: + x = x.value + + # Unboxed arrays, which should be replicated across all devices + if hasattr(x, "shape"): + return PartitionSpec() + else: + return None + + return jax.tree_map(f, tree, is_leaf=lambda x: isinstance(x, containers.Node)) + + +# Dynamic Axis Mapping Context +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass +class _AxisRules(threading.local): + """Dynamic logical axis to mesh axis binding context.""" + + rules: LogicalRules = () + + +# Global axis binding context. 
+_axis_rules = _AxisRules()
+
+
+def set_logical_axis_rules(rules: LogicalRules):
+  """Sets the global logical axis to mesh axis binding."""
+  _axis_rules.rules = rules
+
+
+def get_logical_axis_rules() -> LogicalRules:
+  """Returns the global logical axis to mesh axis binding."""
+  return _axis_rules.rules
+
+
+@contextlib.contextmanager
+def logical_axis_rules(rules: LogicalRules):
+  """Context manager for setting the logical to mesh axis bindings."""
+  old_rules = _axis_rules.rules
+  try:
+    _axis_rules.rules = rules
+    yield
+  finally:
+    _axis_rules.rules = old_rules
+
+
+class _UnassignedAxis:
+  """Sentinel class for unassigned logical axis name."""
+
+  def __repr__(self):
+    return "UnassignedAxis"
+
+  def __bool__(self):
+    return False
+
+
+_unassigned_axis = _UnassignedAxis()
+
+
+def _mesh_assignment_free(new_assignment, existing_assignments):
+  """Determines if a given mesh axis has already been assigned."""
+  new = set(jax.tree_util.tree_leaves(new_assignment))
+  existing = set(jax.tree_util.tree_leaves(existing_assignments))
+  if existing.intersection(new):
+    return False
+  return True
+
+
+def _logical_to_mesh_axes(
+  array_dim_names: tp.Optional[tp.Sequence[tp.Optional[str]]],
+  rules: tp.Optional[LogicalRules] = None,
+) -> tp.Optional[list[tp.Union[_UnassignedAxis, None, str, tuple[str]]]]:
+  """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
+  if array_dim_names is None:
+    return None
+  if rules is None:
+    rules = _axis_rules.rules
+  axis_name_counts = collections.Counter(array_dim_names)
+  dups = tuple(
+    k for k, v in axis_name_counts.items() if v > 1 and k is not None
+  )
+  if dups:
+    raise ValueError(
+      f"Unsupported: Dimensions {dups} occur more than once in array names."
+    )
+  if not isinstance(rules, (tuple, list)):
+    raise ValueError("Unknown axis rule specification type.")
+  # We assign mesh axes using a priority based ruleset over logical axis names.
+  result: list[tp.Union[_UnassignedAxis, None, str, tuple[str]]]
+  result = [_unassigned_axis] * len(array_dim_names)
+  for rule_model_name, rule_mesh_names in rules:
+    if rule_model_name in array_dim_names:
+      pos = array_dim_names.index(rule_model_name)
+      if (
+        _mesh_assignment_free(rule_mesh_names, result)
+        and result[pos] == _unassigned_axis
+      ):
+        result[pos] = rule_mesh_names
+  return result
+
+
+def logical_to_mesh_axes(
+  array_dim_names: tp.Optional[tp.Sequence[tp.Optional[str]]],
+  rules: tp.Optional[LogicalRules] = None,
+) -> tp.Optional[jax.sharding.PartitionSpec]:
+  """Compute layout for an array.
+
+  The rules are in order of precedence, and consist of pairs
+  ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array
+  dimension (if present and unused) should be sharded across the given
+  mesh dimension (if present and unused).
+
+  A Layout of an Array is expressed as a tuple with one element for each
+  dimension in the Array. The element is either None, or is the name of a
+  mesh-dimension, meaning that this dimension of the array is sharded across
+  this dimension of the mesh.
+
+  For example, given an array with
+    array_dim_names = ('batch', 'length', 'heads', 'features')
+  and the layout rules
+    rules = (('batch', 'X'),
+             ('features', 'X'),
+             ('heads', 'Y'),
+             ('batch', 'Z'))
+
+  this function will return
+
+    PartitionSpec('X', None, 'Y', None)
+
+  Args:
+    array_dim_names: Tuple of array dimension names or None.
+    rules: Optional logical to mesh rules override. Defaults to using the
+      rules defined in the dynamic context set from the
+      `logical_axis_rules` function.
+
+  Returns:
+    PartitionSpec for the parameter.
+  """
+  result = _logical_to_mesh_axes(array_dim_names, rules)
+  if result is None:
+    return None
+  # We default to None - ie unsharded along the dimension.
+  result = [None if x is _unassigned_axis else x for x in result]
+  return jax.sharding.PartitionSpec(*result)
+
+
+def logical_to_mesh(
+  tree: tp.Any, rules: tp.Optional[LogicalRules] = None
+) -> tp.Any:
+  """Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs."""
+  return jax.tree_map(
+    lambda x: logical_to_mesh_axes(x, rules),
+    tree,
+    is_leaf=lambda x: isinstance(x, jax.sharding.PartitionSpec),
+  )
+
+
+def logical_to_mesh_sharding(
+  tree: tp.Any,
+  mesh: jax.sharding.Mesh,
+  rules: tp.Optional[LogicalRules] = None,
+) -> tp.Any:
+  """Convert pytrees of logical PartitionSpecs to shardings."""
+  return jax.tree_map(
+    lambda x: jax.sharding.NamedSharding(mesh, x),
+    logical_to_mesh(tree, rules),
+    is_leaf=lambda x: isinstance(x, jax.sharding.PartitionSpec),
+  )
+
+
+def _global_mesh_defined() -> bool:
+  """Checks if global xmap/pjit mesh resource environment is defined."""
+  maps_env = maps.thread_resources.env
+  return maps_env.physical_mesh.devices.shape != ()  # pylint: disable=g-explicit-bool-comparison
+
+
+class RulesFallback(enum.Enum):
+  """How a sharding constraint should behave when no matching rule is found."""
+
+  AXIS_IS_UNSHARDED = "axis_is_unsharded"
+  RAISE_ERROR = "raise_error"
+  NO_CONSTRAINT = "no_constraint"
+
+
+def _with_sharding_constraint(
+  x: Array,
+  axis_resources: tp.Optional[jax.sharding.PartitionSpec],
+  mesh: tp.Optional[jax.sharding.Mesh] = None,
+):
+  """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
+  if not _global_mesh_defined() and mesh is None:
+    return x
+  else:
+    if mesh is not None and axis_resources is not None:
+      sharding = jax.sharding.NamedSharding(mesh, axis_resources)
+      return jax.lax.with_sharding_constraint(x, sharding)
+    return jax.lax.with_sharding_constraint(x, axis_resources)
+
+
+def _with_sharding_constraint_one_fallback(
+  axis_resources: LogicalPartitionSpec,
+  x: Array,
+  fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED,
+  rules: tp.Optional[LogicalRules] = None,
+  mesh: tp.Optional[jax.sharding.Mesh] = None,
+):
+  """Either imposes a sharding constraint or applies fallback."""
+  mesh_axes = _logical_to_mesh_axes(axis_resources, rules)
+  if mesh_axes is None:
+    return _with_sharding_constraint(x, None, mesh=mesh)
+
+  if fallback == RulesFallback.AXIS_IS_UNSHARDED:
+    mesh_axes = [None if x is _unassigned_axis else x for x in mesh_axes]
+  else:
+    if any(x is _unassigned_axis for x in mesh_axes):
+      if fallback == RulesFallback.RAISE_ERROR:
+        raise ValueError(f"Axis names {axis_resources} did not match a rule")
+      else:
+        return x
+  return _with_sharding_constraint(
+    x, jax.sharding.PartitionSpec(*mesh_axes), mesh=mesh
+  )
+
+
+def _is_logical_spec(x):
+  return x is None or (
+    isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x)
+  )
+
+
+def with_logical_constraint(
+  x: ArrayPytree,
+  logical_axis_resources: LogicalPartitionSpecPytree,
+  rules: tp.Optional[LogicalRules] = None,
+  mesh: tp.Optional[jax.sharding.Mesh] = None,
+  fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED,
+):
+  """Version of pjit's with_sharding_constraint that uses logical axis names."""
+  # If no axis binding is set, this is a no-op.
+  if rules is None:
+    rules = _axis_rules.rules
+  if not rules or logical_axis_resources is None:
+    return x
+  # Translate logical names to mesh assignments.
+  return jax.tree_util.tree_map(
+    functools.partial(
+      _with_sharding_constraint_one_fallback,
+      fallback=fallback,
+      rules=rules,
+      mesh=mesh,
+    ),
+    logical_axis_resources,
+    x,
+    is_leaf=_is_logical_spec,
+  )
+
+
+# Logical Partitioning Axis Metadata
+# ------------------------------------------------------------------------------
+
+
+@tp.runtime_checkable
+class LogicallyPartitioned(tp.Protocol):
+  unbox_fn: tp.Callable[[containers.Container[tp.Any]], tp.Any]
+  sharding: Sharding
+  mesh: tp.Optional[Mesh]
+  rules: tp.Optional[LogicalRules]
+
+
+def with_logical_partitioning(
+  initializer: F,
+  sharding: Sharding,
+  mesh: tp.Optional[jax.sharding.Mesh] = None,
+  rules: tp.Optional[LogicalRules] = None,
+  **metadata: tp.Any,
+) -> F:
+  """Wraps a function's return value with LogicallyPartitioned.
+
+  Example::
+
+    kernel_init = with_logical_partitioning(
+      nn.initializers.lecun_normal(), (None, "data"))
+    partitioned_dense = nn.Dense(features, kernel_init=kernel_init)
+
+  Args:
+    initializer: The function to be wrapped. Typically this is an initializer.
+    sharding: The logical axis names passed to ``LogicallyPartitioned``.
+    mesh: The mesh to use for the partitioning. If None, the global mesh
+      resource is used if available.
+    rules: Optional logical to mesh rules to use. If None, the global rules
+      are used if available.
+
+  Returns:
+    A function wrapping ``initializer`` that will return an instance of
+    ``LogicallyPartitioned``.
+  """
+
+  def unbox_fn(node: containers.Node[tp.Any]) -> tp.Any:
+    """Returns the wrapped value with the partitioning constraint applied."""
+    if _global_mesh_defined() or (
+      isinstance(node, LogicallyPartitioned) and node.mesh is not None
+    ):
+      return with_logical_constraint(
+        node.value,
+        get_partition_spec(node),
+        rules=node.rules,
+        mesh=node.mesh,
+      )
+    return node.value
+
+  @functools.wraps(initializer)
+  def wrapper(*args, **kwargs):
+    y = initializer(*args, **kwargs)
+    if _global_mesh_defined() or (mesh is not None):
+      return with_logical_constraint(
+        y,
+        sharding,
+        rules=rules,
+        mesh=mesh,
+      )
+    return y
+
+  return containers.with_metadata(
+    tp.cast(F, wrapper),
+    unbox_fn=unbox_fn,
+    sharding=sharding,
+    mesh=mesh,
+    rules=rules,
+    **metadata,
+  )
diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py
new file mode 100644
index 0000000000..03460652c2
--- /dev/null
+++ b/flax/experimental/nnx/nnx/state.py
@@ -0,0 +1,211 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
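+
+# Usage sketch for the `State` mapping defined below (paths and leaves are
+# hypothetical). `partition` must be exhaustive while `filter` is not:
+#
+#   state = State({"linear/kernel": kernel, "bn/mean": mean})
+#   params, rest = state.partition(nnx.Param, ...)  # `...` catches the rest
+#   batch_stats = state.filter(nnx.BatchStat)
+#   restored = State.merge(params, rest)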
+ +import typing as tp + +import jax +import jax.tree_util as jtu + +from flax.experimental.nnx.nnx import nodes, partitioning, reprlib +from flax.experimental.nnx.nnx.containers import Node + +A = tp.TypeVar("A") + +Leaf = tp.Any +Path = str +StateDict = tp.Dict[Path, tp.Any] +StateMapping = tp.Mapping[Path, tp.Any] + + +class State(tp.Mapping[Path, Leaf], reprlib.Representable): + __slots__ = ("_mapping",) + + def __init__( + self, + __input: tp.Union[ + tp.Mapping[Path, Leaf], + tp.Iterator[tp.Tuple[Path, Leaf]], + ], + /, + ): + if isinstance(__input, tp.Mapping): + self._mapping = dict(sorted(__input.items(), key=lambda x: x[0])) + else: + self._mapping = dict(sorted(__input, key=lambda x: x[0])) + + def __getitem__(self, __key: Path) -> Leaf: + return self._mapping[__key] + + def __iter__(self) -> tp.Iterator[Path]: + return iter(self._mapping) + + def __len__(self) -> int: + return len(self._mapping) + + def __nnx_repr__(self): + yield reprlib.Object(type(self), value_sep=": ", start="({", end="})") + + for k, v in self._mapping.items(): + yield reprlib.Attr(repr(k), v) + + @tp.overload + def partition(self, first: partitioning.Filter, /) -> "State": + ... + + @tp.overload + def partition( + self, + first: partitioning.Filter, + second: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Tuple["State", ...]: + ... + + def partition( + self, first: partitioning.Filter, /, *filters: partitioning.Filter + ) -> tp.Union["State", tp.Tuple["State", ...]]: + filters = (first, *filters) + *states, rest = _split_state(self, *filters) + + if rest: + raise ValueError( + "Non-exhaustive filters, got a non-empty remainder: " + f"{list(rest.keys())}.\nUse `...` to match all remaining elements." + ) + + if len(states) == 1: + states = State(states[0]) + else: + states = tuple(State(state) for state in states) + return states + + @tp.overload + def filter( + self, + first: partitioning.Filter, + /, + ) -> "State": + ... + + @tp.overload + def filter( + self, + first: partitioning.Filter, + second: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Tuple["State", ...]: + ... 
+ + def filter( + self, + first: partitioning.Filter, + /, + *filters: partitioning.Filter, + ) -> tp.Union["State", tp.Tuple["State", ...]]: + *states, _rest = _split_state(self, first, *filters) + + assert len(states) == len(filters) + 1 + + if len(states) == 1: + states = State(states[0]) + else: + states = tuple(State(state) for state in states) + + return states + + @staticmethod + def merge(state: "State", /, *states: "State") -> "State": + states = (state, *states) + + if len(states) == 1: + return states[0] + + new_state: StateDict = {} + + for state in states: + new_state.update(state) + + return State(new_state) + + def __or__(self, other: "State") -> "State": + if not other: + return self + return State.merge(self, other) + + def __sub__(self, other: "State") -> "State": + if not other: + return self + + # create new State via __new__ to avoid __init__ sorting + _mapping = {k: v for k, v in self.items() if k not in other} + state = object.__new__(State) + state._mapping = _mapping + return state + + +def _state_flatten_with_keys( + x: State, +): + children = tuple((jtu.DictKey(key), value) for key, value in x.items()) + return children, tuple(x.keys()) + + +def _state_unflatten( + keys: tp.Tuple[Path, ...], + leaves: tp.Tuple[Leaf, ...], +): + state = object.__new__(State) + state._mapping = dict(zip(keys, leaves)) + return state + + +jax.tree_util.register_pytree_with_keys( + State, _state_flatten_with_keys, _state_unflatten +) + + +def _split_state( + state: StateMapping, + *filters: partitioning.Filter, +) -> tp.Tuple[StateDict, ...]: + for i, filter_ in enumerate(filters): + if filter_ is ... and i != len(filters) - 1: + raise ValueError( + "Ellipsis `...` can only be used as the last filter, " + f"got it at index {i}." + ) + predicates = tuple(map(partitioning.to_predicate, filters)) + + # we have n + 1 states, where n is the number of predicates + # the last state is for values that don't match any predicate + states: tp.Tuple[StateDict, ...] = tuple( + {} for _ in range(len(predicates) + 1) + ) + + for path, value in state.items(): + for i, predicate in enumerate(predicates): + if predicate(path, value): + states[i][path] = value + break + else: + # if we didn't break, set leaf to last state + states[-1][path] = value + + return states + + +# register nodes +nodes.register_node_type(State) diff --git a/flax/experimental/nnx/nnx/tracers.py b/flax/experimental/nnx/nnx/tracers.py new file mode 100644 index 0000000000..6f6472b2bc --- /dev/null +++ b/flax/experimental/nnx/nnx/tracers.py @@ -0,0 +1,113 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
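+
+# Sketch of the trace bookkeeping below: `TraceState` records the JAX and NNX
+# main traces active at construction, and `is_valid()` later detects use from
+# a different tracer context (e.g. a Module captured inside `jax.jit`):
+#
+#   ts = TraceState()
+#   assert ts.is_valid()  # same traces as at construction
+#
+#   @jax.jit
+#   def f(x):
+#     assert not ts.is_valid()  # a new main trace is active while tracing
+#     return x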
+
+# Taken from flax/core/tracer.py 🏴‍☠️
+
+import contextlib
+import dataclasses
+import threading
+import typing as tp
+
+import jax
+import jax.core
+from jax.core import MainTrace
+
+from flax.experimental.nnx.nnx import reprlib
+
+
+@tp.runtime_checkable
+class Tracer(tp.Protocol):
+  _trace: jax.core.Trace
+
+
+def get_top_trace(pytree: tp.Union[tp.Any, Tracer]) -> MainTrace:
+  """Returns the main top trace of a sequence of tracers."""
+  if isinstance(pytree, Tracer):
+    return pytree._trace.main
+
+  return jax.core.find_top_trace(jax.tree_util.tree_leaves(pytree)).main
+
+
+def current_jax_trace() -> MainTrace:
+  """Returns the innermost Jax tracer."""
+  return get_top_trace(())
+
+
+def get_all_traces(pytree: tp.Union[tp.Any, Tracer]) -> tp.Set[MainTrace]:
+  """Returns the set of main traces of all tracers in the pytree."""
+  if isinstance(pytree, Tracer):
+    return {pytree._trace.main}
+  else:
+    return {
+      trace._trace.main
+      for trace in jax.tree_util.tree_leaves(pytree)
+      if isinstance(trace, Tracer)
+    }
+
+
+def trace_level(main):
+  """Returns the level of the trace, or -infinity if it is None."""
+  if main:
+    return main.level
+  return float("-inf")
+
+
+@dataclasses.dataclass
+class TraceContext(threading.local):
+  nnx_trace_stack: tp.List[MainTrace] = dataclasses.field(
+    default_factory=lambda: [current_jax_trace()]
+  )
+
+
+TRACE_CONTEXT = TraceContext()
+
+
+@contextlib.contextmanager
+def nnx_trace(trace: MainTrace):
+  TRACE_CONTEXT.nnx_trace_stack.append(trace)
+  try:
+    yield
+  finally:
+    TRACE_CONTEXT.nnx_trace_stack.pop()
+
+
+def current_nnx_trace() -> MainTrace:
+  return TRACE_CONTEXT.nnx_trace_stack[-1]
+
+
+class TraceState(reprlib.Representable):
+  __slots__ = ["_jax_trace", "_nnx_trace"]
+
+  def __init__(self):
+    self._jax_trace = current_jax_trace()
+    self._nnx_trace = current_nnx_trace()
+
+  @property
+  def jax_trace(self):
+    return self._jax_trace
+
+  @property
+  def nnx_trace(self):
+    return self._nnx_trace
+
+  def is_valid(self) -> bool:
+    return (
+      self._jax_trace is current_jax_trace()
+      and self._nnx_trace is current_nnx_trace()
+    )
+
+  def __nnx_repr__(self):
+    yield reprlib.Object(f"{type(self).__name__}")
+    yield reprlib.Attr("jax_trace", self._jax_trace)
+    yield reprlib.Attr("nnx_trace", self._nnx_trace)
diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py
new file mode 100644
index 0000000000..545d1e8602
--- /dev/null
+++ b/flax/experimental/nnx/nnx/transforms.py
@@ -0,0 +1,966 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
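+
+# Usage sketch for the lifted transforms defined below (illustrative only):
+# `nnx.jit` partitions the Module into a pure pytree before tracing and merges
+# state updates back; `nnx.grad` differentiates with respect to a filtered
+# subset of the Module's state and returns it as a `State` of gradients:
+#
+#   import jax.numpy as jnp
+#
+#   @nnx.jit
+#   def train_step(model, x, y):
+#     def loss_fn(model):
+#       return jnp.mean((model(x) - y) ** 2)
+#
+#     grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
+#     # ... apply the gradient State to the model's parameters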
+ +import dataclasses +import functools +import typing as tp +from types import MappingProxyType +from typing import Any + +import jax +import jax.numpy as jnp +import jax.stages +import jax.tree_util as jtu + +from flax.experimental.nnx.nnx import containers, contextlib, partitioning, spmd, tracers +from flax.experimental.nnx.nnx.module import ( + CallableProxy, + DelayedAccessor, + Module, + ModuleDef, + ModuleMeta, + PureModule, +) +from flax.experimental.nnx.nnx.state import State + +A = tp.TypeVar("A") +C = tp.TypeVar("C") +B = tp.TypeVar("B") +F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any]) +G = tp.TypeVar("G", bound=tp.Callable[..., tp.Any]) +M = tp.TypeVar("M", bound=Module) + +AxisName = tp.Hashable +Leaf = tp.Any +Leaves = tp.List[Leaf] + + +class JitTransform(jax.stages.Wrapped): + + def __init__( + self, + fun: tp.Callable[..., tp.Any], + stateful: bool, + **jit_kwargs, + ): + @functools.partial(jax.jit, **jit_kwargs) + def jitted_fn(pure_module: PureModule[Module], *args, **kwargs): + if "ctx" in kwargs and isinstance(kwargs["ctx"], contextlib.PureContext): + kwargs["ctx"] = kwargs["ctx"].merge() + + nnx_trace = tracers.get_top_trace((args, kwargs)) + with tracers.nnx_trace(nnx_trace): + module = pure_module.merge() + out = fun(module, *args, **kwargs) + + if self.stateful: + updates = module.get_state() + out = (updates, out) + + return out + + self.jitted_fn = jitted_fn + self.stateful = stateful + + def __call__(self, module: tp.Any, *args, **kwargs): + if not isinstance(module, Module): + raise TypeError(f"Expected Module, got {type(module).__name__}") + if "ctx" in kwargs and isinstance(kwargs["ctx"], contextlib.Context): + kwargs["ctx"] = kwargs["ctx"].partition() + + pure_module = module.partition() + out = self.jitted_fn(pure_module, *args, **kwargs) + if self.stateful: + updates: State + updates, out = out + module.update_state(updates) + return out + + def __repr__(self): + return f"JitTransform({self.jitted_fn})" + + def lower(self, *args, **kwargs): + return self.jitted_fn.lower(*args, **kwargs) + + +UNSPECIFIED = object() + + +def jit( + fun: tp.Callable[..., tp.Any], + *, + stateful: bool = True, + in_shardings: tp.Any = UNSPECIFIED, + out_shardings: tp.Any = UNSPECIFIED, + static_argnums: tp.Union[int, tp.Sequence[int], None] = None, + static_argnames: tp.Union[str, tp.Iterable[str], None] = None, + donate_argnums: tp.Union[int, tp.Sequence[int]] = (), + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, +) -> jax.stages.Wrapped: + if static_argnames is None: + static_argnames = [] + elif isinstance(static_argnames, str): + static_argnames = [static_argnames] + else: + static_argnames = list(static_argnames) + + static_argnames.append("_nnx__dagdef") + + jit_kwargs = dict( + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + ) + + if in_shardings is not UNSPECIFIED: + jit_kwargs["in_shardings"] = in_shardings + if out_shardings is not UNSPECIFIED: + jit_kwargs["out_shardings"] = out_shardings + + ref_jit = JitTransform( + fun, + stateful, + **jit_kwargs, + ) + ref_jit = functools.wraps(fun)(ref_jit) + # _update_decorator_fields(ref_jit, fun) + return ref_jit + + +class GradTransform: + + def __init__( + self, + fun: tp.Callable[..., tp.Any], + stateful: bool, + predicate: 
partitioning.Predicate,
+    has_aux: bool,
+    holomorphic: bool,
+    allow_int: bool,
+    reduce_axes: tp.Sequence[AxisName],
+  ):
+    @functools.partial(
+      jax.grad,
+      argnums=0,  # we'll handle this ourselves
+      has_aux=has_aux or stateful,
+      holomorphic=holomorphic,
+      allow_int=allow_int,
+      reduce_axes=reduce_axes,
+    )
+    def grad_fn(
+      diff: State,
+      non_diff: State,
+      moduledef: ModuleDef[Module],
+      *args: tp.Any,
+    ):
+      with tracers.nnx_trace(tracers.get_top_trace(diff)):
+        module = moduledef.merge(diff, non_diff)
+        out = fun(module, *args)
+
+      if self.stateful:
+        updates = module.get_state()
+        if self.has_aux:
+          loss, aux = out
+          out = (loss, (updates, aux))
+        else:
+          out = (out, updates)
+
+      return out
+
+    self.grad_fn = grad_fn
+    self.predicate = predicate
+    self.has_aux = has_aux
+    self.stateful = stateful
+
+  def __call__(self, module: Module, *args: tp.Any):
+    if not isinstance(module, Module):
+      raise TypeError(f"Expected a Module, got {type(module).__name__}")
+
+    (diff, nondiff), moduledef = module.partition(self.predicate, ...)
+
+    grads = self.grad_fn(diff, nondiff, moduledef, *args)
+
+    if self.stateful:
+      updates: State
+      if self.has_aux:
+        grads, (updates, aux) = grads
+        out = grads, aux
+      else:
+        out, updates = grads
+      module.update_state(updates)
+    else:
+      out = grads
+
+    return out
+
+  def __repr__(self):
+    return f"GradTransform({self.grad_fn})"
+
+
+@tp.overload
+def grad(
+  fun: tp.Callable[..., tp.Any],
+  wrt: partitioning.Filter = containers.Param,
+  *,
+  stateful: bool = True,
+  holomorphic: bool = False,
+  allow_int: bool = False,
+  reduce_axes: tp.Sequence[AxisName] = (),
+) -> tp.Callable[..., State]:
+  ...
+
+
+@tp.overload
+def grad(
+  fun: tp.Callable[..., tp.Any],
+  wrt: partitioning.Filter = containers.Param,
+  *,
+  stateful: bool = True,
+  has_aux: tp.Literal[True],
+  holomorphic: bool = False,
+  allow_int: bool = False,
+  reduce_axes: tp.Sequence[AxisName] = (),
+) -> tp.Callable[..., tp.Tuple[State, tp.Any]]:
+  ...
+ + +def grad( + fun: tp.Callable[..., tp.Any], + wrt: partitioning.Filter = containers.Param, + *, + stateful: bool = True, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., tp.Union[tp.Tuple[State, tp.Any], State]]: + predicate = partitioning.to_predicate(wrt) + ref_grad = GradTransform( + fun, + stateful=stateful, + predicate=predicate, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + ) + ref_grad = functools.wraps(fun)(ref_grad) + # _update_decorator_fields(ref_grad, fun) + return ref_grad + + +@dataclasses.dataclass +class ScanOptions: + variable_axes: tp.Mapping[partitioning.Filter, int] + variable_broadcast: partitioning.Filter + variable_carry: partitioning.Filter + split_rngs: contextlib.RngFilter + in_axes: tp.Any + out_axes: tp.Any + length: tp.Optional[int] + reverse: bool + unroll: int + data_transform: tp.Optional[tp.Callable[..., tp.Any]] + metadata_params: tp.Mapping[tp.Any, tp.Any] + + +class ScanMeta(ModuleMeta): + + def __call__( + self, + module_constructor: tp.Callable[..., M], + *, + variable_axes: tp.Mapping[partitioning.Filter, int] = MappingProxyType( + {} + ), + variable_broadcast: partitioning.Filter = None, + variable_carry: partitioning.Filter = ..., + split_rngs: contextlib.RngFilter = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + length: tp.Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + data_transform: tp.Optional[tp.Callable[..., tp.Any]] = None, + metadata_params: tp.Mapping[tp.Any, tp.Any] = {}, + ) -> tp.Callable[..., "Scan[M]"]: + super_call = super().__call__ + + def _create_scan(*args, **kwargs) -> Scan[M]: + return super_call( + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + variable_axes=variable_axes, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + split_rngs=split_rngs, + in_axes=in_axes, + out_axes=out_axes, + length=length, + reverse=reverse, + unroll=unroll, + data_transform=data_transform, + metadata_params=metadata_params, + ) + + return _create_scan + + +class Scan(Module, tp.Generic[M], metaclass=ScanMeta): + + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + variable_axes: tp.Mapping[partitioning.Filter, int] = MappingProxyType( + {} + ), + variable_broadcast: partitioning.Filter = None, + variable_carry: partitioning.Filter = ..., + split_rngs: contextlib.RngFilter = None, + in_axes: tp.Any = 0, + out_axes: tp.Any = 0, + length: tp.Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + data_transform: tp.Optional[tp.Callable[..., tp.Any]] = None, + metadata_params: tp.Mapping[tp.Any, tp.Any] = MappingProxyType({}), + module_init_args: tp.Tuple[tp.Any, ...], + module_init_kwargs: tp.Dict[str, tp.Any], + ): + self.module_constructor = module_constructor + self.options = ScanOptions( + variable_axes=variable_axes, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + split_rngs=split_rngs, + in_axes=in_axes, + out_axes=out_axes, + length=length, + reverse=reverse, + unroll=unroll, + data_transform=data_transform, + metadata_params=metadata_params, + ) + self.scan_module = scan_init( + self.options, module_constructor, module_init_args, module_init_kwargs + ) + + def __call__( + self, + carry_arg: C, + axes_arg, + *broadcast_args, + ctx: tp.Optional[contextlib.Context] = None, + **broadcast_kwargs, + ) -> tp.Tuple[C, tp.Any]: + return 
self.call(  # type: ignore
+      carry_arg, axes_arg, *broadcast_args, ctx=ctx, **broadcast_kwargs
+    )
+
+  @property
+  def call(self) -> M:
+    accessor = DelayedAccessor()
+
+    def _context(
+      accessor,
+      carry_arg: C,
+      axes_arg,
+      *broadcast_args,
+      **broadcast_kwargs,
+    ) -> tp.Tuple[C, tp.Any]:
+      def _apply(module, *args, **kwargs):
+        return accessor(module)(*args, **kwargs)
+
+      return scan_apply(
+        self.options,
+        _apply,
+        self.scan_module,
+        carry_arg,
+        axes_arg,
+        broadcast_args,
+        broadcast_kwargs,
+      )
+
+    return CallableProxy(_context, accessor)  # type: ignore
+
+
+class ScanCall(tp.Protocol, tp.Generic[C, A, B]):
+
+  def __call__(
+    self,
+    module: Module,
+    carry_arg: C,
+    axes_arg: A,
+    *broadcast_args: tp.Any,
+    **broadcast_kwargs: tp.Any,
+  ) -> tp.Tuple[C, B]:
+    ...
+
+
+def scan_init(
+  options: ScanOptions,
+  module_constructor: tp.Callable[..., M],
+  module_init_args: tp.Tuple[tp.Any, ...],
+  module_init_kwargs: tp.Dict[str, tp.Any],
+) -> M:
+  if options.variable_axes and options.length is None:
+    raise ValueError("Cannot use variable_axes without specifying a length")
+
+  ctx = module_init_kwargs.pop("ctx", None)
+
+  if ctx is not None and not isinstance(ctx, contextlib.Context):
+    raise TypeError(f"Expected a Context, got {type(ctx).__name__}")
+
+  key_values = []
+
+  if ctx is not None:
+    keys, ctxdef = ctx.partition()
+    split_predicate = contextlib.to_rng_predicate(options.split_rngs)
+
+    key_axes = []
+    key_names = tuple(keys.keys())
+
+    for name, key in keys.items():
+      if split_predicate(name):
+        if options.length is None:
+          raise ValueError("Cannot split RNGs without specifying a length")
+        key = jax.random.split(key, options.length)
+        key_axes.append(0)
+      else:
+        key_axes.append(None)
+      key_values.append(key)
+  else:
+    key_names = None
+    ctxdef = None
+    key_axes = None
+
+  moduledef: tp.Optional[ModuleDef[M]] = None
+
+  def _init_state(*key_values):
+    nonlocal moduledef
+
+    if ctxdef is not None:
+      assert key_names is not None
+      keys = dict(zip(key_names, key_values))
+      ctx = ctxdef.merge(keys)
+      module_init_kwargs["ctx"] = ctx
+
+    module = module_constructor(*module_init_args, **module_init_kwargs)
+
+    # lift module
+    filters = (
+      *options.variable_axes.keys(),
+      options.variable_broadcast,
+      options.variable_carry,
+    )
+
+    states, moduledef = module.partition(*filters)
+
+    return states
+
+  if ctxdef is not None or options.variable_axes:
+    init_out_axes = (*options.variable_axes.values(), None, None)
+    _init_state = jax.vmap(
+      _init_state,
+      in_axes=key_axes,
+      out_axes=init_out_axes,
+      axis_size=options.length,
+    )
+
+  *axes_states, broadcast_state, carry_state = _init_state(*key_values)
+  moduledef = tp.cast(ModuleDef[M], moduledef)
+
+  # add additional axis name to Variable.sharding
+  if spmd.PARTITION_NAME in options.metadata_params:
+    axes_states = [
+      spmd.add_axis(state, index, options.metadata_params)
+      for state, index in zip(axes_states, options.variable_axes.values())
+    ]
+
+  module = moduledef.merge(*axes_states, broadcast_state, carry_state)
+
+  return module
+
+
+def scan_apply(
+  options: ScanOptions,
+  f: ScanCall[C, A, B],
+  module: Module,
+  carry_arg: C,
+  axes_arg: A,
+  broadcast_args: tuple[tp.Any, ...],
+  broadcast_kwargs: dict[str, tp.Any],
+) -> tp.Tuple[C, B]:
+  # split module state
+  filters = (
+    *options.variable_axes.keys(),
+    options.variable_broadcast,
+    options.variable_carry,
+  )
+  (
+    *axes_states,
+    broadcast_state,
+    carry_state,
+  ), moduledef = module.partition(*filters)
+
+  # transpose axes state
+  axes_states = tuple(
+    jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state)
+    for axes_state, axis in zip(axes_states, options.variable_axes.values())
+  )
+  # transpose axes arg
+  axes_arg = tree_map_upto_left(
+    lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), node),
+    options.in_axes,
+    axes_arg,
+  )
+
+  # infer length
+  lengths: tp.Set[int] = set(
+    x.shape[0] for x in jax.tree_util.tree_leaves((axes_states, axes_arg))
+  )
+
+  if len(lengths) > 1:
+    raise ValueError(
+      "Inconsistent lengths between variable_axes states and "
+      f"axes_arg: {lengths}"
+    )
+  elif len(lengths) == 0:
+    if options.length is None:
+      raise ValueError(
+        "Cannot infer length from variable_axes states or axes_arg, "
+        "please specify `length`"
+      )
+    length = options.length
+  else:
+    length = lengths.pop()
+    if options.length is not None and options.length != length:
+      raise ValueError(
+        f"Specified length {options.length} does not match the inferred "
+        f"length {length}"
+      )
+
+  # split rng state
+  axes_keys: tp.Optional[tp.Dict[str, jax.Array]]
+  broadcast_keys: tp.Optional[tp.Dict[str, jax.Array]]
+
+  ctx = broadcast_kwargs.pop("ctx", None)
+  if ctx is not None:
+    if not isinstance(ctx, contextlib.Context):
+      raise TypeError(f"Expected a Context, got {type(ctx).__name__}")
+
+    axes_keys = {}
+    broadcast_keys = {}
+
+    keys, ctxdef = ctx.partition()
+    split_predicate = contextlib.to_rng_predicate(options.split_rngs)
+
+    for name, key in keys.items():
+      if split_predicate(name):
+        axes_keys[name] = jax.random.split(key, length)
+      else:
+        broadcast_keys[name] = key
+  else:
+    ctxdef = None
+    axes_keys = None
+    broadcast_keys = None
+
+  def scan_fn(
+    carry: tp.Tuple[State, tp.Any],
+    axes: tp.Tuple[
+      tp.Optional[tp.Dict[str, jax.Array]], tp.Tuple[State, ...], tp.Any
+    ],
+  ):
+    carry_state, carry_arg = carry
+    axes_keys, axes_states, axes_arg = axes
+
+    # merge rng state
+    if ctxdef is not None:
+      assert axes_keys is not None and broadcast_keys is not None
+      ctx = ctxdef.merge({**axes_keys, **broadcast_keys})
+      broadcast_kwargs["ctx"] = ctx
+
+    # remove metadata axis name from Variable.sharding
+    if spmd.PARTITION_NAME in options.metadata_params:
+      axes_states = [
+        spmd.remove_axis(state, index, options.metadata_params)
+        for state, index in zip(axes_states, options.variable_axes.values())
+      ]
+
+    # merge module state
+    module = moduledef.merge(*axes_states, broadcast_state, carry_state)
+
+    (carry_out, axes_out) = f(
+      module, carry_arg, axes_arg, *broadcast_args, **broadcast_kwargs
+    )
+
+    # split module state
+    (*axes_states_out, broadcast_state_out, carry_state_out), _ = (
+      module.partition(*filters)
+    )
+
+    carry_state_new = carry_state_out - carry_state
+    broadcast_state_new = broadcast_state - broadcast_state_out
+
+    # remove new carry state
+    carry_state_out = carry_state_out - carry_state_new
+
+    # add metadata axis name to Variable.sharding
+    if spmd.PARTITION_NAME in options.metadata_params:
+      axes_states_out = [
+        spmd.add_axis(state, index, options.metadata_params)
+        for state, index in zip(
+          axes_states_out, options.variable_axes.values()
+        )
+      ]
+
+    carry = (carry_state_out, carry_out)
+    out = (axes_states_out, broadcast_state_new, carry_state_new, axes_out)
+
+    return carry, out
+
+  scan_partial = lambda length, unroll: lambda carry, axes: jax.lax.scan(
+    scan_fn,
+    carry,
+    axes,
+    length=length,
+    reverse=options.reverse,
+    unroll=unroll,
+  )
+
+  carry = (carry_state, carry_arg)
+  axes = (axes_keys, axes_states, axes_arg)
+
+  abstract_output = jax.eval_shape(
+    scan_partial(length, options.unroll), carry, axes
+  )
+  carry_state_new = abstract_output[1][2]
+  has_new_carry_state = len(jax.tree_util.tree_leaves(carry_state_new)) > 0
+
+  if has_new_carry_state:
+    # run scan for 1 step
+    axes1 = jax.tree_map(lambda x: x[:1], axes)
+    carry, scan_out = scan_partial(1, 1)(carry, axes1)
+    carry_state, carry_arg = carry
+    axes_states1, broadcast_state_new, carry_state_new, out1 = scan_out
+
+    # slice new broadcast state and carry state
+    broadcast_state_new, carry_state_new = jax.tree_map(
+      lambda x: x[0], (broadcast_state_new, carry_state_new)
+    )
+    # merge states
+    broadcast_state = State.merge(broadcast_state, broadcast_state_new)
+    carry_state = State.merge(carry_state, carry_state_new)
+    # update carry
+    carry = (carry_state, carry_arg)
+
+    # run scan for the rest of the steps
+    axes_rest = jax.tree_map(lambda x: x[1:], axes)
+    carry, scan_out = scan_partial(length - 1, options.unroll)(carry, axes_rest)
+    carry_state, carry_out = carry
+    axes_states_rest, broadcast_state_new, carry_state_new, out_rest = scan_out
+
+    # concatenate outputs
+    axes_states, out = jax.tree_map(
+      lambda x, y: jnp.concatenate((x, y), axis=0),
+      (axes_states1, out1),
+      (axes_states_rest, out_rest),
+    )
+  else:
+    carry, scan_out = scan_partial(length, options.unroll)(carry, axes)
+    carry_state, carry_out = carry
+    axes_states, broadcast_state_new, carry_state_new, out = scan_out
+
+  # transpose axes state
+  axes_states = tuple(
+    jax.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state)
+    for axes_state, axis in zip(axes_states, options.variable_axes.values())
+  )
+  # transpose outputs
+  out = tree_map_upto_left(
+    lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, 0, axis), node),
+    options.out_axes,
+    out,
+  )
+
+  module.update_state(
+    (*axes_states, carry_state, broadcast_state_new, carry_state_new)
+  )
+
+  return carry_out, out
+
+
+def scan(
+  f: F,
+  *,
+  variable_axes: tp.Mapping[partitioning.Filter, int] = MappingProxyType({}),
+  variable_broadcast: partitioning.Filter = None,
+  variable_carry: partitioning.Filter = ...,
+  split_rngs: contextlib.RngFilter = None,
+  in_axes: tp.Any = 0,
+  out_axes: tp.Any = 0,
+  length: tp.Optional[int] = None,
+  reverse: bool = False,
+  unroll: int = 1,
+  data_transform: tp.Optional[tp.Callable[..., tp.Any]] = None,
+  metadata_params: tp.Mapping[tp.Any, tp.Any] = {},
+  is_init: tp.Optional[bool] = None,
+) -> F:
+  if is_init is None:
+    is_init = f.__name__ == "__init__"
+
+  options = ScanOptions(
+    variable_axes=variable_axes,
+    variable_broadcast=variable_broadcast,
+    variable_carry=variable_carry,
+    split_rngs=split_rngs,
+    in_axes=in_axes,
+    out_axes=out_axes,
+    length=length,
+    reverse=reverse,
+    unroll=unroll,
+    data_transform=data_transform,
+    metadata_params=metadata_params,
+  )
+
+  if is_init:
+
+    @functools.wraps(f)
+    def init_wrapper(module: Module, *args, **kwargs):
+      def module_constructor(*args, **kwargs):
+        f(module, *args, **kwargs)
+        return module
+
+      lifted_module = scan_init(options, module_constructor, args, kwargs)
+      module.update_state(lifted_module)
+
+    wrapper = init_wrapper
+
+  else:
+
+    @functools.wraps(f)
+    def apply_wrapper(
+      module: Module, carry_arg: C, axes_arg, *args, **kwargs
+    ) -> tuple[C, tp.Any]:
+      return scan_apply(options, f, module, carry_arg, axes_arg, args, kwargs)
+
+    wrapper = apply_wrapper
+
+  return wrapper  # type: ignore
+
+
+class RematMeta(ModuleMeta):
+
+  def __call__(
+    self,
+    module_constructor: tp.Callable[..., M],
+    # variables: lift.CollectionFilter = True,
+    # rngs: lift.PRNGSequenceFilter = True,
+    prevent_cse: bool = True,
+    static_argnums: tp.Union[int, tuple[int, ...]] = (),
+    policy: tp.Optional[tp.Callable[..., bool]] = None,
+  ) -> tp.Callable[..., "Remat[M]"]:
+    super_call = super().__call__
+
+    def create_remat(*args, **kwargs) -> Remat[M]:
+      return super_call(
+        module_constructor=module_constructor,
+        module_init_args=args,
+        module_init_kwargs=kwargs,
+        prevent_cse=prevent_cse,
+        static_argnums=static_argnums,
+        policy=policy,
+      )
+
+    return create_remat
+
+
+@dataclasses.dataclass
+class RematOptions:
+  # variables: lift.CollectionFilter,
+  # rngs: lift.PRNGSequenceFilter,
+  prevent_cse: bool
+  static_argnums: tp.Union[int, tuple[int, ...]]
+  policy: tp.Optional[tp.Callable[..., bool]]
+
+  def __post_init__(self):
+    if isinstance(self.static_argnums, int):
+      self.static_argnums = (self.static_argnums,)
+
+    # add 2 as an offset to account for state and keys
+    self.static_argnums = tuple(
+      x + 2 if x >= 0 else x for x in self.static_argnums
+    )
+
+
+class Remat(Module, tp.Generic[M], metaclass=RematMeta):
+
+  def __init__(
+    self,
+    *,
+    module_constructor: tp.Callable[..., M],
+    # variables: lift.CollectionFilter,
+    # rngs: lift.PRNGSequenceFilter,
+    prevent_cse: bool = True,
+    static_argnums: tp.Union[int, tuple[int, ...]] = (),
+    policy: tp.Optional[tp.Callable[..., bool]] = None,
+    module_init_args: tp.Tuple[tp.Any, ...],
+    module_init_kwargs: tp.Dict[str, tp.Any],
+  ):
+    self.options = RematOptions(
+      prevent_cse=prevent_cse,
+      static_argnums=static_argnums,
+      policy=policy,
+    )
+    self.module_constructor = module_constructor
+    self.remat_module = self.module_constructor(
+      *module_init_args, **module_init_kwargs
+    )
+
+  def __call__(
+    self,
+    *args,
+    ctx: tp.Optional[contextlib.Context] = None,
+  ):
+    return self.call(*args, ctx=ctx)  # type: ignore
+
+  @property
+  def call(self) -> M:
+    accessor = DelayedAccessor()
+
+    def _call(
+      accessor, *args, ctx: tp.Optional[contextlib.Context] = None
+    ) -> tp.Any:
+      def _apply(module, *args, **kwargs):
+        return accessor(module)(*args, **kwargs)
+
+      return remat_apply(
+        self.options,
+        _apply,
+        self.remat_module,
+        args,
+        ctx,
+      )
+
+    return CallableProxy(_call, accessor)  # type: ignore
+
+
+class RematCall(tp.Protocol):
+
+  def __call__(self, *args, ctx: tp.Optional[contextlib.Context]) -> tp.Any:
+    ...
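+
+# A minimal usage sketch for the `Remat` wrapper (the `MLP` Module, its sizes,
+# and `x` are hypothetical; `nnx.context` is the user-facing Context
+# constructor). Construction goes through `RematMeta.__call__`, so the wrapped
+# Module's own constructor arguments are passed in a second call:
+#
+#   rematted = Remat(MLP)(din=16, dout=16, ctx=nnx.context(0))
+#   y = rematted(x)  # the forward pass is checkpointed via jax.checkpoint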
+ + +def remat_apply( + options: RematOptions, + f: RematCall, + module: Module, + args: tp.Tuple[tp.Any, ...], + ctx: tp.Optional[contextlib.Context], +): + state, moduledef = module.partition() + + if ctx is not None: + keys, ctxdef = ctx.partition() + else: + keys = None + ctxdef = None + + def _remat_fn( + state: State, + keys: tp.Optional[tp.Dict[str, jax.Array]], + *args, + ) -> tp.Tuple[State, tp.Any]: + kwargs = {} + if keys is not None: + assert ctxdef is not None + kwargs["ctx"] = ctxdef.merge(keys) + + module = moduledef.merge(state) + out = f(module, *args, **kwargs) + + state, _ = module.partition() + + return state, out + + state, out = jax.checkpoint( + _remat_fn, + prevent_cse=options.prevent_cse, + static_argnums=options.static_argnums, + policy=options.policy, + )(state, keys, *args) + + module.update_state(state) + + return out + + +def remat( + f: F, + *, + # variables: lift.CollectionFilter, + # rngs: lift.PRNGSequenceFilter, + prevent_cse: bool = True, + static_argnums: tp.Union[int, tuple[int, ...]] = (), + policy: tp.Optional[tp.Callable[..., bool]] = None, + is_init: tp.Optional[bool] = None, +) -> F: + if is_init is None: + is_init = f.__name__ == "__init__" + + options = RematOptions( + # variables=variables, + # rngs=rngs, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + + if is_init: + return f + else: + + @functools.wraps(f) + def wrapper( + module: Module, *args, ctx: tp.Optional[contextlib.Context] = None + ): + return remat_apply(options, f, module, args, ctx) + + return wrapper # type: ignore + + +def tree_map_upto_left( + f: tp.Callable[[tp.Any, tp.Any], tp.Any], left: tp.Any, right: tp.Any +) -> tp.Any: + leaves_left, treedef = jtu.tree_flatten(left) + leaves_right = treedef.flatten_up_to(right) + + return treedef.unflatten( + f(left_leaf, right_leaf) + for left_leaf, right_leaf in zip(leaves_left, leaves_right) + ) diff --git a/flax/experimental/nnx/poetry.lock b/flax/experimental/nnx/poetry.lock new file mode 100644 index 0000000000..17d2f22f60 --- /dev/null +++ b/flax/experimental/nnx/poetry.lock @@ -0,0 +1,3216 @@ +# This file is automatically @generated by Poetry and should not be changed by hand. + +[[package]] +name = "absl-py" +version = "1.4.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." 
+category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "absl-py-1.4.0.tar.gz", hash = "sha256:d2c244d01048ba476e7c080bd2c6df5e141d211de80223460d5b3b8a2a58433d"}, + {file = "absl_py-1.4.0-py3-none-any.whl", hash = "sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47"}, +] + +[[package]] +name = "aiohttp" +version = "3.8.5" +description = "Async http client/server framework (asyncio)" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"}, + {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"}, + {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"}, + {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"}, + {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"}, + {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"}, + {file = "aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"}, + {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = 
"sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"}, + {file = 
"aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"}, + {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"}, + {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"}, + {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"}, + {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"}, + {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"}, +] + +[package.dependencies] +aiosignal = ">=1.1.2" +async-timeout = ">=4.0.0a3,<5.0" +attrs = ">=17.3.0" +charset-normalizer = ">=2.0,<4.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns", "cchardet"] + +[[package]] +name = "aiosignal" 
+version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" + +[[package]] +name = "appnope" +version = "0.1.3" +description = "Disable App Nap on macOS >= 10.9" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, + {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, +] + +[[package]] +name = "asttokens" +version = "2.2.1" +description = "Annotate AST trees with source code positions" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.2.1-py2.py3-none-any.whl", hash = "sha256:6b0ac9e93fb0335014d382b8fa9b3afa7df546984258005da0b9e7095b3deb1c"}, + {file = "asttokens-2.2.1.tar.gz", hash = "sha256:4622110b2a6f30b77e1473affaa97e711bc2f07d3f10848420ff1898edbe94f3"}, +] + +[package.dependencies] +six = "*" + +[package.extras] +test = ["astroid", "pytest"] + +[[package]] +name = "async-timeout" +version = "4.0.2" +description = "Timeout context manager for asyncio programs" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, + {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"}, +] + +[[package]] +name = "attrs" +version = "23.1.0" +description = "Classes Without Boilerplate" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, + {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[docs,tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] + +[[package]] +name = "backcall" +version = "0.2.0" +description = "Specifications for callback functions passed in to an API" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"}, + {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, +] + +[[package]] +name = "black" +version = "23.3.0" +description = "The uncompromising code formatter." 
+category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, + {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, + {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, + {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, + {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, + {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, + {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, + {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, + {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, + {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, + {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, + {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, + {file = "black-23.3.0-py3-none-any.whl", hash = 
"sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, + {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, +] + +[package.dependencies] +click = ">=8.0.0" +ipython = {version = ">=7.8.0", optional = true, markers = "extra == \"jupyter\""} +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tokenize-rt = {version = ">=3.2.0", optional = true, markers = "extra == \"jupyter\""} +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "certifi" +version = "2023.5.7" +description = "Python package for providing Mozilla's CA Bundle." +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, + {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, +] + +[[package]] +name = "cffi" +version = "1.15.1" +description = "Foreign Function Interface for Python calling C code." +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "cffi-1.15.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a66d3508133af6e8548451b25058d5812812ec3798c886bf38ed24a98216fab2"}, + {file = "cffi-1.15.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:470c103ae716238bbe698d67ad020e1db9d9dba34fa5a899b5e21577e6d52ed2"}, + {file = "cffi-1.15.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:9ad5db27f9cabae298d151c85cf2bad1d359a1b9c686a275df03385758e2f914"}, + {file = "cffi-1.15.1-cp27-cp27m-win32.whl", hash = "sha256:b3bbeb01c2b273cca1e1e0c5df57f12dce9a4dd331b4fa1635b8bec26350bde3"}, + {file = "cffi-1.15.1-cp27-cp27m-win_amd64.whl", hash = "sha256:e00b098126fd45523dd056d2efba6c5a63b71ffe9f2bbe1a4fe1716e1d0c331e"}, + {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:d61f4695e6c866a23a21acab0509af1cdfd2c013cf256bbf5b6b5e2695827162"}, + {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:ed9cb427ba5504c1dc15ede7d516b84757c3e3d7868ccc85121d9310d27eed0b"}, + {file = "cffi-1.15.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d39875251ca8f612b6f33e6b1195af86d1b3e60086068be9cc053aa4376e21"}, + {file = "cffi-1.15.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:285d29981935eb726a4399badae8f0ffdff4f5050eaa6d0cfc3f64b857b77185"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3eb6971dcff08619f8d91607cfc726518b6fa2a9eba42856be181c6d0d9515fd"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21157295583fe8943475029ed5abdcf71eb3911894724e360acff1d61c1d54bc"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5635bd9cb9731e6d4a1132a498dd34f764034a8ce60cef4f5319c0541159392f"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2012c72d854c2d03e45d06ae57f40d78e5770d252f195b93f581acf3ba44496e"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:dd86c085fae2efd48ac91dd7ccffcfc0571387fe1193d33b6394db7ef31fe2a4"}, + {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fa6693661a4c91757f4412306191b6dc88c1703f780c8234035eac011922bc01"}, + {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59c0b02d0a6c384d453fece7566d1c7e6b7bae4fc5874ef2ef46d56776d61c9e"}, + {file = "cffi-1.15.1-cp310-cp310-win32.whl", hash = "sha256:cba9d6b9a7d64d4bd46167096fc9d2f835e25d7e4c121fb2ddfc6528fb0413b2"}, + {file = "cffi-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:ce4bcc037df4fc5e3d184794f27bdaab018943698f4ca31630bc7f84a7b69c6d"}, + {file = "cffi-1.15.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d08afd128ddaa624a48cf2b859afef385b720bb4b43df214f85616922e6a5ac"}, + {file = "cffi-1.15.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3799aecf2e17cf585d977b780ce79ff0dc9b78d799fc694221ce814c2c19db83"}, + {file = "cffi-1.15.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a591fe9e525846e4d154205572a029f653ada1a78b93697f3b5a8f1f2bc055b9"}, + {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3548db281cd7d2561c9ad9984681c95f7b0e38881201e157833a2342c30d5e8c"}, + {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91fc98adde3d7881af9b59ed0294046f3806221863722ba7d8d120c575314325"}, + {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94411f22c3985acaec6f83c6df553f2dbe17b698cc7f8ae751ff2237d96b9e3c"}, + {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:03425bdae262c76aad70202debd780501fabeaca237cdfddc008987c0e0f59ef"}, + {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc4d65aeeaa04136a12677d3dd0b1c0c94dc43abac5860ab33cceb42b801c1e8"}, + {file = "cffi-1.15.1-cp311-cp311-win32.whl", hash = "sha256:a0f100c8912c114ff53e1202d0078b425bee3649ae34d7b070e9697f93c5d52d"}, + {file = "cffi-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:04ed324bda3cda42b9b695d51bb7d54b680b9719cfab04227cdd1e04e5de3104"}, + {file = "cffi-1.15.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50a74364d85fd319352182ef59c5c790484a336f6db772c1a9231f1c3ed0cbd7"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e263d77ee3dd201c3a142934a086a4450861778baaeeb45db4591ef65550b0a6"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cec7d9412a9102bdc577382c3929b337320c4c4c4849f2c5cdd14d7368c5562d"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4289fc34b2f5316fbb762d75362931e351941fa95fa18789191b33fc4cf9504a"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:173379135477dc8cac4bc58f45db08ab45d228b3363adb7af79436135d028405"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6975a3fac6bc83c4a65c9f9fcab9e47019a11d3d2cf7f3c0d03431bf145a941e"}, + {file = "cffi-1.15.1-cp36-cp36m-win32.whl", hash = "sha256:2470043b93ff09bf8fb1d46d1cb756ce6132c54826661a32d4e4d132e1977adf"}, + {file = "cffi-1.15.1-cp36-cp36m-win_amd64.whl", hash = "sha256:30d78fbc8ebf9c92c9b7823ee18eb92f2e6ef79b45ac84db507f52fbe3ec4497"}, + {file = "cffi-1.15.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:198caafb44239b60e252492445da556afafc7d1e3ab7a1fb3f0584ef6d742375"}, + {file = 
"cffi-1.15.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ef34d190326c3b1f822a5b7a45f6c4535e2f47ed06fec77d3d799c450b2651e"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8102eaf27e1e448db915d08afa8b41d6c7ca7a04b7d73af6514df10a3e74bd82"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5df2768244d19ab7f60546d0c7c63ce1581f7af8b5de3eb3004b9b6fc8a9f84b"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8c4917bd7ad33e8eb21e9a5bbba979b49d9a97acb3a803092cbc1133e20343c"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2642fe3142e4cc4af0799748233ad6da94c62a8bec3a6648bf8ee68b1c7426"}, + {file = "cffi-1.15.1-cp37-cp37m-win32.whl", hash = "sha256:e229a521186c75c8ad9490854fd8bbdd9a0c9aa3a524326b55be83b54d4e0ad9"}, + {file = "cffi-1.15.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a0b71b1b8fbf2b96e41c4d990244165e2c9be83d54962a9a1d118fd8657d2045"}, + {file = "cffi-1.15.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:320dab6e7cb2eacdf0e658569d2575c4dad258c0fcc794f46215e1e39f90f2c3"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e74c6b51a9ed6589199c787bf5f9875612ca4a8a0785fb2d4a84429badaf22a"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5c84c68147988265e60416b57fc83425a78058853509c1b0629c180094904a5"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b926aa83d1edb5aa5b427b4053dc420ec295a08e40911296b9eb1b6170f6cca"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87c450779d0914f2861b8526e035c5e6da0a3199d8f1add1a665e1cbc6fc6d02"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2c9f67e9821cad2e5f480bc8d83b8742896f1242dba247911072d4fa94c192"}, + {file = "cffi-1.15.1-cp38-cp38-win32.whl", hash = "sha256:8b7ee99e510d7b66cdb6c593f21c043c248537a32e0bedf02e01e9553a172314"}, + {file = "cffi-1.15.1-cp38-cp38-win_amd64.whl", hash = "sha256:00a9ed42e88df81ffae7a8ab6d9356b371399b91dbdf0c3cb1e84c03a13aceb5"}, + {file = "cffi-1.15.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:54a2db7b78338edd780e7ef7f9f6c442500fb0d41a5a4ea24fff1c929d5af585"}, + {file = "cffi-1.15.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fcd131dd944808b5bdb38e6f5b53013c5aa4f334c5cad0c72742f6eba4b73db0"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7473e861101c9e72452f9bf8acb984947aa1661a7704553a9f6e4baa5ba64415"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c9a799e985904922a4d207a94eae35c78ebae90e128f0c4e521ce339396be9d"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3bcde07039e586f91b45c88f8583ea7cf7a0770df3a1649627bf598332cb6984"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33ab79603146aace82c2427da5ca6e58f2b3f2fb5da893ceac0c42218a40be35"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d598b938678ebf3c67377cdd45e09d431369c3b1a5b331058c338e201f12b27"}, + {file = 
"cffi-1.15.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db0fbb9c62743ce59a9ff687eb5f4afbe77e5e8403d6697f7446e5f609976f76"}, + {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d85c6a2bef81588d9227dde12db8a7f47f639f4a17c9ae08e773aa9c697bf3"}, + {file = "cffi-1.15.1-cp39-cp39-win32.whl", hash = "sha256:40f4774f5a9d4f5e344f31a32b5096977b5d48560c5592e2f3d2c4374bd543ee"}, + {file = "cffi-1.15.1-cp39-cp39-win_amd64.whl", hash = "sha256:70df4e3b545a17496c9b3f41f5115e69a4f2e77e94e1d2a8e1070bc0c38c8a3c"}, + {file = "cffi-1.15.1.tar.gz", hash = "sha256:d400bfb9a37b1351253cb402671cea7e89bdecc294e8016a707f6d1d8ac934f9"}, +] + +[package.dependencies] +pycparser = "*" + +[[package]] +name = "cfgv" +version = "3.3.1" +description = "Validate configuration and produce human readable error messages." +category = "dev" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "cfgv-3.3.1-py2.py3-none-any.whl", hash = "sha256:c6a0883f3917a037485059700b9e75da2464e6c27051014ad85ba6aaa5884426"}, + {file = "cfgv-3.3.1.tar.gz", hash = "sha256:f5a830efb9ce7a445376bb66ec94c638a9787422f96264c98edc6bdeed8ab736"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.2.0" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "dev" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = 
"sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, + {file = 
"charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, + {file = 
"charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, + {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, +] + +[[package]] +name = "chex" +version = "0.1.81" +description = "Chex: Testing made fun, in JAX!" 
+category = "main" +optional = false +python-versions = ">=3.9" +files = [ + {file = "chex-0.1.81-py3-none-any.whl", hash = "sha256:b78b60d440b9172cbf7a15f2c126eac21b3761752bc9f901ab4e3c63530e1554"}, + {file = "chex-0.1.81.tar.gz", hash = "sha256:22801735acb0402453e59cf9d6f71c3bfededce77c450230eba49ec0755afdbc"}, +] + +[package.dependencies] +absl-py = ">=0.9.0" +dm-tree = ">=0.1.5" +jax = ">=0.4.6" +jaxlib = ">=0.1.37" +numpy = ">=1.25.0" +toolz = ">=0.9.0" +typing-extensions = ">=4.2.0" + +[[package]] +name = "click" +version = "8.1.6" +description = "Composable command line interface toolkit" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.6-py3-none-any.whl", hash = "sha256:fa244bb30b3b5ee2cae3da8f55c9e5e0c0e86093306301fb418eb9dc40fbded5"}, + {file = "click-8.1.6.tar.gz", hash = "sha256:48ee849951919527a045bfe3bf7baa8a959c423134e1a5b98c05c20ba75a1cbd"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "comm" +version = "0.1.3" +description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "comm-0.1.3-py3-none-any.whl", hash = "sha256:16613c6211e20223f215fc6d3b266a247b6e2641bf4e0a3ad34cb1aff2aa3f37"}, + {file = "comm-0.1.3.tar.gz", hash = "sha256:a61efa9daffcfbe66fd643ba966f846a624e4e6d6767eda9cf6e993aadaab93e"}, +] + +[package.dependencies] +traitlets = ">=5.3" + +[package.extras] +lint = ["black (>=22.6.0)", "mdformat (>0.7)", "mdformat-gfm (>=0.3.5)", "ruff (>=0.0.156)"] +test = ["pytest"] +typing = ["mypy (>=0.990)"] + +[[package]] +name = "contourpy" +version = "1.1.0" +description = "Python library for calculating contours of 2D quadrilateral grids" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "contourpy-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89f06eff3ce2f4b3eb24c1055a26981bffe4e7264acd86f15b97e40530b794bc"}, + {file = "contourpy-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dffcc2ddec1782dd2f2ce1ef16f070861af4fb78c69862ce0aab801495dda6a3"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25ae46595e22f93592d39a7eac3d638cda552c3e1160255258b695f7b58e5655"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:17cfaf5ec9862bc93af1ec1f302457371c34e688fbd381f4035a06cd47324f48"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa"}, + {file = "contourpy-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa"}, + {file = 
"contourpy-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2"}, + {file = "contourpy-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882"}, + {file = "contourpy-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:052cc634bf903c604ef1a00a5aa093c54f81a2612faedaa43295809ffdde885e"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9382a1c0bc46230fb881c36229bfa23d8c303b889b788b939365578d762b5c18"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a"}, + {file = "contourpy-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e"}, + {file = "contourpy-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"}, + {file = "contourpy-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed"}, + {file = "contourpy-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62013a2cf68abc80dadfd2307299bfa8f5aa0dcaec5b2954caeb5fa094171103"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0b6616375d7de55797d7a66ee7d087efe27f03d336c27cf1f32c02b8c1a5ac70"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f"}, + {file = "contourpy-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1"}, + {file = "contourpy-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4"}, + {file = "contourpy-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9"}, + {file = "contourpy-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f2931ed4741f98f74b410b16e5213f71dcccee67518970c42f64153ea9313b9"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:30f511c05fab7f12e0b1b7730ebdc2ec8deedcfb505bc27eb570ff47c51a8f15"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a"}, + {file = "contourpy-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002"}, + {file = "contourpy-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d"}, + {file = "contourpy-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493"}, + {file = "contourpy-1.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439"}, + {file = "contourpy-1.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:a67259c2b493b00e5a4d0f7bfae51fb4b3371395e47d079a4446e9b0f4d70e76"}, + {file = "contourpy-1.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2b836d22bd2c7bb2700348e4521b25e077255ebb6ab68e351ab5aa91ca27e027"}, + {file = "contourpy-1.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084eaa568400cfaf7179b847ac871582199b1b44d5699198e9602ecbbb5f6104"}, + {file = "contourpy-1.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:911ff4fd53e26b019f898f32db0d4956c9d227d51338fb3b03ec72ff0084ee5f"}, + {file = "contourpy-1.1.0.tar.gz", hash = "sha256:e53046c3863828d21d531cc3b53786e6580eb1ba02477e8681009b6aa0870b21"}, +] + +[package.dependencies] +numpy = ">=1.16" + +[package.extras] +bokeh = ["bokeh", "selenium"] +docs = ["furo", "sphinx-copybutton"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.2.0)", "types-Pillow"] +test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] +test-no-images = ["pytest", "pytest-cov", "wurlitzer"] + +[[package]] +name = "coverage" +version = "7.2.7" +description = "Code coverage measurement for Python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"}, + {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"}, + {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"}, + {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"}, + {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"}, + {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"}, + {file = 
"coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"}, + {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"}, + {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"}, + {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"}, + {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"}, + {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"}, + {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"}, + {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"}, + {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"}, + {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"}, + {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"}, + {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"}, + {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"}, + {file = 
"coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"}, + {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"}, + {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"}, + {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"}, + {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"}, + {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"}, + {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"}, + {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"}, + {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"}, + {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"}, + {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"}, + {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"}, + {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"}, + {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"}, + {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"}, + {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"}, + {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"}, + {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + +[[package]] +name = "cycler" +version = "0.11.0" +description = "Composable style cycles" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"}, + {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"}, +] + +[[package]] +name = "datasets" +version = "2.13.1" +description = "HuggingFace community-driven open-source library of datasets" +category = "dev" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "datasets-2.13.1-py3-none-any.whl", hash = "sha256:844d8dbc1759e0b6b8e5063af019dc95d6af07ea075002b03323a280bf8d53f6"}, + {file = "datasets-2.13.1.tar.gz", hash = "sha256:bacb7750b1a434417312b4281a55225a3f7e0163abdd12a2a3e2d700310d5221"}, +] + +[package.dependencies] +aiohttp = "*" +dill = ">=0.3.0,<0.3.7" +fsspec = {version = ">=2021.11.1", extras = ["http"]} +huggingface-hub = ">=0.11.0,<1.0.0" +multiprocess = "*" +numpy = ">=1.17" +packaging = "*" +pandas = "*" +pyarrow = ">=8.0.0" +pyyaml = ">=5.1" +requests = ">=2.19.0" +tqdm = ">=4.62.1" +xxhash = "*" + +[package.extras] +apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] +audio = ["librosa", "soundfile (>=0.12.1)"] +benchmarks = ["numpy (==1.18.5)", "protobuf (==3.20.3)", "tensorflow (==2.3.0)", "torch (==1.7.1)", "transformers (==3.0.2)"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +docs = ["s3fs"] +jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] +quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] +s3 = ["s3fs"] +tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] +tensorflow-gpu = ["tensorflow-gpu 
(>=2.2.0,!=2.6.0,!=2.6.1)"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +torch = ["torch"] +vision = ["Pillow (>=6.2.1)"] + +[[package]] +name = "debugpy" +version = "1.6.7" +description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "debugpy-1.6.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b3e7ac809b991006ad7f857f016fa92014445085711ef111fdc3f74f66144096"}, + {file = "debugpy-1.6.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3876611d114a18aafef6383695dfc3f1217c98a9168c1aaf1a02b01ec7d8d1e"}, + {file = "debugpy-1.6.7-cp310-cp310-win32.whl", hash = "sha256:33edb4afa85c098c24cc361d72ba7c21bb92f501104514d4ffec1fb36e09c01a"}, + {file = "debugpy-1.6.7-cp310-cp310-win_amd64.whl", hash = "sha256:ed6d5413474e209ba50b1a75b2d9eecf64d41e6e4501977991cdc755dc83ab0f"}, + {file = "debugpy-1.6.7-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:38ed626353e7c63f4b11efad659be04c23de2b0d15efff77b60e4740ea685d07"}, + {file = "debugpy-1.6.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:279d64c408c60431c8ee832dfd9ace7c396984fd7341fa3116aee414e7dcd88d"}, + {file = "debugpy-1.6.7-cp37-cp37m-win32.whl", hash = "sha256:dbe04e7568aa69361a5b4c47b4493d5680bfa3a911d1e105fbea1b1f23f3eb45"}, + {file = "debugpy-1.6.7-cp37-cp37m-win_amd64.whl", hash = "sha256:f90a2d4ad9a035cee7331c06a4cf2245e38bd7c89554fe3b616d90ab8aab89cc"}, + {file = "debugpy-1.6.7-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:5224eabbbeddcf1943d4e2821876f3e5d7d383f27390b82da5d9558fd4eb30a9"}, + {file = "debugpy-1.6.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae1123dff5bfe548ba1683eb972329ba6d646c3a80e6b4c06cd1b1dd0205e9b"}, + {file = "debugpy-1.6.7-cp38-cp38-win32.whl", hash = "sha256:9cd10cf338e0907fdcf9eac9087faa30f150ef5445af5a545d307055141dd7a4"}, + {file = "debugpy-1.6.7-cp38-cp38-win_amd64.whl", hash = "sha256:aaf6da50377ff4056c8ed470da24632b42e4087bc826845daad7af211e00faad"}, + {file = "debugpy-1.6.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:0679b7e1e3523bd7d7869447ec67b59728675aadfc038550a63a362b63029d2c"}, + {file = "debugpy-1.6.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de86029696e1b3b4d0d49076b9eba606c226e33ae312a57a46dca14ff370894d"}, + {file = "debugpy-1.6.7-cp39-cp39-win32.whl", hash = "sha256:d71b31117779d9a90b745720c0eab54ae1da76d5b38c8026c654f4a066b0130a"}, + {file = "debugpy-1.6.7-cp39-cp39-win_amd64.whl", hash = "sha256:c0ff93ae90a03b06d85b2c529eca51ab15457868a377c4cc40a23ab0e4e552a3"}, + {file = "debugpy-1.6.7-py2.py3-none-any.whl", hash = "sha256:53f7a456bc50706a0eaabecf2d3ce44c4d5010e46dfc65b6b81a518b42866267"}, + {file = "debugpy-1.6.7.zip", hash = "sha256:c4c2f0810fa25323abfdfa36cbbbb24e5c3b1a42cb762782de64439c575d67f2"}, +] + +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = 
"sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + +[[package]] +name = "dill" +version = "0.3.6" +description = "serialize all of python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "dill-0.3.6-py3-none-any.whl", hash = "sha256:a07ffd2351b8c678dfc4a856a3005f8067aea51d6ba6c700796a4d9e280f39f0"}, + {file = "dill-0.3.6.tar.gz", hash = "sha256:e5db55f3687856d8fbdab002ed78544e1c4559a130302693d839dfe8f93f2373"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] + +[[package]] +name = "distlib" +version = "0.3.7" +description = "Distribution utilities" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"}, + {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, +] + +[[package]] +name = "dm-tree" +version = "0.1.8" +description = "Tree is a library for working with nested data structures." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, + {file = "dm_tree-0.1.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60"}, + {file = "dm_tree-0.1.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f"}, + {file = "dm_tree-0.1.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2869228d9c619074de501a3c10dc7f07c75422f8fab36ecdcb859b6f1b1ec3ef"}, + {file = "dm_tree-0.1.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d20f2faa3672b52e5013f4077117bfb99c4cfc0b445d3bde1584c34032b57436"}, + {file = "dm_tree-0.1.8-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410"}, + {file = "dm_tree-0.1.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d7c26e431fc93cc7e0cba867eb000db6a05f6f2b25af11ac4e9dada88fc5bca"}, + {file = "dm_tree-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144"}, + {file = "dm_tree-0.1.8-cp310-cp310-win_amd64.whl", hash = "sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee"}, + {file = "dm_tree-0.1.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7"}, + {file = "dm_tree-0.1.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b"}, + {file = "dm_tree-0.1.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5"}, + {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1607ce49aa42f010d1e5e616d92ce899d66835d4d8bea49679582435285515de"}, + {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:343a4a4ebaa127451ff971254a4be4084eb4bdc0b2513c32b46f6f728fd03f9e"}, + {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, + {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, + {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, + {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, + {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, + {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, + {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6"}, + {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f7915660f59c09068e428613c480150180df1060561fd0d1470684ae7007bd1"}, + {file = "dm_tree-0.1.8-cp37-cp37m-win_amd64.whl", hash = "sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6"}, + {file = "dm_tree-0.1.8-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0e9620ccf06393eb6b613b5e366469304622d4ea96ae6540b28a33840e6c89cf"}, + {file = "dm_tree-0.1.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b095ba4f8ca1ba19350fd53cf1f8f3eb0bd406aa28af64a6dfc86707b32a810a"}, + {file = "dm_tree-0.1.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d"}, + {file = "dm_tree-0.1.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d3172394079a86c3a759179c65f64c48d1a42b89495fcf38976d11cc3bb952c"}, + {file = "dm_tree-0.1.8-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8"}, + {file = "dm_tree-0.1.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68"}, + {file = "dm_tree-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134"}, + {file = "dm_tree-0.1.8-cp38-cp38-win_amd64.whl", hash = "sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5"}, + {file = "dm_tree-0.1.8-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f"}, + {file = "dm_tree-0.1.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:250b692fb75f45f02e2f58fbef9ab338904ef334b90557565621fa251df267cf"}, + {file = "dm_tree-0.1.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81fce77f22a302d7a5968aebdf4efafef4def7ce96528719a354e6990dcd49c7"}, + {file = "dm_tree-0.1.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb"}, + {file = "dm_tree-0.1.8-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe962015b2fe1282892b28ebe962faed53c7f98d942da9a4625cbf27baef913"}, + {file = "dm_tree-0.1.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c52cbf4f8b3dbd0beaedf44f69fa85eec5e9dede612e08035e06ada6ec9426"}, + {file = 
"dm_tree-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:181c35521d480d0365f39300542cb6cd7fd2b77351bb43d7acfda15aef63b317"}, + {file = "dm_tree-0.1.8-cp39-cp39-win_amd64.whl", hash = "sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368"}, +] + +[[package]] +name = "etils" +version = "1.3.0" +description = "Collection of common python utils" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "etils-1.3.0-py3-none-any.whl", hash = "sha256:809a92ff72f12149441492cf4d9a26b56a4741dffb4dfb9c4c7b7afe055c2d28"}, + {file = "etils-1.3.0.tar.gz", hash = "sha256:0a695ec45a982ae7c9deb437f1f251346d88b43ca59be67e961f61fe8bc8cae4"}, +] + +[package.dependencies] +importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""} +typing_extensions = {version = "*", optional = true, markers = "extra == \"epath\""} +zipp = {version = "*", optional = true, markers = "extra == \"epath\""} + +[package.extras] +all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] +array-types = ["etils[enp]"] +dev = ["chex", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] +eapp = ["absl-py", "etils[epy]", "simple_parsing"] +ecolab = ["etils[enp]", "etils[epy]", "jupyter", "mediapy", "numpy"] +edc = ["etils[epy]"] +enp = ["etils[epy]", "numpy"] +epath = ["etils[epy]", "importlib_resources", "typing_extensions", "zipp"] +epy = ["typing_extensions"] +etqdm = ["absl-py", "etils[epy]", "tqdm"] +etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"] +etree-dm = ["dm-tree", "etils[etree]"] +etree-jax = ["etils[etree]", "jax[cpu]"] +etree-tf = ["etils[etree]", "tensorflow"] +lazy-imports = ["etils[ecolab]"] + +[[package]] +name = "exceptiongroup" +version = "1.1.2" +description = "Backport of PEP 654 (exception groups)" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.2-py3-none-any.whl", hash = "sha256:e346e69d186172ca7cf029c8c1d16235aa0e04035e5750b4b95039e65204328f"}, + {file = "exceptiongroup-1.1.2.tar.gz", hash = "sha256:12c3e887d6485d16943a309616de20ae5582633e0a2eda17f4e10fd61c1e8af5"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "executing" +version = "1.2.0" +description = "Get the currently executing AST node of a frame, and other information" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "executing-1.2.0-py2.py3-none-any.whl", hash = "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc"}, + {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"}, +] + +[package.extras] +tests = ["asttokens", "littleutils", "pytest", "rich"] + +[[package]] +name = "filelock" +version = "3.12.2" +description = "A platform independent file lock." 
+category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, + {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, +] + +[package.extras] +docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] + +[[package]] +name = "flax" +version = "0.7.0" +description = "Flax: A neural network library for JAX designed for flexibility" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "flax-0.7.0-py3-none-any.whl", hash = "sha256:c63e64124be8011b3d2f65a866d98627a5879f243e18351e85bcd0ab29228fc4"}, + {file = "flax-0.7.0.tar.gz", hash = "sha256:171715d7df050eb748867f14a6d42338adba060edaa1e3b4d3e978a3483db8c5"}, +] + +[package.dependencies] +jax = ">=0.4.2" +msgpack = "*" +numpy = ">=1.12" +optax = "*" +orbax-checkpoint = "*" +PyYAML = ">=5.4.1" +rich = ">=11.1" +tensorstore = "*" +typing-extensions = ">=4.1.1" + +[package.extras] +all = ["matplotlib"] +testing = ["clu", "einops", "gymnasium[accept-rom-license,atari]", "jaxlib", "jraph (>=0.0.6dev0)", "ml-collections", "mypy", "nbstripout", "opencv-python", "pytest", "pytest-cov", "pytest-custom-exit-code", "pytest-xdist (==1.34.0)", "pytype", "sentencepiece", "tensorflow", "tensorflow-datasets", "tensorflow-text (>=2.11.0)", "torch"] + +[[package]] +name = "fonttools" +version = "4.41.0" +description = "Tools to manipulate font files" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fonttools-4.41.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ba2a367ff478cd108d5319c0dc4fd4eb4ea3476dbfc45b00c45718e889cd9463"}, + {file = "fonttools-4.41.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:69178674505ec81adf4af2a3bbacd0cb9a37ba7831bc3fca307f80e48ab2767b"}, + {file = "fonttools-4.41.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86edb95c4d1fe4fae2111d7e0c10c6e42b7790b377bcf1952303469eee5b52bb"}, + {file = "fonttools-4.41.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50f8bdb421270f71b54695c62785e300fab4bb6127be40bf9f3084962a0c3adb"}, + {file = "fonttools-4.41.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c890061915e95b619c1d3cc3c107c6fb021406b701c0c24b03e74830d522f210"}, + {file = "fonttools-4.41.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b329ae7ce971b5c4148d6cdb8119c0ce4587265b2330d4f2f3776ef851bee020"}, + {file = "fonttools-4.41.0-cp310-cp310-win32.whl", hash = "sha256:bc9e7b1e268be7a23fc66471b615c324e99c5db39ce8c49dd6dd8e962c7bc1b8"}, + {file = "fonttools-4.41.0-cp310-cp310-win_amd64.whl", hash = "sha256:f3fe90dfb297bd8265238c06787911cd81c2cb89ac5b13e1c911928bdabfce0f"}, + {file = "fonttools-4.41.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e38bd91eae257f36c2b7245c0278e9cd9d754f3a66b8d2b548c623ba66e387b6"}, + {file = "fonttools-4.41.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:415cf7c806a3f56fb280dadcf3c92c85c0415e75665ca957b4a2a2e39c17a5c9"}, + {file = "fonttools-4.41.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:381558eafffc1432d08ca58063e71c7376ecaae48e9318354a90a1049a644845"}, + {file = 
"fonttools-4.41.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ee75b8ca48f6c48af25e967dce995ef94e46872b35c7d454b983c62c9c7006d"}, + {file = "fonttools-4.41.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d45f28c20bb67dee0f4a4caae709f40b0693d764b7b2bf2d58890f36b1bfcef0"}, + {file = "fonttools-4.41.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5448a87f6ed57ed844b64a05d3792827af584a8584613f6289867f4e77eb603b"}, + {file = "fonttools-4.41.0-cp311-cp311-win32.whl", hash = "sha256:69dbe0154e15b68dd671441ea8f23dad87488b24a6e650d45958f4722819a443"}, + {file = "fonttools-4.41.0-cp311-cp311-win_amd64.whl", hash = "sha256:ea879afd1d6189fca02a85a7868560c9bb8415dccff6b7ae6d81e4f06b3ab30d"}, + {file = "fonttools-4.41.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8f602dd5bcde7e4241419924f23c6f0d66723dd5408a58c3a2f781745c693f45"}, + {file = "fonttools-4.41.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:06eac087ea55b3ebb2207d93b5ac56c847163899f05f5a77e1910f688fe10030"}, + {file = "fonttools-4.41.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e22d0144d735f6c7df770509b8c0c33414bf460df0d5dddc98a159e5dbb10eb"}, + {file = "fonttools-4.41.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19d461c801b8904d201c6c38a99bfcfef673bfdfe0c7f026f582ef78896434e0"}, + {file = "fonttools-4.41.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:72d40a32d6443871ea0d147813caad58394b48729dfa4fbc45dcaac54f9506f2"}, + {file = "fonttools-4.41.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0614b6348866092d00df3dfb37e037fc06412ca67087de361a2777ea5ed62c16"}, + {file = "fonttools-4.41.0-cp38-cp38-win32.whl", hash = "sha256:e43f6c7f9ba4f9d29edee530e45f9aa162872ec9549398b85971477a99f2a806"}, + {file = "fonttools-4.41.0-cp38-cp38-win_amd64.whl", hash = "sha256:eb9dfa87152bd97019adc387b2f29ef6af601de4386f36570ca537ace96d96ed"}, + {file = "fonttools-4.41.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d2dae84a3d0f76884a6102c62f2795b2d6602c2c95cfcce74c8a590b6200e533"}, + {file = "fonttools-4.41.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cc3324e4159e6d1f55c3615b4c1c211f87cc96cc0cc7c946c8447dc1319f2e9d"}, + {file = "fonttools-4.41.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c654b1facf1f3b742e4d9b2dcdf0fa867b1f007b1b4981cc58a75ef5dca2a3c"}, + {file = "fonttools-4.41.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:560ea1a604c927399f36742abf342a4c5f3fee8e8e8a484b774dfe9630bd9a91"}, + {file = "fonttools-4.41.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9387b09694fbf8ac7dcf887069068f81fb4124d05e09557ac7daabfbec1744bd"}, + {file = "fonttools-4.41.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:465d0f24bf4f75160f441793b55076b7a080a57d3a1f738390af2c20bee24fbb"}, + {file = "fonttools-4.41.0-cp39-cp39-win32.whl", hash = "sha256:841c491fa3e9c54e8f9cd5dae059e88f45e086aea090c28be9d42f59c8b99e01"}, + {file = "fonttools-4.41.0-cp39-cp39-win_amd64.whl", hash = "sha256:efd59e83223cb77952997fb850c7a7c2a958c9af0642060f536722c2a9e9d53b"}, + {file = "fonttools-4.41.0-py3-none-any.whl", hash = "sha256:5b1c2b21b40229166a864f2b0aec06d37f0a204066deb1734c93370e0c76339d"}, + {file = "fonttools-4.41.0.tar.gz", hash = "sha256:6faff25991dec48f8cac882055a09ae1a29fd15bc160bc3d663e789e994664c2"}, +] + +[package.extras] +all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", 
"scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.0.0)", "xattr", "zopfli (>=0.1.4)"] +graphite = ["lz4 (>=1.7.4.2)"] +interpolatable = ["munkres", "scipy"] +lxml = ["lxml (>=4.0,<5)"] +pathops = ["skia-pathops (>=0.5.0)"] +plot = ["matplotlib"] +repacker = ["uharfbuzz (>=0.23.0)"] +symfont = ["sympy"] +type1 = ["xattr"] +ufo = ["fs (>=2.2.0,<3)"] +unicode = ["unicodedata2 (>=15.0.0)"] +woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] + +[[package]] +name = "frozenlist" +version = "1.4.0" +description = "A list-like structure which implements collections.abc.MutableSequence" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, + {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, + {file = "frozenlist-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9ac08e601308e41eb533f232dbf6b7e4cea762f9f84f6357136eed926c15d12c"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d081f13b095d74b67d550de04df1c756831f3b83dc9881c38985834387487f1b"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71932b597f9895f011f47f17d6428252fc728ba2ae6024e13c3398a087c2cdea"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:981b9ab5a0a3178ff413bca62526bb784249421c24ad7381e39d67981be2c326"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e41f3de4df3e80de75845d3e743b3f1c4c8613c3997a912dbf0229fc61a8b963"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6918d49b1f90821e93069682c06ffde41829c346c66b721e65a5c62b4bab0300"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e5c8764c7829343d919cc2dfc587a8db01c4f70a4ebbc49abde5d4b158b007b"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8d0edd6b1c7fb94922bf569c9b092ee187a83f03fb1a63076e7774b60f9481a8"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e29cda763f752553fa14c68fb2195150bfab22b352572cb36c43c47bedba70eb"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:0c7c1b47859ee2cac3846fde1c1dc0f15da6cec5a0e5c72d101e0f83dcb67ff9"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:901289d524fdd571be1c7be054f48b1f88ce8dddcbdf1ec698b27d4b8b9e5d62"}, + {file = "frozenlist-1.4.0-cp310-cp310-win32.whl", hash = "sha256:1a0848b52815006ea6596c395f87449f693dc419061cc21e970f139d466dc0a0"}, + {file = "frozenlist-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:b206646d176a007466358aa21d85cd8600a415c67c9bd15403336c331a10d956"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de343e75f40e972bae1ef6090267f8260c1446a1695e77096db6cfa25e759a95"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad2a9eb6d9839ae241701d0918f54c51365a51407fd80f6b8289e2dfca977cc3"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:bd7bd3b3830247580de99c99ea2a01416dfc3c34471ca1298bccabf86d0ff4dc"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdf1847068c362f16b353163391210269e4f0569a3c166bc6a9f74ccbfc7e839"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38461d02d66de17455072c9ba981d35f1d2a73024bee7790ac2f9e361ef1cd0c"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5a32087d720c608f42caed0ef36d2b3ea61a9d09ee59a5142d6070da9041b8f"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd65632acaf0d47608190a71bfe46b209719bf2beb59507db08ccdbe712f969b"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261b9f5d17cac914531331ff1b1d452125bf5daa05faf73b71d935485b0c510b"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b89ac9768b82205936771f8d2eb3ce88503b1556324c9f903e7156669f521472"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:008eb8b31b3ea6896da16c38c1b136cb9fec9e249e77f6211d479db79a4eaf01"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:490132667476f6781b4c9458298b0c1cddf237488abd228b0b3650e5ecba7467"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:76d4711f6f6d08551a7e9ef28c722f4a50dd0fc204c56b4bcd95c6cc05ce6fbb"}, + {file = "frozenlist-1.4.0-cp311-cp311-win32.whl", hash = "sha256:a02eb8ab2b8f200179b5f62b59757685ae9987996ae549ccf30f983f40602431"}, + {file = "frozenlist-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:515e1abc578dd3b275d6a5114030b1330ba044ffba03f94091842852f806f1c1"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ca265542ca427bf97aed183c1676e2a9c66942e822b14dc6e5f42e038f92a503"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:491e014f5c43656da08958808588cc6c016847b4360e327a62cb308c791bd2d9"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ae5cd0f333f94f2e03aaf140bb762c64783935cc764ff9c82dff626089bebf"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e78fb68cf9c1a6aa4a9a12e960a5c9dfbdb89b3695197aa7064705662515de2"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5655a942f5f5d2c9ed93d72148226d75369b4f6952680211972a33e59b1dfdc"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11b0746f5d946fecf750428a95f3e9ebe792c1ee3b1e96eeba145dc631a9672"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e66d2a64d44d50d2543405fb183a21f76b3b5fd16f130f5c99187c3fb4e64919"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:88f7bc0fcca81f985f78dd0fa68d2c75abf8272b1f5c323ea4a01a4d7a614efc"}, + {file = 
"frozenlist-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5833593c25ac59ede40ed4de6d67eb42928cca97f26feea219f21d0ed0959b79"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:b826d97e4276750beca7c8f0f1a4938892697a6bcd8ec8217b3312dad6982781"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ceb6ec0a10c65540421e20ebd29083c50e6d1143278746a4ef6bcf6153171eb8"}, + {file = "frozenlist-1.4.0-cp38-cp38-win32.whl", hash = "sha256:2b8bcf994563466db019fab287ff390fffbfdb4f905fc77bc1c1d604b1c689cc"}, + {file = "frozenlist-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:a6c8097e01886188e5be3e6b14e94ab365f384736aa1fca6a0b9e35bd4a30bc7"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6c38721585f285203e4b4132a352eb3daa19121a035f3182e08e437cface44bf"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0c6da9aee33ff0b1a451e867da0c1f47408112b3391dd43133838339e410963"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93ea75c050c5bb3d98016b4ba2497851eadf0ac154d88a67d7a6816206f6fa7f"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa384489fefeb62321b238e64c07ef48398fe80f9e1e6afeff22e140e0850eef"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10ff5faaa22786315ef57097a279b833ecab1a0bfb07d604c9cbb1c4cdc2ed87"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:007df07a6e3eb3e33e9a1fe6a9db7af152bbd8a185f9aaa6ece10a3529e3e1c6"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4f399d28478d1f604c2ff9119907af9726aed73680e5ed1ca634d377abb087"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5374b80521d3d3f2ec5572e05adc94601985cc526fb276d0c8574a6d749f1b3"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ce31ae3e19f3c902de379cf1323d90c649425b86de7bbdf82871b8a2a0615f3d"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7211ef110a9194b6042449431e08c4d80c0481e5891e58d429df5899690511c2"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:556de4430ce324c836789fa4560ca62d1591d2538b8ceb0b4f68fb7b2384a27a"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7645a8e814a3ee34a89c4a372011dcd817964ce8cb273c8ed6119d706e9613e3"}, + {file = "frozenlist-1.4.0-cp39-cp39-win32.whl", hash = "sha256:19488c57c12d4e8095a922f328df3f179c820c212940a498623ed39160bc3c2f"}, + {file = "frozenlist-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:6221d84d463fb110bdd7619b69cb43878a11d51cbb9394ae3105d082d5199167"}, + {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"}, +] + +[[package]] +name = "fsspec" +version = "2023.6.0" +description = "File-system specification" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = 
"fsspec-2023.6.0-py3-none-any.whl", hash = "sha256:1cbad1faef3e391fba6dc005ae9b5bdcbf43005c9167ce78c915549c352c869a"}, + {file = "fsspec-2023.6.0.tar.gz", hash = "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"}, +] + +[package.dependencies] +aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} +requests = {version = "*", optional = true, markers = "extra == \"http\""} + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + +[[package]] +name = "huggingface-hub" +version = "0.16.4" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +category = "dev" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"}, + {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +inference = ["aiohttp", "pydantic"] +quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["torch"] +typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] + +[[package]] +name = "identify" +version = "2.5.25" +description = "File identification library for Python" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.5.25-py2.py3-none-any.whl", hash = 
"sha256:9df2489842707d431b38ce3410ef8df40da5b10a3e28a3fcac1a42523e956409"}, + {file = "identify-2.5.25.tar.gz", hash = "sha256:db4de0e758c0db8f81996816cd2f3f2f8c5c8d49a7fd02f3b4109aac6fd80e29"}, +] + +[package.extras] +license = ["ukkonen"] + +[[package]] +name = "idna" +version = "3.4" +description = "Internationalized Domain Names in Applications (IDNA)" +category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, + {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, +] + +[[package]] +name = "importlib-metadata" +version = "6.8.0" +description = "Read metadata from Python packages" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, + {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + +[[package]] +name = "importlib-resources" +version = "6.0.0" +description = "Read resources from Python packages" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_resources-6.0.0-py3-none-any.whl", hash = "sha256:d952faee11004c045f785bb5636e8f885bed30dc3c940d5d42798a2a4541c185"}, + {file = "importlib_resources-6.0.0.tar.gz", hash = "sha256:4cf94875a8368bd89531a756df9a9ebe1f150e0f885030b461237bc7f2d905f2"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "ipykernel" +version = "6.24.0" +description = "IPython Kernel for Jupyter" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ipykernel-6.24.0-py3-none-any.whl", hash = "sha256:2f5fffc7ad8f1fd5aadb4e171ba9129d9668dbafa374732cf9511ada52d6547f"}, + {file = "ipykernel-6.24.0.tar.gz", hash = "sha256:29cea0a716b1176d002a61d0b0c851f34536495bc4ef7dd0222c88b41b816123"}, +] + +[package.dependencies] +appnope = {version = "*", markers = "platform_system == \"Darwin\""} +comm = ">=0.1.1" +debugpy = ">=1.6.5" +ipython = ">=7.23.1" +jupyter-client = ">=6.1.12" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" 
+matplotlib-inline = ">=0.1" +nest-asyncio = "*" +packaging = "*" +psutil = "*" +pyzmq = ">=20" +tornado = ">=6.1" +traitlets = ">=5.4.0" + +[package.extras] +cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"] +pyqt5 = ["pyqt5"] +pyside6 = ["pyside6"] +test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "ipython" +version = "8.14.0" +description = "IPython: Productive Interactive Computing" +category = "dev" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ipython-8.14.0-py3-none-any.whl", hash = "sha256:248aca623f5c99a6635bc3857677b7320b9b8039f99f070ee0d20a5ca5a8e6bf"}, + {file = "ipython-8.14.0.tar.gz", hash = "sha256:1d197b907b6ba441b692c48cf2a3a2de280dc0ac91a3405b39349a50272ca0a1"}, +] + +[package.dependencies] +appnope = {version = "*", markers = "sys_platform == \"darwin\""} +backcall = "*" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +decorator = "*" +jedi = ">=0.16" +matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} +pickleshare = "*" +prompt-toolkit = ">=3.0.30,<3.0.37 || >3.0.37,<3.1.0" +pygments = ">=2.4.0" +stack-data = "*" +traitlets = ">=5" +typing-extensions = {version = "*", markers = "python_version < \"3.10\""} + +[package.extras] +all = ["black", "curio", "docrepr", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.21)", "pandas", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] +black = ["black"] +doc = ["docrepr", "ipykernel", "matplotlib", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"] +kernel = ["ipykernel"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["pytest (<7.1)", "pytest-asyncio", "testpath"] +test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"] + +[[package]] +name = "isort" +version = "5.12.0" +description = "A Python utility / library to sort Python imports." +category = "dev" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"}, + {file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"}, +] + +[package.extras] +colors = ["colorama (>=0.4.3)"] +pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"] +plugins = ["setuptools"] +requirements-deprecated-finder = ["pip-api", "pipreqs"] + +[[package]] +name = "jax" +version = "0.4.13" +description = "Differentiate, compile, and transform Numpy code." 
+category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, +] + +[package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} +ml_dtypes = ">=0.1.0" +numpy = ">=1.21" +opt_einsum = "*" +scipy = ">=1.7" + +[package.extras] +australis = ["protobuf (>=3.13,<4)"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] + +[[package]] +name = "jaxlib" +version = "0.4.13" +description = "XLA library for JAX" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jaxlib-0.4.13-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac"}, + {file = "jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8"}, + {file = "jaxlib-0.4.13-cp310-cp310-win_amd64.whl", hash = "sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649"}, + {file = "jaxlib-0.4.13-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89"}, + {file = "jaxlib-0.4.13-cp311-cp311-win_amd64.whl", hash = "sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0"}, + {file = "jaxlib-0.4.13-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc"}, + {file = "jaxlib-0.4.13-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5"}, + {file = 
"jaxlib-0.4.13-cp39-cp39-win_amd64.whl", hash = "sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e"}, +] + +[package.dependencies] +ml-dtypes = ">=0.1.0" +numpy = ">=1.21" +scipy = ">=1.7" + +[package.extras] +cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] + +[[package]] +name = "jedi" +version = "0.18.2" +description = "An autocompletion tool for Python that can be used for text editors." +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"}, + {file = "jedi-0.18.2.tar.gz", hash = "sha256:bae794c30d07f6d910d32a7048af09b5a39ed740918da923c6b780790ebac612"}, +] + +[package.dependencies] +parso = ">=0.8.0,<0.9.0" + +[package.extras] +docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] +qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] + +[[package]] +name = "jupyter-client" +version = "8.3.0" +description = "Jupyter protocol implementation and client libraries" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_client-8.3.0-py3-none-any.whl", hash = "sha256:7441af0c0672edc5d28035e92ba5e32fadcfa8a4e608a434c228836a89df6158"}, + {file = "jupyter_client-8.3.0.tar.gz", hash = "sha256:3af69921fe99617be1670399a0b857ad67275eefcfa291e2c81a160b7b650f5f"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +python-dateutil = ">=2.8.2" +pyzmq = ">=23.0" +tornado = ">=6.2" +traitlets = ">=5.3" + +[package.extras] +docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] +test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] + +[[package]] +name = "jupyter-core" +version = "5.3.1" +description = "Jupyter core package. A base package on which Jupyter projects rely." 
+category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_core-5.3.1-py3-none-any.whl", hash = "sha256:ae9036db959a71ec1cac33081eeb040a79e681f08ab68b0883e9a676c7a90dce"}, + {file = "jupyter_core-5.3.1.tar.gz", hash = "sha256:5ba5c7938a7f97a6b0481463f7ff0dbac7c15ba48cf46fa4035ca6e838aa1aba"}, +] + +[package.dependencies] +platformdirs = ">=2.5" +pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\""} +traitlets = ">=5.3" + +[package.extras] +docs = ["myst-parser", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] +test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "kiwisolver" +version = "1.4.4" +description = "A fast implementation of the Cassowary constraint solver" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "kiwisolver-1.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2f5e60fabb7343a836360c4f0919b8cd0d6dbf08ad2ca6b9cf90bf0c76a3c4f6"}, + {file = "kiwisolver-1.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:10ee06759482c78bdb864f4109886dff7b8a56529bc1609d4f1112b93fe6423c"}, + {file = "kiwisolver-1.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c79ebe8f3676a4c6630fd3f777f3cfecf9289666c84e775a67d1d358578dc2e3"}, + {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:abbe9fa13da955feb8202e215c4018f4bb57469b1b78c7a4c5c7b93001699938"}, + {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7577c1987baa3adc4b3c62c33bd1118c3ef5c8ddef36f0f2c950ae0b199e100d"}, + {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8ad8285b01b0d4695102546b342b493b3ccc6781fc28c8c6a1bb63e95d22f09"}, + {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ed58b8acf29798b036d347791141767ccf65eee7f26bde03a71c944449e53de"}, + {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a68b62a02953b9841730db7797422f983935aeefceb1679f0fc85cbfbd311c32"}, + {file = "kiwisolver-1.4.4-cp310-cp310-win32.whl", hash = "sha256:e92a513161077b53447160b9bd8f522edfbed4bd9759e4c18ab05d7ef7e49408"}, + {file = "kiwisolver-1.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:3fe20f63c9ecee44560d0e7f116b3a747a5d7203376abeea292ab3152334d004"}, + {file = "kiwisolver-1.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ea21f66820452a3f5d1655f8704a60d66ba1191359b96541eaf457710a5fc6"}, + {file = "kiwisolver-1.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bc9db8a3efb3e403e4ecc6cd9489ea2bac94244f80c78e27c31dcc00d2790ac2"}, + {file = "kiwisolver-1.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d5b61785a9ce44e5a4b880272baa7cf6c8f48a5180c3e81c59553ba0cb0821ca"}, + {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c2dbb44c3f7e6c4d3487b31037b1bdbf424d97687c1747ce4ff2895795c9bf69"}, + {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295ecd49304dcf3bfbfa45d9a081c96509e95f4b9d0eb7ee4ec0530c4a96514"}, + {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4bd472dbe5e136f96a4b18f295d159d7f26fd399136f5b17b08c4e5f498cd494"}, + {file = 
"kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf7d9fce9bcc4752ca4a1b80aabd38f6d19009ea5cbda0e0856983cf6d0023f5"}, + {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78d6601aed50c74e0ef02f4204da1816147a6d3fbdc8b3872d263338a9052c51"}, + {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:877272cf6b4b7e94c9614f9b10140e198d2186363728ed0f701c6eee1baec1da"}, + {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:db608a6757adabb32f1cfe6066e39b3706d8c3aa69bbc353a5b61edad36a5cb4"}, + {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:5853eb494c71e267912275e5586fe281444eb5e722de4e131cddf9d442615626"}, + {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f0a1dbdb5ecbef0d34eb77e56fcb3e95bbd7e50835d9782a45df81cc46949750"}, + {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:283dffbf061a4ec60391d51e6155e372a1f7a4f5b15d59c8505339454f8989e4"}, + {file = "kiwisolver-1.4.4-cp311-cp311-win32.whl", hash = "sha256:d06adcfa62a4431d404c31216f0f8ac97397d799cd53800e9d3efc2fbb3cf14e"}, + {file = "kiwisolver-1.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:e7da3fec7408813a7cebc9e4ec55afed2d0fd65c4754bc376bf03498d4e92686"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:62ac9cc684da4cf1778d07a89bf5f81b35834cb96ca523d3a7fb32509380cbf6"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41dae968a94b1ef1897cb322b39360a0812661dba7c682aa45098eb8e193dbdf"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02f79693ec433cb4b5f51694e8477ae83b3205768a6fb48ffba60549080e295b"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0611a0a2a518464c05ddd5a3a1a0e856ccc10e67079bb17f265ad19ab3c7597"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:db5283d90da4174865d520e7366801a93777201e91e79bacbac6e6927cbceede"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1041feb4cda8708ce73bb4dcb9ce1ccf49d553bf87c3954bdfa46f0c3f77252c"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-win32.whl", hash = "sha256:a553dadda40fef6bfa1456dc4be49b113aa92c2a9a9e8711e955618cd69622e3"}, + {file = "kiwisolver-1.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:03baab2d6b4a54ddbb43bba1a3a2d1627e82d205c5cf8f4c924dc49284b87166"}, + {file = "kiwisolver-1.4.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:841293b17ad704d70c578f1f0013c890e219952169ce8a24ebc063eecf775454"}, + {file = "kiwisolver-1.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f4f270de01dd3e129a72efad823da90cc4d6aafb64c410c9033aba70db9f1ff0"}, + {file = "kiwisolver-1.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f9f39e2f049db33a908319cf46624a569b36983c7c78318e9726a4cb8923b26c"}, + {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c97528e64cb9ebeff9701e7938653a9951922f2a38bd847787d4a8e498cc83ae"}, + {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d1573129aa0fd901076e2bfb4275a35f5b7aa60fbfb984499d661ec950320b0"}, + {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:ad881edc7ccb9d65b0224f4e4d05a1e85cf62d73aab798943df6d48ab0cd79a1"}, + {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b428ef021242344340460fa4c9185d0b1f66fbdbfecc6c63eff4b7c29fad429d"}, + {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:2e407cb4bd5a13984a6c2c0fe1845e4e41e96f183e5e5cd4d77a857d9693494c"}, + {file = "kiwisolver-1.4.4-cp38-cp38-win32.whl", hash = "sha256:75facbe9606748f43428fc91a43edb46c7ff68889b91fa31f53b58894503a191"}, + {file = "kiwisolver-1.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:5bce61af018b0cb2055e0e72e7d65290d822d3feee430b7b8203d8a855e78766"}, + {file = "kiwisolver-1.4.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8c808594c88a025d4e322d5bb549282c93c8e1ba71b790f539567932722d7bd8"}, + {file = "kiwisolver-1.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f0a71d85ecdd570ded8ac3d1c0f480842f49a40beb423bb8014539a9f32a5897"}, + {file = "kiwisolver-1.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b533558eae785e33e8c148a8d9921692a9fe5aa516efbdff8606e7d87b9d5824"}, + {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:efda5fc8cc1c61e4f639b8067d118e742b812c930f708e6667a5ce0d13499e29"}, + {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7c43e1e1206cd421cd92e6b3280d4385d41d7166b3ed577ac20444b6995a445f"}, + {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc8d3bd6c72b2dd9decf16ce70e20abcb3274ba01b4e1c96031e0c4067d1e7cd"}, + {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ea39b0ccc4f5d803e3337dd46bcce60b702be4d86fd0b3d7531ef10fd99a1ac"}, + {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968f44fdbf6dd757d12920d63b566eeb4d5b395fd2d00d29d7ef00a00582aac9"}, + {file = "kiwisolver-1.4.4-cp39-cp39-win32.whl", hash = "sha256:da7e547706e69e45d95e116e6939488d62174e033b763ab1496b4c29b76fabea"}, + {file = "kiwisolver-1.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:ba59c92039ec0a66103b1d5fe588fa546373587a7d68f5c96f743c3396afc04b"}, + {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:91672bacaa030f92fc2f43b620d7b337fd9a5af28b0d6ed3f77afc43c4a64b5a"}, + {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:787518a6789009c159453da4d6b683f468ef7a65bbde796bcea803ccf191058d"}, + {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da152d8cdcab0e56e4f45eb08b9aea6455845ec83172092f09b0e077ece2cf7a"}, + {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ecb1fa0db7bf4cff9dac752abb19505a233c7f16684c5826d1f11ebd9472b871"}, + {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:28bc5b299f48150b5f822ce68624e445040595a4ac3d59251703779836eceff9"}, + {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:81e38381b782cc7e1e46c4e14cd997ee6040768101aefc8fa3c24a4cc58e98f8"}, + {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2a66fdfb34e05b705620dd567f5a03f239a088d5a3f321e7b6ac3239d22aa286"}, + {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:872b8ca05c40d309ed13eb2e582cab0c5a05e81e987ab9c521bf05ad1d5cf5cb"}, + {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:70e7c2e7b750585569564e2e5ca9845acfaa5da56ac46df68414f29fea97be9f"}, + {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9f85003f5dfa867e86d53fac6f7e6f30c045673fa27b603c397753bebadc3008"}, + {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e307eb9bd99801f82789b44bb45e9f541961831c7311521b13a6c85afc09767"}, + {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1792d939ec70abe76f5054d3f36ed5656021dcad1322d1cc996d4e54165cef9"}, + {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6cb459eea32a4e2cf18ba5fcece2dbdf496384413bc1bae15583f19e567f3b2"}, + {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:36dafec3d6d6088d34e2de6b85f9d8e2324eb734162fba59d2ba9ed7a2043d5b"}, + {file = "kiwisolver-1.4.4.tar.gz", hash = "sha256:d41997519fcba4a1e46eb4a2fe31bc12f0ff957b2b81bac28db24744f333e955"}, +] + +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + +[[package]] +name = "matplotlib" +version = "3.7.2" +description = "Python plotting package" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "matplotlib-3.7.2-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:2699f7e73a76d4c110f4f25be9d2496d6ab4f17345307738557d345f099e07de"}, + {file = "matplotlib-3.7.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a8035ba590658bae7562786c9cc6ea1a84aa49d3afab157e414c9e2ea74f496d"}, + {file = "matplotlib-3.7.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f8e4a49493add46ad4a8c92f63e19d548b2b6ebbed75c6b4c7f46f57d36cdd1"}, + {file = "matplotlib-3.7.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71667eb2ccca4c3537d9414b1bc00554cb7f91527c17ee4ec38027201f8f1603"}, + {file = "matplotlib-3.7.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:152ee0b569a37630d8628534c628456b28686e085d51394da6b71ef84c4da201"}, + {file = "matplotlib-3.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:070f8dddd1f5939e60aacb8fa08f19551f4b0140fab16a3669d5cd6e9cb28fc8"}, + {file = "matplotlib-3.7.2-cp310-cp310-win32.whl", hash = 
"sha256:fdbb46fad4fb47443b5b8ac76904b2e7a66556844f33370861b4788db0f8816a"}, + {file = "matplotlib-3.7.2-cp310-cp310-win_amd64.whl", hash = "sha256:23fb1750934e5f0128f9423db27c474aa32534cec21f7b2153262b066a581fd1"}, + {file = "matplotlib-3.7.2-cp311-cp311-macosx_10_12_universal2.whl", hash = "sha256:30e1409b857aa8a747c5d4f85f63a79e479835f8dffc52992ac1f3f25837b544"}, + {file = "matplotlib-3.7.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:50e0a55ec74bf2d7a0ebf50ac580a209582c2dd0f7ab51bc270f1b4a0027454e"}, + {file = "matplotlib-3.7.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ac60daa1dc83e8821eed155796b0f7888b6b916cf61d620a4ddd8200ac70cd64"}, + {file = "matplotlib-3.7.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:305e3da477dc8607336ba10bac96986d6308d614706cae2efe7d3ffa60465b24"}, + {file = "matplotlib-3.7.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c308b255efb9b06b23874236ec0f10f026673ad6515f602027cc8ac7805352d"}, + {file = "matplotlib-3.7.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60c521e21031632aa0d87ca5ba0c1c05f3daacadb34c093585a0be6780f698e4"}, + {file = "matplotlib-3.7.2-cp311-cp311-win32.whl", hash = "sha256:26bede320d77e469fdf1bde212de0ec889169b04f7f1179b8930d66f82b30cbc"}, + {file = "matplotlib-3.7.2-cp311-cp311-win_amd64.whl", hash = "sha256:af4860132c8c05261a5f5f8467f1b269bf1c7c23902d75f2be57c4a7f2394b3e"}, + {file = "matplotlib-3.7.2-cp38-cp38-macosx_10_12_universal2.whl", hash = "sha256:a1733b8e84e7e40a9853e505fe68cc54339f97273bdfe6f3ed980095f769ddc7"}, + {file = "matplotlib-3.7.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d9881356dc48e58910c53af82b57183879129fa30492be69058c5b0d9fddf391"}, + {file = "matplotlib-3.7.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f081c03f413f59390a80b3e351cc2b2ea0205839714dbc364519bcf51f4b56ca"}, + {file = "matplotlib-3.7.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cd120fca3407a225168238b790bd5c528f0fafde6172b140a2f3ab7a4ea63e9"}, + {file = "matplotlib-3.7.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a2c1590b90aa7bd741b54c62b78de05d4186271e34e2377e0289d943b3522273"}, + {file = "matplotlib-3.7.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d2ff3c984b8a569bc1383cd468fc06b70d7b59d5c2854ca39f1436ae8394117"}, + {file = "matplotlib-3.7.2-cp38-cp38-win32.whl", hash = "sha256:5dea00b62d28654b71ca92463656d80646675628d0828e08a5f3b57e12869e13"}, + {file = "matplotlib-3.7.2-cp38-cp38-win_amd64.whl", hash = "sha256:0f506a1776ee94f9e131af1ac6efa6e5bc7cb606a3e389b0ccb6e657f60bb676"}, + {file = "matplotlib-3.7.2-cp39-cp39-macosx_10_12_universal2.whl", hash = "sha256:6515e878f91894c2e4340d81f0911857998ccaf04dbc1bba781e3d89cbf70608"}, + {file = "matplotlib-3.7.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:71f7a8c6b124e904db550f5b9fe483d28b896d4135e45c4ea381ad3b8a0e3256"}, + {file = "matplotlib-3.7.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:12f01b92ecd518e0697da4d97d163b2b3aa55eb3eb4e2c98235b3396d7dad55f"}, + {file = "matplotlib-3.7.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7e28d6396563955f7af437894a36bf2b279462239a41028323e04b85179058b"}, + {file = "matplotlib-3.7.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbcf59334ff645e6a67cd5f78b4b2cdb76384cdf587fa0d2dc85f634a72e1a3e"}, + {file = "matplotlib-3.7.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:318c89edde72ff95d8df67d82aca03861240512994a597a435a1011ba18dbc7f"}, + {file = "matplotlib-3.7.2-cp39-cp39-win32.whl", hash = "sha256:ce55289d5659b5b12b3db4dc9b7075b70cef5631e56530f14b2945e8836f2d20"}, + {file = "matplotlib-3.7.2-cp39-cp39-win_amd64.whl", hash = "sha256:2ecb5be2b2815431c81dc115667e33da0f5a1bcf6143980d180d09a717c4a12e"}, + {file = "matplotlib-3.7.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fdcd28360dbb6203fb5219b1a5658df226ac9bebc2542a9e8f457de959d713d0"}, + {file = "matplotlib-3.7.2-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c3cca3e842b11b55b52c6fb8bd6a4088693829acbfcdb3e815fa9b7d5c92c1b"}, + {file = "matplotlib-3.7.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebf577c7a6744e9e1bd3fee45fc74a02710b214f94e2bde344912d85e0c9af7c"}, + {file = "matplotlib-3.7.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:936bba394682049919dda062d33435b3be211dc3dcaa011e09634f060ec878b2"}, + {file = "matplotlib-3.7.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bc221ffbc2150458b1cd71cdd9ddd5bb37962b036e41b8be258280b5b01da1dd"}, + {file = "matplotlib-3.7.2-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35d74ebdb3f71f112b36c2629cf32323adfbf42679e2751252acd468f5001c07"}, + {file = "matplotlib-3.7.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:717157e61b3a71d3d26ad4e1770dc85156c9af435659a25ee6407dc866cb258d"}, + {file = "matplotlib-3.7.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:20f844d6be031948148ba49605c8b96dfe7d3711d1b63592830d650622458c11"}, + {file = "matplotlib-3.7.2.tar.gz", hash = "sha256:a8cdb91dddb04436bd2f098b8fdf4b81352e68cf4d2c6756fcc414791076569b"}, +] + +[package.dependencies] +contourpy = ">=1.0.1" +cycler = ">=0.10" +fonttools = ">=4.22.0" +importlib-resources = {version = ">=3.2.0", markers = "python_version < \"3.10\""} +kiwisolver = ">=1.0.1" +numpy = ">=1.20" +packaging = ">=20.0" +pillow = ">=6.2.0" +pyparsing = ">=2.3.1,<3.1" +python-dateutil = ">=2.7" + +[[package]] +name = "matplotlib-inline" +version = "0.1.6" +description = "Inline Matplotlib backend for Jupyter" +category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, + {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, +] + +[package.dependencies] +traitlets = "*" + +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + +[[package]] +name = "ml-dtypes" +version = "0.2.0" +description = "" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ml_dtypes-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df6a76e1c8adf484feb138ed323f9f40a7b6c21788f120f7c78bec20ac37ee81"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc29a0524ef5e23a7fbb8d881bdecabeb3fc1d19d9db61785d077a86cb94fab2"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:f08c391c2794f2aad358e6f4c70785a9a7b1df980ef4c232b3ccd4f6fe39f719"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:75015818a7fccf99a5e8ed18720cb430f3e71a8838388840f4cdf225c036c983"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e70047ec2c83eaee01afdfdabee2c5b0c133804d90d0f7db4dd903360fcc537c"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d28b8861a8931695e5a31176cad5ae85f6504906650dea5598fbec06c94606"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e85ba8e24cf48d456e564688e981cf379d4c8e644db0a2f719b78de281bac2ca"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:832a019a1b6db5c4422032ca9940a990fa104eee420f643713241b3a518977fa"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8faaf0897942c8253dd126662776ba45f0a5861968cf0f06d6d465f8a7bc298a"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b984cddbe8173b545a0e3334fe56ea1a5c3eb67c507f60d0cfde1d3fa8f8c2"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022d5a4ee6be14569c2a9d1549e16f1ec87ca949681d0dca59995445d5fcdd5b"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:50845af3e9a601810751b55091dee6c2562403fa1cb4e0123675cf3a4fc2c17a"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f00c71c8c63e03aff313bc6a7aeaac9a4f1483a921a6ffefa6d4404efd1af3d0"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80d304c836d73f10605c58ccf7789c171cc229bfb678748adfb7cea2510dfd0e"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32107e7fa9f62db9a5281de923861325211dfff87bd23faefb27b303314635ab"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:1749b60348da71fd3c2ab303fdbc1965958dc50775ead41f5669c932a341cafd"}, + {file = "ml_dtypes-0.2.0.tar.gz", hash = "sha256:6488eb642acaaf08d8020f6de0a38acee7ac324c1e6e92ee0c0fea42422cb797"}, +] + +[package.dependencies] +numpy = [ + {version = ">1.20", markers = "python_version <= \"3.9\""}, + {version = ">=1.23.3", markers = "python_version > \"3.10\""}, + {version = ">=1.21.2", markers = "python_version > \"3.9\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + +[[package]] +name = "msgpack" +version = "1.0.5" +description = "MessagePack serializer" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "msgpack-1.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:525228efd79bb831cf6830a732e2e80bc1b05436b086d4264814b4b2955b2fa9"}, + {file = "msgpack-1.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4f8d8b3bf1ff2672567d6b5c725a1b347fe838b912772aa8ae2bf70338d5a198"}, + {file = "msgpack-1.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cdc793c50be3f01106245a61b739328f7dccc2c648b501e237f0699fe1395b81"}, + {file = "msgpack-1.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cb47c21a8a65b165ce29f2bec852790cbc04936f502966768e4aae9fa763cb7"}, + {file = "msgpack-1.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e42b9594cc3bf4d838d67d6ed62b9e59e201862a25e9a157019e171fbe672dd3"}, + {file = 
"msgpack-1.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:55b56a24893105dc52c1253649b60f475f36b3aa0fc66115bffafb624d7cb30b"}, + {file = "msgpack-1.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:1967f6129fc50a43bfe0951c35acbb729be89a55d849fab7686004da85103f1c"}, + {file = "msgpack-1.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:20a97bf595a232c3ee6d57ddaadd5453d174a52594bf9c21d10407e2a2d9b3bd"}, + {file = "msgpack-1.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d25dd59bbbbb996eacf7be6b4ad082ed7eacc4e8f3d2df1ba43822da9bfa122a"}, + {file = "msgpack-1.0.5-cp310-cp310-win32.whl", hash = "sha256:382b2c77589331f2cb80b67cc058c00f225e19827dbc818d700f61513ab47bea"}, + {file = "msgpack-1.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:4867aa2df9e2a5fa5f76d7d5565d25ec76e84c106b55509e78c1ede0f152659a"}, + {file = "msgpack-1.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9f5ae84c5c8a857ec44dc180a8b0cc08238e021f57abdf51a8182e915e6299f0"}, + {file = "msgpack-1.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9e6ca5d5699bcd89ae605c150aee83b5321f2115695e741b99618f4856c50898"}, + {file = "msgpack-1.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5494ea30d517a3576749cad32fa27f7585c65f5f38309c88c6d137877fa28a5a"}, + {file = "msgpack-1.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ab2f3331cb1b54165976a9d976cb251a83183631c88076613c6c780f0d6e45a"}, + {file = "msgpack-1.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28592e20bbb1620848256ebc105fc420436af59515793ed27d5c77a217477705"}, + {file = "msgpack-1.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe5c63197c55bce6385d9aee16c4d0641684628f63ace85f73571e65ad1c1e8d"}, + {file = "msgpack-1.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ed40e926fa2f297e8a653c954b732f125ef97bdd4c889f243182299de27e2aa9"}, + {file = "msgpack-1.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b2de4c1c0538dcb7010902a2b97f4e00fc4ddf2c8cda9749af0e594d3b7fa3d7"}, + {file = "msgpack-1.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bf22a83f973b50f9d38e55c6aade04c41ddda19b00c4ebc558930d78eecc64ed"}, + {file = "msgpack-1.0.5-cp311-cp311-win32.whl", hash = "sha256:c396e2cc213d12ce017b686e0f53497f94f8ba2b24799c25d913d46c08ec422c"}, + {file = "msgpack-1.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c4c68d87497f66f96d50142a2b73b97972130d93677ce930718f68828b382e2"}, + {file = "msgpack-1.0.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a2b031c2e9b9af485d5e3c4520f4220d74f4d222a5b8dc8c1a3ab9448ca79c57"}, + {file = "msgpack-1.0.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f837b93669ce4336e24d08286c38761132bc7ab29782727f8557e1eb21b2080"}, + {file = "msgpack-1.0.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1d46dfe3832660f53b13b925d4e0fa1432b00f5f7210eb3ad3bb9a13c6204a6"}, + {file = "msgpack-1.0.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:366c9a7b9057e1547f4ad51d8facad8b406bab69c7d72c0eb6f529cf76d4b85f"}, + {file = "msgpack-1.0.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:4c075728a1095efd0634a7dccb06204919a2f67d1893b6aa8e00497258bf926c"}, + {file = "msgpack-1.0.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = 
"sha256:f933bbda5a3ee63b8834179096923b094b76f0c7a73c1cfe8f07ad608c58844b"}, + {file = "msgpack-1.0.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:36961b0568c36027c76e2ae3ca1132e35123dcec0706c4b7992683cc26c1320c"}, + {file = "msgpack-1.0.5-cp36-cp36m-win32.whl", hash = "sha256:b5ef2f015b95f912c2fcab19c36814963b5463f1fb9049846994b007962743e9"}, + {file = "msgpack-1.0.5-cp36-cp36m-win_amd64.whl", hash = "sha256:288e32b47e67f7b171f86b030e527e302c91bd3f40fd9033483f2cacc37f327a"}, + {file = "msgpack-1.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:137850656634abddfb88236008339fdaba3178f4751b28f270d2ebe77a563b6c"}, + {file = "msgpack-1.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c05a4a96585525916b109bb85f8cb6511db1c6f5b9d9cbcbc940dc6b4be944b"}, + {file = "msgpack-1.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a62ec00b636583e5cb6ad313bbed36bb7ead5fa3a3e38938503142c72cba4f"}, + {file = "msgpack-1.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef8108f8dedf204bb7b42994abf93882da1159728a2d4c5e82012edd92c9da9f"}, + {file = "msgpack-1.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1835c84d65f46900920b3708f5ba829fb19b1096c1800ad60bae8418652a951d"}, + {file = "msgpack-1.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e57916ef1bd0fee4f21c4600e9d1da352d8816b52a599c46460e93a6e9f17086"}, + {file = "msgpack-1.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:17358523b85973e5f242ad74aa4712b7ee560715562554aa2134d96e7aa4cbbf"}, + {file = "msgpack-1.0.5-cp37-cp37m-win32.whl", hash = "sha256:cb5aaa8c17760909ec6cb15e744c3ebc2ca8918e727216e79607b7bbce9c8f77"}, + {file = "msgpack-1.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:ab31e908d8424d55601ad7075e471b7d0140d4d3dd3272daf39c5c19d936bd82"}, + {file = "msgpack-1.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b72d0698f86e8d9ddf9442bdedec15b71df3598199ba33322d9711a19f08145c"}, + {file = "msgpack-1.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:379026812e49258016dd84ad79ac8446922234d498058ae1d415f04b522d5b2d"}, + {file = "msgpack-1.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:332360ff25469c346a1c5e47cbe2a725517919892eda5cfaffe6046656f0b7bb"}, + {file = "msgpack-1.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:476a8fe8fae289fdf273d6d2a6cb6e35b5a58541693e8f9f019bfe990a51e4ba"}, + {file = "msgpack-1.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9985b214f33311df47e274eb788a5893a761d025e2b92c723ba4c63936b69b1"}, + {file = "msgpack-1.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48296af57cdb1d885843afd73c4656be5c76c0c6328db3440c9601a98f303d87"}, + {file = "msgpack-1.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:addab7e2e1fcc04bd08e4eb631c2a90960c340e40dfc4a5e24d2ff0d5a3b3edb"}, + {file = "msgpack-1.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:916723458c25dfb77ff07f4c66aed34e47503b2eb3188b3adbec8d8aa6e00f48"}, + {file = "msgpack-1.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:821c7e677cc6acf0fd3f7ac664c98803827ae6de594a9f99563e48c5a2f27eb0"}, + {file = "msgpack-1.0.5-cp38-cp38-win32.whl", hash = "sha256:1c0f7c47f0087ffda62961d425e4407961a7ffd2aa004c81b9c07d9269512f6e"}, + {file = "msgpack-1.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:bae7de2026cbfe3782c8b78b0db9cbfc5455e079f1937cb0ab8d133496ac55e1"}, + {file = 
"msgpack-1.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:20c784e66b613c7f16f632e7b5e8a1651aa5702463d61394671ba07b2fc9e025"}, + {file = "msgpack-1.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:266fa4202c0eb94d26822d9bfd7af25d1e2c088927fe8de9033d929dd5ba24c5"}, + {file = "msgpack-1.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18334484eafc2b1aa47a6d42427da7fa8f2ab3d60b674120bce7a895a0a85bdd"}, + {file = "msgpack-1.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57e1f3528bd95cc44684beda696f74d3aaa8a5e58c816214b9046512240ef437"}, + {file = "msgpack-1.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:586d0d636f9a628ddc6a17bfd45aa5b5efaf1606d2b60fa5d87b8986326e933f"}, + {file = "msgpack-1.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a740fa0e4087a734455f0fc3abf5e746004c9da72fbd541e9b113013c8dc3282"}, + {file = "msgpack-1.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3055b0455e45810820db1f29d900bf39466df96ddca11dfa6d074fa47054376d"}, + {file = "msgpack-1.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a61215eac016f391129a013c9e46f3ab308db5f5ec9f25811e811f96962599a8"}, + {file = "msgpack-1.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:362d9655cd369b08fda06b6657a303eb7172d5279997abe094512e919cf74b11"}, + {file = "msgpack-1.0.5-cp39-cp39-win32.whl", hash = "sha256:ac9dd47af78cae935901a9a500104e2dea2e253207c924cc95de149606dc43cc"}, + {file = "msgpack-1.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:06f5174b5f8ed0ed919da0e62cbd4ffde676a374aba4020034da05fab67b9164"}, + {file = "msgpack-1.0.5.tar.gz", hash = "sha256:c075544284eadc5cddc70f4757331d99dcbc16b2bbd4849d15f8aae4cf36d31c"}, +] + +[[package]] +name = "multidict" +version = "6.0.4" +description = "multidict implementation" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, + {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"}, + {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"}, + {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = 
"sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"}, + {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"}, + {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"}, + {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"}, + {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"}, + {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"}, + {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"}, + {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"}, + {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"}, + {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash 
= "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"}, + {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"}, + {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"}, + {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"}, + {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"}, + {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"}, + {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"}, + {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"}, + {file = 
"multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"}, + {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"}, + {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"}, + {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"}, + {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"}, + {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"}, + {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"}, + {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"}, + {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, +] + +[[package]] +name = "multiprocess" +version = "0.70.14" +description = "better multiprocessing and multithreading in python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multiprocess-0.70.14-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:560a27540daef4ce8b24ed3cc2496a3c670df66c96d02461a4da67473685adf3"}, + {file = "multiprocess-0.70.14-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:bfbbfa36f400b81d1978c940616bc77776424e5e34cb0c94974b178d727cfcd5"}, + {file = "multiprocess-0.70.14-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:89fed99553a04ec4f9067031f83a886d7fdec5952005551a896a4b6a59575bb9"}, + {file = 
"multiprocess-0.70.14-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:40a5e3685462079e5fdee7c6789e3ef270595e1755199f0d50685e72523e1d2a"}, + {file = "multiprocess-0.70.14-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:44936b2978d3f2648727b3eaeab6d7fa0bedf072dc5207bf35a96d5ee7c004cf"}, + {file = "multiprocess-0.70.14-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e628503187b5d494bf29ffc52d3e1e57bb770ce7ce05d67c4bbdb3a0c7d3b05f"}, + {file = "multiprocess-0.70.14-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0d5da0fc84aacb0e4bd69c41b31edbf71b39fe2fb32a54eaedcaea241050855c"}, + {file = "multiprocess-0.70.14-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:6a7b03a5b98e911a7785b9116805bd782815c5e2bd6c91c6a320f26fd3e7b7ad"}, + {file = "multiprocess-0.70.14-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cea5bdedd10aace3c660fedeac8b087136b4366d4ee49a30f1ebf7409bce00ae"}, + {file = "multiprocess-0.70.14-py310-none-any.whl", hash = "sha256:7dc1f2f6a1d34894c8a9a013fbc807971e336e7cc3f3ff233e61b9dc679b3b5c"}, + {file = "multiprocess-0.70.14-py37-none-any.whl", hash = "sha256:93a8208ca0926d05cdbb5b9250a604c401bed677579e96c14da3090beb798193"}, + {file = "multiprocess-0.70.14-py38-none-any.whl", hash = "sha256:6725bc79666bbd29a73ca148a0fb5f4ea22eed4a8f22fce58296492a02d18a7b"}, + {file = "multiprocess-0.70.14-py39-none-any.whl", hash = "sha256:63cee628b74a2c0631ef15da5534c8aedbc10c38910b9c8b18dcd327528d1ec7"}, + {file = "multiprocess-0.70.14.tar.gz", hash = "sha256:3eddafc12f2260d27ae03fe6069b12570ab4764ab59a75e81624fac453fbf46a"}, +] + +[package.dependencies] +dill = ">=0.3.6" + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." 
+category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +name = "nest-asyncio" +version = "1.5.6" +description = "Patch asyncio to allow nested event loops" +category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.5.6-py3-none-any.whl", hash = "sha256:b9a953fb40dceaa587d109609098db21900182b16440652454a146cffb06e8b8"}, + {file = "nest_asyncio-1.5.6.tar.gz", hash = "sha256:d267cc1ff794403f7df692964d1d2a3fa9418ffea2a3f6859a439ff482fef290"}, +] + +[[package]] +name = "nodeenv" +version = "1.8.0" +description = "Node.js virtual environment builder" +category = "dev" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +files = [ + {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, + {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, +] + +[package.dependencies] +setuptools = "*" + +[[package]] +name = "numpy" +version = "1.25.1" +description = "Fundamental package for array computing in Python" +category = "main" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.25.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d339465dff3eb33c701430bcb9c325b60354698340229e1dff97745e6b3efa"}, + {file = "numpy-1.25.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d736b75c3f2cb96843a5c7f8d8ccc414768d34b0a75f466c05f3a739b406f10b"}, + {file = "numpy-1.25.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a90725800caeaa160732d6b31f3f843ebd45d6b5f3eec9e8cc287e30f2805bf"}, + {file = "numpy-1.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c6c9261d21e617c6dc5eacba35cb68ec36bb72adcff0dee63f8fbc899362588"}, + {file = "numpy-1.25.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0def91f8af6ec4bb94c370e38c575855bf1d0be8a8fbfba42ef9c073faf2cf19"}, + {file = "numpy-1.25.1-cp310-cp310-win32.whl", hash = "sha256:fd67b306320dcadea700a8f79b9e671e607f8696e98ec255915c0c6d6b818503"}, + {file = "numpy-1.25.1-cp310-cp310-win_amd64.whl", hash = "sha256:c1516db588987450b85595586605742879e50dcce923e8973f79529651545b57"}, + {file = "numpy-1.25.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6b82655dd8efeea69dbf85d00fca40013d7f503212bc5259056244961268b66e"}, + {file = "numpy-1.25.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e8f6049c4878cb16960fbbfb22105e49d13d752d4d8371b55110941fb3b17800"}, + {file = "numpy-1.25.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41a56b70e8139884eccb2f733c2f7378af06c82304959e174f8e7370af112e09"}, + {file = "numpy-1.25.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5154b1a25ec796b1aee12ac1b22f414f94752c5f94832f14d8d6c9ac40bcca6"}, + {file = "numpy-1.25.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38eb6548bb91c421261b4805dc44def9ca1a6eef6444ce35ad1669c0f1a3fc5d"}, + {file = "numpy-1.25.1-cp311-cp311-win32.whl", hash = "sha256:791f409064d0a69dd20579345d852c59822c6aa087f23b07b1b4e28ff5880fcb"}, + {file = "numpy-1.25.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:c40571fe966393b212689aa17e32ed905924120737194b5d5c1b20b9ed0fb171"}, + {file = "numpy-1.25.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d7abcdd85aea3e6cdddb59af2350c7ab1ed764397f8eec97a038ad244d2d105"}, + {file = "numpy-1.25.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1a180429394f81c7933634ae49b37b472d343cccb5bb0c4a575ac8bbc433722f"}, + {file = "numpy-1.25.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d412c1697c3853c6fc3cb9751b4915859c7afe6a277c2bf00acf287d56c4e625"}, + {file = "numpy-1.25.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20e1266411120a4f16fad8efa8e0454d21d00b8c7cee5b5ccad7565d95eb42dd"}, + {file = "numpy-1.25.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f76aebc3358ade9eacf9bc2bb8ae589863a4f911611694103af05346637df1b7"}, + {file = "numpy-1.25.1-cp39-cp39-win32.whl", hash = "sha256:247d3ffdd7775bdf191f848be8d49100495114c82c2bd134e8d5d075fb386a1c"}, + {file = "numpy-1.25.1-cp39-cp39-win_amd64.whl", hash = "sha256:1d5d3c68e443c90b38fdf8ef40e60e2538a27548b39b12b73132456847f4b631"}, + {file = "numpy-1.25.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:35a9527c977b924042170a0887de727cd84ff179e478481404c5dc66b4170009"}, + {file = "numpy-1.25.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d3fe3dd0506a28493d82dc3cf254be8cd0d26f4008a417385cbf1ae95b54004"}, + {file = "numpy-1.25.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:012097b5b0d00a11070e8f2e261128c44157a8689f7dedcf35576e525893f4fe"}, + {file = "numpy-1.25.1.tar.gz", hash = "sha256:9a3a9f3a61480cc086117b426a8bd86869c213fc4072e606f01c4e4b66eb92bf"}, +] + +[[package]] +name = "opt-einsum" +version = "3.3.0" +description = "Optimizing numpys einsum function" +category = "main" +optional = false +python-versions = ">=3.5" +files = [ + {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, + {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, +] + +[package.dependencies] +numpy = ">=1.7" + +[package.extras] +docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] +tests = ["pytest", "pytest-cov", "pytest-pep8"] + +[[package]] +name = "optax" +version = "0.1.5" +description = "A gradient processing and optimisation library in JAX." 
+category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "optax-0.1.5-py3-none-any.whl", hash = "sha256:4057461448abd1fccdefd5e6c7ebc6ea8daa3105041f2631d6efd506544ecde0"}, + {file = "optax-0.1.5.tar.gz", hash = "sha256:0aa379b56f51dbd525562f5ee6805a180a2616f3e9fe8080582352bcbb520f2e"}, +] + +[package.dependencies] +absl-py = ">=0.7.1" +chex = ">=0.1.5" +jax = ">=0.1.55" +jaxlib = ">=0.1.37" +numpy = ">=1.18.0" + +[[package]] +name = "orbax-checkpoint" +version = "0.2.7" +description = "Orbax Checkpoint" +category = "dev" +optional = false +python-versions = ">=3.9" +files = [ + {file = "orbax_checkpoint-0.2.7-py3-none-any.whl", hash = "sha256:8cbb09f82334aecb6f06d05c72ce2e78ecd114c7d59f086b24af99b83a8fea96"}, + {file = "orbax_checkpoint-0.2.7.tar.gz", hash = "sha256:8f23a301641ec33ea094c756629855daccab263e114c6f27cebcf6c4e3f0e90e"}, +] + +[package.dependencies] +absl-py = "*" +etils = {version = "*", extras = ["epath", "epy"]} +jax = ">=0.4.9" +jaxlib = "*" +msgpack = "*" +nest_asyncio = "*" +numpy = "*" +protobuf = "*" +pyyaml = "*" +tensorstore = ">=0.1.35" +typing_extensions = "*" + +[package.extras] +testing = ["flax", "pytest", "pytest-xdist"] + +[[package]] +name = "packaging" +version = "23.1" +description = "Core utilities for Python packages" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, + {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, +] + +[[package]] +name = "pandas" +version = "2.0.3" +description = "Powerful data structures for data analysis, time series, and statistics" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, + {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"}, + {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"}, + {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"}, + {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"}, + {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = 
"sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"}, + {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"}, + {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"}, + {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"}, + {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"}, + {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.1" + +[package.extras] +all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] +aws = ["s3fs (>=2021.08.0)"] +clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] +compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] +computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"] +feather = ["pyarrow (>=7.0.0)"] +fss = ["fsspec (>=2021.07.0)"] +gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"] 
+hdf5 = ["tables (>=3.6.1)"] +html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"] +mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"] +parquet = ["pyarrow (>=7.0.0)"] +performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"] +plot = ["matplotlib (>=3.6.1)"] +postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"] +spss = ["pyreadstat (>=1.1.2)"] +sql-other = ["SQLAlchemy (>=1.4.16)"] +test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.6.3)"] + +[[package]] +name = "parso" +version = "0.8.3" +description = "A Python Parser" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, + {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, +] + +[package.extras] +qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +testing = ["docopt", "pytest (<6.0.0)"] + +[[package]] +name = "pathspec" +version = "0.11.1" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, + {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, +] + +[[package]] +name = "pexpect" +version = "4.8.0" +description = "Pexpect allows easy control of interactive console applications." +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, + {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, +] + +[package.dependencies] +ptyprocess = ">=0.5" + +[[package]] +name = "pickleshare" +version = "0.7.5" +description = "Tiny 'shelve'-like database with concurrency support" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, + {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, +] + +[[package]] +name = "pillow" +version = "10.0.0" +description = "Python Imaging Library (Fork)" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "Pillow-10.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1f62406a884ae75fb2f818694469519fb685cc7eaff05d3451a9ebe55c646891"}, + {file = "Pillow-10.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d5db32e2a6ccbb3d34d87c87b432959e0db29755727afb37290e10f6e8e62614"}, + {file = "Pillow-10.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edf4392b77bdc81f36e92d3a07a5cd072f90253197f4a52a55a8cec48a12483b"}, + {file = "Pillow-10.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:520f2a520dc040512699f20fa1c363eed506e94248d71f85412b625026f6142c"}, + {file = "Pillow-10.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:8c11160913e3dd06c8ffdb5f233a4f254cb449f4dfc0f8f4549eda9e542c93d1"}, + {file = 
"Pillow-10.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a74ba0c356aaa3bb8e3eb79606a87669e7ec6444be352870623025d75a14a2bf"}, + {file = "Pillow-10.0.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d5d0dae4cfd56969d23d94dc8e89fb6a217be461c69090768227beb8ed28c0a3"}, + {file = "Pillow-10.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:22c10cc517668d44b211717fd9775799ccec4124b9a7f7b3635fc5386e584992"}, + {file = "Pillow-10.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:dffe31a7f47b603318c609f378ebcd57f1554a3a6a8effbc59c3c69f804296de"}, + {file = "Pillow-10.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:9fb218c8a12e51d7ead2a7c9e101a04982237d4855716af2e9499306728fb485"}, + {file = "Pillow-10.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d35e3c8d9b1268cbf5d3670285feb3528f6680420eafe35cccc686b73c1e330f"}, + {file = "Pillow-10.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ed64f9ca2f0a95411e88a4efbd7a29e5ce2cea36072c53dd9d26d9c76f753b3"}, + {file = "Pillow-10.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b6eb5502f45a60a3f411c63187db83a3d3107887ad0d036c13ce836f8a36f1d"}, + {file = "Pillow-10.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c1fbe7621c167ecaa38ad29643d77a9ce7311583761abf7836e1510c580bf3dd"}, + {file = "Pillow-10.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:cd25d2a9d2b36fcb318882481367956d2cf91329f6892fe5d385c346c0649629"}, + {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3b08d4cc24f471b2c8ca24ec060abf4bebc6b144cb89cba638c720546b1cf538"}, + {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d737a602fbd82afd892ca746392401b634e278cb65d55c4b7a8f48e9ef8d008d"}, + {file = "Pillow-10.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:3a82c40d706d9aa9734289740ce26460a11aeec2d9c79b7af87bb35f0073c12f"}, + {file = "Pillow-10.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:bc2ec7c7b5d66b8ec9ce9f720dbb5fa4bace0f545acd34870eff4a369b44bf37"}, + {file = "Pillow-10.0.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:d80cf684b541685fccdd84c485b31ce73fc5c9b5d7523bf1394ce134a60c6883"}, + {file = "Pillow-10.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76de421f9c326da8f43d690110f0e79fe3ad1e54be811545d7d91898b4c8493e"}, + {file = "Pillow-10.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81ff539a12457809666fef6624684c008e00ff6bf455b4b89fd00a140eecd640"}, + {file = "Pillow-10.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce543ed15570eedbb85df19b0a1a7314a9c8141a36ce089c0a894adbfccb4568"}, + {file = "Pillow-10.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:685ac03cc4ed5ebc15ad5c23bc555d68a87777586d970c2c3e216619a5476223"}, + {file = "Pillow-10.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d72e2ecc68a942e8cf9739619b7f408cc7b272b279b56b2c83c6123fcfa5cdff"}, + {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d50b6aec14bc737742ca96e85d6d0a5f9bfbded018264b3b70ff9d8c33485551"}, + {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:00e65f5e822decd501e374b0650146063fbb30a7264b4d2744bdd7b913e0cab5"}, + {file = "Pillow-10.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:f31f9fdbfecb042d046f9d91270a0ba28368a723302786c0009ee9b9f1f60199"}, + {file = "Pillow-10.0.0-cp312-cp312-win_arm64.whl", hash = 
"sha256:1ce91b6ec08d866b14413d3f0bbdea7e24dfdc8e59f562bb77bc3fe60b6144ca"}, + {file = "Pillow-10.0.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:349930d6e9c685c089284b013478d6f76e3a534e36ddfa912cde493f235372f3"}, + {file = "Pillow-10.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a684105f7c32488f7153905a4e3015a3b6c7182e106fe3c37fbb5ef3e6994c3"}, + {file = "Pillow-10.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4f69b3700201b80bb82c3a97d5e9254084f6dd5fb5b16fc1a7b974260f89f43"}, + {file = "Pillow-10.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f07ea8d2f827d7d2a49ecf1639ec02d75ffd1b88dcc5b3a61bbb37a8759ad8d"}, + {file = "Pillow-10.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:040586f7d37b34547153fa383f7f9aed68b738992380ac911447bb78f2abe530"}, + {file = "Pillow-10.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f88a0b92277de8e3ca715a0d79d68dc82807457dae3ab8699c758f07c20b3c51"}, + {file = "Pillow-10.0.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c7cf14a27b0d6adfaebb3ae4153f1e516df54e47e42dcc073d7b3d76111a8d86"}, + {file = "Pillow-10.0.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:3400aae60685b06bb96f99a21e1ada7bc7a413d5f49bce739828ecd9391bb8f7"}, + {file = "Pillow-10.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:dbc02381779d412145331789b40cc7b11fdf449e5d94f6bc0b080db0a56ea3f0"}, + {file = "Pillow-10.0.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:9211e7ad69d7c9401cfc0e23d49b69ca65ddd898976d660a2fa5904e3d7a9baa"}, + {file = "Pillow-10.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:faaf07ea35355b01a35cb442dd950d8f1bb5b040a7787791a535de13db15ed90"}, + {file = "Pillow-10.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9f72a021fbb792ce98306ffb0c348b3c9cb967dce0f12a49aa4c3d3fdefa967"}, + {file = "Pillow-10.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f7c16705f44e0504a3a2a14197c1f0b32a95731d251777dcb060aa83022cb2d"}, + {file = "Pillow-10.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:76edb0a1fa2b4745fb0c99fb9fb98f8b180a1bbceb8be49b087e0b21867e77d3"}, + {file = "Pillow-10.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:368ab3dfb5f49e312231b6f27b8820c823652b7cd29cfbd34090565a015e99ba"}, + {file = "Pillow-10.0.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:608bfdee0d57cf297d32bcbb3c728dc1da0907519d1784962c5f0c68bb93e5a3"}, + {file = "Pillow-10.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5c6e3df6bdd396749bafd45314871b3d0af81ff935b2d188385e970052091017"}, + {file = "Pillow-10.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:7be600823e4c8631b74e4a0d38384c73f680e6105a7d3c6824fcf226c178c7e6"}, + {file = "Pillow-10.0.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:92be919bbc9f7d09f7ae343c38f5bb21c973d2576c1d45600fce4b74bafa7ac0"}, + {file = "Pillow-10.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8182b523b2289f7c415f589118228d30ac8c355baa2f3194ced084dac2dbba"}, + {file = "Pillow-10.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:38250a349b6b390ee6047a62c086d3817ac69022c127f8a5dc058c31ccef17f3"}, + {file = "Pillow-10.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:88af2003543cc40c80f6fca01411892ec52b11021b3dc22ec3bc9d5afd1c5334"}, + {file = "Pillow-10.0.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:c189af0545965fa8d3b9613cfdb0cd37f9d71349e0f7750e1fd704648d475ed2"}, + {file = 
"Pillow-10.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce7b031a6fc11365970e6a5686d7ba8c63e4c1cf1ea143811acbb524295eabed"}, + {file = "Pillow-10.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:db24668940f82321e746773a4bc617bfac06ec831e5c88b643f91f122a785684"}, + {file = "Pillow-10.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:efe8c0681042536e0d06c11f48cebe759707c9e9abf880ee213541c5b46c5bf3"}, + {file = "Pillow-10.0.0.tar.gz", hash = "sha256:9c82b5b3e043c7af0d95792d0d20ccf68f61a1fec6b3530e718b688422727396"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "platformdirs" +version = "3.9.1" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "platformdirs-3.9.1-py3-none-any.whl", hash = "sha256:ad8291ae0ae5072f66c16945166cb11c63394c7a3ad1b1bc9828ca3162da8c2f"}, + {file = "platformdirs-3.9.1.tar.gz", hash = "sha256:1b42b450ad933e981d56e59f1b97495428c9bd60698baab9f3eb3d00d5822421"}, +] + +[package.extras] +docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"] + +[[package]] +name = "pluggy" +version = "1.2.0" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, + {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pre-commit" +version = "3.3.3" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pre_commit-3.3.3-py2.py3-none-any.whl", hash = "sha256:10badb65d6a38caff29703362271d7dca483d01da88f9d7e05d0b97171c136cb"}, + {file = "pre_commit-3.3.3.tar.gz", hash = "sha256:a2256f489cd913d575c145132ae196fe335da32d91a8294b7afe6622335dd023"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + +[[package]] +name = "prompt-toolkit" +version = "3.0.39" +description = "Library for building powerful interactive command lines in Python" +category = "dev" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, + {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, +] + +[package.dependencies] +wcwidth = "*" + +[[package]] +name = "protobuf" +version = "4.23.4" +description = "" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "protobuf-4.23.4-cp310-abi3-win32.whl", hash = "sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b"}, + {file = "protobuf-4.23.4-cp310-abi3-win_amd64.whl", hash = "sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12"}, + {file = "protobuf-4.23.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd"}, + {file = "protobuf-4.23.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a"}, + {file = "protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597"}, + {file = "protobuf-4.23.4-cp37-cp37m-win32.whl", hash = "sha256:c3e0939433c40796ca4cfc0fac08af50b00eb66a40bbbc5dee711998fb0bbc1e"}, + {file = "protobuf-4.23.4-cp37-cp37m-win_amd64.whl", hash = "sha256:9053df6df8e5a76c84339ee4a9f5a2661ceee4a0dab019e8663c50ba324208b0"}, + {file = "protobuf-4.23.4-cp38-cp38-win32.whl", hash = "sha256:e1c915778d8ced71e26fcf43c0866d7499891bca14c4368448a82edc61fdbc70"}, + {file = "protobuf-4.23.4-cp38-cp38-win_amd64.whl", hash = "sha256:351cc90f7d10839c480aeb9b870a211e322bf05f6ab3f55fcb2f51331f80a7d2"}, + {file = "protobuf-4.23.4-cp39-cp39-win32.whl", hash = "sha256:6dd9b9940e3f17077e820b75851126615ee38643c2c5332aa7a359988820c720"}, + {file = "protobuf-4.23.4-cp39-cp39-win_amd64.whl", hash = "sha256:0a5759f5696895de8cc913f084e27fd4125e8fb0914bb729a17816a33819f474"}, + {file = "protobuf-4.23.4-py3-none-any.whl", hash = "sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff"}, + {file = "protobuf-4.23.4.tar.gz", hash = "sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9"}, +] + +[[package]] +name = "psutil" +version = "5.9.5" +description = "Cross-platform lib for process and system monitoring in Python." 
+category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"}, + {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"}, + {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"}, + {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"}, + {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"}, + {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"}, + {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, + {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + +[[package]] +name = "ptyprocess" +version = "0.7.0" +description = "Run a subprocess in a pseudo terminal" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.2" +description = "Safely evaluate AST nodes without side effects" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, + {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, +] + +[package.extras] +tests = ["pytest"] + +[[package]] +name = "pyarrow" +version = "12.0.1" +description = "Python library for Apache Arrow" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyarrow-12.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:6d288029a94a9bb5407ceebdd7110ba398a00412c5b0155ee9813a40d246c5df"}, + {file = "pyarrow-12.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:345e1828efdbd9aa4d4de7d5676778aba384a2c3add896d995b23d368e60e5af"}, + {file = "pyarrow-12.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d6009fdf8986332b2169314da482baed47ac053311c8934ac6651e614deacd6"}, + {file = "pyarrow-12.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d3c4cbbf81e6dd23fe921bc91dc4619ea3b79bc58ef10bce0f49bdafb103daf"}, + {file = "pyarrow-12.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:cdacf515ec276709ac8042c7d9bd5be83b4f5f39c6c037a17a60d7ebfd92c890"}, + {file = "pyarrow-12.0.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:749be7fd2ff260683f9cc739cb862fb11be376de965a2a8ccbf2693b098db6c7"}, + {file = "pyarrow-12.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6895b5fb74289d055c43db3af0de6e16b07586c45763cb5e558d38b86a91e3a7"}, + {file = "pyarrow-12.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1887bdae17ec3b4c046fcf19951e71b6a619f39fa674f9881216173566c8f718"}, + {file = "pyarrow-12.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2c9cb8eeabbadf5fcfc3d1ddea616c7ce893db2ce4dcef0ac13b099ad7ca082"}, + {file = "pyarrow-12.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:ce4aebdf412bd0eeb800d8e47db854f9f9f7e2f5a0220440acf219ddfddd4f63"}, + {file = "pyarrow-12.0.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:e0d8730c7f6e893f6db5d5b86eda42c0a130842d101992b581e2138e4d5663d3"}, + {file = "pyarrow-12.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43364daec02f69fec89d2315f7fbfbeec956e0d991cbbef471681bd77875c40f"}, + {file = "pyarrow-12.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:051f9f5ccf585f12d7de836e50965b3c235542cc896959320d9776ab93f3b33d"}, + {file = "pyarrow-12.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:be2757e9275875d2a9c6e6052ac7957fbbfc7bc7370e4a036a9b893e96fedaba"}, + {file = "pyarrow-12.0.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:cf812306d66f40f69e684300f7af5111c11f6e0d89d6b733e05a3de44961529d"}, + {file = "pyarrow-12.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:459a1c0ed2d68671188b2118c63bac91eaef6fc150c77ddd8a583e3c795737bf"}, + {file = "pyarrow-12.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85e705e33eaf666bbe508a16fd5ba27ca061e177916b7a317ba5a51bee43384c"}, + {file = "pyarrow-12.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9120c3eb2b1f6f516a3b7a9714ed860882d9ef98c4b17edcdc91d95b7528db60"}, + {file = "pyarrow-12.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:c780f4dc40460015d80fcd6a6140de80b615349ed68ef9adb653fe351778c9b3"}, + {file = "pyarrow-12.0.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:a3c63124fc26bf5f95f508f5d04e1ece8cc23a8b0af2a1e6ab2b1ec3fdc91b24"}, + {file = "pyarrow-12.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b13329f79fa4472324f8d32dc1b1216616d09bd1e77cfb13104dec5463632c36"}, + {file = "pyarrow-12.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb656150d3d12ec1396f6dde542db1675a95c0cc8366d507347b0beed96e87ca"}, + {file = "pyarrow-12.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6251e38470da97a5b2e00de5c6a049149f7b2bd62f12fa5dbb9ac674119ba71a"}, + {file = "pyarrow-12.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:3de26da901216149ce086920547dfff5cd22818c9eab67ebc41e863a5883bac7"}, + {file = "pyarrow-12.0.1.tar.gz", hash = 
"sha256:cce317fc96e5b71107bf1f9f184d5e54e2bd14bbf3f9a3d62819961f0af86fec"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + +[[package]] +name = "pycparser" +version = "2.21" +description = "C parser in Python" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, + {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, +] + +[[package]] +name = "pygments" +version = "2.15.1" +description = "Pygments is a syntax highlighting package written in Python." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "Pygments-2.15.1-py3-none-any.whl", hash = "sha256:db2db3deb4b4179f399a09054b023b6a586b76499d36965813c71aa8ed7b5fd1"}, + {file = "Pygments-2.15.1.tar.gz", hash = "sha256:8ace4d3c1dd481894b2005f560ead0f9f19ee64fe983366be1a21e171d12775c"}, +] + +[package.extras] +plugins = ["importlib-metadata"] + +[[package]] +name = "pyink" +version = "23.3.0" +description = "Pyink is a python formatter, forked from Black with slightly different behavior." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyink-23.3.0-py3-none-any.whl", hash = "sha256:514d234bcea150b43530061ce20b1f6e3b82e847cfb6635fb13aebb4d57628b8"}, + {file = "pyink-23.3.0.tar.gz", hash = "sha256:2ee26ad46a26dfee79e748019f3d452be26009464383c82f054a0fb71890e417"}, +] + +[package.dependencies] +black = ">=23.3.0" +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "pyparsing" +version = "3.0.9" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "dev" +optional = false +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, + {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + +[[package]] +name = "pytest" +version = "7.4.0" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, + {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for 
measuring coverage." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytz" +version = "2023.3" +description = "World timezone definitions, modern and historical" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2023.3-py2.py3-none-any.whl", hash = "sha256:a151b3abb88eda1d4e34a9814df37de2a80e301e68ba0fd856fb9b46bfbbbffb"}, + {file = "pytz-2023.3.tar.gz", hash = "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588"}, +] + +[[package]] +name = "pywin32" +version = "306" +description = "Python for Window Extensions" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + +[[package]] +name = 
"pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = 
"PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + +[[package]] +name = "pyzmq" +version = "25.1.0" +description = "Python bindings for 0MQ" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pyzmq-25.1.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:1a6169e69034eaa06823da6a93a7739ff38716142b3596c180363dee729d713d"}, + {file = "pyzmq-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:19d0383b1f18411d137d891cab567de9afa609b214de68b86e20173dc624c101"}, + {file = "pyzmq-25.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1e931d9a92f628858a50f5bdffdfcf839aebe388b82f9d2ccd5d22a38a789dc"}, + {file = "pyzmq-25.1.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:97d984b1b2f574bc1bb58296d3c0b64b10e95e7026f8716ed6c0b86d4679843f"}, + {file = "pyzmq-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:154bddda2a351161474b36dba03bf1463377ec226a13458725183e508840df89"}, + {file = "pyzmq-25.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:cb6d161ae94fb35bb518b74bb06b7293299c15ba3bc099dccd6a5b7ae589aee3"}, + {file = "pyzmq-25.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:90146ab578931e0e2826ee39d0c948d0ea72734378f1898939d18bc9c823fcf9"}, + {file = "pyzmq-25.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:831ba20b660b39e39e5ac8603e8193f8fce1ee03a42c84ade89c36a251449d80"}, + {file = "pyzmq-25.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3a522510e3434e12aff80187144c6df556bb06fe6b9d01b2ecfbd2b5bfa5c60c"}, + {file = "pyzmq-25.1.0-cp310-cp310-win32.whl", hash = "sha256:be24a5867b8e3b9dd5c241de359a9a5217698ff616ac2daa47713ba2ebe30ad1"}, + {file = "pyzmq-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:5693dcc4f163481cf79e98cf2d7995c60e43809e325b77a7748d8024b1b7bcba"}, + {file = "pyzmq-25.1.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:13bbe36da3f8aaf2b7ec12696253c0bf6ffe05f4507985a8844a1081db6ec22d"}, + {file = "pyzmq-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:69511d604368f3dc58d4be1b0bad99b61ee92b44afe1cd9b7bd8c5e34ea8248a"}, + {file = "pyzmq-25.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a983c8694667fd76d793ada77fd36c8317e76aa66eec75be2653cef2ea72883"}, + {file = "pyzmq-25.1.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:332616f95eb400492103ab9d542b69d5f0ff628b23129a4bc0a2fd48da6e4e0b"}, + {file = "pyzmq-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58416db767787aedbfd57116714aad6c9ce57215ffa1c3758a52403f7c68cff5"}, + {file = "pyzmq-25.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:cad9545f5801a125f162d09ec9b724b7ad9b6440151b89645241d0120e119dcc"}, + {file = "pyzmq-25.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d6128d431b8dfa888bf51c22a04d48bcb3d64431caf02b3cb943269f17fd2994"}, + {file = "pyzmq-25.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:2b15247c49d8cbea695b321ae5478d47cffd496a2ec5ef47131a9e79ddd7e46c"}, + {file = "pyzmq-25.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:442d3efc77ca4d35bee3547a8e08e8d4bb88dadb54a8377014938ba98d2e074a"}, + {file = "pyzmq-25.1.0-cp311-cp311-win32.whl", hash = "sha256:65346f507a815a731092421d0d7d60ed551a80d9b75e8b684307d435a5597425"}, + {file = "pyzmq-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:8b45d722046fea5a5694cba5d86f21f78f0052b40a4bbbbf60128ac55bfcc7b6"}, + {file = "pyzmq-25.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f45808eda8b1d71308c5416ef3abe958f033fdbb356984fabbfc7887bed76b3f"}, + {file = "pyzmq-25.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b697774ea8273e3c0460cf0bba16cd85ca6c46dfe8b303211816d68c492e132"}, + {file = "pyzmq-25.1.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b324fa769577fc2c8f5efcd429cef5acbc17d63fe15ed16d6dcbac2c5eb00849"}, + {file = "pyzmq-25.1.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:5873d6a60b778848ce23b6c0ac26c39e48969823882f607516b91fb323ce80e5"}, + {file = "pyzmq-25.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:f0d9e7ba6a815a12c8575ba7887da4b72483e4cfc57179af10c9b937f3f9308f"}, + {file = "pyzmq-25.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:414b8beec76521358b49170db7b9967d6974bdfc3297f47f7d23edec37329b00"}, + {file = 
"pyzmq-25.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:01f06f33e12497dca86353c354461f75275a5ad9eaea181ac0dc1662da8074fa"}, + {file = "pyzmq-25.1.0-cp36-cp36m-win32.whl", hash = "sha256:b5a07c4f29bf7cb0164664ef87e4aa25435dcc1f818d29842118b0ac1eb8e2b5"}, + {file = "pyzmq-25.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:968b0c737797c1809ec602e082cb63e9824ff2329275336bb88bd71591e94a90"}, + {file = "pyzmq-25.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:47b915ba666c51391836d7ed9a745926b22c434efa76c119f77bcffa64d2c50c"}, + {file = "pyzmq-25.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5af31493663cf76dd36b00dafbc839e83bbca8a0662931e11816d75f36155897"}, + {file = "pyzmq-25.1.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5489738a692bc7ee9a0a7765979c8a572520d616d12d949eaffc6e061b82b4d1"}, + {file = "pyzmq-25.1.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1fc56a0221bdf67cfa94ef2d6ce5513a3d209c3dfd21fed4d4e87eca1822e3a3"}, + {file = "pyzmq-25.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:75217e83faea9edbc29516fc90c817bc40c6b21a5771ecb53e868e45594826b0"}, + {file = "pyzmq-25.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:3830be8826639d801de9053cf86350ed6742c4321ba4236e4b5568528d7bfed7"}, + {file = "pyzmq-25.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3575699d7fd7c9b2108bc1c6128641a9a825a58577775ada26c02eb29e09c517"}, + {file = "pyzmq-25.1.0-cp37-cp37m-win32.whl", hash = "sha256:95bd3a998d8c68b76679f6b18f520904af5204f089beebb7b0301d97704634dd"}, + {file = "pyzmq-25.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:dbc466744a2db4b7ca05589f21ae1a35066afada2f803f92369f5877c100ef62"}, + {file = "pyzmq-25.1.0-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:3bed53f7218490c68f0e82a29c92335daa9606216e51c64f37b48eb78f1281f4"}, + {file = "pyzmq-25.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:eb52e826d16c09ef87132c6e360e1879c984f19a4f62d8a935345deac43f3c12"}, + {file = "pyzmq-25.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ddbef8b53cd16467fdbfa92a712eae46dd066aa19780681a2ce266e88fbc7165"}, + {file = "pyzmq-25.1.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9301cf1d7fc1ddf668d0abbe3e227fc9ab15bc036a31c247276012abb921b5ff"}, + {file = "pyzmq-25.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e23a8c3b6c06de40bdb9e06288180d630b562db8ac199e8cc535af81f90e64b"}, + {file = "pyzmq-25.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4a82faae00d1eed4809c2f18b37f15ce39a10a1c58fe48b60ad02875d6e13d80"}, + {file = "pyzmq-25.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c8398a1b1951aaa330269c35335ae69744be166e67e0ebd9869bdc09426f3871"}, + {file = "pyzmq-25.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d40682ac60b2a613d36d8d3a0cd14fbdf8e7e0618fbb40aa9fa7b796c9081584"}, + {file = "pyzmq-25.1.0-cp38-cp38-win32.whl", hash = "sha256:33d5c8391a34d56224bccf74f458d82fc6e24b3213fc68165c98b708c7a69325"}, + {file = "pyzmq-25.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:c66b7ff2527e18554030319b1376d81560ca0742c6e0b17ff1ee96624a5f1afd"}, + {file = "pyzmq-25.1.0-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:af56229ea6527a849ac9fb154a059d7e32e77a8cba27e3e62a1e38d8808cb1a5"}, + {file = "pyzmq-25.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bdca18b94c404af6ae5533cd1bc310c4931f7ac97c148bbfd2cd4bdd62b96253"}, + {file = 
"pyzmq-25.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0b6b42f7055bbc562f63f3df3b63e3dd1ebe9727ff0f124c3aa7bcea7b3a00f9"}, + {file = "pyzmq-25.1.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4c2fc7aad520a97d64ffc98190fce6b64152bde57a10c704b337082679e74f67"}, + {file = "pyzmq-25.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be86a26415a8b6af02cd8d782e3a9ae3872140a057f1cadf0133de685185c02b"}, + {file = "pyzmq-25.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:851fb2fe14036cfc1960d806628b80276af5424db09fe5c91c726890c8e6d943"}, + {file = "pyzmq-25.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2a21fec5c3cea45421a19ccbe6250c82f97af4175bc09de4d6dd78fb0cb4c200"}, + {file = "pyzmq-25.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bad172aba822444b32eae54c2d5ab18cd7dee9814fd5c7ed026603b8cae2d05f"}, + {file = "pyzmq-25.1.0-cp39-cp39-win32.whl", hash = "sha256:4d67609b37204acad3d566bb7391e0ecc25ef8bae22ff72ebe2ad7ffb7847158"}, + {file = "pyzmq-25.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:71c7b5896e40720d30cd77a81e62b433b981005bbff0cb2f739e0f8d059b5d99"}, + {file = "pyzmq-25.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4cb27ef9d3bdc0c195b2dc54fcb8720e18b741624686a81942e14c8b67cc61a6"}, + {file = "pyzmq-25.1.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0c4fc2741e0513b5d5a12fe200d6785bbcc621f6f2278893a9ca7bed7f2efb7d"}, + {file = "pyzmq-25.1.0-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fc34fdd458ff77a2a00e3c86f899911f6f269d393ca5675842a6e92eea565bae"}, + {file = "pyzmq-25.1.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8751f9c1442624da391bbd92bd4b072def6d7702a9390e4479f45c182392ff78"}, + {file = "pyzmq-25.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:6581e886aec3135964a302a0f5eb68f964869b9efd1dbafdebceaaf2934f8a68"}, + {file = "pyzmq-25.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5482f08d2c3c42b920e8771ae8932fbaa0a67dff925fc476996ddd8155a170f3"}, + {file = "pyzmq-25.1.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7fbcafa3ea16d1de1f213c226005fea21ee16ed56134b75b2dede5a2129e62"}, + {file = "pyzmq-25.1.0-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:adecf6d02b1beab8d7c04bc36f22bb0e4c65a35eb0b4750b91693631d4081c70"}, + {file = "pyzmq-25.1.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6d39e42a0aa888122d1beb8ec0d4ddfb6c6b45aecb5ba4013c27e2f28657765"}, + {file = "pyzmq-25.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7018289b402ebf2b2c06992813523de61d4ce17bd514c4339d8f27a6f6809492"}, + {file = "pyzmq-25.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9e68ae9864d260b18f311b68d29134d8776d82e7f5d75ce898b40a88df9db30f"}, + {file = "pyzmq-25.1.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e21cc00e4debe8f54c3ed7b9fcca540f46eee12762a9fa56feb8512fd9057161"}, + {file = "pyzmq-25.1.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f666ae327a6899ff560d741681fdcdf4506f990595201ed39b44278c471ad98"}, + {file = "pyzmq-25.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f5efcc29056dfe95e9c9db0dfbb12b62db9c4ad302f812931b6d21dd04a9119"}, + {file = "pyzmq-25.1.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = 
"sha256:48e5e59e77c1a83162ab3c163fc01cd2eebc5b34560341a67421b09be0891287"}, + {file = "pyzmq-25.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:108c96ebbd573d929740d66e4c3d1bdf31d5cde003b8dc7811a3c8c5b0fc173b"}, + {file = "pyzmq-25.1.0.tar.gz", hash = "sha256:80c41023465d36280e801564a69cbfce8ae85ff79b080e1913f6e90481fb8957"}, +] + +[package.dependencies] +cffi = {version = "*", markers = "implementation_name == \"pypy\""} + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "rich" +version = "13.4.2" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "dev" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.4.2-py3-none-any.whl", hash = "sha256:8f87bc7ee54675732fa66a05ebfe489e27264caeeff3728c945d25971b6485ec"}, + {file = "rich-13.4.2.tar.gz", hash = "sha256:d653d6bccede5844304c605d5aac802c7cf9621efd700b46c7ec2b51ea914898"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + +[[package]] +name = "scipy" +version = "1.11.1" +description = "Fundamental algorithms for scientific computing in Python" +category = "main" +optional = false +python-versions = "<3.13,>=3.9" +files = [ + {file = "scipy-1.11.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:aec8c62fbe52914f9cf28d846cf0401dd80ab80788bbab909434eb336ed07c04"}, + {file = "scipy-1.11.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3b9963798df1d8a52db41a6fc0e6fa65b1c60e85d73da27ae8bb754de4792481"}, + {file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e8eb42db36526b130dfbc417609498a6192381abc1975b91e3eb238e0b41c1a"}, + {file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:366a6a937110d80dca4f63b3f5b00cc89d36f678b2d124a01067b154e692bab1"}, + {file = "scipy-1.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:08d957ca82d3535b3b9ba6c8ff355d78fe975271874e2af267cb5add5bd78625"}, + {file = "scipy-1.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:e866514bc2d660608447b6ba95c8900d591f2865c07cca0aa4f7ff3c4ca70f30"}, + {file = "scipy-1.11.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ba94eeef3c9caa4cea7b402a35bb02a5714ee1ee77eb98aca1eed4543beb0f4c"}, + {file = "scipy-1.11.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:512fdc18c65f76dadaca139348e525646d440220d8d05f6d21965b8d4466bccd"}, + {file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cce154372f0ebe88556ed06d7b196e9c2e0c13080ecb58d0f35062dc7cc28b47"}, + {file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4bb943010203465ac81efa392e4645265077b4d9e99b66cf3ed33ae12254173"}, + {file = "scipy-1.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:249cfa465c379c9bb2c20123001e151ff5e29b351cbb7f9c91587260602c58d0"}, + {file = "scipy-1.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:ffb28e3fa31b9c376d0fb1f74c1f13911c8c154a760312fbee87a21eb21efe31"}, + {file = "scipy-1.11.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:39154437654260a52871dfde852adf1b93b1d1bc5dc0ffa70068f16ec0be2624"}, + {file = "scipy-1.11.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b588311875c58d1acd4ef17c983b9f1ab5391755a47c3d70b6bd503a45bfaf71"}, + {file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d51565560565a0307ed06fa0ec4c6f21ff094947d4844d6068ed04400c72d0c3"}, + {file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b41a0f322b4eb51b078cb3441e950ad661ede490c3aca66edef66f4b37ab1877"}, + {file = "scipy-1.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:396fae3f8c12ad14c5f3eb40499fd06a6fef8393a6baa352a652ecd51e74e029"}, + {file = "scipy-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:be8c962a821957fdde8c4044efdab7a140c13294997a407eaee777acf63cbf0c"}, + {file = "scipy-1.11.1.tar.gz", hash = "sha256:fb5b492fa035334fd249f0973cc79ecad8b09c604b42a127a677b45a9a3d4289"}, +] + +[package.dependencies] +numpy = ">=1.21.6,<1.28.0" + +[package.extras] +dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] +test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + +[[package]] +name = "setuptools" +version = "68.0.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "setuptools-68.0.0-py3-none-any.whl", hash = "sha256:11e52c67415a381d10d6b462ced9cfb97066179f0e871399e006c4ab101fc85f"}, + {file = "setuptools-68.0.0.tar.gz", hash = "sha256:baf1fdb41c6da4cd2eae722e135500da913332ab3f2f5c7d33af9b492acb5235"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = 
"sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "stack-data" +version = "0.6.2" +description = "Extract data from python stack frames and tracebacks for informative displays" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "stack_data-0.6.2-py3-none-any.whl", hash = "sha256:cbb2a53eb64e5785878201a97ed7c7b94883f48b87bfb0bbe8b623c74679e4a8"}, + {file = "stack_data-0.6.2.tar.gz", hash = "sha256:32d2dd0376772d01b6cb9fc996f3c8b57a357089dec328ed4b6553d037eaf815"}, +] + +[package.dependencies] +asttokens = ">=2.1.0" +executing = ">=1.2.0" +pure-eval = "*" + +[package.extras] +tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] + +[[package]] +name = "tensorstore" +version = "0.1.40" +description = "Read and write large, multi-dimensional arrays" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tensorstore-0.1.40-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:9a184f8b3a6392e1b7f209bd8932328a2df086625f75d8bc9d21ddff6e909614"}, + {file = "tensorstore-0.1.40-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e93b5ae2bbb54e7a1b15ae055d6cfc7e6f6b0311741b4f58d8f8ff29a40f5671"}, + {file = "tensorstore-0.1.40-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66ecb6cc12d5f79d0231b76da4b369216aa75c7c368e48eb646e3c786b66e36c"}, + {file = "tensorstore-0.1.40-cp310-cp310-win_amd64.whl", hash = "sha256:69e3f3352843fe4c3334a7634aea8b60e53263639d06383edbd9fc63f86293cd"}, + {file = "tensorstore-0.1.40-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:79ef100f26e731c12fc4872efaaa4bee505b918104731b3521a55a2974281a51"}, + {file = "tensorstore-0.1.40-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:49bd2dbfd64f409e27f5bd326e98c95a9ade40f66f482e5277b097b94c070a7d"}, + {file = "tensorstore-0.1.40-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:060e63f10079cfc26ee1b110b190e88c91d3278528405d6991de266bcdd273fb"}, + {file = "tensorstore-0.1.40-cp311-cp311-win_amd64.whl", hash = "sha256:26ab0b2dfefbb3c5bb12ae8e27ca273f8b15069fad344bde615a5e2dbdf1171d"}, + {file = "tensorstore-0.1.40-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:dc3242c372786660066ede40340dfd7984de86f3c841ad4d8d1eb82a1d80affb"}, + {file = "tensorstore-0.1.40-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81224f28022dfa186e8fbcc5d0b9859c2b2b3f284c5f3db42d6cc958cb83dab5"}, + {file = "tensorstore-0.1.40-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7751e74b9eb571d827d4bbce6f7c5b349343e45ee1962d934019ef57bc79e73"}, + {file = "tensorstore-0.1.40-cp38-cp38-win_amd64.whl", hash = "sha256:f32ce021112fba4efb096e0ebe16733b8786d33a031e3f0ec289b86d426e25c3"}, + {file = "tensorstore-0.1.40-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:33a84480f6f462fb456a6b4a5ad42e08cc030ece719d9b305c6cb74aa4a70717"}, + {file = "tensorstore-0.1.40-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d9e25a1a4a32fc6dc99f1b1f04d2a46a3e62234de176f160a59cfa09cab158eb"}, + {file = "tensorstore-0.1.40-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ee0b3c07583de7908f451ea7ede6d8ec37b9e0b4947ddbc9e83e600bdfa6ce6"}, + {file = "tensorstore-0.1.40-cp39-cp39-win_amd64.whl", hash = "sha256:9dd7027591743a19965f3b87363b0c7725d6d8e52acb5a9ea19533dbcf40fa13"}, + {file = "tensorstore-0.1.40.tar.gz", hash = "sha256:41517dbd3919e5a5ee2be69b51bdd528b57c9b35f533e6fc83f6155a378fdf8a"}, +] + +[package.dependencies] +numpy = ">=1.16.0" + 
+[[package]] +name = "tokenize-rt" +version = "5.1.0" +description = "A wrapper around the stdlib `tokenize` which roundtrips." +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tokenize_rt-5.1.0-py2.py3-none-any.whl", hash = "sha256:9b7bb843e77dd6ed0be5564bfaaba200083911e0497841cd3e9235a6a9794d74"}, + {file = "tokenize_rt-5.1.0.tar.gz", hash = "sha256:08f0c2daa94c4052e53c2fcaa8e32585e6ae9bdfc800974092d031401694e002"}, +] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "toolz" +version = "0.12.0" +description = "List processing tools and functional utilities" +category = "main" +optional = false +python-versions = ">=3.5" +files = [ + {file = "toolz-0.12.0-py3-none-any.whl", hash = "sha256:2059bd4148deb1884bb0eb770a3cde70e7f954cfbbdc2285f1f2de01fd21eb6f"}, + {file = "toolz-0.12.0.tar.gz", hash = "sha256:88c570861c440ee3f2f6037c4654613228ff40c93a6c25e0eba70d17282c6194"}, +] + +[[package]] +name = "tornado" +version = "6.3.2" +description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +category = "dev" +optional = false +python-versions = ">= 3.8" +files = [ + {file = "tornado-6.3.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:c367ab6c0393d71171123ca5515c61ff62fe09024fa6bf299cd1339dc9456829"}, + {file = "tornado-6.3.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b46a6ab20f5c7c1cb949c72c1994a4585d2eaa0be4853f50a03b5031e964fc7c"}, + {file = "tornado-6.3.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2de14066c4a38b4ecbbcd55c5cc4b5340eb04f1c5e81da7451ef555859c833f"}, + {file = "tornado-6.3.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:05615096845cf50a895026f749195bf0b10b8909f9be672f50b0fe69cba368e4"}, + {file = "tornado-6.3.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b17b1cf5f8354efa3d37c6e28fdfd9c1c1e5122f2cb56dac121ac61baa47cbe"}, + {file = "tornado-6.3.2-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:29e71c847a35f6e10ca3b5c2990a52ce38b233019d8e858b755ea6ce4dcdd19d"}, + {file = "tornado-6.3.2-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:834ae7540ad3a83199a8da8f9f2d383e3c3d5130a328889e4cc991acc81e87a0"}, + {file = "tornado-6.3.2-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:6a0848f1aea0d196a7c4f6772197cbe2abc4266f836b0aac76947872cd29b411"}, + {file = "tornado-6.3.2-cp38-abi3-win32.whl", hash = "sha256:7efcbcc30b7c654eb6a8c9c9da787a851c18f8ccd4a5a3a95b05c7accfa068d2"}, + {file = "tornado-6.3.2-cp38-abi3-win_amd64.whl", hash = "sha256:0c325e66c8123c606eea33084976c832aa4e766b7dff8aedd7587ea44a604cdf"}, + {file = "tornado-6.3.2.tar.gz", hash = "sha256:4b927c4f19b71e627b13f3db2324e4ae660527143f9e1f2e2fb404f3a187e2ba"}, +] + +[[package]] +name = "tqdm" +version = "4.65.0" +description = "Fast, Extensible Progress Meter" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"}, + {file = 
"tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["py-make (>=0.1.0)", "twine", "wheel"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "traitlets" +version = "5.9.0" +description = "Traitlets Python configuration system" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "traitlets-5.9.0-py3-none-any.whl", hash = "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8"}, + {file = "traitlets-5.9.0.tar.gz", hash = "sha256:f6cde21a9c68cf756af02035f72d5a723bf607e862e7be33ece505abf4a3bad9"}, +] + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] + +[[package]] +name = "typing-extensions" +version = "4.7.1" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, + {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, +] + +[[package]] +name = "tzdata" +version = "2023.3" +description = "Provider of IANA time zone data" +category = "dev" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"}, + {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, +] + +[[package]] +name = "urllib3" +version = "2.0.4" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"}, + {file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "virtualenv" +version = "20.24.0" +description = "Virtual Python Environment builder" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.24.0-py3-none-any.whl", hash = "sha256:18d1b37fc75cc2670625702d76849a91ebd383768b4e91382a8d51be3246049e"}, + {file = "virtualenv-20.24.0.tar.gz", hash = "sha256:e2a7cef9da880d693b933db7654367754f14e20650dc60e8ee7385571f8593a3"}, +] + +[package.dependencies] +distlib = ">=0.3.6,<1" +filelock = ">=3.12,<4" +platformdirs = ">=3.5.1,<4" + +[package.extras] +docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.3.1)", "pytest-env (>=0.8.1)", "pytest-freezer (>=0.4.6)", "pytest-mock (>=3.10)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=67.8)", "time-machine (>=2.9)"] + +[[package]] +name = "wcwidth" +version = "0.2.6" +description = "Measures the displayed width of unicode strings in a terminal" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"}, + {file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"}, +] + +[[package]] +name = "xxhash" +version = "3.2.0" +description = "Python binding for xxHash" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "xxhash-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:af44b9e59c4b2926a4e3c7f9d29949ff42fcea28637ff6b8182e654461932be8"}, + {file = "xxhash-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1bdd57973e2b802ef32553d7bebf9402dac1557874dbe5c908b499ea917662cd"}, + {file = "xxhash-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b7c9aa77bbce61a5e681bd39cb6a804338474dcc90abe3c543592aa5d6c9a9b"}, + {file = "xxhash-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11bf87dc7bb8c3b0b5e24b7b941a9a19d8c1f88120b6a03a17264086bc8bb023"}, + {file = "xxhash-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2783d41487ce6d379fdfaa7332fca5187bf7010b9bddcf20cafba923bc1dc665"}, + {file = "xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:561076ca0dcef2fbc20b2bc2765bff099e002e96041ae9dbe910a863ca6ee3ea"}, + {file = "xxhash-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a26eeb4625a6e61cedc8c1b39b89327c9c7e1a8c2c4d786fe3f178eb839ede6"}, + {file = "xxhash-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = 
"sha256:d93a44d0104d1b9b10de4e7aadf747f6efc1d7ec5ed0aa3f233a720725dd31bd"}, + {file = "xxhash-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:89585adc73395a10306d2e2036e50d6c4ac0cf8dd47edf914c25488871b64f6d"}, + {file = "xxhash-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:a892b4b139126a86bfdcb97cd912a2f8c4e8623869c3ef7b50871451dd7afeb0"}, + {file = "xxhash-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e998efb190653f70e0f30d92b39fc645145369a4823bee46af8ddfc244aa969d"}, + {file = "xxhash-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e8ed3bd2b8bb3277710843ca63e4f5c3ee6f8f80b083be5b19a7a9905420d11e"}, + {file = "xxhash-3.2.0-cp310-cp310-win32.whl", hash = "sha256:20181cbaed033c72cb881b2a1d13c629cd1228f113046133469c9a48cfcbcd36"}, + {file = "xxhash-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:a0f7a16138279d707db778a63264d1d6016ac13ffd3f1e99f54b2855d6c0d8e1"}, + {file = "xxhash-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5daff3fb5bfef30bc5a2cb143810d376d43461445aa17aece7210de52adbe151"}, + {file = "xxhash-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75bb5be3c5de702a547715f320ecf5c8014aeca750ed5147ca75389bd22e7343"}, + {file = "xxhash-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01f36b671ff55cb1d5c2f6058b799b697fd0ae4b4582bba6ed0999678068172a"}, + {file = "xxhash-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d4d4519123aac73c93159eb8f61db9682393862dd669e7eae034ecd0a35eadac"}, + {file = "xxhash-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:994e4741d5ed70fc2a335a91ef79343c6b1089d7dfe6e955dd06f8ffe82bede6"}, + {file = "xxhash-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919bc1b010aa6ff0eb918838ff73a435aed9e9a19c3202b91acecd296bf75607"}, + {file = "xxhash-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17b65454c5accbb079c45eca546c27c4782f5175aa320758fafac896b1549d27"}, + {file = "xxhash-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b0c094d5e65a46dbf3fe0928ff20873a747e6abfd2ed4b675beeb2750624bc2e"}, + {file = "xxhash-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f94163ebe2d5546e6a5977e96d83621f4689c1054053428cf8d4c28b10f92f69"}, + {file = "xxhash-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:cead7c0307977a00b3f784cff676e72c147adbcada19a2e6fc2ddf54f37cf387"}, + {file = "xxhash-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a0e1bd0260c1da35c1883321ce2707ceea07127816ab625e1226ec95177b561a"}, + {file = "xxhash-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc8878935671490efe9275fb4190a6062b73277bd273237179b9b5a2aa436153"}, + {file = "xxhash-3.2.0-cp311-cp311-win32.whl", hash = "sha256:a433f6162b18d52f7068175d00bd5b1563b7405f926a48d888a97b90a160c40d"}, + {file = "xxhash-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:a32d546a1752e4ee7805d6db57944f7224afa7428d22867006b6486e4195c1f3"}, + {file = "xxhash-3.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:82daaab720866bf690b20b49de5640b0c27e3b8eea2d08aa75bdca2b0f0cfb63"}, + {file = "xxhash-3.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3126df6520cbdbaddd87ce74794b2b6c45dd2cf6ac2b600a374b8cdb76a2548c"}, + {file = "xxhash-3.2.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:e172c1ee40507ae3b8d220f4048aaca204f203e1e4197e8e652f5c814f61d1aa"}, + {file = "xxhash-3.2.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5384f1d9f30876f5d5b618464fb19ff7ce6c0fe4c690fbaafd1c52adc3aae807"}, + {file = "xxhash-3.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26cb52174a7e96a17acad27a3ca65b24713610ac479c99ac9640843822d3bebf"}, + {file = "xxhash-3.2.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbcd613a5e76b1495fc24db9c37a6b7ee5f214fd85979187ec4e032abfc12ded"}, + {file = "xxhash-3.2.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:f988daf25f31726d5b9d0be6af636ca9000898f9ea43a57eac594daea25b0948"}, + {file = "xxhash-3.2.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:bbc30c98ab006ab9fc47e5ed439c00f706bc9d4441ff52693b8b6fea335163e0"}, + {file = "xxhash-3.2.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:2408d49260b0a4a7cc6ba445aebf38e073aeaf482f8e32767ca477e32ccbbf9e"}, + {file = "xxhash-3.2.0-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:3f4152fd0bf8b03b79f2f900fd6087a66866537e94b5a11fd0fd99ef7efe5c42"}, + {file = "xxhash-3.2.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:0eea848758e4823a01abdbcccb021a03c1ee4100411cbeeb7a5c36a202a0c13c"}, + {file = "xxhash-3.2.0-cp36-cp36m-win32.whl", hash = "sha256:77709139af5123c578ab06cf999429cdb9ab211047acd0c787e098dcb3f1cb4d"}, + {file = "xxhash-3.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:91687671fd9d484a4e201ad266d366b695a45a1f2b41be93d116ba60f1b8f3b3"}, + {file = "xxhash-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e4af8bc5c3fcc2192c266421c6aa2daab1a18e002cb8e66ef672030e46ae25cf"}, + {file = "xxhash-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8be562e2ce3e481d9209b6f254c3d7c5ff920eb256aba2380d2fb5ba75d4f87"}, + {file = "xxhash-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9eba0c7c12126b12f7fcbea5513f28c950d28f33d2a227f74b50b77789e478e8"}, + {file = "xxhash-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2198c4901a0223c48f6ec0a978b60bca4f4f7229a11ca4dc96ca325dd6a29115"}, + {file = "xxhash-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50ce82a71b22a3069c02e914bf842118a53065e2ec1c6fb54786e03608ab89cc"}, + {file = "xxhash-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5019fb33711c30e54e4e57ae0ca70af9d35b589d385ac04acd6954452fa73bb"}, + {file = "xxhash-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0d54ac023eef7e3ac9f0b8841ae8a376b933043bc2ad428121346c6fa61c491c"}, + {file = "xxhash-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c55fa832fc3fe64e0d29da5dc9b50ba66ca93312107cec2709300ea3d3bab5c7"}, + {file = "xxhash-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:f4ce006215497993ae77c612c1883ca4f3973899573ce0c52fee91f0d39c4561"}, + {file = "xxhash-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1afb9b9d27fd675b436cb110c15979976d92d761ad6e66799b83756402f3a974"}, + {file = "xxhash-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:baa99cebf95c1885db21e119395f222a706a2bb75a545f0672880a442137725e"}, + {file = "xxhash-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:75aa692936942ccb2e8fd6a386c81c61630ac1b6d6e921698122db8a930579c3"}, + {file = "xxhash-3.2.0-cp37-cp37m-win_amd64.whl", hash = 
"sha256:0a2cdfb5cae9fafb9f7b65fd52ecd60cf7d72c13bb2591ea59aaefa03d5a8827"}, + {file = "xxhash-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3a68d1e8a390b660d94b9360ae5baa8c21a101bd9c4790a8b30781bada9f1fc6"}, + {file = "xxhash-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ce7c3ce28f94302df95eaea7c9c1e2c974b6d15d78a0c82142a97939d7b6c082"}, + {file = "xxhash-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0dcb419bf7b0bc77d366e5005c25682249c5521a63fd36c51f584bd91bb13bd5"}, + {file = "xxhash-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae521ed9287f86aac979eeac43af762f03d9d9797b2272185fb9ddd810391216"}, + {file = "xxhash-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0d16775094423088ffa357d09fbbb9ab48d2fb721d42c0856b801c86f616eec"}, + {file = "xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe454aeab348c42f56d6f7434ff758a3ef90787ac81b9ad5a363cd61b90a1b0b"}, + {file = "xxhash-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:052fd0efdd5525c2dbc61bebb423d92aa619c4905bba605afbf1e985a562a231"}, + {file = "xxhash-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:02badf3754e2133de254a4688798c4d80f0060635087abcb461415cb3eb82115"}, + {file = "xxhash-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:66b8a90b28c13c2aae7a71b32638ceb14cefc2a1c8cf23d8d50dfb64dfac7aaf"}, + {file = "xxhash-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:649cdf19df175925ad87289ead6f760cd840730ee85abc5eb43be326a0a24d97"}, + {file = "xxhash-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4b948a03f89f5c72d69d40975af8af241111f0643228796558dc1cae8f5560b0"}, + {file = "xxhash-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49f51fab7b762da7c2cee0a3d575184d3b9be5e2f64f26cae2dd286258ac9b3c"}, + {file = "xxhash-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1a42994f0d42b55514785356722d9031f064fd34e495b3a589e96db68ee0179d"}, + {file = "xxhash-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:0a6d58ba5865475e53d6c2c4fa6a62e2721e7875e146e2681e5337a6948f12e7"}, + {file = "xxhash-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:aabdbc082030f8df613e2d2ea1f974e7ad36a539bdfc40d36f34e55c7e4b8e94"}, + {file = "xxhash-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:498843b66b9ca416e9d03037e5875c8d0c0ab9037527e22df3b39aa5163214cd"}, + {file = "xxhash-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a910b1193cd90af17228f5d6069816646df0148f14f53eefa6b2b11a1dedfcd0"}, + {file = "xxhash-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb6d8ce31dc25faf4da92991320e211fa7f42de010ef51937b1dc565a4926501"}, + {file = "xxhash-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:883dc3d3942620f4c7dbc3fd6162f50a67f050b714e47da77444e3bcea7d91cc"}, + {file = "xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59dc8bfacf89b8f5be54d55bc3b4bd6d74d0c5320c8a63d2538ac7df5b96f1d5"}, + {file = "xxhash-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:61e6aa1d30c2af692aa88c4dd48709426e8b37bff6a574ee2de677579c34a3d6"}, + {file = "xxhash-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:314ec0bd21f0ee8d30f2bd82ed3759314bd317ddbbd8555668f3d20ab7a8899a"}, + {file = "xxhash-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", 
hash = "sha256:dad638cde3a5357ad3163b80b3127df61fb5b5e34e9e05a87697144400ba03c7"}, + {file = "xxhash-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:eaa3ea15025b56076d806b248948612289b093e8dcda8d013776b3848dffff15"}, + {file = "xxhash-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:7deae3a312feb5c17c97cbf18129f83cbd3f1f9ec25b0f50e2bd9697befb22e7"}, + {file = "xxhash-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:add774341c09853b1612c64a526032d95ab1683053325403e1afbe3ad2f374c5"}, + {file = "xxhash-3.2.0-cp39-cp39-win32.whl", hash = "sha256:9b94749130ef3119375c599bfce82142c2500ef9ed3280089157ee37662a7137"}, + {file = "xxhash-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:e57d94a1552af67f67b27db5dba0b03783ea69d5ca2af2f40e098f0ba3ce3f5f"}, + {file = "xxhash-3.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92fd765591c83e5c5f409b33eac1d3266c03d3d11c71a7dbade36d5cdee4fbc0"}, + {file = "xxhash-3.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8970f6a411a9839a02b23b7e90bbbba4a6de52ace009274998566dc43f36ca18"}, + {file = "xxhash-3.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5f3e33fe6cbab481727f9aeb136a213aed7e33cd1ca27bd75e916ffacc18411"}, + {file = "xxhash-3.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:368265392cb696dd53907e2328b5a8c1bee81cf2142d0cc743caf1c1047abb36"}, + {file = "xxhash-3.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:3b1f3c6d67fa9f49c4ff6b25ce0e7143bab88a5bc0f4116dd290c92337d0ecc7"}, + {file = "xxhash-3.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c5e8db6e1ee7267b7c412ad0afd5863bf7a95286b8333a5958c8097c69f94cf5"}, + {file = "xxhash-3.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:761df3c7e2c5270088b691c5a8121004f84318177da1ca1db64222ec83c44871"}, + {file = "xxhash-3.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2d15a707e7f689531eb4134eccb0f8bf3844bb8255ad50823aa39708d9e6755"}, + {file = "xxhash-3.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e6b2ba4ff53dd5f57d728095e3def7375eb19c90621ce3b41b256de84ec61cfd"}, + {file = "xxhash-3.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:61b0bcf946fdfd8ab5f09179dc2b5c74d1ef47cedfc6ed0ec01fdf0ee8682dd3"}, + {file = "xxhash-3.2.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7b79f0f302396d8e0d444826ceb3d07b61977793886ebae04e82796c02e42dc"}, + {file = "xxhash-3.2.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0773cd5c438ffcd5dbff91cdd503574f88a4b960e70cedeb67736583a17a918"}, + {file = "xxhash-3.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ec1f57127879b419a2c8d2db9d9978eb26c61ae17e5972197830430ae78d25b"}, + {file = "xxhash-3.2.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d4b15c00e807b1d3d0b612338c814739dec310b80fb069bd732b98ddc709ad7"}, + {file = "xxhash-3.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9d3f686e3d1c8900c5459eee02b60c7399e20ec5c6402364068a343c83a61d90"}, + {file = "xxhash-3.2.0.tar.gz", hash = "sha256:1afd47af8955c5db730f630ad53ae798cf7fae0acb64cebb3cf94d35c47dd088"}, +] + +[[package]] +name = "yarl" +version = "1.9.2" +description = "Yet another URL library" +category = "dev" +optional = false 
+python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, + {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, + {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, + {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, + {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, + {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, + {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, + {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, + {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, + {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = 
"sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, + {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, + {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, + {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + +[[package]] +name = "zipp" +version = "3.16.2" +description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.16.2-py3-none-any.whl", hash = "sha256:679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0"}, + {file = "zipp-3.16.2.tar.gz", hash = "sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + +[metadata] +lock-version = "2.0" +python-versions = ">=3.9,<3.12" +content-hash = "aa6554c9a3bc25b60de9b18ea2b74f4cda29f3e17baafa193f50776bf2cd61a7" diff --git a/flax/experimental/nnx/pyproject.toml b/flax/experimental/nnx/pyproject.toml new file mode 100644 index 0000000000..7a8132f133 --- /dev/null +++ b/flax/experimental/nnx/pyproject.toml @@ -0,0 +1,46 @@ +[tool.poetry] +name = "nnx" +version = "0.0.8" +description = "" +authors = ["Cristian Garcia "] +readme = "README.md" + +[tool.poetry.dependencies] +python = ">=3.9,<3.12" +jax = "*" +jaxlib = "*" +optax = "*" +typing-extensions = "*" + + +[tool.poetry.group.test.dependencies] +pytest = ">=7.2.2" +pytest-cov = ">=4.0.0" +flax = ">=0.6.10" + + +[tool.poetry.group.dev.dependencies] +black = { version = "23.3.0", extras = ["jupyter"] } +isort = "5.12.0" +ipykernel = "^6.22.0" +pre-commit = ">=3.3.2" +pyink = "23.3.0" + +[tool.poetry.group.examples.dependencies] +matplotlib = "^3.7.1" +datasets = "^2.12.0" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "@overload", + "@tp.overload", + "@abstractmethod", +] + +[tool.pyink] +pyink-indentation = 2 diff --git a/flax/experimental/nnx/scripts/deploy-docs.sh b/flax/experimental/nnx/scripts/deploy-docs.sh new file mode 100644 index 0000000000..f6fe53fa8b --- /dev/null +++ b/flax/experimental/nnx/scripts/deploy-docs.sh @@ -0,0 +1,2 @@ +cp README.md docs/index.md +mkdocs gh-deploy --clean \ No newline at end of file diff --git a/flax/experimental/nnx/scripts/run-all-examples.bash b/flax/experimental/nnx/scripts/run-all-examples.bash new file mode 100644 index 0000000000..a25d3534af --- /dev/null +++ b/flax/experimental/nnx/scripts/run-all-examples.bash @@ -0,0 +1,8 @@ +set -e + +for f in $(find examples -name "*.py"); do + echo -e "\n---------------------------------" + echo "$f" + echo "---------------------------------" + poetry run time python "$f" +done diff --git 
a/flax/experimental/nnx/scripts/update_version.py b/flax/experimental/nnx/scripts/update_version.py new file mode 100644 index 0000000000..6f617c1fe4 --- /dev/null +++ b/flax/experimental/nnx/scripts/update_version.py @@ -0,0 +1,49 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from pathlib import Path + +import typer + + +# NOTE: this script could be written in bash using sed, but it's unclear that would be worth it +def main(release_name: str): + release_name = release_name.replace("-create-release", "") + + # Update pyproject.toml + pyproject_path = Path("pyproject.toml") + pyproject_text = pyproject_path.read_text() + pyproject_text = re.sub( + r'version = ".*"', + f'version = "{release_name}"', + pyproject_text, + count=1, + ) + pyproject_path.write_text(pyproject_text) + + # Update __init__.py + init_path = Path("nnx/__init__.py") + init_text = init_path.read_text() + init_text = re.sub( + r'__version__ = "(.*?)"', + f'__version__ = "{release_name}"', + init_text, + count=1, + ) + init_path.write_text(init_text) + + +if __name__ == "__main__": + typer.run(main) diff --git a/flax/experimental/nnx/tests/__init__.py b/flax/experimental/nnx/tests/__init__.py new file mode 100644 index 0000000000..e80ba0b35f --- /dev/null +++ b/flax/experimental/nnx/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/flax/experimental/nnx/tests/test_containers.py b/flax/experimental/nnx/tests/test_containers.py new file mode 100644 index 0000000000..6d70cfbf6b --- /dev/null +++ b/flax/experimental/nnx/tests/test_containers.py @@ -0,0 +1,75 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import jax +import numpy as np +import pytest + +from flax.experimental import nnx + + +class TestContainers: + + def test_node_idempotence(self): + x = nnx.Node(1) + x = nnx.Node(x) + + assert isinstance(x, nnx.Node) + + def test_variable_idempotence(self): + x = nnx.Variable(1) + x = nnx.Variable(x) + + assert isinstance(x, nnx.Variable) + assert x.value == 1 + + def test_variable_cannot_change_collection(self): + x = nnx.Param(1) + + with pytest.raises(ValueError, match="is not compatible with return type"): + x = nnx.BatchStat(x) + + def test_container_cannot_change_type(self): + x = nnx.Variable(1) + + with pytest.raises(ValueError, match="is not compatible with return type"): + x = nnx.Node(x) + + x = nnx.Node(2) + + with pytest.raises(ValueError, match="is not compatible with return type"): + x = nnx.Variable(x) + + def test_static_is_empty(self): + leaves = jax.tree_util.tree_leaves(nnx.Static(1)) + + assert len(leaves) == 0 + + def test_static_empty_pytree(self): + static = nnx.Static(2) + + static = jax.tree_map(lambda x: x + 1, static) + + assert static.value == 2 + + def test_static_array_not_jitable(self): + @jax.jit + def f(x): + return x + + # the first call does not raise, due to a bug in jax + f(nnx.Static(np.random.uniform(size=(10, 10)))) + + with pytest.raises(ValueError): + f(nnx.Static(np.random.uniform(size=(10, 10)))) diff --git a/flax/experimental/nnx/tests/test_context.py b/flax/experimental/nnx/tests/test_context.py new file mode 100644 index 0000000000..65a76347e8 --- /dev/null +++ b/flax/experimental/nnx/tests/test_context.py @@ -0,0 +1,114 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
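+ +# A minimal sketch of the Context RNG API as exercised below (it mirrors these +# tests and is not an authoritative reference): +# +# ctx = nnx.context(0) # seed the "params" stream +# key = ctx.make_rng("params") # derive a fresh key; bumps the stream count +# keys, ctxdef = ctx.partition() # split into key state + static definition +# ctx2 = ctxdef.merge(keys) # rebuild the Context, e.g. inside a transform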
+ +from typing import Any + +import jax +import numpy as np +import pytest + +from flax.experimental import nnx +from flax.experimental.nnx.nnx.contextlib import _stable_hash + + +class TestContext: + + def test_hash(self): + _hash = _stable_hash("hi") + assert isinstance(_hash, int) + + def test_rng_stream(self): + key0 = jax.random.PRNGKey(0) + ctx = nnx.context(key0) + assert ctx._rngs["params"].count == 0 + + key1 = ctx.make_rng("params") + assert ctx._rngs["params"].count == 1 + assert ctx._rngs["params"].key is key0 + assert not np.equal(key0, key1).all() + + key2 = ctx.make_rng("params") + assert ctx._rngs["params"].count == 2 + assert ctx._rngs["params"].key is key0 + assert not np.equal(key1, key2).all() + + def test_rng_fork(self): + key0 = jax.random.PRNGKey(0) + ctx1 = nnx.context(key0) + ctx2 = ctx1.partition().merge() + + assert ctx2._rngs["params"].count == 0 + assert ctx2._rngs["params"].count_path == (0,) + + key1 = ctx1.make_rng("params") + key2 = ctx2.make_rng("params") + + assert not np.equal(key1, key2).all() + + def test_rng_trace_level_constraints(self): + ctx = nnx.context(0) + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match="Cannot use Context from a different trace level", + ): + ctx.make_rng("params") + + f() + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match="Cannot use Context from a different trace level", + ): + ctx.partition() + + f() + + ctx1: Any = None + + @jax.jit + def g(): + nonlocal ctx1 + ctx1 = nnx.context(1) + + g() + + assert isinstance(ctx1, nnx.Context) + with pytest.raises( + nnx.TraceContextError, + match="Cannot use Context from a different trace level", + ): + ctx1.make_rng("params") + + def test_partition_merge(self): + ctx = nnx.context(dropout=0) + + keys, ctxdef = ctx.partition() + + assert "dropout" in keys + assert ctxdef._rng_counts == (("dropout", (0,)),) + + ctx2 = ctxdef.merge(keys) + + key1 = ctx.make_rng("dropout") + key2 = ctx2.make_rng("dropout") + assert not np.equal(key1, key2).all() + + ctx3 = ctxdef.merge(keys) + key3 = ctx3.make_rng("dropout") + assert np.equal(key2, key3).all() diff --git a/flax/experimental/nnx/tests/test_helpers.py b/flax/experimental/nnx/tests/test_helpers.py new file mode 100644 index 0000000000..6dbe65df7e --- /dev/null +++ b/flax/experimental/nnx/tests/test_helpers.py @@ -0,0 +1,79 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import jax +import jax.numpy as jnp +import optax + +from flax.experimental import nnx + + +class TestHelpers: + + def test_train_state(self): + m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) + + (params, batch_stats), moduledef = m.partition(nnx.Param, nnx.BatchStat) + + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(1.0), + batch_stats=batch_stats, + other=nnx.Node(100), + int=200, + static=nnx.Static(300), + ) + + leaves = jax.tree_util.tree_leaves(state) + + assert 1 in leaves + assert 2 in leaves + assert 100 in leaves + assert 200 not in leaves + assert 300 not in leaves + + def test_train_state_methods(self): + class Foo(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(2, 4, ctx=ctx) + self.batch_norm = nnx.BatchNorm(4, ctx=ctx) + + def __call__(self, x: jax.Array, train: bool) -> jax.Array: + x = self.linear(x) + x = self.batch_norm(x, use_running_average=not train) + return x + + module = Foo(ctx=nnx.context(0)) + (params, batch_stats), moduledef = module.partition( + nnx.Param, nnx.BatchStat + ) + + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(1.0), + batch_stats=batch_stats, + ) + + x = jax.numpy.ones((1, 2)) + y, _updates = state.apply("params", "batch_stats")(x, train=True) + + assert y.shape == (1, 4) + + # fake gradient + grads = jax.tree_map(jnp.ones_like, state.params) + # test apply_gradients + state = state.apply_gradients(grads) diff --git a/flax/experimental/nnx/tests/test_ids.py b/flax/experimental/nnx/tests/test_ids.py new file mode 100644 index 0000000000..78172ad8e0 --- /dev/null +++ b/flax/experimental/nnx/tests/test_ids.py @@ -0,0 +1,31 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from flax.experimental.nnx.nnx import ids + + +class TestIds: + + def test_hashable(self): + id1 = ids.uuid() + id2 = ids.uuid() + assert id1 == id1 + assert id1 != id2 + assert hash(id1) != hash(id2) + id1c = copy.copy(id1) + id1dc = copy.deepcopy(id1) + assert hash(id1) != hash(id1c) + assert hash(id1) != hash(id1dc) diff --git a/flax/experimental/nnx/tests/test_integration.py b/flax/experimental/nnx/tests/test_integration.py new file mode 100644 index 0000000000..177c4dd50e --- /dev/null +++ b/flax/experimental/nnx/tests/test_integration.py @@ -0,0 +1,258 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
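+ +# End-to-end examples in two styles, matching the tests below: lifted +# transforms (nnx.jit / nnx.grad applied directly to stateful Modules) and the +# functional API (model.partition() -> moduledef.apply(...) -> +# moduledef.merge(...)).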
+ +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np + +from flax.experimental import nnx + +A = tp.TypeVar("A") + + +class TestIntegration: + + def test_shared_modules(self): + class Block(nnx.Module): + + def __init__(self, linear: nnx.Linear, *, ctx): + self.linear = linear + self.bn = nnx.BatchNorm(2, ctx=ctx) + + def __call__(self, x, *, ctx): + x = self.linear(x) + x = self.bn(x, ctx=ctx) + return nnx.relu(x) + + class Model(nnx.Module): + + def __init__(self, *, ctx): + shared = nnx.Linear(2, 2, ctx=ctx) + self.block1 = Block(shared, ctx=ctx) + self.block2 = Block(shared, ctx=ctx) + + def __call__(self, x, *, ctx): + x = self.block1(x, ctx=ctx) + x = self.block2(x, ctx=ctx) + return x + + @nnx.jit + def train_step(model: Model, x, y): + @nnx.grad + def loss_fn(model: Model): + ctx = nnx.context(flags=dict(use_running_average=False)) + y_pred = model(x, ctx=ctx) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(model) + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) + ) + + model = Model(ctx=nnx.context(0)) + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + + for _i in range(3): + train_step(model, x, y) + + assert model.block1.linear is model.block2.linear + assert model.block1.linear.bias is not None + assert model.block1.bn is not model.block2.bn + + def test_shared_modules_pure(self): + class Block(nnx.Module): + + def __init__(self, linear: nnx.Linear, *, ctx: nnx.Context): + self.linear = linear + self.bn = nnx.BatchNorm(2, ctx=ctx) + + def __call__(self, x, *, ctx: nnx.Context): + x = self.linear(x) + x = self.bn(x, ctx=ctx) + return nnx.relu(x) + + class Model(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + shared = nnx.Linear(2, 2, ctx=ctx) + self.block1 = Block(shared, ctx=ctx) + self.block2 = Block(shared, ctx=ctx) + + def __call__(self, x, *, ctx: nnx.Context): + x = self.block1(x, ctx=ctx) + x = self.block2(x, ctx=ctx) + return x + + @jax.jit + def train_step(pure_module: nnx.PureModule[Model], x, y): + model = pure_module.merge() + + @nnx.grad + def loss_fn(model: Model): + ctx = nnx.context(flags=dict(use_running_average=False)) + y_pred = model(x, ctx=ctx) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(model) + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) + ) + + return model.partition() + + pure_module = Model(ctx=nnx.context(0)).partition() + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + + for _i in range(3): + pure_module = train_step(pure_module, x, y) + + model = pure_module.merge() + + assert model.block1.linear.bias is not None + assert model.block2.linear.bias is not None + assert model.block1.linear.kernel is model.block2.linear.kernel + assert model.block1.linear.bias is model.block2.linear.bias + assert model.block1.bn is not model.block2.bn + + def test_stateful_example(self): + class State(nnx.Variable[A]): + pass + + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = State(0) + + def __call__(self, x): + self.count += 1 + return x @ self.w + self.b[None] + + model = Linear(din=12, dout=2, ctx=nnx.context(0)) + # forward pass + x = jnp.ones((8, 12)) + y = model(x) + assert model.count == 1 + + @nnx.jit + def train_step(model, x, y): + def loss_fn(model): + y_pred = 
model(x) + return jax.numpy.mean((y_pred - y) ** 2) + + # compute gradient + grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) + # SGD update + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) + ) + + # execute the training step + train_step(model, x, y) + assert model.count == 2 + + def test_functional_example(self): + class Count(nnx.Variable[A]): + pass + + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = Count(0) + + def __call__(self, x): + self.count += 1 + return x @ self.w + self.b[None] + + model = Linear(din=12, dout=2, ctx=nnx.context(0)) + # forward pass + x = jnp.ones((8, 12)) + y = model(x) + assert model.count == 1 + + (params, counts), moduledef = model.partition(nnx.Param, Count) + + @jax.jit + def train_step(params, counts, x, y): + def loss_fn(params): + y_pred, (updates, _) = moduledef.apply(params, counts)(x) + loss = jax.numpy.mean((y_pred - y) ** 2) + return loss, updates.filter(Count) + + # compute gradient + grads, counts = jax.grad(loss_fn, has_aux=True)(params) + # SGD update + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) + + return params, counts + + # execute the training step + params, counts = train_step(params, counts, x, y) + model = moduledef.merge(params, counts) + assert model.count == 2 + + def test_intermediates_example(self): + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + y = x @ self.w + self.b[None] + self.y = nnx.Intermediate(y) + return y + + model = Linear(12, 2, ctx=nnx.context(0)) + + y = model(jnp.ones((8, 12))) + + intermediates = model.pop_state(nnx.Intermediate) + + assert "y" in intermediates + + def test_intermediates_example_functional(self): + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + y = x @ self.w + self.b[None] + self.y = nnx.Intermediate(y) + return y + + model = Linear(12, 2, ctx=nnx.context(0)) + + state, moduledef = model.partition() + + y, (state, _) = moduledef.apply(state)(jnp.ones((8, 12))) + + intermediates, state = state.partition(nnx.Intermediate, ...) + + assert "y" in intermediates diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py new file mode 100644 index 0000000000..b03a744033 --- /dev/null +++ b/flax/experimental/nnx/tests/test_module.py @@ -0,0 +1,587 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
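+ +# Covers the core Module surface as used in this file: partition()/merge() +# round-trips (including shared and self references), trace-level mutation +# guards, sow()/pop_state() intermediates, and the dataclass field helpers.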
+ +from typing import Any + +import jax +import jax.numpy as jnp +import pytest + +from flax.experimental import nnx + + +class TestModule: + + def test_has_module_state(self): + class Foo(nnx.Module): + ... + + foo = Foo() + + assert hasattr(foo, "_module__state") + + def test_trace_level(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match="Cannot mutate Module from different trace level", + ): + m.a = 2 + + f() + + def test_split_merge(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(pure_module: nnx.PureModule[nnx.Dict[int]]): + m = pure_module.merge() + m.a = 2 + return m.partition() + + m2 = g(m.partition()).merge() + + assert m2.a == 2 + + def test_no_trace_level_error_on_grad(self): + # No trace level error occurs because jax doesn't update + # its top trace for grad. + m = nnx.Dict(a=nnx.Param(1.0)) + + @jax.grad + def f(_): + m.a = 2.0 + return 1.0 + + f(1.0) + + def test_trace_level_error_on_nnx_grad(self): + # error occurs because nnx updates its nnx_trace + # in nnx.grad. + m = nnx.Dict(a=nnx.Param(1.0)) + + @nnx.grad + def f(_): + with pytest.raises( + nnx.TraceContextError, + match="Cannot mutate Module from different trace level", + ): + m.a = 2.0 + return 1.0 + + f(m) + + def test_call(self): + class Foo(nnx.Module): + + def __init__(self, c: float, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, ())) + self.c = jnp.asarray(c) + + def __call__(self, x, *, ctx: nnx.Context): + key = ctx.make_rng("e") + return self.w * x + jax.random.normal(key, ()) + self.c + + foo = Foo(c=1.0, ctx=nnx.context(0)) + + y = foo(x=2.0, ctx=nnx.context(e=1)) + + assert isinstance(y, jax.Array) + + def test_shared_module(self): + m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) + m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) + + m3 = m2.partition().merge() + + assert m3["x"] is m3["y"] + assert m3["x"]["a"] is m3["y"]["a"] + assert m3["x"]["b"] is m3["y"]["b"] + + def test_module_graph(self): + class Foo(nnx.Module): + + def __init__(self): + self.a = nnx.Param(1) + self.sub = self + + m = Foo() + + state, moduledef = m.partition() + assert len(state) == 1 + + m2 = moduledef.merge(state) + assert m2 is m2.sub + + def test_deref_through_jit(self): + r1 = nnx.Node(1) + r2 = nnx.Node(2) + + m = m0 = nnx.Dict({"a": nnx.Sequence([r1, r2]), "b": r1}) + + @jax.jit + def f(pure_module: nnx.PureModule[nnx.Dict[Any]]): + m = pure_module.merge() + + assert m["a"][0] is not m["b"] + assert m["a"][1] is not m["b"] + + return m.partition() + + m = f(m.partition()).merge() + + assert m["a"][0] is not m["b"] + assert m["a"][1] is not m["b"] + + # compare with pytree0 + assert m["a"][0] is not m0["a"][0] + assert m["a"][1] is not m0["a"][1] + assert m["b"] is not m0["b"] + + def test_cross_barrier(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(pure_module: nnx.PureModule[nnx.Dict[int]]): + m = pure_module.merge() + m.a += 1 + return m.partition() + + m2 = g(m.partition()).merge() + assert m2 is not m + assert m.a == 1 + assert m2.a == 2 + + def test_no_rejit(self): + n = 0 + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(pure_module): + nonlocal n + n += 1 + m = pure_module.merge() + m.a += 1 + return m.partition() + + m2 = g(m.partition()).merge() + + assert n == 1 + assert m2 is not m + assert m.a == 1 + assert m2.a == 2 + + g(m.partition()) + assert n == 1 + + g(m2.partition()) + assert n == 1 + + m2.b = nnx.Param(10) + g(m2.partition()) + + assert n == 2 + + def 
test_deref_number_of_fields(self): + r1 = nnx.Node(1) + r2 = nnx.Node(2) + v1 = 3 + m = nnx.Dict({ + "a": nnx.Sequence([r1, r2, v1]), + "b": nnx.Dict({"c": r1, "d": r2}), + }) + + p, moduledef = m.partition() + assert len(p) == 4 + assert len(jax.tree_util.tree_leaves(p)) == 4 + + def test_deref_arrays_are_nodes(self): + # test arrays are nodes + r1 = nnx.Node(1) + r2 = nnx.Node(2) + v1 = jax.numpy.array(3) + m = nnx.Dict({ + "a": nnx.Sequence([r1, r2, v1]), + "b": nnx.Dict({"c": r1, "d": r2}), + }) + + p, moduledef = m.partition() + assert len(p) == 5 + assert len(jax.tree_util.tree_leaves(p)) == 5 + + def test_clone(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), 3]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), + ) + + m2 = m.clone() + + assert m is not m2 + assert m2.a[0] == m2.b.c + assert m2.a[1] == m2.b.d + + assert m.a[0] == m2.a[0] + assert m.a[1] == m2.a[1] + assert m.b.c == m2.b.c + assert m.b.d == m2.b.d + + def test_sow_basic(self): + class Foo(nnx.Module): + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + y1 = m(2) + y2 = m(10) + + assert y1 == 3 + assert y2 == 11 + assert m.y == (3, 11) + + intermediates = m.pop_state(nnx.Intermediate) + + assert isinstance(intermediates["y"], nnx.Intermediate) + assert intermediates["y"].value == (3, 11) + + assert not hasattr(m, "y") + + def test_sow_existing_non_variable_field(self): + class Foo(nnx.Module): + + def __init__(self) -> None: + self.y = 10 + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + + with pytest.raises(ValueError, match="to be a Variable, got"): + m(2) + + def test_sow_wrong_collection(self): + class Foo(nnx.Module): + + def __init__(self) -> None: + self.y = nnx.Param(10) + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + + with pytest.raises(ValueError, match="to be of type"): + m(2) + + def test_sow_non_tuple(self): + class Foo(nnx.Module): + + def __init__(self) -> None: + self.y = nnx.Intermediate(10) + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + + with pytest.raises(ValueError, match="to be a tuple,"): + m(2) + + +class TestModuleDataclass: + + def test_basic(self): + @nnx.dataclass + class Foo(nnx.Module): + a: int = nnx.static_field() + b: int = nnx.node_field() + c: int = nnx.param_field() + d: int = nnx.var_field(nnx.BatchStat) + e: int + f: int + + m = Foo( + a=1, # static + b=2, # node + c=3, # param + d=4, # var + e=5, # static int + f=nnx.Node(6), # test that we can pass in a node + ) + + state, moduledef = m.partition() + + assert len(state) == 4 + assert state["b"] == nnx.Node(2) + assert state["c"] == nnx.Param(3) + assert state["d"] == nnx.BatchStat(4) + assert state["f"] == nnx.Node(6) + + def test_no_override(self): + @nnx.dataclass + class Foo(nnx.Module): + a: int = nnx.node_field() + + with pytest.raises(ValueError, match="is not compatible with return type"): + _m = Foo(a=nnx.Param(1)) + + _m = Foo(a=nnx.Node(1)) + + +class TestModuleDef: + + def test_apply(self): + class Foo(nnx.Module): + + def __init__(self, c: float, *, ctx: nnx.Context): + self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), ())) + self.c = jnp.asarray(c) + + def __call__(self, x, *, ctx: nnx.Context): + key = ctx.make_rng("e") + return self.w * x + jax.random.normal(key, ()) + self.c + + ctx = nnx.context(0) + foo = Foo(c=1.0, ctx=ctx) + + states, moduledef = foo.partition() + + 
assert isinstance(states, nnx.State) + assert isinstance(states["w"], nnx.Param) + assert isinstance(states["c"], jax.Array) + + y, _updates = moduledef.apply(states)(x=2.0, ctx=nnx.context(e=1)) + + assert isinstance(y, jax.Array) + + def test_derefed_mod_apply(self): + class Foo(nnx.Module): + + def __init__(self, c: float, *, ctx: nnx.Context): + self.w = nnx.Param( + jax.random.uniform(ctx.make_rng("params"), ()), + ) + self.c = jnp.asarray(c) + + def __call__(self, x, *, ctx: nnx.Context): + key = ctx.make_rng("e") + return self.w * x + jax.random.normal(key, ()) + self.c + + foo = Foo(c=1.0, ctx=nnx.context(0)) + + pure_module = foo.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + assert isinstance(pure_module.states["w"], nnx.Param) + assert isinstance(pure_module.states["c"], jax.Array) + + y, states = pure_module.apply(x=2.0, ctx=nnx.context(e=1)) + + assert isinstance(y, jax.Array) + + +class TestPureModule: + + def test_partition_merge(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = state, moduledef = m.partition() + + m2 = pure_module.merge() + + assert isinstance(state, nnx.State) + assert isinstance(moduledef, nnx.ModuleDef) + assert isinstance(m2, nnx.Dict) + assert isinstance(m2.a, nnx.Sequence) + assert isinstance(m2.b, nnx.Dict) + assert len(m.get_state()) == 5 + assert len(m2.get_state()) == 5 + + def test_partition_merge_with_filters(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = (params, batch_stats, rest), moduledef = m.partition( + nnx.Param, nnx.BatchStat, ... + ) + + m2 = pure_module.merge() + + assert isinstance(params, nnx.State) + assert isinstance(batch_stats, nnx.State) + assert isinstance(rest, nnx.State) + assert isinstance(moduledef, nnx.ModuleDef) + assert isinstance(m2, nnx.Dict) + assert isinstance(m2.a, nnx.Sequence) + assert isinstance(m2.b, nnx.Dict) + assert len(m.get_state()) == 5 + assert len(m2.get_state()) == 5 + + def test_filter(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition() + + params = pure_module.filter(nnx.Param) + batch_stats = pure_module.filter(nnx.BatchStat) + rest = pure_module.filter(nnx.Not(nnx.Variable)) + + assert len(params) == 3 + assert len(batch_stats) == 1 + assert len(rest) == 1 + + def test_filter_with_filters(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition(nnx.Param, ...) 
+ + params = pure_module.filter(nnx.Param) + batch_stats = pure_module.filter(nnx.BatchStat) + rest = pure_module.filter(nnx.Not(nnx.Variable)) + + assert len(params) == 3 + assert len(batch_stats) == 1 + assert len(rest) == 1 + + def test_partition_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + + pure_module = pure_module.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + + def test_partition_with_filters_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition(nnx.Param, ...) + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, tuple) + + pure_module = pure_module.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + + def test_partition_with_filters_partition_with_filters(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition(nnx.Param, ...) + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, tuple) + + pure_module = pure_module.partition(nnx.BatchStat, ...) + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, tuple) + + def test_pop(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition() + + params, pure_module2 = pure_module.pop_state(nnx.Param) + + assert isinstance(params, nnx.State) + assert isinstance(pure_module2, nnx.Pure) + assert isinstance(pure_module2.states, nnx.State) + assert len(params) == 3 + assert len(pure_module2.states) == 2 + + (params, batch_stats), pure_module2 = pure_module.pop_state( + nnx.Param, nnx.BatchStat + ) + + assert isinstance(params, nnx.State) + assert isinstance(batch_stats, nnx.State) + assert isinstance(pure_module2, nnx.Pure) + assert isinstance(pure_module2.states, nnx.State) + assert len(params) == 3 + assert len(batch_stats) == 1 + assert len(pure_module2.states) == 1 + + def test_on_all(self): + class Bar(nnx.Module): + + def __init__(self): + self.a = nnx.Param(1) + + class Foo(nnx.Module): + + def __init__(self, bar): + self.bar1 = bar + self.bar2 = bar + self.b = nnx.Param(2) + + foo = Foo(Bar()) + + def f(bar: Bar): + bar.a += 1 + + foo.for_each(Bar, f) + + assert foo.bar1.a == 2 + assert foo.bar2.a == 2 + + def g(foo: Foo): + foo.b += 1 + + foo.for_each(Foo, g) + + assert foo.b == 3 diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/experimental/nnx/tests/test_partitioning.py new file mode 100644 index 0000000000..7fb44918e7 --- /dev/null +++ b/flax/experimental/nnx/tests/test_partitioning.py @@ -0,0 +1,162 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import numpy as np +import pytest + +from flax.experimental import nnx + + +class TestPartitioning: + + def test_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(2)]), + b=nnx.Param(2), + c=100, + ) + + (params, rest), moduledef = m.partition(nnx.Param, ...) + + assert len(params) == 2 + assert len(rest) == 1 + + # check params + assert params["a/0"].value == m.a[0] + assert params["b"].value == m.b + + # check rest + assert rest["a/1"].value == m.a[1] + + m2 = moduledef.merge(params, rest) + + assert m2.a[0] == m.a[0] + assert m2.a[1] == m.a[1] + assert m2.b == m.b + assert m2.c == 100 + + def test_complete_partitioning(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + # no error + m.partition(nnx.Param, nnx.BatchStat, nnx.Node) + + def test_complete_partitioning_plus_ellipsis(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + # no error if additional ... is passed at the end + m.partition(nnx.Param, nnx.BatchStat, nnx.Node, ...) + + def test_incomplete_partition_error(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + with pytest.raises( + ValueError, match="Non-exhaustive filters, got a non-empty remainder" + ): + m.partition(nnx.Param) + + def test_ellipsis_not_last_error(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + with pytest.raises( + ValueError, match="Ellipsis `...` can only be used as the last filter," + ): + m.partition(..., nnx.Param) + + def test_update_from(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]), + b=nnx.Param(2), + c=100, + ) + + state = m.partition()[0] + state = jax.tree_map(lambda x: x * 2, state) + + m.update_state(state) + + assert m.a[0] == 2 + assert m.a[1] == 6 + assert m.b == 4 + assert m.c == 100 + + def test_update_from_with_array_leaf(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]), + b=nnx.Param(2), + c=jax.numpy.array(100), + ) + + pure_module: nnx.Pure = m.partition() + pure_module = jax.tree_map(lambda x: x * 2, pure_module) + + m.update_state(pure_module.states) + + assert m.a[0] == 2 + assert m.a[1] == 6 + assert m.b == 4 + assert m.c == 200 + + def test_grad_example(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1.0), nnx.BatchStat(-10)]), + b=nnx.Param(2.0), + c=100, + ) + + params = m.filter(nnx.Param) + + def loss(params): + return sum(2 * p for p in jax.tree_util.tree_leaves(params)) + + grads = jax.grad(loss)(params) + m.update_state(grads) + + assert m.a[0] == 2.0 + assert m.a[1] == -10 + assert m.b == 2.0 + assert m.c == 100 + + def test_get_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(10.0), nnx.Param(20.0)]), + b=nnx.Param(10.0), + c=7, + d=5.0, + ) + + # test Variables not shared + assert vars(m.a)["0"] is not vars(m)["b"] + + 
state = m.filter(nnx.Node) + assert state["a/0"].value == m.a[0] + assert state["a/1"].value == m.a[1] + assert state["b"].value == m.b + assert state["b"] is not state["a/0"] + assert len(state) == 3 diff --git a/flax/experimental/nnx/tests/test_pytree.py b/flax/experimental/nnx/tests/test_pytree.py new file mode 100644 index 0000000000..88078c37e0 --- /dev/null +++ b/flax/experimental/nnx/tests/test_pytree.py @@ -0,0 +1,273 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Generic, TypeVar + +import jax +import pytest +from flax import serialization + +from flax.experimental import nnx + + +class TestPytree: + + def test_immutable_pytree(self): + class Foo(nnx.Pytree): + + def __init__(self, y) -> None: + self.x = 2 + self.y = nnx.Node(y) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + with pytest.raises( + AttributeError, match="is immutable, trying to update field" + ): + pytree.x = 4 + + def test_immutable_pytree_dataclass(self): + @nnx.dataclass(frozen=True) + class Foo(nnx.Pytree): + y: int = nnx.node_field() + x: int = nnx.field(default=2) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + with pytest.raises(AttributeError, match="cannot assign to field"): + pytree.x = 4 + + def test_jit(self): + @nnx.dataclass + class Foo(nnx.Pytree): + a: int = nnx.node_field() + b: int = nnx.field() + + module = Foo(a=1, b=2) + + @jax.jit + def f(m: Foo): + return m.a + m.b + + assert f(module) == 3 + + def test_flax_serialization(self): + class Bar(nnx.Pytree): + + def __init__(self, a, b): + self.a = a + self.b = nnx.Node(b) + + @nnx.dataclass + class Foo(nnx.Pytree): + bar: Bar + c: int = nnx.node_field() + d: int = nnx.field() + + foo: Foo = Foo(bar=Bar(a=1, b=2), c=3, d=4) + + state_dict = serialization.to_state_dict(foo) + + assert state_dict == { + "bar": { + "b": 2, + }, + "c": 3, + } + + state_dict["bar"]["b"] = 5 + + foo = serialization.from_state_dict(foo, state_dict) + + assert foo.bar.b == 5 + + del state_dict["bar"]["b"] + + with pytest.raises(ValueError, match="Missing field"): + serialization.from_state_dict(foo, state_dict) + + state_dict["bar"]["b"] = 5 + + # add unknown field + state_dict["x"] = 6 + + with pytest.raises(ValueError, match="Unknown field"): + serialization.from_state_dict(foo, state_dict) + + def test_generics(self): + T = TypeVar("T") + + class MyClass(nnx.Pytree, Generic[T]): + + def __init__(self, x: T): + self.x = x + + MyClass[int] + + def test_key_paths(self): + @nnx.dataclass + class Bar(nnx.Pytree): + a: int = nnx.node_field(default=1) + b: int = 
nnx.field(default=2) + + @nnx.dataclass + class Foo(nnx.Pytree): + x: int = nnx.node_field(default=3) + y: int = nnx.field(default=4) + z: Bar = nnx.node_field(default_factory=Bar) + + foo = Foo() + + path_values, treedef = jax.tree_util.tree_flatten_with_path(foo) + path_values = [(list(map(str, path)), value) for path, value in path_values] + + assert path_values[0] == ([".x", ".value"], 3) + assert path_values[1] == ([".z", ".value", ".a", ".value"], 1) + + def test_replace_unknown_fields_error(self): + class Foo(nnx.Pytree): + pass + + with pytest.raises(ValueError, match="Trying to replace unknown fields"): + Foo().replace(y=1) + + def test_dataclass_inheritance(self): + @nnx.dataclass + class A(nnx.Pytree): + a: int = nnx.node_field(default=1) + b: int = nnx.field(default=2) + + @nnx.dataclass + class B(A): + c: int = nnx.node_field(default=3) + + pytree = B() + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [1, 3] + + def test_pytree_with_new(self): + class A(nnx.Pytree): + + def __init__(self, a): + self.a = a + + def __new__(cls, a): + return super().__new__(cls) + + pytree = A(a=1) + + pytree = jax.tree_map(lambda x: x * 2, pytree) + + def test_deterministic_order(self): + class A(nnx.Pytree): + + def __init__(self, order: bool): + if order: + self.a = 1 + self.b = 2 + else: + self.b = 2 + self.a = 1 + + p1 = A(order=True) + p2 = A(order=False) + + leaves1 = jax.tree_util.tree_leaves(p1) + leaves2 = jax.tree_util.tree_leaves(p2) + + assert leaves1 == leaves2 + + +class TestMutablePytree: + + def test_pytree(self): + class Foo(nnx.Pytree, mutable=True): + + def __init__(self, y) -> None: + self.x = 2 + self.y = nnx.Node(y) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 + + def test_no_new_fields_after_init(self): + class Foo(nnx.Pytree, mutable=True): + + def __init__(self, x): + self.x = nnx.Node(x) + + foo = Foo(x=1) + foo.x = 2 + + with pytest.raises(AttributeError, match=r"Cannot add new fields to"): + foo.y = 2 + + def test_pytree_dataclass(self): + @nnx.dataclass + class Foo(nnx.Pytree, mutable=True): + y: int = nnx.node_field() + x: int = nnx.field(default=2) + + pytree: Foo = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/experimental/nnx/tests/test_spmd.py new file mode 100644 index 0000000000..63c2adab55 --- /dev/null +++ b/flax/experimental/nnx/tests/test_spmd.py @@ -0,0 +1,89 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
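+ +# Note: test_init below builds a 2x2 device mesh and is skipped on cpu/gpu via +# jtu.skip_on_devices, so in practice it requires a backend with at least four +# devices (e.g. TPU).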
+ +import jax +import jax.numpy as jnp +import optax +from jax._src import test_util as jtu +from jax.experimental import mesh_utils +from jax.sharding import Mesh, PartitionSpec + +from flax.experimental import nnx + + +class TestSPMD: + + @jtu.skip_on_devices("cpu", "gpu") + def test_init(self): + class Foo(nnx.Module): + + def __init__(self): + self.w = nnx.Param( + nnx.with_logical_partitioning( + lambda: jnp.ones((8, 2)), + sharding=("row", "col"), + )() + ) + + def __call__(self, x): + return x @ self.w + + @jax.jit + def create_module(): + return Foo().partition() + + mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ("model", "data")) + + with mesh, nnx.logical_axis_rules([("row", "model"), ("col", "data")]): + m: Foo = create_module().merge() + + assert m.w.shape == (8, 2) + assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) + + def test_get_partition_spec(self): + class Foo(nnx.Module): + + def __init__(self): + self.w = nnx.Param( + nnx.with_logical_partitioning( + lambda: jnp.ones((8, 2)), + sharding=("row", "col"), + )() + ) + + def __call__(self, x): + return x @ self.w + + params, moduledef = Foo().partition() + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.adam(1e-3), + ) + logical_state_spec = nnx.get_partition_spec(state) + + assert logical_state_spec.params["w"] == PartitionSpec("row", "col") + assert logical_state_spec.opt_state[0].mu["w"] == PartitionSpec( + "row", "col" + ) + assert logical_state_spec.opt_state[0].nu["w"] == PartitionSpec( + "row", "col" + ) + + with nnx.logical_axis_rules([("row", "model"), ("col", "data")]): + state_spec = nnx.logical_to_mesh(logical_state_spec) + + assert state_spec.params["w"] == PartitionSpec("model", "data") + assert state_spec.opt_state[0].mu["w"] == PartitionSpec("model", "data") + assert state_spec.opt_state[0].nu["w"] == PartitionSpec("model", "data") diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py new file mode 100644 index 0000000000..72d0ea3033 --- /dev/null +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -0,0 +1,408 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
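+ +# A rough sketch of the lifted-transform pattern under test (it mirrors TestJIT +# below; stateful=True is the default): +# +# @nnx.jit +# def step(m: nnx.Dict): +#   m.a = 2  # the mutation is carried back to `m` across the jit boundary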
+ +import typing as tp +from functools import partial + +import jax +import jax.numpy as jnp +import pytest + +from flax.experimental import nnx + + +class TestJIT: + + def test_jit(self): + m = nnx.Dict(a=nnx.Param(1)) + + @nnx.jit + def g(m: nnx.Dict): + m.a = 2 + return 1.0 + + out = g(m) + + assert m.a == 2 + assert out == 1.0 + + def test_jit_stateless(self): + m = nnx.Dict(a=nnx.Param(1)) + + @partial(nnx.jit, stateful=False) + def g(m: nnx.Dict): + m.a = 2 + return 1.0 + + out = g(m) + + assert m.a == 1 + assert out == 1.0 + + +class TestGrad: + + def test_grad(self): + p1 = nnx.Param(10.0) + p2 = nnx.Param(20.0) + + m = nnx.Dict( + a=nnx.Sequence([p1, p2]), + b=p1, + c=7, + d=5.0, + ) + + @nnx.grad + def f(m: nnx.Dict): + # sum all params + return m["a"][0] + m["a"][1] + m["b"] + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads["a/0"].value == 1.0 + assert isinstance(grads["a/0"], nnx.Node) + assert grads["a/1"].value == 1.0 + assert isinstance(grads["a/1"], nnx.Node) + assert grads["b"].value == 1.0 + assert isinstance(grads["b"], nnx.Node) + assert len(grads) == 3 + + m.update_state(grads) + + assert m["a"][0] == 1.0 + assert m["a"][1] == 1.0 + assert m["b"] == 1.0 + assert m["c"] == 7 + assert m["d"] == 5.0 + + def test_grad_with_multiple_ref_types(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), + b=nnx.Param(10.0), + c=7, + d=5.0, + ) + + @nnx.grad + def f(m: nnx.Dict): + # sum all params + return m.a[0] + m.a[1] + m.b + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads["a/0"].value == 1.0 + assert isinstance(grads["a/0"], nnx.Param) + assert len(grads) == 2 + + m.update_state(grads) + + assert m.a[0] == 1.0 + assert m.a[1] == 20.0 + assert m.b == 1.0 + assert m.c == 7 + assert m.d == 5.0 + + def test_grad_with_type_predicate(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), + b=nnx.Param(10.0), + c=7, + d=5.0, + ) + + @partial(nnx.grad, wrt=nnx.BatchStat) + def f(m: nnx.Dict): + # sum all params + return m.a[0] + m.a[1] + m.b + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads["a/1"].value == 1.0 + assert isinstance(grads["a/1"], nnx.BatchStat) + assert len(grads) == 1 + + m.update_state(grads) + + assert m.a[0] == 10.0 + assert m.a[1] == 1.0 + assert m.b == 10.0 + assert m.c == 7 + assert m.d == 5.0 + + +class TestScan: + + def test_basic(self): + class Block(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + self.node = jnp.ones((2,)) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + jax.debug.print("x={x}", x=x) + x = self.linear(x) + x = nnx.gelu(x) + return x, None + + MLP = nnx.Scan( + Block, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 + ) + + module = MLP(ctx=nnx.context(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, out = module.call(x, None) + + assert y.shape == (1, 3) + assert out is None + + def test_complex(self): + class Block(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + self.bn = nnx.BatchNorm(3, ctx=ctx) + self.dropout = nnx.Dropout(0.5) + self.node = jnp.ones((2,)) + + def __call__( + self, x: jax.Array, _, *, ctx: nnx.Context + ) -> tp.Tuple[jax.Array, None]: + jax.debug.print("x={x}", x=x) + x = self.linear(x) + x = self.bn(x, ctx=ctx) + x = 
self.dropout(x, ctx=ctx) + x = nnx.gelu(x) + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + # variable_carry="batch_stats", + split_rngs=["params", "dropout"], + length=5, + ) + + module = MLP(ctx=nnx.context(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + ctx = nnx.context( + dropout=1, flags=dict(deterministic=False, use_running_average=False) + ) + y, out = module.call(x, None, ctx=ctx) + + assert y.shape == (1, 3) + assert out is None + + def test_complex_decorator(self): + scan_over_layers = partial( + nnx.scan, + variable_axes={nnx.Param: 0}, + split_rngs=["params", "dropout"], + length=5, + ) + + class Block(nnx.Module): + + @scan_over_layers + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + self.bn = nnx.BatchNorm(3, ctx=ctx) + self.dropout = nnx.Dropout(0.5) + self.node = jnp.ones((2,)) + + @scan_over_layers + def __call__( + self, x: jax.Array, _, *, ctx: nnx.Context + ) -> tp.Tuple[jax.Array, None]: + jax.debug.print("x={x}", x=x) + x = self.linear(x) + x = self.bn(x, ctx=ctx) + x = self.dropout(x, ctx=ctx) + x = nnx.gelu(x) + return x, None + + module = Block(ctx=nnx.context(0)) + + assert module.linear.kernel.shape == (5, 3, 3) + assert module.linear.bias.shape == (5, 3) + assert module.node.shape == (2,) + + x = jnp.ones((1, 3)) + ctx = nnx.context( + dropout=1, flags=dict(deterministic=False, use_running_average=False) + ) + y, out = module(x, None, ctx=ctx) + + assert y.shape == (1, 3) + assert out is None + + def test_scan_with_sharding(self): + class Block(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear( + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), + sharding=("din", "dout"), + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros(), + sharding=("dout",), + ), + ctx=ctx, + ) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + + # test sharding layer axes is not present inside scan + state = self.linear.get_state() + assert state["kernel"].value.shape == (3, 3) + assert state["kernel"].sharding == ("din", "dout") + assert state["bias"].value.shape == (3,) + assert state["bias"].sharding == ("dout",) + + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + split_rngs=["params"], + length=5, + metadata_params={nnx.PARTITION_NAME: "layers"}, + ) + + m = MLP(ctx=nnx.context(0)) + + # test sharding layers axes is set + state = m.get_state() + assert state["scan_module/linear/kernel"].value.shape == (5, 3, 3) + assert state["scan_module/linear/kernel"].sharding == ( + "layers", + "din", + "dout", + ) + assert state["scan_module/linear/bias"].value.shape == (5, 3) + assert state["scan_module/linear/bias"].sharding == ("layers", "dout") + + x = jnp.ones((1, 3)) + y, out = m.call(x, None) + + # test sharding axes is preserved + state = m.get_state() + assert state["scan_module/linear/kernel"].value.shape == (5, 3, 3) + assert state["scan_module/linear/kernel"].sharding == ( + "layers", + "din", + "dout", + ) + assert state["scan_module/linear/bias"].value.shape == (5, 3) + assert state["scan_module/linear/bias"].sharding == ("layers", "dout") + + +class TestRemat: + + def test_basic_remat(self): + RematLinear = nnx.Remat(nnx.Linear) + + module = RematLinear(2, 3, ctx=nnx.context(0)) + + y = module.call(jnp.ones((1, 2))) + + assert 
y.shape == (1, 3) + + def test_remat_decorator(self): + class RematLinear(nnx.Module): + + @nnx.remat + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + self.linear = nnx.Linear(din, dout, ctx=ctx) + + @nnx.remat + def __call__(self, x: jax.Array) -> jax.Array: + return self.linear(x) + + module = RematLinear(2, 3, ctx=nnx.context(0)) + + y = module(jnp.ones((1, 2))) + + assert y.shape == (1, 3) + + def test_remat_with_scan(self): + class LinearBlock(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + return x, None + + RematLinear = nnx.Remat(LinearBlock) + + ScanRematLinear = nnx.Scan( + RematLinear, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 + ) + + m = ScanRematLinear(ctx=nnx.context(0)) + + assert m.scan_module.remat_module.linear.kernel.shape == (5, 3, 3) + assert m.scan_module.remat_module.linear.bias.shape == (5, 3) + + y, _ = m.call.call(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) + + y, _ = m(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) + + def test_remat_with_scan_decorator(self): + scan = partial( + nnx.scan, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 + ) + + class ScanLinear(nnx.Module): + + @scan + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + + @scan + @nnx.remat + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + return x, None + + m = ScanLinear(ctx=nnx.context(0)) + + assert m.linear.kernel.shape == (5, 3, 3) + assert m.linear.bias.shape == (5, 3) + + y, _ = m(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) diff --git a/flax/experimental/nnx/tests/test_variable.py b/flax/experimental/nnx/tests/test_variable.py new file mode 100644 index 0000000000..ce84027beb --- /dev/null +++ b/flax/experimental/nnx/tests/test_variable.py @@ -0,0 +1,35 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
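+ +# As the single test below checks, Nodes are pytree leaves: jax.tree_map +# returns a new Node and leaves the original value untouched.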
+ +import typing as tp + +import jax +import pytest + +from flax.experimental import nnx + +A = tp.TypeVar("A") + + +class TestVariable: + + def test_value(self): + r1 = nnx.Node(1) + assert r1.value == 1 + + r2 = jax.tree_map(lambda x: x + 1, r1) + + assert r1.value == 1 + assert r2.value == 2 + assert r1 is not r2 diff --git a/pyproject.toml b/pyproject.toml index e8fc9051b3..a0c3cb7b0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,9 @@ disable = "pyi-error" show_error_codes = true no_implicit_optional = true disable_error_code = "attr-defined" +exclude = [ + "flax/experimental/nnx", +] [[tool.mypy.overrides]] module = [ diff --git a/tests/nnx/__init__.py b/tests/nnx/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/nnx/test_containers.py b/tests/nnx/test_containers.py new file mode 100644 index 0000000000..bef884cb47 --- /dev/null +++ b/tests/nnx/test_containers.py @@ -0,0 +1,61 @@ +import jax +import numpy as np +import pytest + +from flax.experimental import nnx + + +class TestContainers: + + def test_node_idempotence(self): + x = nnx.Node(1) + x = nnx.Node(x) + + assert isinstance(x, nnx.Node) + + def test_variable_idempotence(self): + x = nnx.Variable(1) + x = nnx.Variable(x) + + assert isinstance(x, nnx.Variable) + assert x.value == 1 + + def test_variable_cannot_change_collection(self): + x = nnx.Param(1) + + with pytest.raises(ValueError, match="is not compatible with return type"): + x = nnx.BatchStat(x) + + def test_container_cannot_change_type(self): + x = nnx.Variable(1) + + with pytest.raises(ValueError, match="is not compatible with return type"): + x = nnx.Node(x) + + x = nnx.Node(2) + + with pytest.raises(ValueError, match="is not compatible with return type"): + x = nnx.Variable(x) + + def test_static_is_empty(self): + leaves = jax.tree_util.tree_leaves(nnx.Static(1)) + + assert len(leaves) == 0 + + def test_static_empty_pytree(self): + static = nnx.Static(2) + + static = jax.tree_map(lambda x: x + 1, static) + + assert static.value == 2 + + def test_static_array_not_jitable(self): + @jax.jit + def f(x): + return x + + # the first call does not raise, due to a bug in jax + f(nnx.Static(np.random.uniform(size=(10, 10)))) + + with pytest.raises(ValueError): + f(nnx.Static(np.random.uniform(size=(10, 10)))) diff --git a/tests/nnx/test_context.py b/tests/nnx/test_context.py new file mode 100644 index 0000000000..92986b28f3 --- /dev/null +++ b/tests/nnx/test_context.py @@ -0,0 +1,100 @@ +from typing import Any + +import jax +import numpy as np +import pytest + +from flax.experimental import nnx +from flax.experimental.nnx.nnx.contextlib import _stable_hash + + +class TestContext: + + def test_hash(self): + _hash = _stable_hash("hi") + assert isinstance(_hash, int) + + def test_rng_stream(self): + key0 = jax.random.PRNGKey(0) + ctx = nnx.context(key0) + assert ctx._rngs["params"].count == 0 + + key1 = ctx.make_rng("params") + assert ctx._rngs["params"].count == 1 + assert ctx._rngs["params"].key is key0 + assert not np.equal(key0, key1).all() + + key2 = ctx.make_rng("params") + assert ctx._rngs["params"].count == 2 + assert ctx._rngs["params"].key is key0 + assert not np.equal(key1, key2).all() + + def test_rng_fork(self): + key0 = jax.random.PRNGKey(0) + ctx1 = nnx.context(key0) + ctx2 = ctx1.partition().merge() + + assert ctx2._rngs["params"].count == 0 + assert ctx2._rngs["params"].count_path == (0,) + + key1 = ctx1.make_rng("params") + key2 = ctx2.make_rng("params") + + assert not np.equal(key1, key2).all() + + def 
test_rng_trace_level_constraints(self): + ctx = nnx.context(0) + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match="Cannot use Context from a different trace level", + ): + ctx.make_rng("params") + + f() + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match="Cannot use Context from a different trace level", + ): + ctx.partition() + + f() + + ctx1: Any = None + + @jax.jit + def g(): + nonlocal ctx1 + ctx1 = nnx.context(1) + + g() + + assert isinstance(ctx1, nnx.Context) + with pytest.raises( + nnx.TraceContextError, + match="Cannot use Context from a different trace level", + ): + ctx1.make_rng("params") + + def test_partition_merge(self): + ctx = nnx.context(dropout=0) + + keys, ctxdef = ctx.partition() + + assert "dropout" in keys + assert ctxdef._rng_counts == (("dropout", (0,)),) + + ctx2 = ctxdef.merge(keys) + + key1 = ctx.make_rng("dropout") + key2 = ctx2.make_rng("dropout") + assert not np.equal(key1, key2).all() + + ctx3 = ctxdef.merge(keys) + key3 = ctx3.make_rng("dropout") + assert np.equal(key2, key3).all() diff --git a/tests/nnx/test_helpers.py b/tests/nnx/test_helpers.py new file mode 100644 index 0000000000..1b96fd7ac8 --- /dev/null +++ b/tests/nnx/test_helpers.py @@ -0,0 +1,65 @@ +import jax +import jax.numpy as jnp +import optax + +from flax.experimental import nnx + + +class TestHelpers: + + def test_train_state(self): + m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) + + (params, batch_stats), moduledef = m.partition(nnx.Param, nnx.BatchStat) + + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(1.0), + batch_stats=batch_stats, + other=nnx.Node(100), + int=200, + static=nnx.Static(300), + ) + + leaves = jax.tree_util.tree_leaves(state) + + assert 1 in leaves + assert 2 in leaves + assert 100 in leaves + assert 200 not in leaves + assert 300 not in leaves + + def test_train_state_methods(self): + class Foo(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(2, 4, ctx=ctx) + self.batch_norm = nnx.BatchNorm(4, ctx=ctx) + + def __call__(self, x: jax.Array, train: bool) -> jax.Array: + x = self.linear(x) + x = self.batch_norm(x, use_running_average=not train) + return x + + module = Foo(ctx=nnx.context(0)) + (params, batch_stats), moduledef = module.partition( + nnx.Param, nnx.BatchStat + ) + + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(1.0), + batch_stats=batch_stats, + ) + + x = jax.numpy.ones((1, 2)) + y, _updates = state.apply("params", "batch_stats")(x, train=True) + + assert y.shape == (1, 4) + + # fake gradient + grads = jax.tree_map(jnp.ones_like, state.params) + # test apply_gradients + state = state.apply_gradients(grads) diff --git a/tests/nnx/test_ids.py b/tests/nnx/test_ids.py new file mode 100644 index 0000000000..977f24275c --- /dev/null +++ b/tests/nnx/test_ids.py @@ -0,0 +1,17 @@ +import copy + +from flax.experimental.nnx.nnx import ids + + +class TestIds: + + def test_hashable(self): + id1 = ids.uuid() + id2 = ids.uuid() + assert id1 == id1 + assert id1 != id2 + assert hash(id1) != hash(id2) + id1c = copy.copy(id1) + id1dc = copy.deepcopy(id1) + assert hash(id1) != hash(id1c) + assert hash(id1) != hash(id1dc) diff --git a/tests/nnx/test_integration.py b/tests/nnx/test_integration.py new file mode 100644 index 0000000000..e763d666e4 --- /dev/null +++ b/tests/nnx/test_integration.py @@ -0,0 +1,244 @@ +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np + +from flax.experimental import nnx + +A = 
tp.TypeVar("A") + + +class TestIntegration: + + def test_shared_modules(self): + class Block(nnx.Module): + + def __init__(self, linear: nnx.Linear, *, ctx): + self.linear = linear + self.bn = nnx.BatchNorm(2, ctx=ctx) + + def __call__(self, x, *, ctx): + x = self.linear(x) + x = self.bn(x, ctx=ctx) + return nnx.relu(x) + + class Model(nnx.Module): + + def __init__(self, *, ctx): + shared = nnx.Linear(2, 2, ctx=ctx) + self.block1 = Block(shared, ctx=ctx) + self.block2 = Block(shared, ctx=ctx) + + def __call__(self, x, *, ctx): + x = self.block1(x, ctx=ctx) + x = self.block2(x, ctx=ctx) + return x + + @nnx.jit + def train_step(model: Model, x, y): + @nnx.grad + def loss_fn(model: Model): + ctx = nnx.context(flags=dict(use_running_average=False)) + y_pred = model(x, ctx=ctx) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(model) + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) + ) + + model = Model(ctx=nnx.context(0)) + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + + for _i in range(3): + train_step(model, x, y) + + assert model.block1.linear is model.block2.linear + assert model.block1.linear.bias is not None + assert model.block1.bn is not model.block2.bn + + def test_shared_modules_pure(self): + class Block(nnx.Module): + + def __init__(self, linear: nnx.Linear, *, ctx: nnx.Context): + self.linear = linear + self.bn = nnx.BatchNorm(2, ctx=ctx) + + def __call__(self, x, *, ctx: nnx.Context): + x = self.linear(x) + x = self.bn(x, ctx=ctx) + return nnx.relu(x) + + class Model(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + shared = nnx.Linear(2, 2, ctx=ctx) + self.block1 = Block(shared, ctx=ctx) + self.block2 = Block(shared, ctx=ctx) + + def __call__(self, x, *, ctx: nnx.Context): + x = self.block1(x, ctx=ctx) + x = self.block2(x, ctx=ctx) + return x + + @jax.jit + def train_step(pure_module: nnx.PureModule[Model], x, y): + model = pure_module.merge() + + @nnx.grad + def loss_fn(model: Model): + ctx = nnx.context(flags=dict(use_running_average=False)) + y_pred = model(x, ctx=ctx) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(model) + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) + ) + + return model.partition() + + pure_module = Model(ctx=nnx.context(0)).partition() + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + + for _i in range(3): + pure_module = train_step(pure_module, x, y) + + model = pure_module.merge() + + assert model.block1.linear.bias is not None + assert model.block2.linear.bias is not None + assert model.block1.linear.kernel is model.block2.linear.kernel + assert model.block1.linear.bias is model.block2.linear.bias + assert model.block1.bn is not model.block2.bn + + def test_stateful_example(self): + class State(nnx.Variable[A]): + pass + + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = State(0) + + def __call__(self, x): + self.count += 1 + return x @ self.w + self.b[None] + + model = Linear(din=12, dout=2, ctx=nnx.context(0)) + # forward pass + x = jnp.ones((8, 12)) + y = model(x) + assert model.count == 1 + + @nnx.jit + def train_step(model, x, y): + def loss_fn(model): + y_pred = model(x) + return jax.numpy.mean((y_pred - y) ** 2) + + # compute gradient + grads: nnx.State = nnx.grad(loss_fn, 
wrt=nnx.Param)(model) + # SGD update + model.update_state( + jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) + ) + + # execute the training step + train_step(model, x, y) + assert model.count == 2 + + def test_functional_example(self): + class Count(nnx.Variable[A]): + pass + + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = Count(0) + + def __call__(self, x): + self.count += 1 + return x @ self.w + self.b[None] + + model = Linear(din=12, dout=2, ctx=nnx.context(0)) + # forward pass + x = jnp.ones((8, 12)) + y = model(x) + assert model.count == 1 + + (params, counts), moduledef = model.partition(nnx.Param, Count) + + @jax.jit + def train_step(params, counts, x, y): + def loss_fn(params): + y_pred, (updates, _) = moduledef.apply(params, counts)(x) + loss = jax.numpy.mean((y_pred - y) ** 2) + return loss, updates.filter(Count) + + # compute gradient + grads, counts = jax.grad(loss_fn, has_aux=True)(params) + # SGD update + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) + + return params, counts + + # execute the training step + params, counts = train_step(params, counts, x, y) + model = moduledef.merge(params, counts) + assert model.count == 2 + + def test_intermediates_example(self): + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + y = x @ self.w + self.b[None] + self.y = nnx.Intermediate(y) + return y + + model = Linear(12, 2, ctx=nnx.context(0)) + + y = model(jnp.ones((8, 12))) + + intermediates = model.pop_state(nnx.Intermediate) + + assert "y" in intermediates + + def test_intermediates_example_functional(self): + class Linear(nnx.Module): + + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + y = x @ self.w + self.b[None] + self.y = nnx.Intermediate(y) + return y + + model = Linear(12, 2, ctx=nnx.context(0)) + + state, moduledef = model.partition() + + y, (state, _) = moduledef.apply(state)(jnp.ones((8, 12))) + + intermediates, state = state.partition(nnx.Intermediate, ...) + + assert "y" in intermediates diff --git a/tests/nnx/test_module.py b/tests/nnx/test_module.py new file mode 100644 index 0000000000..e73a393e1e --- /dev/null +++ b/tests/nnx/test_module.py @@ -0,0 +1,573 @@ +from typing import Any + +import jax +import jax.numpy as jnp +import pytest + +from flax.experimental import nnx + + +class TestModule: + + def test_has_module_state(self): + class Foo(nnx.Module): + ... 
+ + foo = Foo() + + assert hasattr(foo, "_module__state") + + def test_trace_level(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match="Cannot mutate Module from different trace level", + ): + m.a = 2 + + f() + + def test_split_merge(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(pure_module: nnx.PureModule[nnx.Dict[int]]): + m = pure_module.merge() + m.a = 2 + return m.partition() + + m2 = g(m.partition()).merge() + + assert m2.a == 2 + + def test_no_trace_level_error_on_grad(self): + # No trace level error occurs because jax doesn't update + # its top trace for grad. + m = nnx.Dict(a=nnx.Param(1.0)) + + @jax.grad + def f(_): + m.a = 2.0 + return 1.0 + + f(1.0) + + def test_trace_level_error_on_nnx_grad(self): + # error occurs because nnx updates its nnx_trace + # in nnx.grad. + m = nnx.Dict(a=nnx.Param(1.0)) + + @nnx.grad + def f(_): + with pytest.raises( + nnx.TraceContextError, + match="Cannot mutate Module from different trace level", + ): + m.a = 2.0 + return 1.0 + + f(m) + + def test_call(self): + class Foo(nnx.Module): + + def __init__(self, c: float, *, ctx: nnx.Context): + key = ctx.make_rng("params") + self.w = nnx.Param(jax.random.uniform(key, ())) + self.c = jnp.asarray(c) + + def __call__(self, x, *, ctx: nnx.Context): + key = ctx.make_rng("e") + return self.w * x + jax.random.normal(key, ()) + self.c + + foo = Foo(c=1.0, ctx=nnx.context(0)) + + y = foo(x=2.0, ctx=nnx.context(e=1)) + + assert isinstance(y, jax.Array) + + def test_shared_module(self): + m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) + m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) + + m3 = m2.partition().merge() + + assert m3["x"] is m3["y"] + assert m3["x"]["a"] is m3["y"]["a"] + assert m3["x"]["b"] is m3["y"]["b"] + + def test_module_graph(self): + class Foo(nnx.Module): + + def __init__(self): + self.a = nnx.Param(1) + self.sub = self + + m = Foo() + + state, moduledef = m.partition() + assert len(state) == 1 + + m2 = moduledef.merge(state) + assert m2 is m2.sub + + def test_deref_through_jit(self): + r1 = nnx.Node(1) + r2 = nnx.Node(2) + + m = m0 = nnx.Dict({"a": nnx.Sequence([r1, r2]), "b": r1}) + + @jax.jit + def f(pure_module: nnx.PureModule[nnx.Dict[Any]]): + m = pure_module.merge() + + assert m["a"][0] is not m["b"] + assert m["a"][1] is not m["b"] + + return m.partition() + + m = f(m.partition()).merge() + + assert m["a"][0] is not m["b"] + assert m["a"][1] is not m["b"] + + # compare with pytree0 + assert m["a"][0] is not m0["a"][0] + assert m["a"][1] is not m0["a"][1] + assert m["b"] is not m0["b"] + + def test_cross_barrier(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(pure_module: nnx.PureModule[nnx.Dict[int]]): + m = pure_module.merge() + m.a += 1 + return m.partition() + + m2 = g(m.partition()).merge() + assert m2 is not m + assert m.a == 1 + assert m2.a == 2 + + def test_no_rejit(self): + n = 0 + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(pure_module): + nonlocal n + n += 1 + m = pure_module.merge() + m.a += 1 + return m.partition() + + m2 = g(m.partition()).merge() + + assert n == 1 + assert m2 is not m + assert m.a == 1 + assert m2.a == 2 + + g(m.partition()) + assert n == 1 + + g(m2.partition()) + assert n == 1 + + m2.b = nnx.Param(10) + g(m2.partition()) + + assert n == 2 + + def test_deref_number_of_fields(self): + r1 = nnx.Node(1) + r2 = nnx.Node(2) + v1 = 3 + m = nnx.Dict({ + "a": nnx.Sequence([r1, r2, v1]), + "b": nnx.Dict({"c": r1, "d": r2}), + }) + + p, moduledef = m.partition() + assert len(p) 
== 4 + assert len(jax.tree_util.tree_leaves(p)) == 4 + + def test_deref_arrays_are_nodes(self): + # test arrays are nodes + r1 = nnx.Node(1) + r2 = nnx.Node(2) + v1 = jax.numpy.array(3) + m = nnx.Dict({ + "a": nnx.Sequence([r1, r2, v1]), + "b": nnx.Dict({"c": r1, "d": r2}), + }) + + p, moduledef = m.partition() + assert len(p) == 5 + assert len(jax.tree_util.tree_leaves(p)) == 5 + + def test_clone(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), 3]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), + ) + + m2 = m.clone() + + assert m is not m2 + assert m2.a[0] == m2.b.c + assert m2.a[1] == m2.b.d + + assert m.a[0] == m2.a[0] + assert m.a[1] == m2.a[1] + assert m.b.c == m2.b.c + assert m.b.d == m2.b.d + + def test_sow_basic(self): + class Foo(nnx.Module): + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + y1 = m(2) + y2 = m(10) + + assert y1 == 3 + assert y2 == 11 + assert m.y == (3, 11) + + intermediates = m.pop_state(nnx.Intermediate) + + assert isinstance(intermediates["y"], nnx.Intermediate) + assert intermediates["y"].value == (3, 11) + + assert not hasattr(m, "y") + + def test_sow_existing_non_variable_field(self): + class Foo(nnx.Module): + + def __init__(self) -> None: + self.y = 10 + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + + with pytest.raises(ValueError, match="to be a Variable, got"): + m(2) + + def test_sow_wrong_collection(self): + class Foo(nnx.Module): + + def __init__(self) -> None: + self.y = nnx.Param(10) + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + + with pytest.raises(ValueError, match="to be of type"): + m(2) + + def test_sow_non_tuple(self): + class Foo(nnx.Module): + + def __init__(self) -> None: + self.y = nnx.Intermediate(10) + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, "y", y) + return y + + m = Foo() + + with pytest.raises(ValueError, match="to be a tuple,"): + m(2) + + +class TestModuleDataclass: + + def test_basic(self): + @nnx.dataclass + class Foo(nnx.Module): + a: int = nnx.static_field() + b: int = nnx.node_field() + c: int = nnx.param_field() + d: int = nnx.var_field(nnx.BatchStat) + e: int + f: int + + m = Foo( + a=1, # static + b=2, # node + c=3, # param + d=4, # var + e=5, # static int + f=nnx.Node(6), # test that we can pass in a node + ) + + state, moduledef = m.partition() + + assert len(state) == 4 + assert state["b"] == nnx.Node(2) + assert state["c"] == nnx.Param(3) + assert state["d"] == nnx.BatchStat(4) + assert state["f"] == nnx.Node(6) + + def test_no_override(self): + @nnx.dataclass + class Foo(nnx.Module): + a: int = nnx.node_field() + + with pytest.raises(ValueError, match="is not compatible with return type"): + _m = Foo(a=nnx.Param(1)) + + _m = Foo(a=nnx.Node(1)) + + +class TestModuleDef: + + def test_apply(self): + class Foo(nnx.Module): + + def __init__(self, c: float, *, ctx: nnx.Context): + self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), ())) + self.c = jnp.asarray(c) + + def __call__(self, x, *, ctx: nnx.Context): + key = ctx.make_rng("e") + return self.w * x + jax.random.normal(key, ()) + self.c + + ctx = nnx.context(0) + foo = Foo(c=1.0, ctx=ctx) + + states, moduledef = foo.partition() + + assert isinstance(states, nnx.State) + assert isinstance(states["w"], nnx.Param) + assert isinstance(states["c"], jax.Array) + + y, _updates = moduledef.apply(states)(x=2.0, ctx=nnx.context(e=1)) + + assert isinstance(y, 
jax.Array) + + def test_derefed_mod_apply(self): + class Foo(nnx.Module): + + def __init__(self, c: float, *, ctx: nnx.Context): + self.w = nnx.Param( + jax.random.uniform(ctx.make_rng("params"), ()), + ) + self.c = jnp.asarray(c) + + def __call__(self, x, *, ctx: nnx.Context): + key = ctx.make_rng("e") + return self.w * x + jax.random.normal(key, ()) + self.c + + foo = Foo(c=1.0, ctx=nnx.context(0)) + + pure_module = foo.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + assert isinstance(pure_module.states["w"], nnx.Param) + assert isinstance(pure_module.states["c"], jax.Array) + + y, states = pure_module.apply(x=2.0, ctx=nnx.context(e=1)) + + assert isinstance(y, jax.Array) + + +class TestPureModule: + + def test_partition_merge(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = state, moduledef = m.partition() + + m2 = pure_module.merge() + + assert isinstance(state, nnx.State) + assert isinstance(moduledef, nnx.ModuleDef) + assert isinstance(m2, nnx.Dict) + assert isinstance(m2.a, nnx.Sequence) + assert isinstance(m2.b, nnx.Dict) + assert len(m.get_state()) == 5 + assert len(m2.get_state()) == 5 + + def test_partition_merge_with_filters(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = (params, batch_stats, rest), moduledef = m.partition( + nnx.Param, nnx.BatchStat, ... + ) + + m2 = pure_module.merge() + + assert isinstance(params, nnx.State) + assert isinstance(batch_stats, nnx.State) + assert isinstance(rest, nnx.State) + assert isinstance(moduledef, nnx.ModuleDef) + assert isinstance(m2, nnx.Dict) + assert isinstance(m2.a, nnx.Sequence) + assert isinstance(m2.b, nnx.Dict) + assert len(m.get_state()) == 5 + assert len(m2.get_state()) == 5 + + def test_filter(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition() + + params = pure_module.filter(nnx.Param) + batch_stats = pure_module.filter(nnx.BatchStat) + rest = pure_module.filter(nnx.Not(nnx.Variable)) + + assert len(params) == 3 + assert len(batch_stats) == 1 + assert len(rest) == 1 + + def test_filter_with_filters(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition(nnx.Param, ...) + + params = pure_module.filter(nnx.Param) + batch_stats = pure_module.filter(nnx.BatchStat) + rest = pure_module.filter(nnx.Not(nnx.Variable)) + + assert len(params) == 3 + assert len(batch_stats) == 1 + assert len(rest) == 1 + + def test_partition_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + + pure_module = pure_module.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + + def test_partition_with_filters_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition(nnx.Param, ...) 
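+
+    # with multiple filters, `states` is stored as a tuple of State objects,
+    # one per filter, rather than a single State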
+ + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, tuple) + + pure_module = pure_module.partition() + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, nnx.State) + + def test_partition_with_filters_partition_with_filters(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition(nnx.Param, ...) + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, tuple) + + pure_module = pure_module.partition(nnx.BatchStat, ...) + + assert isinstance(pure_module, nnx.Pure) + assert isinstance(pure_module.states, tuple) + + def test_pop(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + pure_module = m.partition() + + params, pure_module2 = pure_module.pop_state(nnx.Param) + + assert isinstance(params, nnx.State) + assert isinstance(pure_module2, nnx.Pure) + assert isinstance(pure_module2.states, nnx.State) + assert len(params) == 3 + assert len(pure_module2.states) == 2 + + (params, batch_stats), pure_module2 = pure_module.pop_state( + nnx.Param, nnx.BatchStat + ) + + assert isinstance(params, nnx.State) + assert isinstance(batch_stats, nnx.State) + assert isinstance(pure_module2, nnx.Pure) + assert isinstance(pure_module2.states, nnx.State) + assert len(params) == 3 + assert len(batch_stats) == 1 + assert len(pure_module2.states) == 1 + + def test_on_all(self): + class Bar(nnx.Module): + + def __init__(self): + self.a = nnx.Param(1) + + class Foo(nnx.Module): + + def __init__(self, bar): + self.bar1 = bar + self.bar2 = bar + self.b = nnx.Param(2) + + foo = Foo(Bar()) + + def f(bar: Bar): + bar.a += 1 + + foo.for_each(Bar, f) + + assert foo.bar1.a == 2 + assert foo.bar2.a == 2 + + def g(foo: Foo): + foo.b += 1 + + foo.for_each(Foo, g) + + assert foo.b == 3 diff --git a/tests/nnx/test_partitioning.py b/tests/nnx/test_partitioning.py new file mode 100644 index 0000000000..f9bb552c50 --- /dev/null +++ b/tests/nnx/test_partitioning.py @@ -0,0 +1,148 @@ +import typing as tp + +import jax +import numpy as np +import pytest + +from flax.experimental import nnx + + +class TestPartitioning: + + def test_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(2)]), + b=nnx.Param(2), + c=100, + ) + + (params, rest), moduledef = m.partition(nnx.Param, ...) + + assert len(params) == 2 + assert len(rest) == 1 + + # check params + assert params["a/0"].value == m.a[0] + assert params["b"].value == m.b + + # check rest + assert rest["a/1"].value == m.a[1] + + m2 = moduledef.merge(params, rest) + + assert m2.a[0] == m.a[0] + assert m2.a[1] == m.a[1] + assert m2.b == m.b + assert m2.c == 100 + + def test_complete_partitioning(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + # no error + m.partition(nnx.Param, nnx.BatchStat, nnx.Node) + + def test_complete_partitioning_plus_ellipsis(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + # no error if additional ... is passed at the end + m.partition(nnx.Param, nnx.BatchStat, nnx.Node, ...) 
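+
+    # a filter set must be exhaustive over the Module's state; the two tests
+    # below cover the failure modes: leaving state unmatched, and using `...`
+    # anywhere but the last position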
+
+  def test_incomplete_partition_error(self):
+    m = nnx.Dict(
+        a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]),
+        b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)),
+    )
+
+    with pytest.raises(
+        ValueError, match="Non-exhaustive filters, got a non-empty remainder"
+    ):
+      m.partition(nnx.Param)
+
+  def test_ellipsis_not_last_error(self):
+    m = nnx.Dict(
+        a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]),
+        b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)),
+    )
+
+    with pytest.raises(
+        ValueError, match="Ellipsis `...` can only be used as the last filter,"
+    ):
+      m.partition(..., nnx.Param)
+
+  def test_update_from(self):
+    m = nnx.Dict(
+        a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]),
+        b=nnx.Param(2),
+        c=100,
+    )
+
+    state = m.partition()[0]
+    state = jax.tree_map(lambda x: x * 2, state)
+
+    m.update_state(state)
+
+    assert m.a[0] == 2
+    assert m.a[1] == 6
+    assert m.b == 4
+    assert m.c == 100
+
+  def test_update_from_with_array_leaf(self):
+    m = nnx.Dict(
+        a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]),
+        b=nnx.Param(2),
+        c=jax.numpy.array(100),
+    )
+
+    pure_module: nnx.Pure = m.partition()
+    pure_module = jax.tree_map(lambda x: x * 2, pure_module)
+
+    m.update_state(pure_module.states)
+
+    assert m.a[0] == 2
+    assert m.a[1] == 6
+    assert m.b == 4
+    assert m.c == 200
+
+  def test_grad_example(self):
+    m = nnx.Dict(
+        a=nnx.Sequence([nnx.Param(1.0), nnx.BatchStat(-10)]),
+        b=nnx.Param(2.0),
+        c=100,
+    )
+
+    params = m.filter(nnx.Param)
+
+    def loss(params):
+      return sum(2 * p for p in jax.tree_util.tree_leaves(params))
+
+    grads = jax.grad(loss)(params)
+    m.update_state(grads)
+
+    assert m.a[0] == 2.0
+    assert m.a[1] == -10
+    assert m.b == 2.0
+    assert m.c == 100
+
+  def test_get_partition(self):
+    m = nnx.Dict(
+        a=nnx.Sequence([nnx.Param(10.0), nnx.Param(20.0)]),
+        b=nnx.Param(10.0),
+        c=7,
+        d=5.0,
+    )
+
+    # test that Variables are not shared
+    assert vars(m.a)["0"] is not vars(m)["b"]
+
+    state = m.filter(nnx.Node)
+    assert state["a/0"].value == m.a[0]
+    assert state["a/1"].value == m.a[1]
+    assert state["b"].value == m.b
+    assert state["b"] is not state["a/0"]
+    assert len(state) == 3
diff --git a/tests/nnx/test_pytree.py b/tests/nnx/test_pytree.py
new file mode 100644
index 0000000000..f09a0e1f28
--- /dev/null
+++ b/tests/nnx/test_pytree.py
@@ -0,0 +1,259 @@
+from typing import Generic, TypeVar
+
+import jax
+import pytest
+from flax import serialization
+
+from flax.experimental import nnx
+
+
+class TestPytree:
+
+  def test_immutable_pytree(self):
+    class Foo(nnx.Pytree):
+
+      def __init__(self, y) -> None:
+        self.x = 2
+        self.y = nnx.Node(y)
+
+    pytree = Foo(y=3)
+
+    leaves = jax.tree_util.tree_leaves(pytree)
+    assert leaves == [3]
+
+    pytree = jax.tree_map(lambda x: x * 2, pytree)
+    assert pytree.x == 2
+    assert pytree.y == 6
+
+    pytree = pytree.replace(x=3)
+    assert pytree.x == 3
+    assert pytree.y == 6
+
+    with pytest.raises(
+        AttributeError, match="is immutable, trying to update field"
+    ):
+      pytree.x = 4
+
+  def test_immutable_pytree_dataclass(self):
+    @nnx.dataclass(frozen=True)
+    class Foo(nnx.Pytree):
+      y: int = nnx.node_field()
+      x: int = nnx.field(default=2)
+
+    pytree = Foo(y=3)
+
+    leaves = jax.tree_util.tree_leaves(pytree)
+    assert leaves == [3]
+
+    pytree = jax.tree_map(lambda x: x * 2, pytree)
+    assert pytree.x == 2
+    assert pytree.y == 6
+
+    pytree = pytree.replace(x=3)
+    assert pytree.x == 3
+    assert pytree.y == 6
+
+    with pytest.raises(AttributeError, match="cannot assign to field"):
+      pytree.x = 4
+
+  def test_jit(self):
+    @nnx.dataclass
+    class 
Foo(nnx.Pytree): + a: int = nnx.node_field() + b: int = nnx.field() + + module = Foo(a=1, b=2) + + @jax.jit + def f(m: Foo): + return m.a + m.b + + assert f(module) == 3 + + def test_flax_serialization(self): + class Bar(nnx.Pytree): + + def __init__(self, a, b): + self.a = a + self.b = nnx.Node(b) + + @nnx.dataclass + class Foo(nnx.Pytree): + bar: Bar + c: int = nnx.node_field() + d: int = nnx.field() + + foo: Foo = Foo(bar=Bar(a=1, b=2), c=3, d=4) + + state_dict = serialization.to_state_dict(foo) + + assert state_dict == { + "bar": { + "b": 2, + }, + "c": 3, + } + + state_dict["bar"]["b"] = 5 + + foo = serialization.from_state_dict(foo, state_dict) + + assert foo.bar.b == 5 + + del state_dict["bar"]["b"] + + with pytest.raises(ValueError, match="Missing field"): + serialization.from_state_dict(foo, state_dict) + + state_dict["bar"]["b"] = 5 + + # add unknown field + state_dict["x"] = 6 + + with pytest.raises(ValueError, match="Unknown field"): + serialization.from_state_dict(foo, state_dict) + + def test_generics(self): + T = TypeVar("T") + + class MyClass(nnx.Pytree, Generic[T]): + + def __init__(self, x: T): + self.x = x + + MyClass[int] + + def test_key_paths(self): + @nnx.dataclass + class Bar(nnx.Pytree): + a: int = nnx.node_field(default=1) + b: int = nnx.field(default=2) + + @nnx.dataclass + class Foo(nnx.Pytree): + x: int = nnx.node_field(default=3) + y: int = nnx.field(default=4) + z: Bar = nnx.node_field(default_factory=Bar) + + foo = Foo() + + path_values, treedef = jax.tree_util.tree_flatten_with_path(foo) + path_values = [(list(map(str, path)), value) for path, value in path_values] + + assert path_values[0] == ([".x", ".value"], 3) + assert path_values[1] == ([".z", ".value", ".a", ".value"], 1) + + def test_replace_unknown_fields_error(self): + class Foo(nnx.Pytree): + pass + + with pytest.raises(ValueError, match="Trying to replace unknown fields"): + Foo().replace(y=1) + + def test_dataclass_inheritance(self): + @nnx.dataclass + class A(nnx.Pytree): + a: int = nnx.node_field(default=1) + b: int = nnx.field(default=2) + + @nnx.dataclass + class B(A): + c: int = nnx.node_field(default=3) + + pytree = B() + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [1, 3] + + def test_pytree_with_new(self): + class A(nnx.Pytree): + + def __init__(self, a): + self.a = a + + def __new__(cls, a): + return super().__new__(cls) + + pytree = A(a=1) + + pytree = jax.tree_map(lambda x: x * 2, pytree) + + def test_deterministic_order(self): + class A(nnx.Pytree): + + def __init__(self, order: bool): + if order: + self.a = 1 + self.b = 2 + else: + self.b = 2 + self.a = 1 + + p1 = A(order=True) + p2 = A(order=False) + + leaves1 = jax.tree_util.tree_leaves(p1) + leaves2 = jax.tree_util.tree_leaves(p2) + + assert leaves1 == leaves2 + + +class TestMutablePytree: + + def test_pytree(self): + class Foo(nnx.Pytree, mutable=True): + + def __init__(self, y) -> None: + self.x = 2 + self.y = nnx.Node(y) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 + + def test_no_new_fields_after_init(self): + class Foo(nnx.Pytree, mutable=True): + + def __init__(self, x): + self.x = nnx.Node(x) + + foo = Foo(x=1) + foo.x = 2 + + with pytest.raises(AttributeError, match=r"Cannot add new fields to"): + foo.y = 2 + + def 
test_pytree_dataclass(self): + @nnx.dataclass + class Foo(nnx.Pytree, mutable=True): + y: int = nnx.node_field() + x: int = nnx.field(default=2) + + pytree: Foo = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 diff --git a/tests/nnx/test_spmd.py b/tests/nnx/test_spmd.py new file mode 100644 index 0000000000..e00fd37119 --- /dev/null +++ b/tests/nnx/test_spmd.py @@ -0,0 +1,75 @@ +import jax +import jax.numpy as jnp +import optax +from jax._src import test_util as jtu +from jax.experimental import mesh_utils +from jax.sharding import Mesh, PartitionSpec + +from flax.experimental import nnx + + +class TestSPMD: + + @jtu.skip_on_devices("cpu", "gpu") + def test_init(self): + class Foo(nnx.Module): + + def __init__(self): + self.w = nnx.Param( + nnx.with_logical_partitioning( + lambda: jnp.ones((8, 2)), + sharding=("row", "col"), + )() + ) + + def __call__(self, x): + return x @ self.w + + @jax.jit + def create_module(): + return Foo().partition() + + mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ("model", "data")) + + with mesh, nnx.logical_axis_rules([("row", "model"), ("col", "data")]): + m: Foo = create_module().merge() + + assert m.w.shape == (8, 2) + assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) + + def test_get_partition_spec(self): + class Foo(nnx.Module): + + def __init__(self): + self.w = nnx.Param( + nnx.with_logical_partitioning( + lambda: jnp.ones((8, 2)), + sharding=("row", "col"), + )() + ) + + def __call__(self, x): + return x @ self.w + + params, moduledef = Foo().partition() + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.adam(1e-3), + ) + logical_state_spec = nnx.get_partition_spec(state) + + assert logical_state_spec.params["w"] == PartitionSpec("row", "col") + assert logical_state_spec.opt_state[0].mu["w"] == PartitionSpec( + "row", "col" + ) + assert logical_state_spec.opt_state[0].nu["w"] == PartitionSpec( + "row", "col" + ) + + with nnx.logical_axis_rules([("row", "model"), ("col", "data")]): + state_spec = nnx.logical_to_mesh(logical_state_spec) + + assert state_spec.params["w"] == PartitionSpec("model", "data") + assert state_spec.opt_state[0].mu["w"] == PartitionSpec("model", "data") + assert state_spec.opt_state[0].nu["w"] == PartitionSpec("model", "data") diff --git a/tests/nnx/test_transforms.py b/tests/nnx/test_transforms.py new file mode 100644 index 0000000000..65a6a9a4d6 --- /dev/null +++ b/tests/nnx/test_transforms.py @@ -0,0 +1,394 @@ +import typing as tp +from functools import partial + +import jax +import jax.numpy as jnp +import pytest + +from flax.experimental import nnx + + +class TestJIT: + + def test_jit(self): + m = nnx.Dict(a=nnx.Param(1)) + + @nnx.jit + def g(m: nnx.Dict): + m.a = 2 + return 1.0 + + out = g(m) + + assert m.a == 2 + assert out == 1.0 + + def test_jit_stateless(self): + m = nnx.Dict(a=nnx.Param(1)) + + @partial(nnx.jit, stateful=False) + def g(m: nnx.Dict): + m.a = 2 + return 1.0 + + out = g(m) + + assert m.a == 1 + assert out == 1.0 + + +class TestGrad: + + def test_grad(self): + p1 = nnx.Param(10.0) + p2 = nnx.Param(20.0) + + m = nnx.Dict( + a=nnx.Sequence([p1, p2]), + b=p1, + c=7, + d=5.0, + ) + + @nnx.grad + def f(m: nnx.Dict): + # sum all params + return m["a"][0] + m["a"][1] + m["b"] + + grads = f(m) + + assert isinstance(grads, 
nnx.State) + assert grads["a/0"].value == 1.0 + assert isinstance(grads["a/0"], nnx.Node) + assert grads["a/1"].value == 1.0 + assert isinstance(grads["a/1"], nnx.Node) + assert grads["b"].value == 1.0 + assert isinstance(grads["b"], nnx.Node) + assert len(grads) == 3 + + m.update_state(grads) + + assert m["a"][0] == 1.0 + assert m["a"][1] == 1.0 + assert m["b"] == 1.0 + assert m["c"] == 7 + assert m["d"] == 5.0 + + def test_grad_with_multiple_ref_types(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), + b=nnx.Param(10.0), + c=7, + d=5.0, + ) + + @nnx.grad + def f(m: nnx.Dict): + # sum all params + return m.a[0] + m.a[1] + m.b + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads["a/0"].value == 1.0 + assert isinstance(grads["a/0"], nnx.Param) + assert len(grads) == 2 + + m.update_state(grads) + + assert m.a[0] == 1.0 + assert m.a[1] == 20.0 + assert m.b == 1.0 + assert m.c == 7 + assert m.d == 5.0 + + def test_grad_with_type_predicate(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), + b=nnx.Param(10.0), + c=7, + d=5.0, + ) + + @partial(nnx.grad, wrt=nnx.BatchStat) + def f(m: nnx.Dict): + # sum all params + return m.a[0] + m.a[1] + m.b + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads["a/1"].value == 1.0 + assert isinstance(grads["a/1"], nnx.BatchStat) + assert len(grads) == 1 + + m.update_state(grads) + + assert m.a[0] == 10.0 + assert m.a[1] == 1.0 + assert m.b == 10.0 + assert m.c == 7 + assert m.d == 5.0 + + +class TestScan: + + def test_basic(self): + class Block(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + self.node = jnp.ones((2,)) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + jax.debug.print("x={x}", x=x) + x = self.linear(x) + x = nnx.gelu(x) + return x, None + + MLP = nnx.Scan( + Block, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 + ) + + module = MLP(ctx=nnx.context(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, out = module.call(x, None) + + assert y.shape == (1, 3) + assert out is None + + def test_complex(self): + class Block(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + self.bn = nnx.BatchNorm(3, ctx=ctx) + self.dropout = nnx.Dropout(0.5) + self.node = jnp.ones((2,)) + + def __call__( + self, x: jax.Array, _, *, ctx: nnx.Context + ) -> tp.Tuple[jax.Array, None]: + jax.debug.print("x={x}", x=x) + x = self.linear(x) + x = self.bn(x, ctx=ctx) + x = self.dropout(x, ctx=ctx) + x = nnx.gelu(x) + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + # variable_carry="batch_stats", + split_rngs=["params", "dropout"], + length=5, + ) + + module = MLP(ctx=nnx.context(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + ctx = nnx.context( + dropout=1, flags=dict(deterministic=False, use_running_average=False) + ) + y, out = module.call(x, None, ctx=ctx) + + assert y.shape == (1, 3) + assert out is None + + def test_complex_decorator(self): + scan_over_layers = partial( + nnx.scan, + variable_axes={nnx.Param: 0}, + split_rngs=["params", "dropout"], + length=5, + ) + + class Block(nnx.Module): + + @scan_over_layers + def 
__init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + self.bn = nnx.BatchNorm(3, ctx=ctx) + self.dropout = nnx.Dropout(0.5) + self.node = jnp.ones((2,)) + + @scan_over_layers + def __call__( + self, x: jax.Array, _, *, ctx: nnx.Context + ) -> tp.Tuple[jax.Array, None]: + jax.debug.print("x={x}", x=x) + x = self.linear(x) + x = self.bn(x, ctx=ctx) + x = self.dropout(x, ctx=ctx) + x = nnx.gelu(x) + return x, None + + module = Block(ctx=nnx.context(0)) + + assert module.linear.kernel.shape == (5, 3, 3) + assert module.linear.bias.shape == (5, 3) + assert module.node.shape == (2,) + + x = jnp.ones((1, 3)) + ctx = nnx.context( + dropout=1, flags=dict(deterministic=False, use_running_average=False) + ) + y, out = module(x, None, ctx=ctx) + + assert y.shape == (1, 3) + assert out is None + + def test_scan_with_sharding(self): + class Block(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear( + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), + sharding=("din", "dout"), + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros(), + sharding=("dout",), + ), + ctx=ctx, + ) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + + # test sharding layer axes is not present inside scan + state = self.linear.get_state() + assert state["kernel"].value.shape == (3, 3) + assert state["kernel"].sharding == ("din", "dout") + assert state["bias"].value.shape == (3,) + assert state["bias"].sharding == ("dout",) + + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + split_rngs=["params"], + length=5, + metadata_params={nnx.PARTITION_NAME: "layers"}, + ) + + m = MLP(ctx=nnx.context(0)) + + # test sharding layers axes is set + state = m.get_state() + assert state["scan_module/linear/kernel"].value.shape == (5, 3, 3) + assert state["scan_module/linear/kernel"].sharding == ( + "layers", + "din", + "dout", + ) + assert state["scan_module/linear/bias"].value.shape == (5, 3) + assert state["scan_module/linear/bias"].sharding == ("layers", "dout") + + x = jnp.ones((1, 3)) + y, out = m.call(x, None) + + # test sharding axes is preserved + state = m.get_state() + assert state["scan_module/linear/kernel"].value.shape == (5, 3, 3) + assert state["scan_module/linear/kernel"].sharding == ( + "layers", + "din", + "dout", + ) + assert state["scan_module/linear/bias"].value.shape == (5, 3) + assert state["scan_module/linear/bias"].sharding == ("layers", "dout") + + +class TestRemat: + + def test_basic_remat(self): + RematLinear = nnx.Remat(nnx.Linear) + + module = RematLinear(2, 3, ctx=nnx.context(0)) + + y = module.call(jnp.ones((1, 2))) + + assert y.shape == (1, 3) + + def test_remat_decorator(self): + class RematLinear(nnx.Module): + + @nnx.remat + def __init__(self, din: int, dout: int, *, ctx: nnx.Context): + self.linear = nnx.Linear(din, dout, ctx=ctx) + + @nnx.remat + def __call__(self, x: jax.Array) -> jax.Array: + return self.linear(x) + + module = RematLinear(2, 3, ctx=nnx.context(0)) + + y = module(jnp.ones((1, 2))) + + assert y.shape == (1, 3) + + def test_remat_with_scan(self): + class LinearBlock(nnx.Module): + + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + return x, None + + RematLinear = nnx.Remat(LinearBlock) + + ScanRematLinear = nnx.Scan( + RematLinear, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 + ) + + m = 
ScanRematLinear(ctx=nnx.context(0)) + + assert m.scan_module.remat_module.linear.kernel.shape == (5, 3, 3) + assert m.scan_module.remat_module.linear.bias.shape == (5, 3) + + y, _ = m.call.call(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) + + y, _ = m(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) + + def test_remat_with_scan_decorator(self): + scan = partial( + nnx.scan, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 + ) + + class ScanLinear(nnx.Module): + + @scan + def __init__(self, *, ctx: nnx.Context): + self.linear = nnx.Linear(3, 3, ctx=ctx) + + @scan + @nnx.remat + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + return x, None + + m = ScanLinear(ctx=nnx.context(0)) + + assert m.linear.kernel.shape == (5, 3, 3) + assert m.linear.bias.shape == (5, 3) + + y, _ = m(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) diff --git a/tests/nnx/test_variable.py b/tests/nnx/test_variable.py new file mode 100644 index 0000000000..3269d03d19 --- /dev/null +++ b/tests/nnx/test_variable.py @@ -0,0 +1,21 @@ +import typing as tp + +import jax +import pytest + +from flax.experimental import nnx + +A = tp.TypeVar("A") + + +class TestVariable: + + def test_value(self): + r1 = nnx.Node(1) + assert r1.value == 1 + + r2 = jax.tree_map(lambda x: x + 1, r1) + + assert r1.value == 1 + assert r2.value == 2 + assert r1 is not r2