Skip to content

Commit

Permalink
call super on __init_subclas__ (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae authored Mar 14, 2023
1 parent 5bcb5d5 commit 30717af
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,7 @@ dmypy.json

# Pyre type checker
.pyre/

# project specific
.vscode
/tmp
7 changes: 0 additions & 7 deletions .vscode/settings.json

This file was deleted.

1 change: 1 addition & 0 deletions simple_pytree/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class Pytree(metaclass=PytreeMeta):
_pytree__class_is_mutable: bool

def __init_subclass__(cls, mutable: bool = False):
super().__init_subclass__()
# init class variables
cls._pytree__initialized = False # initialize mutable
cls._pytree__class_is_mutable = mutable
Expand Down
10 changes: 10 additions & 0 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
from typing import Generic, TypeVar

import jax
import pytest
Expand Down Expand Up @@ -114,6 +115,15 @@ class Foo(Pytree):
with pytest.raises(ValueError, match="Unknown field"):
serialization.from_state_dict(foo, state_dict)

def test_generics(self):
T = TypeVar("T")

class MyClass(Pytree, Generic[T]):
def __init__(self, x: T):
self.x = x

MyClass[int]


class TestMutablePytree:
def test_pytree(self):
Expand Down

0 comments on commit 30717af

Please sign in to comment.