Skip to content

Commit

Permalink
Merge pull request #2446 from jheek:remove-pytype-generic-workaround
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 473235778
  • Loading branch information
Flax Authors committed Sep 9, 2022
2 parents 8687673 + 70d845a commit e320e11
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,13 @@ def swap(target):
@dataclasses.dataclass(frozen=True)
class In(Generic[T]):
"""Specifies a variable collection should only be lifted as input."""
axis: Any # pytype does not support generic variable annotation
axis: T


@dataclasses.dataclass(frozen=True)
class Out(Generic[T]):
"""Specifies a variable collection should only be lifted as output."""
axis: Any # pytype does not support generic variable annotation
axis: T


def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):
Expand Down

0 comments on commit e320e11

Please sign in to comment.