diff --git a/flax/linen/module.py b/flax/linen/module.py index a4f2b114ab..169bd22f83 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -21,6 +21,7 @@ import re import sys import threading +from types import MappingProxyType import typing import weakref from typing import (Any, Callable, Dict, Iterable, List, Sequence, NamedTuple, Mapping, @@ -1903,6 +1904,8 @@ def tabulate( show_repeated: bool = False, mutable: CollectionFilter = True, console_kwargs: Optional[Mapping[str, Any]] = None, + table_kwargs: Mapping[str, Any] = MappingProxyType({}), + column_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs) -> str: """Creates a summary of the Module represented as a table. @@ -1978,6 +1981,10 @@ def __call__(self, x): console_kwargs: An optional dictionary with additional keyword arguments that are passed to `rich.console.Console` when rendering the table. Default arguments are `{'force_terminal': True, 'force_jupyter': False}`. + table_kwargs: An optional dictionary with additional keyword arguments that + are passed to `rich.table.Table` constructor. + column_kwargs: An optional dictionary with additional keyword arguments that + are passed to `rich.table.Table.add_column` when adding columns to the table. **kwargs: keyword arguments to pass to the forward computation. Returns: @@ -1985,9 +1992,9 @@ def __call__(self, x): """ from flax.linen import summary - tabulate_fn = summary.tabulate(self, rngs, depth=depth, - show_repeated=show_repeated, mutable=mutable, - console_kwargs=console_kwargs) + tabulate_fn = summary.tabulate( + self, rngs, depth=depth, show_repeated=show_repeated, mutable=mutable, + console_kwargs=console_kwargs, table_kwargs=table_kwargs, column_kwargs=column_kwargs) return tabulate_fn(*args, **kwargs) diff --git a/flax/linen/summary.py b/flax/linen/summary.py index 57f2563eed..59d3369c63 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -16,6 +16,7 @@ from abc import ABC, abstractmethod import dataclasses import io +from types import MappingProxyType from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union from flax.core import unfreeze @@ -137,6 +138,8 @@ def tabulate( show_repeated: bool = False, mutable: CollectionFilter = True, console_kwargs: Optional[Mapping[str, Any]] = None, + table_kwargs: Mapping[str, Any] = MappingProxyType({}), + column_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs, ) -> Callable[..., str]: """Returns a function that creates a summary of the Module represented as a table. @@ -217,6 +220,10 @@ def __call__(self, x): console_kwargs: An optional dictionary with additional keyword arguments that are passed to `rich.console.Console` when rendering the table. Default arguments are `{'force_terminal': True, 'force_jupyter': False}`. + table_kwargs: An optional dictionary with additional keyword arguments that + are passed to `rich.table.Table` constructor. + column_kwargs: An optional dictionary with additional keyword arguments that + are passed to `rich.table.Table.add_column` when adding columns to the table. **kwargs: Additional arguments passed to `Module.init`. Returns: @@ -228,7 +235,7 @@ def __call__(self, x): def _tabulate_fn(*fn_args, **fn_kwargs): table_fn = _get_module_table(module, depth=depth, show_repeated=show_repeated) table = table_fn(rngs, *fn_args, mutable=mutable, **fn_kwargs, **kwargs) - return _render_table(table, console_kwargs) + return _render_table(table, console_kwargs, table_kwargs, column_kwargs) return _tabulate_fn @@ -337,7 +344,12 @@ def _process_inputs(args, kwargs) -> Any: return input_values -def _render_table(table: Table, console_extras: Optional[Mapping[str, Any]]) -> str: +def _render_table( + table: Table, + console_extras: Optional[Mapping[str, Any]], + table_kwargs: Mapping[str, Any], + column_kwargs: Mapping[str, Any], +) -> str: """A function that renders a Table to a string representation using rich.""" console_kwargs = {'force_terminal': True, 'force_jupyter': False} if console_extras is not None: @@ -345,19 +357,20 @@ def _render_table(table: Table, console_extras: Optional[Mapping[str, Any]]) -> non_params_cols = 4 rich_table = rich.table.Table( - show_header=True, - show_lines=True, - show_footer=True, - title=f'{table.module.__class__.__name__} Summary', + show_header=True, + show_lines=True, + show_footer=True, + title=f'{table.module.__class__.__name__} Summary', + **table_kwargs, ) - rich_table.add_column('path') - rich_table.add_column('module') - rich_table.add_column('inputs') - rich_table.add_column('outputs') + rich_table.add_column('path', **column_kwargs) + rich_table.add_column('module', **column_kwargs) + rich_table.add_column('inputs', **column_kwargs) + rich_table.add_column('outputs', **column_kwargs) for col in table.collections: - rich_table.add_column(col) + rich_table.add_column(col, **column_kwargs) for row in table: collections_size_repr = []