Skip to content

Commit

Permalink
Merge pull request #3242 from google:improve-typing
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 555853597
  • Loading branch information
Flax Authors committed Aug 11, 2023
2 parents 3ea6381 + 2dbb44a commit 5030149
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 24 deletions.
28 changes: 15 additions & 13 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@

import abc
import functools
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Generic, Mapping, Optional, Tuple, TypeVar, Union

from flax import errors
from flax import struct
import jax
from jax.experimental import maps


TAxisMetadata = Any # TypeVar('TAxisMetadata', bound='AxisMetadata')
# Type of the value held inside an ``AxisMetadata`` box; ``AxisMetadata`` is
# declared ``Generic[A]`` and ``unbox`` is annotated to return ``A``.
A = TypeVar('A')
# Type of a replacement value: ``replace_boxed`` takes a ``B`` and is annotated
# to return a box parameterized by ``B`` rather than by ``A``.
B = TypeVar('B')
# Type variable bounded by ``AxisMetadata[Any]``, i.e. any AxisMetadata subclass
# regardless of what it boxes. The bound is a string (forward reference)
# because ``AxisMetadata`` is defined after this line.
TAxisMetadata = TypeVar('TAxisMetadata', bound='AxisMetadata[Any]')


class AxisMetadata(metaclass=abc.ABCMeta):
class AxisMetadata(Generic[A], metaclass=abc.ABCMeta):
"""Abstract base class for boxed Metadata.
``AxisMetadata`` enables arbitrary, per axis metadata for variables.
Expand All @@ -53,7 +55,7 @@ class AxisMetadata(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def unbox(self) -> Any:
def unbox(self) -> A:
"""Returns the content of the AxisMetadata box.
Note that unlike ``meta.unbox`` the unbox call should recursively unbox
Expand All @@ -70,7 +72,7 @@ def unbox(self) -> Any:
pass

@abc.abstractmethod
def replace_boxed(self, val: Any) -> TAxisMetadata:
def replace_boxed(self, val: B) -> 'AxisMetadata[B]':
"""Replaces the boxed value with the provided value.
Args:
Expand Down Expand Up @@ -129,7 +131,7 @@ def is_axis_metadata(val: Any) -> bool:
return isinstance(val, AxisMetadata)


def map_axis_meta(fn: Callable[[AxisMetadata], Any], tree: Any) -> Any:
def map_axis_meta(fn: Callable[[AxisMetadata[Any]], Any], tree: Any) -> Any:
"""Maps over all PyTree nodes that are AxisMetadata instances."""

def wrapper(x):
Expand Down Expand Up @@ -178,7 +180,7 @@ def _global_mesh_defined() -> bool:
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison


class Partitioned(struct.PyTreeNode, AxisMetadata):
class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
"""Wrapper for partitioning metadata.
``Partitioned`` is used to extend variables with partitioning information
Expand Down Expand Up @@ -241,7 +243,7 @@ def body(mdl, c):
default=None, pytree_node=False
)

def unbox(self, apply_constraint=True) -> Any:
def unbox(self, apply_constraint=True) -> A:
"""Returns the wrapped value with the partitioning applied as a sharding constraint."""
if apply_constraint and (_global_mesh_defined() or self.mesh is not None):
axis_resource = self.get_partition_spec()
Expand All @@ -252,23 +254,23 @@ def unbox(self, apply_constraint=True) -> Any:
else:
return self.value

def replace_boxed(self, val: Any) -> TAxisMetadata:
return self.replace(value=val)
def replace_boxed(self, val: B) -> 'Partitioned[B]':
return self.replace(value=val) # type: ignore

def _get_partition_name(self, params: Dict[Any, Any]) -> str:
  """Fetches the partition name stored in ``params``.

  Args:
    params: dictionary that is expected to hold an entry under the
      ``PARTITION_NAME`` key.

  Returns:
    The value stored under ``PARTITION_NAME``.

  Raises:
    errors.PartitioningUnspecifiedError: if ``PARTITION_NAME`` is not a key of
      ``params``. The error is constructed with this instance as its argument.
  """
  # Guard-clause form: return on the happy path, otherwise fall through to the
  # error.
  if PARTITION_NAME in params:
    return params[PARTITION_NAME]
  raise errors.PartitioningUnspecifiedError(self)

