diff --git a/flax/core/lift.py b/flax/core/lift.py index d1485ad855..0ec7e6abbc 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -315,13 +315,13 @@ def swap(target): @dataclasses.dataclass(frozen=True) class In(Generic[T]): """Specifies a variable collection should only be lifted as input.""" - axis: Any # pytype does not support generic variable annotation + axis: T @dataclasses.dataclass(frozen=True) class Out(Generic[T]): """Specifies a variable collection should only be lifted as output.""" - axis: Any # pytype does not support generic variable annotation + axis: T def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):