From 836e0e439bc3dc6f9fe91446a1e50bc3efe89e70 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 25 Aug 2023 15:53:59 +0000 Subject: [PATCH] check number of positional args on Module __init__ --- flax/linen/kw_only_dataclasses.py | 50 ++++++++++++++++++++++--- tests/linen/kw_only_dataclasses_test.py | 2 +- tests/linen/linen_module_test.py | 12 ++++++ tests/linen/linen_transforms_test.py | 24 ++++++------ 4 files changed, 69 insertions(+), 19 deletions(-) diff --git a/flax/linen/kw_only_dataclasses.py b/flax/linen/kw_only_dataclasses.py index d45eb59a01..84d2ec0032 100644 --- a/flax/linen/kw_only_dataclasses.py +++ b/flax/linen/kw_only_dataclasses.py @@ -51,6 +51,17 @@ class that defines a field with a default, and a subclass that defines a field """ import dataclasses +import inspect +import functools +from types import MappingProxyType +from typing import Any, TypeVar +from typing_extensions import dataclass_transform +import flax + +M = TypeVar('M', bound='flax.linen.Module') +FieldName = str +Annotation = Any +Default = Any class _KwOnlyType: @@ -88,6 +99,7 @@ def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs): return dataclasses.field(metadata=metadata, **kwargs) +@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] def dataclass(cls=None, extra_fields=None, **kwargs): """Wrapper for dataclasses.dataclass that adds support for kw_only fields. @@ -111,7 +123,7 @@ def wrap(cls): return wrap if cls is None else wrap(cls) -def _process_class(cls, extra_fields=None, **kwargs): +def _process_class(cls: type[M], extra_fields=None, **kwargs): """Transforms `cls` into a dataclass that supports kw_only fields.""" if '__annotations__' not in cls.__dict__: cls.__annotations__ = {} @@ -122,7 +134,7 @@ def _process_class(cls, extra_fields=None, **kwargs): base_dataclass_fields = {} # dict[cls, cls.__dataclass_fields__.copy()] # The keyword only fields from `cls` or any of its base classes. - kw_only_fields = {} # dict[field_name, tuple[annotation, default]] + kw_only_fields: dict[FieldName, tuple[Annotation, Default]] = {} # Scan for KW_ONLY marker. kw_only_name = None @@ -138,7 +150,7 @@ def _process_class(cls, extra_fields=None, **kwargs): ) default = getattr(cls, name) if isinstance(default, dataclasses.Field): - default.metadata = {**default.metadata, **{KW_ONLY: True}} + default.metadata = MappingProxyType({**default.metadata, KW_ONLY: True}) else: default = field(default=default, kw_only=True) setattr(cls, name, default) @@ -190,12 +202,38 @@ def _process_class(cls, extra_fields=None, **kwargs): cls_annotations.pop(name, None) cls_annotations[name] = annotation + create_init = '__init__' not in vars(cls) and kwargs.get('init', True) + # Apply the dataclass transform. - transformed_cls = dataclasses.dataclass(cls, **kwargs) + transformed_cls: type[M] = dataclasses.dataclass(cls, **kwargs) # Restore the base classes' __dataclass_fields__. - for cls, dataclass_fields in base_dataclass_fields.items(): - cls.__dataclass_fields__ = dataclass_fields + for _cls, fields in base_dataclass_fields.items(): + _cls.__dataclass_fields__ = fields + + if create_init: + dataclass_init = transformed_cls.__init__ + # use sum to count the number of init fields that are not keyword-only + expected_num_args = sum( + f.init and not f.metadata.get(KW_ONLY, False) + for f in dataclasses.fields(transformed_cls) + ) + + @functools.wraps(dataclass_init) + def init_wrapper(self, *args, **kwargs): + num_args = len(args) + if num_args > expected_num_args: + # we add + 1 to each to account for `self`, matching python's + # default error message + raise TypeError( + f'__init__() takes {expected_num_args + 1} positional ' + f'arguments but {num_args + 1} were given' + ) + + dataclass_init(self, *args, **kwargs) + + init_wrapper.__signature__ = inspect.signature(dataclass_init) # type: ignore + transformed_cls.__init__ = init_wrapper # type: ignore[method-assign] # Return the transformed dataclass return transformed_cls diff --git a/tests/linen/kw_only_dataclasses_test.py b/tests/linen/kw_only_dataclasses_test.py index 9d3e455e39..44cf33c40a 100644 --- a/tests/linen/kw_only_dataclasses_test.py +++ b/tests/linen/kw_only_dataclasses_test.py @@ -62,7 +62,7 @@ class Child(Parent): v1 = Child(4) self.assertDictEqual(dataclasses.asdict(v1), dict(a=2, b=4)) - v2 = Child(4, 5) # pylint: disable=too-many-function-args + v2 = Child(4, a=5) # pylint: disable=too-many-function-args self.assertDictEqual(dataclasses.asdict(v2), dict(a=5, b=4)) def test_subclass_overrides_base(self): diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index da49524e24..5a7e053a2d 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2214,6 +2214,18 @@ def __call__(self, x): # take positional arg. It takes BaseLayer's default kwargs though. np.testing.assert_equal(ChildLayer(8)(np.ones(10)), -8 * np.ones(10)) + def test_positional_cannot_be_kw_only(self): + class Foo(nn.Module): + a: int + + Foo(1) # ok + Foo(a=1) # ok + with self.assertRaisesRegex( + TypeError, r'takes 2 positional arguments but 3 were' + ): + Foo(1, None) + Foo(a=1, parent=None) # type: ignore[call-arg] + def test_intercept_methods(self): mod = IdentityModule(parent=None) x = jnp.ones([]) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 6c9b5c03ff..079bb00491 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -414,8 +414,8 @@ def __call__(self, x): x = jnp.ones((10, 10)) rngs = random.PRNGKey(0) - init_vars = Test(None).init(rngs, x) - _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) + init_vars = Test(parent=None).init(rngs, x) + _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) ) @@ -462,8 +462,8 @@ def __call__(self, x): x = jnp.ones((1, 1)) rngs = random.PRNGKey(0) - init_vars = Test(None).init(rngs, x) - _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) + init_vars = Test(parent=None).init(rngs, x) + _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) ) @@ -508,8 +508,8 @@ def __call__(self, x): x = jnp.ones((1, 1)) rngs = random.PRNGKey(0) - init_vars = Test(None).init(rngs, x) - _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) + init_vars = Test(parent=None).init(rngs, x) + _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32) ) @@ -563,8 +563,8 @@ def __call__(self, x): x = jnp.ones((1, 1)) rngs = random.PRNGKey(0) - init_vars = Test(None).init(rngs, x) - _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) + init_vars = Test(parent=None).init(rngs, x) + _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32) ) @@ -619,8 +619,8 @@ def __call__(self, x): x = jnp.ones((1, 1)) rngs = random.PRNGKey(0) - init_vars = Test(None).init(rngs, x) - _, new_vars = Test(None).apply(init_vars, x, mutable=['counter']) + init_vars = Test(parent=None).init(rngs, x) + _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) ) @@ -662,8 +662,8 @@ def __call__(self, x): x = jnp.ones((3, 1, 2)) rngs = random.PRNGKey(0) - init_vars = Test(None).init(rngs, x) - y = Test(None).apply(init_vars, x) + init_vars = Test(parent=None).init(rngs, x) + y = Test(parent=None).apply(init_vars, x) self.assertEqual( init_vars['params']['outer']['Dense_0']['kernel'].shape, (3, 2, 5) )