def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
def add_axis(self, index: int, params: Dict[Any, Any]) -> 'Partitioned[A]':
axis_name = self._get_partition_name(params)
names = list(self.names)
while len(names) < index:
names.append(None) # type: ignore
names.insert(index, axis_name) # type: ignore
return self.replace(names=tuple(names))

def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
def remove_axis(self, index: int, params: Dict[Any, Any]) -> 'Partitioned[A]':
axis_name = self._get_partition_name(params)
names = list(self.names)
assert names.pop(index) == axis_name
Expand All @@ -287,7 +289,7 @@ def with_partitioning(
fn: Callable[..., Any],
names: LogicalNames,
mesh: Optional[jax.sharding.Mesh] = None,
) -> Callable[..., Partitioned]:
) -> Callable[..., Partitioned[Any]]:
"""Wraps a function's return value with Partitioned.
Example::
Expand Down
86 changes: 83 additions & 3 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@
Dict,
Generic,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)

from flax import config as config
Expand Down Expand Up @@ -816,14 +819,57 @@ def put(target, key, val):

put(variables, name, value)

# Typing-only overload stubs for ``variable``. Each stub differs only in how
# ``unbox`` is declared and in the declared return type; the bodies are ``...``.

# Overload 1: ``unbox`` is not a parameter here, so this covers calls that omit
# it. Declared result: ``Variable[T]`` where ``T`` is what ``init_fn`` returns.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
) -> Variable[T]:
...

# Overload 2: ``unbox`` passed as the literal ``True`` (keyword-only, since it
# follows ``*init_args``). Declared result: ``Variable[T]``.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[True],
) -> Variable[T]:
...

# Overload 3: ``unbox`` passed as the literal ``False``. Declared result:
# ``Variable[meta.AxisMetadata[T]]``, i.e. the variable's value type is the box.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[False],
) -> Variable[meta.AxisMetadata[T]]:
...

# Overload 4: ``unbox`` typed as a plain ``bool``. Declared result is the union
# of the two shapes above because the flag's value is not known statically.
# NOTE(review): this stub gives ``unbox`` a default (``= True``), whereas the
# corresponding ``bool`` overload of ``param`` declares ``unbox: bool`` with no
# default. With the default it also accepts calls that omit ``unbox``, which
# overlaps with overload 1 -- confirm whether that is intended.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
...

def variable(
self,
col: str,
name: str, # pylint: disable=keyword-arg-before-vararg
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Variable[T]:
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
"""Creates a variable if it doesn't exist yet in this scope and returns it.
Args:
Expand All @@ -847,11 +893,45 @@ def variable(
raise errors.ScopeVariableNotFoundError(name, col, self.path_text)
init_value = init_fn(*init_args)
self.put_variable(col, name, init_value)
return Variable(self, col, name, unbox=unbox)
# cast to make static analyzers happy
return cast(
Union[Variable[T], Variable[meta.AxisMetadata[T]]],
Variable(self, col, name, unbox=unbox),
)

# Typing-only overload of ``param`` for calls that do not pass ``unbox``.
# Declared result: ``T``, the return type of ``init_fn``.
@overload
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
...

@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[True],
) -> T:
...

# Typing-only overload of ``param`` for ``unbox`` passed as the literal
# ``False`` (keyword-only, since it follows ``*init_args``). Declared result:
# ``meta.AxisMetadata[T]``, i.e. the boxed form rather than a bare ``T``.
@overload
def param(
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[False],
) -> meta.AxisMetadata[T]:
...

# Typing-only overload of ``param`` for ``unbox`` typed as a plain ``bool``.
# ``unbox`` has no default here, so this stub only applies when the argument is
# actually supplied. Declared result is the union of the unboxed and boxed
# types because the flag's value is not known statically.
@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool
) -> Union[T, meta.AxisMetadata[T]]:
...

