Skip to content

Commit

Permalink
Fix nn.Module typing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559093587
  • Loading branch information
Flax Team committed Aug 22, 2023
1 parent 9e82cdb commit c6c50d3
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ def __getattr__(self, name):
# Base Module definition.
# -----------------------------------------------------------------------------

def module_field(*, kw_only: bool = False, default: Optional[Any] = ...) -> Any:
...

# The ModuleBase class is created only to make static analyzers happy
# mainly pytype and pyright. Some notes:
Expand All @@ -887,14 +889,12 @@ def __getattr__(self, name):
# * Other attributes are annotated for completeness. Because we are using
# the `if typing.TYPE_CHECKING` pattern, these annotations are not present
# at runtime so they don't affect the dataclass behavior.
@dataclass_transform()
@dataclass_transform(field_specifiers=(module_field,))
class ModuleBase:
if typing.TYPE_CHECKING:
name: Optional[str] = None
scope: Optional[Scope]
_state: _ModuleInternalState
_parent_ref: Union['Module', weakref.ReferenceType['Module'], None]
parent: Union['Module', _Sentinel, None]
__dataclass_fields__: Dict[str, dataclasses.Field]


Expand Down Expand Up @@ -934,6 +934,10 @@ def __call__(self, x):
"""

if typing.TYPE_CHECKING:
name: Optional[str] = module_field(kw_only=True, default=None)
parent: Union['Module', _Sentinel, None] = module_field(
kw_only=True, default=None
)

def __init__(self, *args, **kwargs):
# this stub makes sure pytype accepts constructor arguments.
Expand Down

0 comments on commit c6c50d3

Please sign in to comment.