Skip to content

Commit

Permalink
adds table_kwargs and column_kwargs arguments to tabulate
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 13, 2023
1 parent 6739a07 commit c5f4152
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
13 changes: 10 additions & 3 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import re
import sys
import threading
from types import MappingProxyType
import typing
import weakref
from typing import (Any, Callable, Dict, Iterable, List, Sequence, NamedTuple, Mapping,
Expand Down Expand Up @@ -1903,6 +1904,8 @@ def tabulate(
show_repeated: bool = False,
mutable: CollectionFilter = True,
console_kwargs: Optional[Mapping[str, Any]] = None,
table_kwargs: Mapping[str, Any] = MappingProxyType({}),
column_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs) -> str:
"""Creates a summary of the Module represented as a table.
Expand Down Expand Up @@ -1978,16 +1981,20 @@ def __call__(self, x):
console_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.console.Console` when rendering the table. Default arguments
are `{'force_terminal': True, 'force_jupyter': False}`.
table_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table` constructor.
column_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table.add_column` when adding columns to the table.
**kwargs: keyword arguments to pass to the forward computation.
Returns:
A string summarizing the Module.
"""
from flax.linen import summary

tabulate_fn = summary.tabulate(self, rngs, depth=depth,
show_repeated=show_repeated, mutable=mutable,
console_kwargs=console_kwargs)
tabulate_fn = summary.tabulate(
self, rngs, depth=depth, show_repeated=show_repeated, mutable=mutable,
console_kwargs=console_kwargs, table_kwargs=table_kwargs, column_kwargs=column_kwargs)
return tabulate_fn(*args, **kwargs)


Expand Down
35 changes: 24 additions & 11 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABC, abstractmethod
import dataclasses
import io
from types import MappingProxyType
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union
from flax.core import unfreeze

Expand Down Expand Up @@ -137,6 +138,8 @@ def tabulate(
show_repeated: bool = False,
mutable: CollectionFilter = True,
console_kwargs: Optional[Mapping[str, Any]] = None,
table_kwargs: Mapping[str, Any] = MappingProxyType({}),
column_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs,
) -> Callable[..., str]:
"""Returns a function that creates a summary of the Module represented as a table.
Expand Down Expand Up @@ -217,6 +220,10 @@ def __call__(self, x):
console_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.console.Console` when rendering the table. Default arguments
are `{'force_terminal': True, 'force_jupyter': False}`.
table_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table` constructor.
column_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table.add_column` when adding columns to the table.
**kwargs: Additional arguments passed to `Module.init`.
Returns:
Expand All @@ -228,7 +235,7 @@ def __call__(self, x):
def _tabulate_fn(*fn_args, **fn_kwargs):
table_fn = _get_module_table(module, depth=depth, show_repeated=show_repeated)
table = table_fn(rngs, *fn_args, mutable=mutable, **fn_kwargs, **kwargs)
return _render_table(table, console_kwargs)
return _render_table(table, console_kwargs, table_kwargs, column_kwargs)

return _tabulate_fn

Expand Down Expand Up @@ -337,27 +344,33 @@ def _process_inputs(args, kwargs) -> Any:

return input_values

def _render_table(table: Table, console_extras: Optional[Mapping[str, Any]]) -> str:
def _render_table(
table: Table,
console_extras: Optional[Mapping[str, Any]],
table_kwargs: Mapping[str, Any],
column_kwargs: Mapping[str, Any],
) -> str:
"""A function that renders a Table to a string representation using rich."""
console_kwargs = {'force_terminal': True, 'force_jupyter': False}
if console_extras is not None:
console_kwargs.update(console_extras)

non_params_cols = 4
rich_table = rich.table.Table(
show_header=True,
show_lines=True,
show_footer=True,
title=f'{table.module.__class__.__name__} Summary',
show_header=True,
show_lines=True,
show_footer=True,
title=f'{table.module.__class__.__name__} Summary',
**table_kwargs,
)

rich_table.add_column('path')
rich_table.add_column('module')
rich_table.add_column('inputs')
rich_table.add_column('outputs')
rich_table.add_column('path', **column_kwargs)
rich_table.add_column('module', **column_kwargs)
rich_table.add_column('inputs', **column_kwargs)
rich_table.add_column('outputs', **column_kwargs)

for col in table.collections:
rich_table.add_column(col)
rich_table.add_column(col, **column_kwargs)

for row in table:
collections_size_repr = []
Expand Down

0 comments on commit c5f4152

Please sign in to comment.