Skip to content

Commit

Permalink
added cursor api
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Aug 1, 2023
1 parent 675e34d commit a15ae0b
Show file tree
Hide file tree
Showing 2 changed files with 350 additions and 0 deletions.
213 changes: 213 additions & 0 deletions flax/cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import enum
from sqlite3 import Cursor
from typing import Any, Callable, Generic, Mapping, NamedTuple, Optional, Protocol, Sequence, Tuple, TypeVar, Union, runtime_checkable
import jax
from jax._src.tree_util import SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey
from flax.core import freeze, FrozenDict
import dataclasses
from flax.training.train_state import TrainState
import optax

# Type variable for the kind of object a `Cursor` wraps.
A = TypeVar('A')
# Alias for a child key; left unconstrained. NOTE(review): not referenced by
# the code visible in this file - confirm whether it is still needed.
Key = Any


@runtime_checkable
class Indexable(Protocol):
  """Structural type for any object that supports subscripting (`obj[key]`).

  Being `runtime_checkable`, it can be used with `isinstance` to test whether
  an object defines `__getitem__`.
  """

  def __getitem__(self, key: Any) -> Any:
    """Return the item stored under `key`."""


class AccessType(enum.Enum):
  """Kinds of access by which a child value can be reached from its parent.

  NOTE(review): not referenced by the code visible in this file.
  """

  GETATTR = 1  # reached with attribute access
  GETITEM = 2  # reached with subscripting


# def flatten_until_found(tree: Any, targets: Sequence[Any]):
# remaining = len(targets)

# def is_leaf(x):
# nonlocal remaining
# leaf = False
# if x in targets:
# remaining -= 1
# leaf = True
# return leaf or remaining <= 0

# value = jax.tree_util.tree_flatten(tree, is_leaf=is_leaf)

# if remaining > 0:
# return None
# else:
# return value


@dataclasses.dataclass
class ParentKey(Generic[A]):
  """Link from a child cursor back to the cursor it was obtained from."""

  # Cursor that produced the child.
  parent: 'Cursor[A]'
  # Item key or attribute name under which the child is reached from `parent`.
  key: Any


