Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Typing support #3242

Merged
merged 5 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@

import abc
import functools
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Generic, Mapping, Optional, Tuple, TypeVar, Union

from flax import errors
from flax import struct
import jax
from jax.experimental import maps


TAxisMetadata = Any # TypeVar('TAxisMetadata', bound='AxisMetadata')
A = TypeVar('A')
B = TypeVar('B')
TAxisMetadata = TypeVar('TAxisMetadata', bound='AxisMetadata[Any]')


class AxisMetadata(metaclass=abc.ABCMeta):
class AxisMetadata(Generic[A], metaclass=abc.ABCMeta):
"""Abstract base class for boxed Metadata.

``AxisMetadata`` enables arbitrary, per axis metadata for variables.
Expand All @@ -53,7 +55,7 @@ class AxisMetadata(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def unbox(self) -> Any:
def unbox(self) -> A:
"""Returns the content of the AxisMetadata box.

Note that unlike ``meta.unbox`` the unbox call should recursively unbox
Expand All @@ -70,7 +72,7 @@ def unbox(self) -> Any:
pass

@abc.abstractmethod
def replace_boxed(self, val: Any) -> TAxisMetadata:
def replace_boxed(self, val: B) -> 'AxisMetadata[B]':
"""Replaces the boxed value with the provided value.

Args:
Expand Down Expand Up @@ -129,7 +131,7 @@ def is_axis_metadata(val: Any) -> bool:
return isinstance(val, AxisMetadata)


def map_axis_meta(fn: Callable[[AxisMetadata], Any], tree: Any) -> Any:
def map_axis_meta(fn: Callable[[AxisMetadata[Any]], Any], tree: Any) -> Any:
"""Maps over all PyTree nodes that are AxisMetadata instances."""

def wrapper(x):
Expand Down Expand Up @@ -178,7 +180,7 @@ def _global_mesh_defined() -> bool:
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison


class Partitioned(struct.PyTreeNode, AxisMetadata):
class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
"""Wrapper for partitioning metadata.

``Partitioned`` is used to extend variables with partitioning information
Expand Down Expand Up @@ -241,7 +243,7 @@ def body(mdl, c):
default=None, pytree_node=False
)

def unbox(self, apply_constraint=True) -> Any:
def unbox(self, apply_constraint=True) -> A:
"""Returns the wrapped value with the partitioning applied as a sharding constraint."""
if apply_constraint and (_global_mesh_defined() or self.mesh is not None):
axis_resource = self.get_partition_spec()
Expand All @@ -252,23 +254,23 @@ def unbox(self, apply_constraint=True) -> Any:
else:
return self.value

def replace_boxed(self, val: Any) -> TAxisMetadata:
return self.replace(value=val)
def replace_boxed(self, val: B) -> 'Partitioned[B]':
return self.replace(value=val) # type: ignore

def _get_partition_name(self, params: Dict[Any, Any]) -> str:
if PARTITION_NAME not in params:
raise errors.PartitioningUnspecifiedError(self)
return params[PARTITION_NAME]

def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
def add_axis(self, index: int, params: Dict[Any, Any]) -> 'Partitioned[A]':
axis_name = self._get_partition_name(params)
names = list(self.names)
while len(names) < index:
names.append(None) # type: ignore
names.insert(index, axis_name) # type: ignore
return self.replace(names=tuple(names))

def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
def remove_axis(self, index: int, params: Dict[Any, Any]) -> 'Partitioned[A]':
axis_name = self._get_partition_name(params)
names = list(self.names)
assert names.pop(index) == axis_name
Expand All @@ -287,7 +289,7 @@ def with_partitioning(
fn: Callable[..., Any],
names: LogicalNames,
mesh: Optional[jax.sharding.Mesh] = None,
) -> Callable[..., Partitioned]:
) -> Callable[..., Partitioned[Any]]:
"""Wraps a function's return value with Partitioned.

Example::
Expand Down
34 changes: 33 additions & 1 deletion flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
Dict,
Generic,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
overload,
)

from flax import config as config
Expand Down Expand Up @@ -849,9 +851,39 @@ def variable(
self.put_variable(col, name, init_value)
return Variable(self, col, name, unbox=unbox)

@overload
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
...

@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[True],
) -> T:
...

@overload
def param(
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[False],
) -> meta.AxisMetadata[T]:
...

@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool
) -> Union[T, meta.AxisMetadata[T]]:
...

def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
) -> Union[T, meta.AxisMetadata[T]]:
"""Creates a parameter if it doesn't exist yet in this scope and returns it.

If the parameter exists already, the existing value is simply returned.
Expand Down
82 changes: 78 additions & 4 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@
Dict,
Iterable,
List,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Expand All @@ -53,6 +51,7 @@
)
from flax.core import partial_eval
from flax.core import Scope
from flax.core import meta
from flax.core.frozen_dict import FrozenDict
from flax.core.scope import ( # pylint: disable=g-multiple-import
CollectionFilter,
Expand Down Expand Up @@ -1382,9 +1381,39 @@ def variable(
self._state.children[name] = col
return v

@overload
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
...

@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[True],
) -> T:
...

@overload
def param(
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[False],
) -> meta.AxisMetadata[T]:
...

@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool
) -> Union[T, meta.AxisMetadata[T]]:
...

def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
) -> Union[T, meta.AxisMetadata[T]]:
"""Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named "params". See
Expand Down Expand Up @@ -1594,6 +1623,51 @@ def __call__(self, x):
module = self.clone()
return module, variables

@overload
def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: Literal[False],
capture_intermediates: Union[
bool, Callable[['Module', str], bool]
] = False,
**kwargs,
) -> Any:
...

@overload
def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter,
capture_intermediates: Union[
bool, Callable[['Module', str], bool]
] = False,
**kwargs,
) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]:
...

@overload
def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = False,
capture_intermediates: Union[
bool, Callable[['Module', str], bool]
] = False,
**kwargs,
) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]:
...

@traceback_util.api_boundary
def apply(
self,
Expand Down
Loading