diff --git a/flax/cursor.py b/flax/cursor.py
new file mode 100644
index 0000000000..d38e24a0c6
--- /dev/null
+++ b/flax/cursor.py
@@ -0,0 +1,307 @@
+import enum
+# NOTE(review): this looks like an accidental IDE auto-import; the name is
+# rebound by the `Cursor` class defined below, so nothing here uses it.
+from sqlite3 import Cursor
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Mapping,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    runtime_checkable,
+)
+import jax
+from jax._src.tree_util import SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey
+from flax.core import freeze, FrozenDict
+import dataclasses
+from flax.training.train_state import TrainState
+import optax
+
+A = TypeVar('A')
+Key = Any
+
+
+@runtime_checkable
+class Indexable(Protocol):
+  """Structural type for any object that supports `obj[key]`."""
+
+  def __getitem__(self, key) -> Any:
+    ...
+
+
+class AccessType(enum.Enum):
+  """Attribute vs. item access. Currently unused in this module."""
+
+  GETATTR = enum.auto()
+  GETITEM = enum.auto()
+
+
+# def flatten_until_found(tree: Any, targets: Sequence[Any]):
+#   remaining = len(targets)
+
+#   def is_leaf(x):
+#     nonlocal remaining
+#     leaf = False
+#     if x in targets:
+#       remaining -= 1
+#       leaf = True
+#     return leaf or remaining <= 0
+
+#   value = jax.tree_util.tree_flatten(tree, is_leaf=is_leaf)
+
+#   if remaining > 0:
+#     return None
+#   else:
+#     return value
+
+
+@dataclasses.dataclass
+class ParentKey(Generic[A]):
+  """Link from a child cursor to its parent and the key used to reach it."""
+
+  parent: 'Cursor[A]'
+  key: Any
+
+
+class Cursor(Generic[A]):
+  """Records edits to a nested object and rebuilds a modified copy.
+
+  Children are reached with `cursor[key]` or `cursor.name` and are themselves
+  cursors. Assigning with `cursor[key] = value` or `cursor.name = value` only
+  records the change; the wrapped object is never mutated. `build()` returns
+  a copy of the wrapped object with all recorded changes applied.
+
+  Item keys and attribute names share a single `changes` dict, so an item key
+  and an attribute with the same name on one object cannot be told apart.
+  """
+
+  obj: A
+  parent_key: Optional[ParentKey[A]]
+  # `typing.Dict` rather than `dict[...]`: class-level annotations are
+  # evaluated eagerly and builtin generics cannot be subscripted before
+  # Python 3.9.
+  changes: Dict[Any, Union[Any, 'Cursor[A]']]
+
+  def __init__(self, obj: A, parent_key: Optional[ParentKey[A]]):
+    """Wraps `obj`.
+
+    Args:
+      obj: the object this cursor points at.
+      parent_key: the parent cursor and the key under which `obj` lives in
+        it, or None for a root cursor.
+    """
+    # NOTE: we use `vars` here to avoid calling `__setattr__`
+    # vars(self) = self.__dict__
+    vars(self)['obj'] = obj
+    vars(self)['parent_key'] = parent_key
+    vars(self)['changes'] = {}
+
+  @property
+  def root(self) -> 'Cursor[A]':
+    """The top-most cursor, found by following `parent_key` links."""
+    if self.parent_key is None:
+      return self
+    else:
+      return self.parent_key.parent.root
+
+  def __getitem__(self, key) -> 'Cursor[A]':
+    """Returns the entry recorded for `key`, creating a child cursor if none.
+
+    Raises:
+      TypeError: if the wrapped object does not support indexing.
+      KeyError: if the wrapped object is a `Mapping` without `key`.
+    """
+    if key in self.changes:
+      return self.changes[key]
+
+    if not isinstance(self.obj, Indexable):
+      raise TypeError(f'Cannot index into {self.obj}')
+
+    if isinstance(self.obj, Mapping) and key not in self.obj:
+      raise KeyError(f'Key {key} not found in {self.obj}')
+
+    # Cache the child so edits made through it are picked up by `build()`.
+    child = Cursor(self.obj[key], ParentKey(self, key))
+    self.changes[key] = child
+    return child
+
+  def __getattr__(self, name) -> 'Cursor[A]':
+    """Attribute counterpart of `__getitem__`.
+
+    Raises:
+      AttributeError: if the wrapped object has no attribute `name`, or if
+        this cursor's own state has not been initialised.
+    """
+    # `__getattr__` only runs after normal lookup has failed, so being asked
+    # for one of our own fields means `__init__` never ran (e.g. `copy` and
+    # `pickle` create the instance with `__new__` and then probe it for
+    # `__setstate__`). Reading `self.changes` below would then re-enter this
+    # method forever, so fail fast instead.
+    if name in ('obj', 'parent_key', 'changes'):
+      raise AttributeError(name)
+
+    if name in self.changes:
+      return self.changes[name]
+
+    if not hasattr(self.obj, name):
+      raise AttributeError(f'Attribute {name} not found in {self.obj}')
+
+    child = Cursor(getattr(self.obj, name), ParentKey(self, name))
+    self.changes[name] = child
+    return child
+
+  def __setitem__(self, key, value):
+    """Records `value` as the replacement for item `key`."""
+    self.changes[key] = value
+
+  def __setattr__(self, name, value):
+    """Records `value` as the replacement for attribute `name`."""
+    self.changes[name] = value
+
+  def apply_filter(
+      self, filter_fn: Callable[[Sequence[str], Any], Tuple[bool, Any]]
+  ):
+    """Records a change for every leaf that `filter_fn` selects.
+
+    The wrapped object is flattened with
+    `jax.tree_util.tree_flatten_with_path` and `filter_fn` is called once per
+    leaf. Changes are only recorded; call `build()` to obtain the result.
+
+    Example::
+
+      def increment_ints_at_layer1(path, value):
+        if 'layer1' in path and isinstance(value, int):
+          return True, value + 1
+        else:
+          return False, value
+
+      c = cursor(config)
+      c.apply_filter(increment_ints_at_layer1)
+      config = c.build()
+
+    Args:
+      filter_fn: called as `filter_fn(path, value)`, where `path` is the list
+        of stringified keys leading to the leaf and `value` is the leaf. It
+        returns `(selected, new_value)`; `new_value` is recorded only when
+        `selected` is true.
+
+    Raises:
+      KeyError: if a path entry is not a SequenceKey, GetAttrKey, DictKey or
+        FlattenedIndexKey.
+    """
+
+    def raw_key(key):
+      # Plain index, dict key or attribute name carried by a jax path entry.
+      if isinstance(key, SequenceKey):
+        return key.idx
+      if isinstance(key, (DictKey, FlattenedIndexKey)):
+        return key.key
+      if isinstance(key, GetAttrKey):
+        return key.name
+      raise KeyError(
+          f'Key {key} of type {type(key)} is not a valid key type. Must be one'
+          ' of [SequenceKey, GetAttrKey, DictKey, FlattenedIndexKey].'
+      )
+
+    flattened, _ = jax.tree_util.tree_flatten_with_path(self.obj)
+    for key_path, value in flattened:
+      valid, new_value = filter_fn(
+          [str(raw_key(key)) for key in key_path], value
+      )
+      if not valid:
+        continue
+      # Walk down to the cursor owning the leaf, then record the new value on
+      # it. Both steps go through Cursor's own item/attribute hooks, so only
+      # `changes` dicts are touched.
+      node = self
+      for key in key_path[:-1]:
+        if isinstance(key, GetAttrKey):
+          node = getattr(node, raw_key(key))
+        else:
+          node = node[raw_key(key)]
+      leaf_key = key_path[-1]
+      if isinstance(leaf_key, GetAttrKey):
+        setattr(node, raw_key(leaf_key), new_value)
+      else:
+        node[raw_key(leaf_key)] = new_value
+
+  def build(self) -> A:
+    """Returns a copy of the wrapped object with all recorded changes applied.
+
+    Child cursors are built recursively. FrozenDict, dict, list, tuple and
+    dataclass objects are rebuilt. An object of any other type is returned
+    as is when nothing was recorded on it, so that merely navigating to a
+    leaf (for example reading `c['a']`) does not break a later `build()`.
+
+    Raises:
+      ValueError: if changes were recorded on an object of unsupported type.
+    """
+    changes = {
+        key: child.build() if isinstance(child, Cursor) else child
+        for key, child in self.changes.items()
+    }
+    if isinstance(self.obj, FrozenDict):
+      obj = self.obj.copy(changes)
+    elif isinstance(self.obj, (dict, list)):
+      obj = self.obj.copy()
+      for key, value in changes.items():
+        obj[key] = value
+    elif isinstance(self.obj, tuple):
+      obj = list(self.obj)
+      for key, value in changes.items():
+        obj[key] = value
+      obj = tuple(obj)
+    elif dataclasses.is_dataclass(self.obj):
+      obj = dataclasses.replace(self.obj, **changes)
+    elif not changes:
+      # This node was only visited, never edited: there is nothing to apply.
+      obj = self.obj
+    else:
+      raise ValueError(f'Cannot build object of type {type(self.obj).__name__}')
+
+    # NOTE: There is a way to try to do a general replace for pytrees, but it requires
+    # the key of `changes` to store the type of access (getattr, getitem, etc.)
+    # in order to access those values from the original object and try to replace them
+    # with the new value. For simplicity, this is not implemented for now.
+    # ----------------------
+    # changed_values = tuple(changes.values())
+    # result = flatten_until_found(self.obj, changed_values)
+
+    # if result is None:
+    #   raise ValueError('Cannot find object in parent')
+
+    # leaves, treedef = result
+    # leaves = [leaf if leaf is not self.obj else value for leaf in leaves]
+    # obj = jax.tree_util.tree_unflatten(treedef, leaves)
+
+    return obj  # type: ignore
+
+  def set(self, value) -> A:
+    """Replaces the value at this cursor and rebuilds from the root.
+
+    The replacement is recorded on the parent cursor, so later `build()`
+    calls made through the same cursors will also include it.
+
+    Args:
+      value: the new value for this position.
+
+    Returns:
+      The rebuilt root object, or `value` itself if this cursor is the root.
+    """
+    if self.parent_key is None:
+      return value
+    parent, key = self.parent_key.parent, self.parent_key.key
+    parent.changes[key] = value
+    return parent.root.build()
+
+
+def cursor(obj: A) -> Cursor[A]:
+  """Creates a root `Cursor` wrapping `obj`."""
+  return Cursor(obj, None)
diff --git a/tests/cursor_test.py b/tests/cursor_test.py
new file mode 100644
index 0000000000..56851e7a95
--- /dev/null
+++ b/tests/cursor_test.py
@@ -0,0 +1,137 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for flax.cursor."""
+
+
+from absl.testing import absltest
+import jax
+import jax.numpy as jnp
+import optax
+
+import flax.linen as nn
+from flax.core import freeze
+from flax.cursor import cursor
+from flax.training import train_state
+
+# Parse absl flags test_srcdir and test_tmpdir.
+jax.config.parse_flags_with_absl()
+
+
+class CursorTest(absltest.TestCase):
+
+  def test_set_and_build(self):
+    # plain dict and FrozenDict roots
+    dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
+    for tree, wrap in ((dict_obj, lambda x: x), (freeze(dict_obj), freeze)):
+      # one-shot `set` on a nested element
+      self.assertEqual(
+          cursor(tree)['b'][0].set(10),
+          wrap({'a': 1, 'b': (10, 3), 'c': [4, 5]}),
+      )
+      # several recorded edits followed by `build`
+      cur = cursor(tree)
+      cur['b'][0] = 20
+      cur['a'] = (100, 200)
+      rebuilt = cur.build()
+      self.assertEqual(
+          rebuilt, wrap({'a': (100, 200), 'b': (20, 3), 'c': [4, 5]})
+      )
+
+    # list and tuple roots
+    list_obj = [0, dict_obj, (1, 2), [3, 4, 5]]
+    for seq, wrap in ((list_obj, lambda x: x), (tuple(list_obj), tuple)):
+      # one-shot `set` on a nested element
+      self.assertEqual(
+          cursor(seq)[1]['b'][0].set(10),
+          wrap(
+              [0, {'a': 1, 'b': (10, 3), 'c': [4, 5]}, (1, 2), [3, 4, 5]]
+          ),
+      )
+      # several recorded edits followed by `build`
+      cur = cursor(seq)
+      cur[1]['b'][0] = 20
+      cur[2] = (100, 200)
+      rebuilt = cur.build()
+      self.assertEqual(
+          rebuilt,
+          wrap(
+              [0, {'a': 1, 'b': (20, 3), 'c': [4, 5]}, (100, 200), [3, 4, 5]]
+          ),
+      )
+
+    # TrainState root
+    state = train_state.TrainState.create(
+        apply_fn=lambda x: x,
+        params=dict_obj,
+        tx=optax.adam(1e-3),
+    )
+    # one-shot `set` through attribute and item access
+    self.assertEqual(
+        cursor(state).params['b'][0].set(10).params,
+        {'a': 1, 'b': (10, 3), 'c': [4, 5]},
+    )
+    # several recorded edits followed by `build`
+    new_fn = lambda x: x + 1
+    cur = cursor(state)
+    cur.apply_fn = new_fn
+    cur.params['b'][0] = 20
+    cur.params['a'] = (100, 200)
+    new_state = cur.build()
+    self.assertEqual(new_state.apply_fn, new_fn)
+    self.assertEqual(
+        new_state.params, {'a': (100, 200), 'b': (20, 3), 'c': [4, 5]}
+    )
+
+  def test_apply_filter(self):
+    class Model(nn.Module):
+
+      @nn.compact
+      def __call__(self, x):
+        x = nn.Dense(3)(x)
+        x = nn.relu(x)
+        x = nn.Dense(3)(x)
+        x = nn.relu(x)
+        x = nn.Dense(3)(x)
+        x = nn.relu(x)
+        return x
+
+    params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']
+
+    def update_leaf(path, value):
+      if 'kernel' in path:
+        return True, value * 2 + 1
+      elif 'Dense_1' in path and 'bias' in path:
+        return True, value - 1
+      return False, value
+
+    cur = cursor(params)
+    cur.apply_filter(update_leaf)
+    updated = cur.build()
+    for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
+      self.assertTrue(
+          (updated[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
+      )
+      if layer == 'Dense_1':
+        self.assertTrue(
+            (updated[layer]['bias'] == jnp.array([-1, -1, -1])).all()
+        )
+      else:
+        self.assertTrue(
+            (updated[layer]['bias'] == params[layer]['bias']).all()
+        )
+
+
+if __name__ == '__main__':
+  absltest.main()