diff --git a/flax/linen/module.py b/flax/linen/module.py index 406912d22c..9de568b874 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -111,12 +111,17 @@ def _attr_repr(value: Any): def _module_repr(module: 'Module', num_spaces: int = 4): """Returns a pretty printed representation of the module.""" cls = type(module) + try: + fields = dataclasses.fields(cls) + except TypeError: + # Edge case with no fields e.g. module = nn.Module() causes error later. + return object.__repr__(module) cls_name = cls.__name__ rep = '' attributes = { f.name: f.type - for f in dataclasses.fields(cls) + for f in fields if f.name not in ('parent', 'name') and f.repr } child_modules = {