class Cursor(Generic[A]):
  """Records pending edits to a nested object and rebuilds it with them applied.

  Subscripting (`cursor[key]`) or attribute access (`cursor.name`) returns a
  child cursor for that part of the wrapped object. Assigning through a cursor
  stores the new value in `changes`; the wrapped object is not modified.
  `build` returns the object with every recorded change applied.

  Names defined on `Cursor` itself (`obj`, `parent_key`, `changes`, `root`,
  `build`, `set`, `apply_filter`) always resolve to the cursor, never to an
  attribute of the wrapped object.

  Attributes:
    obj: the wrapped object.
    parent_key: the parent cursor and the key leading from it to this cursor,
      or `None` for a root cursor.
    changes: maps an item key or attribute name to either a replacement value
      or the child cursor handed out for it. Item keys and attribute names
      share this single mapping.
  """

  obj: A
  parent_key: Optional[ParentKey[A]]
  changes: dict[Any, Union[Any, 'Cursor[A]']]

  def __init__(self, obj: A, parent_key: Optional[ParentKey[A]]):
    """Wraps `obj`; `parent_key` is `None` for a root cursor."""
    # NOTE: write through `vars(self)` (i.e. `self.__dict__`) because
    # `__setattr__` is overridden to record changes instead of setting state.
    state = vars(self)
    state['obj'] = obj
    state['parent_key'] = parent_key
    state['changes'] = {}

  @property
  def root(self) -> 'Cursor[A]':
    """The top-most ancestor of this cursor (itself when it has no parent)."""
    node = self
    while node.parent_key is not None:
      node = node.parent_key.parent
    return node

  def __getitem__(self, key) -> 'Cursor[A]':
    """Returns the entry recorded for `key`, creating a child cursor if needed.

    If `key` was assigned a plain value earlier, that value is returned as is.

    Raises:
      TypeError: if the wrapped object does not support subscripting.
      KeyError: if the wrapped object is a mapping that lacks `key`.
    """
    if key in self.changes:
      return self.changes[key]

    if not isinstance(self.obj, Indexable):
      raise TypeError(f'Cannot index into {self.obj}')

    if isinstance(self.obj, Mapping) and key not in self.obj:
      raise KeyError(f'Key {key} not found in {self.obj}')

    child = Cursor(self.obj[key], ParentKey(self, key))
    # Memoise the child so that edits made through it are seen by `build`.
    self.changes[key] = child
    return child

  def __getattr__(self, name) -> 'Cursor[A]':
    """Returns the entry recorded for attribute `name`, creating a child cursor.

    Only invoked when normal attribute lookup on the cursor fails.

    Raises:
      AttributeError: if the wrapped object has no such attribute, or if the
        cursor's own state has not been initialised yet.
    """
    state = vars(self)
    # Read the state from `__dict__` directly: on an instance created without
    # `__init__` (e.g. by `copy.copy` or unpickling) `self.changes` would
    # re-enter `__getattr__` and recurse until `RecursionError`.
    if 'changes' not in state or 'obj' not in state:
      raise AttributeError(name)
    changes, obj = state['changes'], state['obj']

    if name in changes:
      return changes[name]

    if not hasattr(obj, name):
      raise AttributeError(f'Attribute {name} not found in {obj}')

    child = Cursor(getattr(obj, name), ParentKey(self, name))
    changes[name] = child
    return child

  def __setitem__(self, key, value):
    """Records `value` as the replacement for item `key`."""
    self.changes[key] = value

  def __setattr__(self, name, value):
    """Records `value` as the replacement for attribute `name`."""
    self.changes[name] = value

  def apply_filter(
      self, filter_fn: Callable[[Sequence[str], Any], Tuple[bool, Any]]
  ):
    """Records a replacement for every leaf that `filter_fn` selects.

    The leaves are enumerated with `jax.tree_util.tree_flatten_with_path` on
    the wrapped object itself, so `filter_fn` sees the original values, not
    values already changed through this cursor.

    Example::

      def increment_ints_at_layer1(path, value):
        if 'layer1' in path and isinstance(value, int):
          return True, value + 1
        return False, value

      c = cursor(config)
      c.apply_filter(increment_ints_at_layer1)
      config = c.build()

    Args:
      filter_fn: called as `filter_fn(path, value)` where `path` is the list
        of stringified keys leading to the leaf. Returns `(selected,
        new_value)`; `new_value` is recorded only when `selected` is true.

    Raises:
      KeyError: if a path contains a key type other than SequenceKey,
        GetAttrKey, DictKey or FlattenedIndexKey.
      ValueError: if a root cursor wraps a single leaf that `filter_fn`
        selects, because there is no container to record the change in.
    """

    def raw_key(key):
      # Plain index / dict key / attribute name carried by a JAX path key.
      if isinstance(key, SequenceKey):
        return key.idx
      if isinstance(key, (DictKey, FlattenedIndexKey)):
        return key.key
      if isinstance(key, GetAttrKey):
        return key.name
      raise KeyError(
          f'Key {key} of type {type(key)} is not a valid key type. Must be one'
          ' of [SequenceKey, GetAttrKey, DictKey, FlattenedIndexKey].'
      )

    def get_child(node, key):
      raw = raw_key(key)
      return getattr(node, raw) if isinstance(key, GetAttrKey) else node[raw]

    def set_child(node, key, value):
      raw = raw_key(key)
      if isinstance(key, GetAttrKey):
        setattr(node, raw, value)
      else:
        node[raw] = value

    flattened, _ = jax.tree_util.tree_flatten_with_path(self.obj)
    for key_path, value in flattened:
      valid, new_value = filter_fn(
          [str(raw_key(key)) for key in key_path], value
      )
      if not valid:
        continue
      if not key_path:
        # The wrapped object is itself a leaf: the change belongs in the
        # parent, under the key that leads to this cursor (same as `set`).
        if self.parent_key is None:
          raise ValueError(
              'Cannot apply a filter to a root cursor whose object is a'
              ' single leaf; there is no container to record the change in.'
          )
        self.parent_key.parent.changes[self.parent_key.key] = new_value
        continue
      # Walk down through child cursors so the change is recorded on the
      # cursor that directly owns the leaf.
      node = self
      for key in key_path[:-1]:
        node = get_child(node, key)
      set_child(node, key_path[-1], new_value)

  def build(self) -> A:
    """Returns the wrapped object with all recorded changes applied.

    Child cursors stored in `changes` are built recursively first.

    Raises:
      ValueError: if changes were recorded for an object whose type is not
        one of FrozenDict, dict, list, tuple (including namedtuples) or a
        dataclass.
    """
    changes = {
        key: child.build() if isinstance(child, Cursor) else child
        for key, child in self.changes.items()
    }
    if isinstance(self.obj, FrozenDict):
      obj = self.obj.copy(changes)
    elif isinstance(self.obj, (dict, list)):
      obj = self.obj.copy()
      for key, value in changes.items():
        obj[key] = value
    elif isinstance(self.obj, tuple) and hasattr(self.obj, '_fields'):
      # Namedtuple: keep the type, and accept both positional keys (from
      # subscripting) and field names (from attribute access).
      fields = self.obj._fields
      obj = self.obj._replace(**{
          fields[key] if isinstance(key, int) else key: value
          for key, value in changes.items()
      })
    elif isinstance(self.obj, tuple):
      items = list(self.obj)
      for key, value in changes.items():
        items[key] = value
      obj = tuple(items)
    elif dataclasses.is_dataclass(self.obj):
      obj = dataclasses.replace(self.obj, **changes)
    elif not changes:
      # Nothing was recorded, e.g. a leaf that was only read through the
      # cursor: hand the object back unchanged instead of failing.
      obj = self.obj
    else:
      # NOTE: There is a way to try to do a general replace for pytrees, but
      # it requires the key of `changes` to store the type of access (getattr,
      # getitem, etc.) in order to access those values from the original
      # object and replace them with the new value. For simplicity, this is
      # not implemented for now.
      raise ValueError(f'Cannot build object of type {type(self.obj).__name__}')

    return obj  # type: ignore

  def set(self, value) -> A:
    """Replaces the value at this cursor and returns the rebuilt root object.

    For a root cursor `value` itself is returned. Otherwise `value` is recorded
    in the parent's `changes` (and stays recorded there) before the root is
    built.
    """
    if self.parent_key is None:
      return value
    parent, key = self.parent_key.parent, self.parent_key.key
    parent.changes[key] = value
    return parent.root.build()


