Skip to content

Commit

Permalink
find static fields from dataclass parents (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae authored Sep 6, 2023
1 parent 9ee3fe6 commit 359a623
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
13 changes: 7 additions & 6 deletions simple_pytree/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,11 @@ def __setattr__(self: P, field: str, value: tp.Any):
def _inherited_static_fields(cls: type) -> tp.Set[str]:
static_fields = set()
for parent_class in cls.mro():
if (
parent_class is not cls
and parent_class is not Pytree
and issubclass(parent_class, Pytree)
):
static_fields.update(parent_class._pytree__static_fields)
if parent_class is not cls and parent_class is not Pytree:
if issubclass(parent_class, Pytree):
static_fields.update(parent_class._pytree__static_fields)
elif dataclasses.is_dataclass(parent_class):
for field in dataclasses.fields(parent_class):
if not field.metadata.get("pytree_node", True):
static_fields.add(field.name)
return static_fields
19 changes: 19 additions & 0 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
from typing import Generic, TypeVar

import jax
Expand Down Expand Up @@ -270,3 +271,21 @@ class Foo(Pytree, mutable=True):
# test mutation
pytree.x = 4
assert pytree.x == 4

def test_dataclass_inheritance(self):
A = dataclasses.make_dataclass(
"A",
[("x", int), "y", ("z", int, static_field(default=5))],
)

@dataclass
class B(Pytree, A):
...

b = B(1, 2)

assert b.x == 1
assert b.y == 2
assert b.z == 5

assert jax.tree_util.tree_leaves(b) == [1, 2]

0 comments on commit 359a623

Please sign in to comment.