From 70d845a81aadecef6ac7894aa308ff456e1ec651 Mon Sep 17 00:00:00 2001 From: jheek Date: Wed, 7 Sep 2022 12:47:53 +0200 Subject: [PATCH] Remove pytype generic workaround --- flax/core/lift.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flax/core/lift.py b/flax/core/lift.py index d1485ad855..0ec7e6abbc 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -315,13 +315,13 @@ def swap(target): @dataclasses.dataclass(frozen=True) class In(Generic[T]): """Specifies a variable collection should only be lifted as input.""" - axis: Any # pytype does not support generic variable annotation + axis: T @dataclasses.dataclass(frozen=True) class Out(Generic[T]): """Specifies a variable collection should only be lifted as output.""" - axis: Any # pytype does not support generic variable annotation + axis: T def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):