def cursor(obj: A) -> Cursor[A]:
  """Wraps `obj` in a root `Cursor`, i.e. one that has no parent."""
  return Cursor(obj, parent_key=None)
137 changes: 137 additions & 0 deletions tests/cursor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for flax.cursor."""


from absl.testing import absltest
import jax
import jax.numpy as jnp
import optax

import flax.linen as nn
from flax.core import freeze
from flax.cursor import cursor
from flax.training import train_state

# Have JAX parse the absl command-line flags (e.g. test_srcdir and
# test_tmpdir) at import time, before any test runs.
jax.config.parse_flags_with_absl()


class CursorTest(absltest.TestCase):
  """Exercises the `set`, `build` and `apply_filter` APIs of `flax.cursor`."""

  def test_set_and_build(self):
    """`set` and `build` yield the expected dicts, sequences and TrainState."""
    identity = lambda x: x
    base = {'a': 1, 'b': (2, 3), 'c': [4, 5]}

    # Regular dicts and FrozenDicts.
    for tree, wrap in ((base, identity), (freeze(base), freeze)):
      # set API
      after_set = cursor(tree)['b'][0].set(10)
      self.assertEqual(
          after_set, wrap({'a': 1, 'b': (10, 3), 'c': [4, 5]})
      )
      # build API
      cur = cursor(tree)
      cur['b'][0] = 20
      cur['a'] = (100, 200)
      self.assertEqual(
          cur.build(), wrap({'a': (100, 200), 'b': (20, 3), 'c': [4, 5]})
      )

    # Lists and tuples.
    items = [0, base, (1, 2), [3, 4, 5]]
    for seq, wrap in ((items, identity), (tuple(items), tuple)):
      # set API
      after_set = cursor(seq)[1]['b'][0].set(10)
      self.assertEqual(
          after_set,
          wrap([0, {'a': 1, 'b': (10, 3), 'c': [4, 5]}, (1, 2), [3, 4, 5]]),
      )
      # build API
      cur = cursor(seq)
      cur[1]['b'][0] = 20
      cur[2] = (100, 200)
      self.assertEqual(
          cur.build(),
          wrap(
              [0, {'a': 1, 'b': (20, 3), 'c': [4, 5]}, (100, 200), [3, 4, 5]]
          ),
      )

    # TrainState.
    state = train_state.TrainState.create(
        apply_fn=lambda x: x,
        params=base,
        tx=optax.adam(1e-3),
    )
    # set API
    state_after_set = cursor(state).params['b'][0].set(10)
    self.assertEqual(
        state_after_set.params, {'a': 1, 'b': (10, 3), 'c': [4, 5]}
    )
    # build API
    new_fn = lambda x: x + 1
    cur = cursor(state)
    cur.apply_fn = new_fn
    cur.params['b'][0] = 20
    cur.params['a'] = (100, 200)
    rebuilt = cur.build()
    self.assertEqual(rebuilt.apply_fn, new_fn)
    self.assertEqual(
        rebuilt.params, {'a': (100, 200), 'b': (20, 3), 'c': [4, 5]}
    )

  def test_apply_filter(self):
    """`apply_filter` rewrites only the leaves selected by the filter."""

    class Model(nn.Module):

      @nn.compact
      def __call__(self, x):
        # Three Dense(3) + relu stages; the test below looks the layers up
        # as Dense_0, Dense_1 and Dense_2.
        for _ in range(3):
          x = nn.relu(nn.Dense(3)(x))
        return x

    params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

    def filter_fn(path, value):
      if 'kernel' in path:
        return True, value * 2 + 1
      if 'Dense_1' in path and 'bias' in path:
        return True, value - 1
      return False, value

    cur = cursor(params)
    cur.apply_filter(filter_fn)
    new_params = cur.build()

    for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
      old, new = params[layer], new_params[layer]
      self.assertTrue((new['kernel'] == 2 * old['kernel'] + 1).all())
      # Only Dense_1's bias is selected by the filter; the others must be
      # left untouched.
      expected_bias = (
          jnp.array([-1, -1, -1]) if layer == 'Dense_1' else old['bias']
      )
      self.assertTrue((new['bias'] == expected_bias).all())


# Run the absltest runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 comments on commit a15ae0b

Please sign in to comment.