Skip to content

Commit

Permalink
Merge pull request #3293 from google:fix-kw-only-dataclass
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560158617
  • Loading branch information
Flax Authors committed Aug 25, 2023
2 parents ba9e24a + 836e0e4 commit 4879b4c
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 19 deletions.
50 changes: 44 additions & 6 deletions flax/linen/kw_only_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ class that defines a field with a default, and a subclass that defines a field
"""

import dataclasses
import inspect
import functools
from types import MappingProxyType
from typing import Any, TypeVar
from typing_extensions import dataclass_transform
import flax

M = TypeVar('M', bound='flax.linen.Module')
FieldName = str
Annotation = Any
Default = Any


class _KwOnlyType:
Expand Down Expand Up @@ -88,6 +99,7 @@ def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs):
return dataclasses.field(metadata=metadata, **kwargs)


@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
def dataclass(cls=None, extra_fields=None, **kwargs):
"""Wrapper for dataclasses.dataclass that adds support for kw_only fields.
Expand All @@ -111,7 +123,7 @@ def wrap(cls):
return wrap if cls is None else wrap(cls)


def _process_class(cls, extra_fields=None, **kwargs):
def _process_class(cls: type[M], extra_fields=None, **kwargs):
"""Transforms `cls` into a dataclass that supports kw_only fields."""
if '__annotations__' not in cls.__dict__:
cls.__annotations__ = {}
Expand All @@ -122,7 +134,7 @@ def _process_class(cls, extra_fields=None, **kwargs):
base_dataclass_fields = {} # dict[cls, cls.__dataclass_fields__.copy()]

# The keyword only fields from `cls` or any of its base classes.
kw_only_fields = {} # dict[field_name, tuple[annotation, default]]
kw_only_fields: dict[FieldName, tuple[Annotation, Default]] = {}

# Scan for KW_ONLY marker.
kw_only_name = None
Expand All @@ -138,7 +150,7 @@ def _process_class(cls, extra_fields=None, **kwargs):
)
default = getattr(cls, name)
if isinstance(default, dataclasses.Field):
default.metadata = {**default.metadata, **{KW_ONLY: True}}
default.metadata = MappingProxyType({**default.metadata, KW_ONLY: True})
else:
default = field(default=default, kw_only=True)
setattr(cls, name, default)
Expand Down Expand Up @@ -190,12 +202,38 @@ def _process_class(cls, extra_fields=None, **kwargs):
cls_annotations.pop(name, None)
cls_annotations[name] = annotation

create_init = '__init__' not in vars(cls) and kwargs.get('init', True)

# Apply the dataclass transform.
transformed_cls = dataclasses.dataclass(cls, **kwargs)
transformed_cls: type[M] = dataclasses.dataclass(cls, **kwargs)

# Restore the base classes' __dataclass_fields__.
for cls, dataclass_fields in base_dataclass_fields.items():
cls.__dataclass_fields__ = dataclass_fields
for _cls, fields in base_dataclass_fields.items():
_cls.__dataclass_fields__ = fields

if create_init:
dataclass_init = transformed_cls.__init__
# use sum to count the number of init fields that are not keyword-only
expected_num_args = sum(
f.init and not f.metadata.get(KW_ONLY, False)
for f in dataclasses.fields(transformed_cls)
)

@functools.wraps(dataclass_init)
def init_wrapper(self, *args, **kwargs):
num_args = len(args)
if num_args > expected_num_args:
# we add + 1 to each to account for `self`, matching python's
# default error message
raise TypeError(
f'__init__() takes {expected_num_args + 1} positional '
f'arguments but {num_args + 1} were given'
)

dataclass_init(self, *args, **kwargs)

init_wrapper.__signature__ = inspect.signature(dataclass_init) # type: ignore
transformed_cls.__init__ = init_wrapper # type: ignore[method-assign]

# Return the transformed dataclass
return transformed_cls
2 changes: 1 addition & 1 deletion tests/linen/kw_only_dataclasses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Child(Parent):
v1 = Child(4)
self.assertDictEqual(dataclasses.asdict(v1), dict(a=2, b=4))

v2 = Child(4, 5) # pylint: disable=too-many-function-args
v2 = Child(4, a=5) # pylint: disable=too-many-function-args
self.assertDictEqual(dataclasses.asdict(v2), dict(a=5, b=4))

def test_subclass_overrides_base(self):
Expand Down
12 changes: 12 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2214,6 +2214,18 @@ def __call__(self, x):
# take positional arg. It takes BaseLayer's default kwargs though.
np.testing.assert_equal(ChildLayer(8)(np.ones(10)), -8 * np.ones(10))

def test_positional_cannot_be_kw_only(self):
class Foo(nn.Module):
a: int

Foo(1) # ok
Foo(a=1) # ok
with self.assertRaisesRegex(
TypeError, r'takes 2 positional arguments but 3 were'
):
Foo(1, None)
Foo(a=1, parent=None) # type: ignore[call-arg]

def test_intercept_methods(self):
mod = IdentityModule(parent=None)
x = jnp.ones([])
Expand Down
24 changes: 12 additions & 12 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ def __call__(self, x):

x = jnp.ones((10, 10))
rngs = random.PRNGKey(0)
init_vars = Test(None).init(rngs, x)
_, new_vars = Test(None).apply(init_vars, x, mutable=['counter'])
init_vars = Test(parent=None).init(rngs, x)
_, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter'])
self.assertEqual(
init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32)
)
Expand Down Expand Up @@ -462,8 +462,8 @@ def __call__(self, x):

x = jnp.ones((1, 1))
rngs = random.PRNGKey(0)
init_vars = Test(None).init(rngs, x)
_, new_vars = Test(None).apply(init_vars, x, mutable=['counter'])
init_vars = Test(parent=None).init(rngs, x)
_, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter'])
self.assertEqual(
init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32)
)
Expand Down Expand Up @@ -508,8 +508,8 @@ def __call__(self, x):

x = jnp.ones((1, 1))
rngs = random.PRNGKey(0)
init_vars = Test(None).init(rngs, x)
_, new_vars = Test(None).apply(init_vars, x, mutable=['counter'])
init_vars = Test(parent=None).init(rngs, x)
_, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter'])
self.assertEqual(
init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32)
)
Expand Down Expand Up @@ -563,8 +563,8 @@ def __call__(self, x):

x = jnp.ones((1, 1))
rngs = random.PRNGKey(0)
init_vars = Test(None).init(rngs, x)
_, new_vars = Test(None).apply(init_vars, x, mutable=['counter'])
init_vars = Test(parent=None).init(rngs, x)
_, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter'])
self.assertEqual(
init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32)
)
Expand Down Expand Up @@ -619,8 +619,8 @@ def __call__(self, x):

x = jnp.ones((1, 1))
rngs = random.PRNGKey(0)
init_vars = Test(None).init(rngs, x)
_, new_vars = Test(None).apply(init_vars, x, mutable=['counter'])
init_vars = Test(parent=None).init(rngs, x)
_, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter'])
self.assertEqual(
init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32)
)
Expand Down Expand Up @@ -662,8 +662,8 @@ def __call__(self, x):

x = jnp.ones((3, 1, 2))
rngs = random.PRNGKey(0)
init_vars = Test(None).init(rngs, x)
y = Test(None).apply(init_vars, x)
init_vars = Test(parent=None).init(rngs, x)
y = Test(parent=None).apply(init_vars, x)
self.assertEqual(
init_vars['params']['outer']['Dense_0']['kernel'].shape, (3, 2, 5)
)
Expand Down

0 comments on commit 4879b4c

Please sign in to comment.