def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
) -> Union[T, meta.AxisMetadata[T]]:
"""Creates a parameter if it doesn't exist yet in this scope and returns it.
If the parameter exists already, the existing value is simply returned.
Expand Down
8 changes: 6 additions & 2 deletions flax/linen/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Attention core modules for Flax."""

import functools
from typing import (Any, Callable, Optional, Tuple)
from typing import (Any, Callable, Optional, Tuple, Union)
from flax.linen.dtypes import promote_dtype

from flax.linen import initializers
Expand Down Expand Up @@ -334,7 +334,11 @@ def __call__(
)
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
indices: tuple[Union[int, jax.Array], ...] = (0,) * len(batch_dims) + (
cur_index,
0,
0,
)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
Expand Down
84 changes: 78 additions & 6 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@
Dict,
Iterable,
List,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Expand All @@ -53,6 +51,7 @@
)
from flax.core import partial_eval
from flax.core import Scope
from flax.core import meta
from flax.core.frozen_dict import FrozenDict
from flax.core.scope import ( # pylint: disable=g-multiple-import
CollectionFilter,
Expand Down Expand Up @@ -1332,14 +1331,57 @@ def clone_fn(m: Module) -> Module:

return module

# Typing-only overload stubs for ``variable``. Each stub differs only in how
# ``unbox`` is declared and in the declared return type; the bodies are ``...``.

# Overload 1: ``unbox`` is not a parameter here, so this covers calls that omit
# it. Declared result: ``Variable[T]`` where ``T`` is what ``init_fn`` returns.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
) -> Variable[T]:
...

# Overload 2: ``unbox`` passed as the literal ``True`` (keyword-only, since it
# follows ``*init_args``). Declared result: ``Variable[T]``.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[True],
) -> Variable[T]:
...

# Overload 3: ``unbox`` passed as the literal ``False``. Declared result:
# ``Variable[meta.AxisMetadata[T]]``, i.e. the variable's value type is the box.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[False],
) -> Variable[meta.AxisMetadata[T]]:
...

# Overload 4: ``unbox`` typed as a plain ``bool``. Declared result is the union
# of the two shapes above because the flag's value is not known statically.
# NOTE(review): this stub gives ``unbox`` a default (``= True``), whereas the
# corresponding ``bool`` overload of ``param`` declares ``unbox: bool`` with no
# default. With the default it also accepts calls that omit ``unbox``, which
# overlaps with overload 1 -- confirm whether that is intended.
@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
...

def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., Any]] = None,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Variable:
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
"""Declares and returns a variable in this Module.
See :mod:`flax.core.variables` for more information. See also :meth:`param`
Expand Down Expand Up @@ -1383,9 +1425,39 @@ def variable(
self._state.children[name] = col
return v

# Typing-only overload of ``param`` for calls that do not pass ``unbox``.
# Declared result: ``T``, the return type of ``init_fn``.
@overload
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
...

@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[True],
) -> T:
...

# Typing-only overload of ``param`` for ``unbox`` passed as the literal
# ``False`` (keyword-only, since it follows ``*init_args``). Declared result:
# ``meta.AxisMetadata[T]``, i.e. the boxed form rather than a bare ``T``.
@overload
def param(
self,
name: str,
init_fn: Callable[..., T],
*init_args,
unbox: Literal[False],
) -> meta.AxisMetadata[T]:
...

# Typing-only overload of ``param`` for ``unbox`` typed as a plain ``bool``.
# ``unbox`` has no default here, so this stub only applies when the argument is
# actually supplied. Declared result is the union of the unboxed and boxed
# types because the flag's value is not known statically.
@overload
def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool
) -> Union[T, meta.AxisMetadata[T]]:
...

def param(
self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True
) -> Union[T, meta.AxisMetadata[T]]:
"""Declares and returns a parameter in this Module.
Parameters are read-only variables in the collection named "params". See
Expand Down

0 comments on commit 5030149

Please sign in to comment.