diff --git a/flax/linen/module.py b/flax/linen/module.py index 406912d22c..50fef7a978 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -111,6 +111,9 @@ def _attr_repr(value: Any): def _module_repr(module: 'Module', num_spaces: int = 4): """Returns a pretty printed representation of the module.""" cls = type(module) + if not hasattr(cls, '_FIELDS'): + # Edge case with no fields e.g. module = nn.Module() causes error later. + return object.__repr__(module) cls_name = cls.__name__ rep = ''