diff --git a/flax/linen/summary.py b/flax/linen/summary.py index d6676729f0..b04f8961de 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -242,11 +242,6 @@ def tabulate( Total Parameters: 50 (200 B) - - **Note**: rows order in the table does not represent execution order, - instead it aligns with the order of keys in `variables` which are sorted - alphabetically. - **Note**: `vjp_flops` returns `0` if the module is not differentiable. Args: diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index fcb15f0608..d85271c465 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -167,4 +167,5 @@ from .extract import to_tree as to_tree from .extract import from_tree as from_tree from .extract import NodeStates as NodeStates +from .summary import tabulate as tabulate from . import traversals as traversals \ No newline at end of file diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index 63ed371be9..1028efb2b1 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate: else: raise TypeError(f'Invalid collection filter: {filter:!r}. ') -def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]: +def filters_to_predicates( + filters: tp.Sequence[Filter], +) -> tuple[Predicate, ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py new file mode 100644 index 0000000000..f1c92e5537 --- /dev/null +++ b/flax/nnx/summary.py @@ -0,0 +1,301 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pytype: skip-file + + +import io +import typing as tp +from itertools import groupby +from types import MappingProxyType + +import jax +import rich.console +import rich.table +import rich.text +import yaml +import jax.numpy as jnp + +from flax.nnx import graph, rnglib, variablelib + + +def tabulate( + obj, + depth: int | None = None, + table_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + column_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + console_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), +) -> str: + """Creates a summary of the graph object represented as a table. + + The table summarizes the object's state and metadata. The table is + structured as follows: + + - The first column represents the path of the object in the graph. + - The second column represents the type of the object. + - The following columns provide information about the object's state, + grouped by Variable types. + + Example: + + >>> from flax import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.bn = nnx.BatchNorm(dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.2, rngs=rngs) + ... + ... def __call__(self, x): + ... return nnx.relu(self.dropout(self.bn(self.linear(x)))) + ... + >>> class Foo(nnx.Module): + ... def __init__(self, rngs: nnx.Rngs): + ... self.block1 = Block(32, 128, rngs=rngs) + ... self.block2 = Block(128, 10, rngs=rngs) + ... + ... def __call__(self, x): + ... return self.block2(self.block1(x)) + ... + >>> foo = Foo(nnx.Rngs(0)) + >>> # print(nnx.tabulate(foo)) + + Foo Summary + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ + ┃ path ┃ type ┃ BatchStat ┃ Param ┃ RngState ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ + │ block1/bn │ BatchNorm │ mean: float32[128] │ bias: float32[128] │ │ + │ │ │ var: float32[128] │ scale: float32[128] │ │ + │ │ │ │ │ │ + │ │ │ 256 (1.0 KB) │ 256 (1.0 KB) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block1/dropout/rngs/default │ RngStream │ │ │ count: │ + │ │ │ │ │ value: uint32[] │ + │ │ │ │ │ tag: default │ + │ │ │ │ │ key: │ + │ │ │ │ │ value: key[] │ + │ │ │ │ │ tag: default │ + │ │ │ │ │ │ + │ │ │ │ │ 2 (12 B) │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block1/linear │ Linear │ │ bias: float32[128] │ │ + │ │ │ │ kernel: float32[32,128] │ │ + │ │ │ │ │ │ + │ │ │ │ 4,224 (16.9 KB) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block2/bn │ BatchNorm │ mean: float32[10] │ bias: float32[10] │ │ + │ │ │ var: float32[10] │ scale: float32[10] │ │ + │ │ │ │ │ │ + │ │ │ 20 (80 B) │ 20 (80 B) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block2/linear │ Linear │ │ bias: float32[10] │ │ + │ │ │ │ kernel: float32[128,10] │ │ + │ │ │ │ │ │ + │ │ │ │ 1,290 (5.2 KB) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ │ Total │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ + └─────────────────────────────┴───────────┴────────────────────┴─────────────────────────┴─────────────────────┘ + + Total Parameters: 6,068 (24.3 KB) + + + Note that ``block2/dropout`` is not shown in the table because it shares the + same ``RngState`` with ``block1/dropout``. + + Args: + obj: A object to summarize. It can a pytree or a graph objects + such as nnx.Module or nnx.Optimizer. + depth: The depth of the table. + table_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table`` constructor. + column_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table.add_column`` when adding columns to + the table. + console_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.console.Console`` when rendering the table. + Default arguments are ``{'force_terminal': True, 'force_jupyter': + False}``. + + Returns: + A string summarizing the object. + """ + _console_kwargs = {'force_terminal': True, 'force_jupyter': False} + _console_kwargs.update(console_kwargs) + state = graph.state(obj) + graph_map = dict(graph.iter_graph(obj)) + flat_state = sorted(state.flat_state()) + + def key_fn( + path_state: tuple[graph.PathParts, variablelib.VariableState[tp.Any]], + ): + path, _ = path_state + if depth is None or len(path) <= depth: + return path[:-1] + else: + return path[:depth] + + rows = groupby(flat_state, key_fn) + table = sorted((path, list(flat_states)) for path, flat_states in rows) + + state_types_set = {variable_state.type for _, variable_state in flat_state} + # replace RngKey and RngCount with RngState + if rnglib.RngKey in state_types_set: + state_types_set.remove(rnglib.RngKey) + state_types_set.add(rnglib.RngState) + if rnglib.RngCount in state_types_set: + state_types_set.remove(rnglib.RngCount) + state_types_set.add(rnglib.RngState) + # sort based on MRO + state_types = _sort_variable_types(state_types_set) + + rich_table = rich.table.Table( + show_header=True, + show_lines=True, + show_footer=True, + title=f'{type(obj).__name__} Summary', + **table_kwargs, + ) + + rich_table.add_column('path', **column_kwargs) + rich_table.add_column('type', **column_kwargs) + + for state_type in state_types: + rich_table.add_column(state_type.__name__, **column_kwargs) + + for key_path, row_states in table: + row: list[str] = [] + node = graph_map[key_path] + type_state_groups = variablelib.split_flat_state(row_states, state_types) + path_str = '/'.join(map(str, key_path)) + node_type = type(node).__name__ + row.extend([path_str, node_type]) + + for state_type, type_path_and_states in zip(state_types, type_state_groups): + attributes = {} + for state_path, variable_state in type_path_and_states: + if len(state_path) == len(key_path) + 1: + name = str(state_path[-1]) + value = variable_state.value + value_repr = _render_array(value) if _has_shape_dtype(value) else '' + metadata = variable_state.get_metadata() + + if metadata: + attributes[name] = { + 'value': value_repr, + **metadata, + } + elif value_repr: + attributes[name] = value_repr + + if attributes: + col_repr = _as_yaml_str(attributes) + '\n\n' + else: + col_repr = '' + + type_states = [state for _, state in type_path_and_states] + size_, bytes_ = _size_and_bytes(type_states) + col_repr += f'[bold]{_size_and_bytes_repr(size_, bytes_)}[/bold]' + row.append(col_repr) + + rich_table.add_row(*row) + + rich_table.columns[1].footer = rich.text.Text.from_markup( + 'Total', justify='right' + ) + flat_states = variablelib.split_flat_state(flat_state, state_types) + + for i, (state_type, type_path_and_states) in enumerate( + zip(state_types, flat_states) + ): + type_states = [state for _, state in type_path_and_states] + size_, bytes_ = _size_and_bytes(type_states) + size_repr = _size_and_bytes_repr(size_, bytes_) + rich_table.columns[i + 2].footer = size_repr + + rich_table.caption_style = 'bold' + rich_table.caption = ( + f'\nTotal Parameters: {_size_and_bytes_repr(*_size_and_bytes(state))}' + ) + + return '\n' + _get_rich_repr(rich_table, _console_kwargs) + '\n' + + +def _get_rich_repr(obj, console_kwargs): + f = io.StringIO() + console = rich.console.Console(file=f, **console_kwargs) + console.print(obj) + return f.getvalue() + + +def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]: + leaves = jax.tree.leaves(pytree) + size = sum(x.size for x in leaves if hasattr(x, 'size')) + num_bytes = sum( + x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') + ) + return size, num_bytes + + +def _size_and_bytes_repr(size: int, num_bytes: int) -> str: + if not size: + return '' + bytes_repr = _bytes_repr(num_bytes) + return f'{size:,} [dim]({bytes_repr})[/dim]' + + +def _bytes_repr(num_bytes): + count, units = ( + (f'{num_bytes / 1e9 :,.1f}', 'GB') + if num_bytes > 1e9 + else (f'{num_bytes / 1e6 :,.1f}', 'MB') + if num_bytes > 1e6 + else (f'{num_bytes / 1e3 :,.1f}', 'KB') + if num_bytes > 1e3 + else (f'{num_bytes:,}', 'B') + ) + + return f'{count} {units}' + + +def _has_shape_dtype(value): + return hasattr(value, 'shape') and hasattr(value, 'dtype') + + +def _as_yaml_str(value) -> str: + if (hasattr(value, '__len__') and len(value) == 0) or value is None: + return '' + + file = io.StringIO() + yaml.safe_dump( + value, + file, + default_flow_style=False, + indent=2, + sort_keys=False, + explicit_end=False, + ) + return file.getvalue().replace('\n...', '').replace("'", '').strip() + + +def _render_array(x): + shape, dtype = jnp.shape(x), jnp.result_type(x) + shape_repr = ','.join(str(x) for x in shape) + return f'[dim]{dtype}[/dim][{shape_repr}]' + + +def _sort_variable_types(types: tp.Iterable[type]) -> list[type]: + def _variable_parents_count(t: type): + return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) + + type_sort_key = {t: (-_variable_parents_count(t), t.__name__) for t in types} + return sorted(types, key=lambda t: type_sort_key[t]) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 4752a9b7bd..85b441711b 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -911,7 +911,7 @@ def wrapper(*args): def split_flat_state( flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]], - filters: tuple[filterlib.Filter, ...], + filters: tp.Sequence[filterlib.Filter], ) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]: predicates = filterlib.filters_to_predicates(filters) # we have n + 1 states, where n is the number of predicates diff --git a/uv.lock b/uv.lock index e08e2dbf53..6307908fc7 100644 --- a/uv.lock +++ b/uv.lock @@ -3,13 +3,13 @@ requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] [[package]] @@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 } wheels = [ @@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 } wheels = [ @@ -1202,7 +1202,7 @@ name = "ipython" version = "8.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "decorator" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "jedi" }, @@ -1246,7 +1246,7 @@ wheels = [ [[package]] name = "jax" -version = "0.4.37" +version = "0.4.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, @@ -1255,14 +1255,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/30/ad7617a960c86782587540a179cef676962322d1e5411415b1aa24f02ce0/jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b", size = 1915966 } +sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034 } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/3f/6c5553baaa7faa3fa8bae8279b1e46cb54c7ce52360139eae53498786ea5/jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e", size = 2221192 }, + { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864 }, ] [[package]] name = "jaxlib" -version = "0.4.36" +version = "0.4.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -1270,26 +1270,26 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/23/8d/8a44618f3493f29d769b2b40778d24075689cc8697b98e2c43bafbe50edf/jaxlib-0.4.36-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d69f991833b6dca794767049843462805936c89553b136a8ebb8485334204457", size = 98648230 }, - { url = "https://files.pythonhosted.org/packages/78/b8/207485eab566dcfbc29bb833714ac1ca47a1665ca605b1ff7d3d5dd2afbe/jaxlib-0.4.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:807814c1ba3ec69cffaa93d3f90651c694a9b8a750b43832cc167ed590c821dd", size = 78553787 }, - { url = "https://files.pythonhosted.org/packages/26/42/3c2b0dc86a17aafd8f46ba0e4388f39f55706ee25f6c463c3dadea7a71e2/jaxlib-0.4.36-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1bc27d9ae09549d7652eafe1fdb10c21546cd2fd02bb24a49a7e6208b69163b0", size = 84008742 }, - { url = "https://files.pythonhosted.org/packages/b9/b2/29be712098342df10075fe085c0b39d783a579bd3325fb0d69c22712cf27/jaxlib-0.4.36-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3379f03a794d6a30b75765d2786f6e31052f364196fcd49aaae292a3c16f12ec", size = 100263041 }, - { url = "https://files.pythonhosted.org/packages/63/a9/93404a2f1d59647749d4d6dbab7bee9f5a7bfaeb9ade25b7e66c0ca0949a/jaxlib-0.4.36-cp310-cp310-win_amd64.whl", hash = "sha256:63e575ac8a515dee8171dd4a88c460d538bbcc9d959cabc9781e961763678f84", size = 63270658 }, - { url = "https://files.pythonhosted.org/packages/e4/7d/9394ff39af5c23bb98a241c33742a328df5a43c21d569855ea7e096aaf5e/jaxlib-0.4.36-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:213792db3b876206b45f6a9fbea15e4dd22a9e80be25b03136f20c94784fecfa", size = 98669744 }, - { url = "https://files.pythonhosted.org/packages/34/5a/9f3c9e5cec23e60f78bb3c3da108a5ef664601862dbc4e84fc4be3654f5d/jaxlib-0.4.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d7a89adf4c9d3cddd20482931dedc7a9e2669e904196a9599d9a605b3d9e552", size = 78574312 }, - { url = "https://files.pythonhosted.org/packages/ff/5c/bf78ed9b8d0f174a562f6496049a4872e14a3bb3a80de09c4292d04be5f0/jaxlib-0.4.36-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c395fe8cc5bd6558dd2fbce78e24172b6f27762e17628720ae03d693001283f3", size = 84038323 }, - { url = "https://files.pythonhosted.org/packages/67/af/6a9dd26e8a6bedd4c9fe702059767256b0d9ed18c29a180a4598d5795bb4/jaxlib-0.4.36-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc324c6b1c64fe68400934c653e4e622f12576120dcdb451c3b4ea4dcaba2ae9", size = 100285487 }, - { url = "https://files.pythonhosted.org/packages/b7/46/31c3a519a94e84c672ca264c4151998e3e3fd11c481d8fa5af5885b91a1e/jaxlib-0.4.36-cp311-cp311-win_amd64.whl", hash = "sha256:c9e0c45a79e63aea65447f82bd0fa21c17b9afe884aa18dd5362b9965abe9d72", size = 63308064 }, - { url = "https://files.pythonhosted.org/packages/e3/0e/3b4a99c09431ee5820624d4dcf4efa7becd3c83b56ff0f09a078f4c421a2/jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d", size = 98718357 }, - { url = "https://files.pythonhosted.org/packages/d3/46/05e70a1236ec3782333b3e9469f971c9d45af2aa0aebf602acd9d76292eb/jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e", size = 78596060 }, - { url = "https://files.pythonhosted.org/packages/8e/76/6b969cbf197b8c53c84c2642069722e84a3a260af084a8acbbf90ca444ea/jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442", size = 84053202 }, - { url = "https://files.pythonhosted.org/packages/fe/f2/7624a304426daa7b135b85caf1b8eccf879e7cb10bc074656ce628309cb0/jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357", size = 100325610 }, - { url = "https://files.pythonhosted.org/packages/bb/8b/ded8420cd9198eb677869ffd557d9880af5833c7bf39e604e80b56550e09/jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3", size = 63338518 }, - { url = "https://files.pythonhosted.org/packages/5d/22/b72811c61e8b594951d3ee03245cb0932c723ac35e75569005c3c976eec2/jaxlib-0.4.36-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:02df9c0e1323dde01e966c22eb12432905d2d4de8aac7b603cad2083101b0e6b", size = 98719384 }, - { url = "https://files.pythonhosted.org/packages/f1/66/3f4a97097983914899100db9e5312493fe1d6adc924e47a0e47e15c553f5/jaxlib-0.4.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ec980e85983f41999c4dc84137dec70507d958e23d7eefa104da93053d135f", size = 78596150 }, - { url = "https://files.pythonhosted.org/packages/3a/6f/cf02f56d1532962d8ca77a6548acab8204294b96b5a153ca4a2caf4971fc/jaxlib-0.4.36-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7ce9368515348d869d6c59d9904c3cb3c81f22ff3e9e969eae0e3563fe472080", size = 84055851 }, - { url = "https://files.pythonhosted.org/packages/28/10/4fc4e9719c065c6455491730011e87fe4b5120a9a008161cc32663feb9ce/jaxlib-0.4.36-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:93f1c502d08e517f842fe7b18428bb086cfd077db0ea9a2418fb21e5b4e06d3d", size = 100325986 }, - { url = "https://files.pythonhosted.org/packages/ba/28/fece5385e736ef2f1b5bed133f8001f0fc66dd0104707381343e047b341a/jaxlib-0.4.36-cp313-cp313-win_amd64.whl", hash = "sha256:bddf436a243e83ec6bc16bcbb74d15b1960a69318c9ea796fb2109492bc52575", size = 63338694 }, + { url = "https://files.pythonhosted.org/packages/ee/d4/e6a0881a88b8f17491c2ee271fd77c348b0221d9e2ec92dad23a2c9e41bc/jaxlib-0.4.38-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:55c19b9d3f33a6fc59f644aa5a21fba02639ccdd776cb4a9b5526625f57839ff", size = 99663603 }, + { url = "https://files.pythonhosted.org/packages/b6/6d/11569ce873f04c82ec22e58d822f4187dccae1d400c0d6dd05ed314d5328/jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:30b2f52cb50d74734af2f477c2533a7a583e3bb7b2c8acdeb361ee77d940577a", size = 79475708 }, + { url = "https://files.pythonhosted.org/packages/72/61/1de2405d13089c83b1ad87ec0266479c9d00080659dae2474892ae356306/jaxlib-0.4.38-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ee19c163a8fdf0839d4c18b88a5fbfb4e731ba7c437416d3e5483e570bb764e4", size = 93219045 }, + { url = "https://files.pythonhosted.org/packages/9c/24/0829decf233c6af9efe7c53888ae8ac72395e0979869cd9cee487e35dac3/jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:61aeccb9a27c67fdb8450f6357240019cd4511cb9d62a44e4764756d384853ad", size = 101732107 }, + { url = "https://files.pythonhosted.org/packages/0d/04/120c4caac6151f7297fedf9dd776362aa2d417d3f87bda826050b4da45e8/jaxlib-0.4.38-cp310-cp310-win_amd64.whl", hash = "sha256:d6ab745a89d0fb737a36fe1d8b86659e3fffe6ee8303b20651b26193d5edc0ef", size = 64223924 }, + { url = "https://files.pythonhosted.org/packages/b0/6a/b9fba73eb5e758e40a514919e096a039d27dc0ab4776a6cc977f5153a55f/jaxlib-0.4.38-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:b67fdeabd6dfed08b7768f3bdffb521160085f8305669bd197beef61d08de08b", size = 99679916 }, + { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377 }, + { url = "https://files.pythonhosted.org/packages/94/96/7d9a0b9f35af4727df44b68ade4c6f15163840727d1cb47251b1ea515e30/jaxlib-0.4.38-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:43db58c4c427627296366a56c10318e1f00f503690e17f94bb4344293e1995e0", size = 93241543 }, + { url = "https://files.pythonhosted.org/packages/a3/2d/68f85037e60c981b37b18b23ace458c677199dea4722ddce541b48ddfc63/jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2751ff7037d6a997d0be0e77cc4be381c5a9f9bb8b314edb755c13a6fd969f45", size = 101751923 }, + { url = "https://files.pythonhosted.org/packages/cc/24/a9c571c8a189f58e0b54b14d53fc7f5a0a06e4f1d7ab9edcf8d1d91d07e7/jaxlib-0.4.38-cp311-cp311-win_amd64.whl", hash = "sha256:35226968fc9de6873d1571670eac4117f5ed80e955f7a1775204d1044abe16c6", size = 64255189 }, + { url = "https://files.pythonhosted.org/packages/49/df/08b94c593c0867c7eaa334592807ba74495de4be90580f360db8b96221dc/jaxlib-0.4.38-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:3fefea985f0415816f3bbafd3f03a437050275ef9bac9a72c1314e1644ac57c1", size = 99737849 }, + { url = "https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242 }, + { url = "https://files.pythonhosted.org/packages/53/25/dd670d8bdf3799ece76d12cfe6a6a250ea256057aa4b0fcace4753a99d2d/jaxlib-0.4.38-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:496f45b0e001a2341309cd0c74af0b670537dced79c168cb230cfcc773f0aa86", size = 93251503 }, + { url = "https://files.pythonhosted.org/packages/f9/cc/37fce5162f6b9070203fd76cc0f298d9b3bfdf01939a78935a6078d63621/jaxlib-0.4.38-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dad6c0a96567c06d083c0469fec40f201210b099365bd698be31a6d2ec88fd59", size = 101792792 }, + { url = "https://files.pythonhosted.org/packages/6f/7a/8515950a60a4ea5b13cc98fc0a42e36553b2db5a6eedc00d3bd7836f77b5/jaxlib-0.4.38-cp312-cp312-win_amd64.whl", hash = "sha256:966cdec36cfa978f5b4582bcb4147fe511725b94c1a752dac3a5f52ce46b6fa3", size = 64288223 }, + { url = "https://files.pythonhosted.org/packages/91/03/aee503c7077c6dbbd568842303426c6ec1cef9bff330c418c9e71906cccd/jaxlib-0.4.38-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:41e55ae5818a882e5789e848f6f16687ac132bcfbb5a5fa114a5d18b78d05f2d", size = 99739026 }, + { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735 }, + { url = "https://files.pythonhosted.org/packages/e4/0b/8cbff0b6d62a4694351c49baf53b7ed8deb8a6854d129408c38158e11676/jaxlib-0.4.38-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:248cca3771ebf24b070f49701364ceada33e6139445b06c782cca5ac5ad92bf4", size = 93251882 }, + { url = "https://files.pythonhosted.org/packages/15/57/7f0283273b69c417071bcd2f4c2ed076479ec5ffc22a647f13c21da8d071/jaxlib-0.4.38-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:2ce77ba8cda9259a4bca97afc1c722e4291a6c463a63f8d372c6edc85117d625", size = 101791137 }, + { url = "https://files.pythonhosted.org/packages/de/de/d6c4d234cd426b97459cb070af90792b48643967a0d28641379ee9e10fc9/jaxlib-0.4.38-cp313-cp313-win_amd64.whl", hash = "sha256:4103db0b3a38a5dc132741237453c24d8547290a22079ba1b577d6c88c95300a", size = 64288459 }, ] [[package]] @@ -1431,7 +1431,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, + { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -2095,7 +2095,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -2122,9 +2122,9 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, @@ -2135,7 +2135,7 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, @@ -2262,7 +2262,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.10.2" +version = "0.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2280,9 +2280,9 @@ dependencies = [ { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/06/c42e2f1563dbaaf5ed1464d7b634324fb9a2da04021073c45777e61af78d/orbax_checkpoint-0.10.2.tar.gz", hash = "sha256:e575ebe1f94e5cb6353ab8c9df81de0ca7cddc118645c3bfc17b8344f19d42f1", size = 248170 } +sdist = { url = "https://files.pythonhosted.org/packages/de/b3/a9a8a6bc08ded7634a9d85ba440400172f0a11f9341897b8fd3389fad245/orbax_checkpoint-0.11.0.tar.gz", hash = "sha256:d4a0dcc81edd29191cf5a4feb9cf2a4edd31fc5da79d7be616a04f11f2a4d484", size = 253035 } wheels = [ - { url = "https://files.pythonhosted.org/packages/61/19/ed366f8894923f3c8db0370e4bdd57ef843d68011dafa00d8175f4a66e1a/orbax_checkpoint-0.10.2-py3-none-any.whl", hash = "sha256:dcfc425674bd8d4934986143bd22a37cd634d034652c5d30d83c539ef8587941", size = 354306 }, + { url = "https://files.pythonhosted.org/packages/87/32/3779fa524a2272f408ab51d869fde9ff1c0ca731eedd01e40436bcf7ba2c/orbax_checkpoint-0.11.0-py3-none-any.whl", hash = "sha256:892a124fce71f3e7c71451a2b2090c0251db1097803a119a00baa377113bc9ba", size = 360423 }, ] [[package]] @@ -2436,7 +2436,7 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version < '3.11' and platform_system == 'Darwin'", "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 } wheels = [ @@ -2454,10 +2454,10 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*' and platform_system == 'Darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')", "python_full_version >= '3.12' and platform_system == 'Darwin'", "python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 } wheels = [ @@ -2606,7 +2606,7 @@ name = "pytest" version = "8.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig" }, { name = "packaging" }, @@ -3195,7 +3195,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -3684,7 +3684,7 @@ name = "triton" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },