diff --git a/docs/api_reference/flax.cursor.rst b/docs/api_reference/flax.cursor.rst index 243c1cc581..56ace5bc14 100644 --- a/docs/api_reference/flax.cursor.rst +++ b/docs/api_reference/flax.cursor.rst @@ -18,7 +18,7 @@ To illustrate, consider the example below:: a = A(A(A(A(A(A(A(0))))))) -To replace the int 0 using ``dataclasses.replace``, we would have to write many nested calls:: +To replace the int ``0`` using ``dataclasses.replace``, we would have to write many nested calls:: a2 = dataclasses.replace( a, @@ -55,6 +55,6 @@ generating a new copy of the original object with the accumulated changes. .. autofunction:: cursor .. autoclass:: Cursor - :members: build, set, apply_update + :members: apply_update, build, find, find_all, set diff --git a/flax/cursor.py b/flax/cursor.py index 49561f1bb0..ea12ad0d0d 100644 --- a/flax/cursor.py +++ b/flax/cursor.py @@ -15,6 +15,7 @@ import enum from typing import Any, Callable, Dict, Generator, Generic, Mapping, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable from flax.core import FrozenDict +from flax.errors import CursorFindError, TraverseTreeError import dataclasses @@ -29,15 +30,16 @@ def __getitem__(self, key) -> Any: ... +class AccessType(enum.Enum): + ITEM = enum.auto() + ATTR = enum.auto() + + @dataclasses.dataclass class ParentKey(Generic[A]): parent: 'Cursor[A]' key: Any - - -class AccessType(enum.Enum): - GETITEM = enum.auto() - GETATTR = enum.auto() + access_type: AccessType def is_named_tuple(obj): @@ -49,97 +51,137 @@ def is_named_tuple(obj): ) -def _get_changes(path, obj, update_fn): - """Helper function for ``Cursor.apply_update``. Returns a generator of - Tuple[Tuple[Union[str, int], Any], ...], where the first element is a - tuple key path where the change was applied from the ``update_fn``, and - the second element is the newly modified value. 
If the generator is - non-empty, then the tuple key path will always be non-empty as well.""" +def _traverse_tree(path, obj, *, update_fn=None, cond_fn=None): + """Helper function for ``Cursor.apply_update`` and ``Cursor.find_all``. + Exactly one of ``update_fn`` and ``cond_fn`` must be not None. + + - If ``update_fn`` is not None, then ``Cursor.apply_update`` is calling + this function and ``_traverse_tree`` will return a generator where + each generated element is of type Tuple[Tuple[Union[str, int], AccessType], Any]. + The first element is a tuple of the key path and access type where the + change was applied from the ``update_fn``, and the second element is + the newly modified value. If the generator is non-empty, then the + tuple key path will always be non-empty as well. + - If ``cond_fn`` is not None, then ``Cursor.find_all`` is calling this + function and ``_traverse_tree`` will return a generator where each + generated element is of type Tuple[Union[str, int], AccessType]. The + tuple contains the key path and access type where the object was found + that fulfilled the conditions of the ``cond_fn``. 
+ """ + if not (bool(update_fn) ^ bool(cond_fn)): + raise TraverseTreeError(update_fn, cond_fn) + if path: str_path = '/'.join(str(key) for key, _ in path) - new_obj = update_fn(str_path, obj) - if new_obj is not obj: - yield path, new_obj + if update_fn: + new_obj = update_fn(str_path, obj) + if new_obj is not obj: + yield path, new_obj + return + elif cond_fn(str_path, obj): # type: ignore + yield path return if isinstance(obj, (FrozenDict, dict)): items = obj.items() - access_type = AccessType.GETITEM + access_type = AccessType.ITEM elif is_named_tuple(obj): items = ((name, getattr(obj, name)) for name in obj._fields) # type: ignore - access_type = AccessType.GETATTR + access_type = AccessType.ATTR elif isinstance(obj, (list, tuple)): items = enumerate(obj) - access_type = AccessType.GETITEM + access_type = AccessType.ITEM elif dataclasses.is_dataclass(obj): items = ( (f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj) if f.init ) - access_type = AccessType.GETATTR + access_type = AccessType.ATTR else: - yield from () # empty generator return - for key, value in items: - yield from _get_changes(path + ((key, access_type),), value, update_fn) + if update_fn: + for key, value in items: + yield from _traverse_tree( + path + ((key, access_type),), value, update_fn=update_fn + ) + else: + for key, value in items: + yield from _traverse_tree( + path + ((key, access_type),), value, cond_fn=cond_fn + ) class Cursor(Generic[A]): - obj: A - parent_key: Optional[ParentKey[A]] - changes: Dict[Any, Union[Any, 'Cursor[A]']] + _obj: A + _parent_key: Optional[ParentKey[A]] + _changes: Dict[Any, 'Cursor[A]'] def __init__(self, obj: A, parent_key: Optional[ParentKey[A]]): # NOTE: we use `vars` here to avoid calling `__setattr__` # vars(self) = self.__dict__ - vars(self)['obj'] = obj - vars(self)['parent_key'] = parent_key - vars(self)['changes'] = {} + vars(self)['_obj'] = obj + vars(self)['_parent_key'] = parent_key + vars(self)['_changes'] = {} @property - def 
root(self) -> 'Cursor[A]': - if self.parent_key is None: + def _root(self) -> 'Cursor[A]': + if self._parent_key is None: return self else: - return self.parent_key.parent.root # type: ignore + return self._parent_key.parent._root # type: ignore + + @property + def _path(self) -> str: + if self._parent_key is None: + return '' + if self._parent_key.access_type == AccessType.ITEM: # type: ignore + if isinstance(self._parent_key.key, str): # type: ignore + key = "'" + self._parent_key.key + "'" # type: ignore + else: + key = str(self._parent_key.key) # type: ignore + return self._parent_key.parent._path + '[' + key + ']' # type: ignore + # self.parent_key.access_type == AccessType.ATTR: + return self._parent_key.parent._path + '.' + self._parent_key.key # type: ignore def __getitem__(self, key) -> 'Cursor[A]': - if key in self.changes: - return self.changes[key] + if key in self._changes: + return self._changes[key] - if not isinstance(self.obj, Indexable): - raise TypeError(f'Cannot index into {self.obj}') + if not isinstance(self._obj, Indexable): + raise TypeError(f'Cannot index into {self._obj}') - if isinstance(self.obj, Mapping) and key not in self.obj: - raise KeyError(f'Key {key} not found in {self.obj}') + if isinstance(self._obj, Mapping) and key not in self._obj: + raise KeyError(f'Key {key} not found in {self._obj}') - if is_named_tuple(self.obj): - return getattr(self, self.obj._fields[key]) # type: ignore + if is_named_tuple(self._obj): + return getattr(self, self._obj._fields[key]) # type: ignore - child = Cursor(self.obj[key], ParentKey(self, key)) - self.changes[key] = child + child = Cursor(self._obj[key], ParentKey(self, key, AccessType.ITEM)) + self._changes[key] = child return child def __getattr__(self, name) -> 'Cursor[A]': - if name in self.changes: - return self.changes[name] + if name in self._changes: + return self._changes[name] - if not hasattr(self.obj, name): - raise AttributeError(f'Attribute {name} not found in {self.obj}') + if not 
hasattr(self._obj, name): + raise AttributeError(f'Attribute {name} not found in {self._obj}') - child = Cursor(getattr(self.obj, name), ParentKey(self, name)) - self.changes[name] = child + child = Cursor( + getattr(self._obj, name), ParentKey(self, name, AccessType.ATTR) + ) + self._changes[name] = child return child def __setitem__(self, key, value): - if is_named_tuple(self.obj): - return setattr(self, self.obj._fields[key], value) # type: ignore - self.changes[key] = Cursor(value, ParentKey(self, key)) + if is_named_tuple(self._obj): + return setattr(self, self._obj._fields[key], value) # type: ignore + self._changes[key] = Cursor(value, ParentKey(self, key, AccessType.ITEM)) def __setattr__(self, name, value): - self.changes[name] = Cursor(value, ParentKey(self, name)) + self._changes[name] = Cursor(value, ParentKey(self, name, AccessType.ATTR)) def set(self, value) -> A: """Set a new value for an attribute, property, element or entry @@ -169,11 +211,11 @@ def set(self, value) -> A: Returns: A copy of the original object with the new set value. """ - if self.parent_key is None: + if self._parent_key is None: return value - parent, key = self.parent_key.parent, self.parent_key.key # type: ignore - parent.changes[key] = value - return parent.root.build() + parent, key = self._parent_key.parent, self._parent_key.key # type: ignore + parent._changes[key] = Cursor(value, self._parent_key) + return parent._root.build() def build(self) -> A: """Create and return a copy of the original object with accumulated changes. 
@@ -213,70 +255,56 @@ def build(self) -> A: """ changes = { key: child.build() if isinstance(child, Cursor) else child - for key, child in self.changes.items() + for key, child in self._changes.items() } - if isinstance(self.obj, FrozenDict): - obj = self.obj.copy(changes) # type: ignore - elif isinstance(self.obj, (dict, list)): - obj = self.obj.copy() # type: ignore + if isinstance(self._obj, FrozenDict): + obj = self._obj.copy(changes) # type: ignore + elif isinstance(self._obj, (dict, list)): + obj = self._obj.copy() # type: ignore for key, value in changes.items(): obj[key] = value - elif is_named_tuple(self.obj): - obj = self.obj._replace(**changes) # type: ignore - elif isinstance(self.obj, tuple): - obj = list(self.obj) # type: ignore + elif is_named_tuple(self._obj): + obj = self._obj._replace(**changes) # type: ignore + elif isinstance(self._obj, tuple): + obj = list(self._obj) # type: ignore for key, value in changes.items(): obj[key] = value obj = tuple(obj) # type: ignore - elif dataclasses.is_dataclass(self.obj): - obj = dataclasses.replace(self.obj, **changes) # type: ignore + elif dataclasses.is_dataclass(self._obj): + obj = dataclasses.replace(self._obj, **changes) # type: ignore else: - obj = self.obj # type: ignore - - # NOTE: There is a way to try to do a general replace for pytrees, but it requires - # the key of `changes` to store the type of access (getattr, getitem, etc.) - # in order to access those value from the original object and try to replace them - # with the new value. For simplicity, this is not implemented for now. 
- # ---------------------- - # changed_values = tuple(changes.values()) - # result = flatten_until_found(self.obj, changed_values) - - # if result is None: - # raise ValueError('Cannot find object in parent') - - # leaves, treedef = result - # leaves = [leaf if leaf is not self.obj else value for leaf in leaves] - # obj = jax.tree_util.tree_unflatten(treedef, leaves) - + obj = self._obj # type: ignore return obj # type: ignore def apply_update( self, update_fn: Callable[[str, Any], Any], ) -> 'Cursor[A]': - """Traverse the Cursor object and apply conditional changes recursively via an ``update_fn``. + """Traverse the Cursor object and record conditional changes recursively via an ``update_fn``. + The changes are recorded in the Cursor object's ``._changes`` dictionary. To generate a copy + of the original object with the accumulated changes, call the ``.build`` method after calling + ``.apply_update``. + The ``update_fn`` has a function signature of ``(str, Any) -> Any``: - The input arguments are the current key path (in the form of a string delimited - by '/') and value at that current key path + by ``'/'``) and value at that current key path - The output is the new value (either modified by the ``update_fn`` or same as the input value if the condition wasn't fulfilled) - To generate a copy of the original object with the accumulated changes, call the ``.build`` method. - NOTES: - - If the ``update_fn`` returns a modified value, this function will not recurse any further - down that branch to apply changes. For example, if we intend to replace an attribute that points + - If the ``update_fn`` returns a modified value, this method will not recurse any further + down that branch to record changes. For example, if we intend to replace an attribute that points to a dictionary with an int, we don't need to look for further changes inside the dictionary, since the dictionary will be replaced anyways. 
- The ``is`` operator is used to determine whether the return value is modified (by comparing it to the input value). Therefore if the ``update_fn`` modifies a mutable container (e.g. lists, dicts, etc.) and returns the same container, ``.apply_update`` will treat the returned value as unmodified as it contains the same ``id``. To avoid this, return a copy of the modified value. - - The ``.apply_update`` WILL NOT apply the ``update_fn`` to the value at the top-most level of - the pytree (i.e. the root node). The ``update_fn`` will be applied recursively, starting at the - root node's children. + - ``.apply_update`` WILL NOT call the ``update_fn`` to the value at the top-most level of + the pytree (i.e. the root node). The ``update_fn`` will first be called on the root node's + children, and then the pytree traversal will continue recursively from there. Example:: @@ -312,7 +340,7 @@ def update_fn(path, value): for layer in ('Dense_0', 'Dense_1', 'Dense_2'): assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all() if layer == 'Dense_1': - assert (new_params[layer]['bias'] == jnp.array([-1, -1, -1])).all() + assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all() else: assert (new_params[layer]['bias'] == params[layer]['bias']).all() @@ -327,33 +355,375 @@ def update_fn(path, value): ) # make sure original params are unchanged Args: - update_fn: the function that will conditionally apply changes to the Cursor object + update_fn: the function that will conditionally record changes to the Cursor object Returns: - The current Cursor object with the updates applied by the ``update_fn``. + The current Cursor object with the recorded conditional changes specified by the + ``update_fn``. To generate a copy of the original object with the accumulated + changes, call the ``.build`` method after calling ``.apply_update``. 
""" - for path, value in _get_changes((), self.obj, update_fn): + for path, value in _traverse_tree((), self._obj, update_fn=update_fn): child = self for key, access_type in path[:-1]: - if access_type is AccessType.GETITEM: + if access_type is AccessType.ITEM: child = child[key] - else: # access_type is AccessType.GETATTR + else: # access_type is AccessType.ATTR child = getattr(child, key) key, access_type = path[-1] - if access_type is AccessType.GETITEM: + if access_type is AccessType.ITEM: child[key] = value - else: # access_type is AccessType.GETATTR + else: # access_type is AccessType.ATTR setattr(child, key, value) return self + def find(self, cond_fn: Callable[[str, Any], bool]) -> 'Cursor[A]': + """Traverse the Cursor object and return a child Cursor object that fulfill the + conditions in the ``cond_fn``. The ``cond_fn`` has a function signature of ``(str, Any) -> bool``: + + - The input arguments are the current key path (in the form of a string delimited + by ``'/'``) and value at that current key path + - The output is a boolean, denoting whether to return the child Cursor object at this path + + Raises a :meth:`CursorFindError ` if no object or more + than one object is found that fulfills the condition of the ``cond_fn``. We raise an + error because the user should always expect this method to return the only object whose + corresponding key path and value fulfill the condition of the ``cond_fn``. + + NOTES: + + - If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse + any further down that branch; i.e. this method will find and return the "earliest" child node + that fulfills the condition in ``cond_fn`` in a particular key path + - ``.find`` WILL NOT search the the value at the top-most level of the pytree (i.e. the root + node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children. 
+ + Example:: + + import flax.linen as nn + from flax.cursor import cursor + import jax + import jax.numpy as jnp + + class Model(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(3)(x) + x = nn.relu(x) + x = nn.Dense(3)(x) + x = nn.relu(x) + x = nn.Dense(3)(x) + x = nn.relu(x) + return x + + params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] + + def cond_fn(path, value): + '''Find the second dense layer params.''' + return 'Dense_1' in path + + new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1) + + for layer in ('Dense_0', 'Dense_1', 'Dense_2'): + if layer == 'Dense_1': + assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all() + else: + assert (new_params[layer]['bias'] == params[layer]['bias']).all() + + c = cursor(params) + c2 = c.find(cond_fn) + c2['kernel'] += 2 + c2['bias'] += 2 + new_params = c.build() + + for layer in ('Dense_0', 'Dense_1', 'Dense_2'): + if layer == 'Dense_1': + assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all() + assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all() + else: + assert (new_params[layer]['kernel'] == params[layer]['kernel']).all() + assert (new_params[layer]['bias'] == params[layer]['bias']).all() + + assert jax.tree_util.tree_all( + jax.tree_util.tree_map( + lambda x, y: (x == y).all(), + params, + Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ + 'params' + ], + ) + ) # make sure original params are unchanged + + Args: + cond_fn: the function that will conditionally find child Cursor objects + Returns: + A child Cursor object that fulfills the condition in the ``cond_fn``. 
+ """ + generator = self.find_all(cond_fn) + try: + cursor = next(generator) + except StopIteration: + raise CursorFindError() + try: + cursor2 = next(generator) + raise CursorFindError(cursor, cursor2) + except StopIteration: + return cursor + + def find_all( + self, cond_fn: Callable[[str, Any], bool] + ) -> Generator['Cursor[A]', None, None]: + """Traverse the Cursor object and return a generator of child Cursor objects that fulfill the + conditions in the ``cond_fn``. The ``cond_fn`` has a function signature of ``(str, Any) -> bool``: + + - The input arguments are the current key path (in the form of a string delimited + by ``'/'``) and value at that current key path + - The output is a boolean, denoting whether to return the child Cursor object at this path + + NOTES: + + - If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse + any further down that branch; i.e. this method will find and return the "earliest" child nodes + that fulfill the condition in ``cond_fn`` in a particular key path + - ``.find_all`` WILL NOT search the the value at the top-most level of the pytree (i.e. the root + node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children. 
+ + Example:: + + import flax.linen as nn + from flax.cursor import cursor + import jax + import jax.numpy as jnp + + class Model(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(3)(x) + x = nn.relu(x) + x = nn.Dense(3)(x) + x = nn.relu(x) + x = nn.Dense(3)(x) + x = nn.relu(x) + return x + + params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] + + def cond_fn(path, value): + '''Find all dense layer params.''' + return 'Dense' in path + + c = cursor(params) + for dense_params in c.find_all(cond_fn): + dense_params['bias'] += 1 + new_params = c.build() + + for layer in ('Dense_0', 'Dense_1', 'Dense_2'): + assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all() + + assert jax.tree_util.tree_all( + jax.tree_util.tree_map( + lambda x, y: (x == y).all(), + params, + Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ + 'params' + ], + ) + ) # make sure original params are unchanged + + Args: + cond_fn: the function that will conditionally find child Cursor objects + Returns: + A generator of child Cursor objects that fulfill the condition in the ``cond_fn``. 
+ """ + for path in _traverse_tree((), self._obj, cond_fn=cond_fn): + child = self + for key, access_type in path: + if access_type is AccessType.ITEM: + child = child[key] + else: # access_type is AccessType.ATTR + child = getattr(child, key) + yield child + + def __str__(self): + return str(self._obj) + + def __repr__(self): + return self._pretty_repr() + + def _pretty_repr(self, indent=2, _prefix_indent=0): + s = 'Cursor(\n' + obj_str = repr(self._obj).replace( + '\n', '\n' + ' ' * (_prefix_indent + indent) + ) + s += ' ' * (_prefix_indent + indent) + f'_obj={obj_str},\n' + s += ' ' * (_prefix_indent + indent) + '_changes={' + if self._changes: + s += '\n' + for key in self._changes: + str_key = repr(key) + prefix = ' ' * (_prefix_indent + 2 * indent) + str_key + ': ' + s += ( + prefix + + self._changes[key]._pretty_repr( + indent=indent, _prefix_indent=len(prefix) + ) + + ',\n' + ) + s = s[ + :-2 + ] # remove comma and newline character for last element in self._changes + s += '\n' + ' ' * (_prefix_indent + indent) + '}\n' + else: + s += '}\n' + s += ' ' * _prefix_indent + ')' + return s + + def __len__(self): + return len(self._obj) + + def __iter__(self): + if isinstance(self._obj, (tuple, list)): + return (self[i] for i in range(len(self._obj))) + else: + raise NotImplementedError( + '__iter__ method only implemented for tuples and lists, not type' + f' {type(self._obj)}' + ) + + def __reversed__(self): + if isinstance(self._obj, (tuple, list)): + return (self[i] for i in range(len(self._obj) - 1, -1, -1)) + else: + raise NotImplementedError( + '__reversed__ method only implemented for tuples and lists, not type' + f' {type(self._obj)}' + ) + + def __add__(self, other): + return self._obj + other + + def __sub__(self, other): + return self._obj - other + + def __mul__(self, other): + return self._obj * other + + def __matmul__(self, other): + return self._obj @ other + + def __truediv__(self, other): + return self._obj / other + + def __floordiv__(self, 
other): + return self._obj // other + + def __mod__(self, other): + return self._obj % other + + def __divmod__(self, other): + return divmod(self._obj, other) + + def __pow__(self, other): + return pow(self._obj, other) + + def __lshift__(self, other): + return self._obj << other + + def __rshift__(self, other): + return self._obj >> other + + def __and__(self, other): + return self._obj & other + + def __xor__(self, other): + return self._obj ^ other + + def __or__(self, other): + return self._obj | other + + def __radd__(self, other): + return other + self._obj + + def __rsub__(self, other): + return other - self._obj + + def __rmul__(self, other): + return other * self._obj + + def __rmatmul__(self, other): + return other @ self._obj + + def __rtruediv__(self, other): + return other / self._obj + + def __rfloordiv__(self, other): + return other // self._obj + + def __rmod__(self, other): + return other % self._obj + + def __rdivmod__(self, other): + return divmod(other, self._obj) + + def __rpow__(self, other): + return pow(other, self._obj) + + def __rlshift__(self, other): + return other << self._obj + + def __rrshift__(self, other): + return other >> self._obj + + def __rand__(self, other): + return other & self._obj + + def __rxor__(self, other): + return other ^ self._obj + + def __ror__(self, other): + return other | self._obj + + def __neg__(self): + return -self._obj + + def __pos__(self): + return +self._obj + + def __abs__(self): + return abs(self._obj) + + def __invert__(self): + return ~self._obj + + def __round__(self, ndigits=None): + return round(self._obj, ndigits) + + def __lt__(self, other): + return self._obj < other + + def __le__(self, other): + return self._obj <= other + + def __eq__(self, other): + return self._obj == other + + def __ne__(self, other): + return self._obj != other + + def __gt__(self, other): + return self._obj > other + + def __ge__(self, other): + return self._obj >= other + def cursor(obj: A) -> Cursor[A]: - """Wrap 
Cursor over obj and return it.
+  """Wrap :class:`Cursor <flax.cursor.Cursor>` over ``obj`` and return it.
   Changes can then be applied to the Cursor object in the following ways:
 
   - single-line change via the ``.set`` method
   - multiple changes, and then calling the ``.build`` method
-  - multiple changes conditioned on the tree path and node value, via the ``.apply_update`` method
+  - multiple changes conditioned on the pytree path and node value via the
+    ``.apply_update`` method, and then calling the ``.build`` method
 
   ``.set`` example::
 
@@ -396,6 +766,16 @@ def update_fn(path, value):
     assert state2.params == {}
     assert state.params == {'a': 1, 'b': 2}  # make sure original params are unchanged
 
+  If the underlying ``obj`` is a ``list`` or ``tuple``, iterating over the Cursor object
+  to get the child Cursors is also possible::
+
+    from flax.cursor import cursor
+
+    c = cursor(((1, 2), (3, 4)))
+    for child_c in c:
+      child_c[1] *= -1
+    assert c.build() == ((1, -2), (3, -4))
+
   View the docstrings for each method to see more examples of their usage.
 
   Args:
diff --git a/flax/errors.py b/flax/errors.py
index df28287728..80e03b9091 100644
--- a/flax/errors.py
+++ b/flax/errors.py
@@ -895,3 +895,57 @@ class AlreadyExistsError(FlaxError):
 
   def __init__(self, path):
     super().__init__(f'Trying overwrite an existing file: "{path}".')
+
+
+#################################################
+# cursor.py errors                              #
+#################################################
+
+
+class CursorFindError(FlaxError):
+  """Error when calling :meth:`Cursor.find() <flax.cursor.Cursor.find>`.
+
+  This error occurs if no object or more than one object is found, given
+  the conditions of the ``cond_fn``.
+  """
+
+  def __init__(self, cursor=None, cursor2=None):
+    # Compare with None explicitly: a Cursor's truth value can fall back to
+    # len() of the wrapped object (Cursor defines __len__).
+    if cursor is not None and cursor2 is not None:
+      super().__init__(
+          'More than one object found given the conditions of the cond_fn. '
+          'The first two objects found have the following paths: '
+          f'{cursor._path} and {cursor2._path}'
+      )
+    else:
+      super().__init__('No object found given the conditions of the cond_fn.')
+
+
+class TraverseTreeError(FlaxError):
+  """Error when calling ``flax.cursor._traverse_tree()``. This function has two
+  modes:
+
+  - if ``update_fn`` is not None, it will traverse the tree and return a
+    generator of tuples containing the path where the ``update_fn`` was
+    applied and the newly modified value.
+  - if ``cond_fn`` is not None, it will traverse the tree and return a
+    generator of tuple paths that fulfilled the conditions of the ``cond_fn``.
+
+  This error occurs if either both ``update_fn`` and ``cond_fn`` are None,
+  or both are not None.
+  """
+
+  def __init__(self, update_fn, cond_fn):
+    # NOTE(review): the "both are None" message previously ended with
+    # "must be None."; tests matching the old wording need updating.
+    if update_fn is None and cond_fn is None:
+      super().__init__(
+          'Both update_fn and cond_fn are None. Exactly one of them must be'
+          ' not None.'
+      )
+    else:
+      super().__init__(
+          'Both update_fn and cond_fn are not None. Exactly one of them must be'
+          ' not None.'
+      )
diff --git a/tests/cursor_test.py b/tests/cursor_test.py
index 72c6672fae..f1c0e55a36 100644
--- a/tests/cursor_test.py
+++ b/tests/cursor_test.py
@@ -16,14 +16,17 @@
 
 
 from absl.testing import absltest
+import dataclasses
+import flax
 import jax
 import jax.numpy as jnp
 import optax
-from typing import Any, Generic, NamedTuple
+from typing import Any, NamedTuple
 
 import flax.linen as nn
 from flax.core import freeze
-from flax.cursor import cursor
+from flax.cursor import cursor, _traverse_tree, AccessType
+from flax.errors import CursorFindError, TraverseTreeError
 from flax.training import train_state
 
 # Parse absl flags test_srcdir and test_tmpdir.
@@ -33,10 +36,370 @@ class GenericTuple(NamedTuple): x: Any y: Any = None + z: Any = None + + +@dataclasses.dataclass +class GenericDataClass: + x: Any + y: Any = None + z: Any = None class CursorTest(absltest.TestCase): + def test_repr(self): + g = GenericTuple(1, 'a', (2, 'b')) + c = cursor( + {'a': {1: {(2, 3): 'z', 4: g, '6': (7, 8)}, 'b': [1, 2, 3]}, 'z': -1} + ) + self.assertEqual( + repr(c), + """Cursor( + _obj={'a': {1: {(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, 'b': [1, 2, 3]}, 'z': -1}, + _changes={} +)""", + ) + + # test overwriting + c['z'] = -2 + c['z'] = -3 + c['a']['b'][1] = -2 + c['a']['b'] = None + + # test deep mutation + c['a'][1][4].x = (2, 4, 6) + c['a'][1][4].z[0] = flax.core.freeze({'a': 1, 'b': {'c': 2, 'd': 3}}) + + self.assertEqual( + repr(c), + """Cursor( + _obj={'a': {1: {(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, 'b': [1, 2, 3]}, 'z': -1}, + _changes={ + 'z': Cursor( + _obj=-3, + _changes={} + ), + 'a': Cursor( + _obj={1: {(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, 'b': [1, 2, 3]}, + _changes={ + 'b': Cursor( + _obj=None, + _changes={} + ), + 1: Cursor( + _obj={(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, + _changes={ + 4: Cursor( + _obj=GenericTuple(x=1, y='a', z=(2, 'b')), + _changes={ + 'x': Cursor( + _obj=(2, 4, 6), + _changes={} + ), + 'z': Cursor( + _obj=(2, 'b'), + _changes={ + 0: Cursor( + _obj=FrozenDict({ + a: 1, + b: { + c: 2, + d: 3, + }, + }), + _changes={} + ) + } + ) + } + ) + } + ) + } + ) + } +)""", + ) + + def test_magic_methods(self): + def same_value(v1, v2): + if isinstance(v1, tuple): + return all( + [ + jnp.all(jax.tree_map(lambda x, y: x == y, e1, e2)) + for e1, e2 in zip(v1, v2) + ] + ) + return jnp.all(jax.tree_map(lambda x, y: x == y, v1, v2)) + + list_obj = [(1, 2), (3, 4)] + for l, tuple_wrap in ((list_obj, lambda x: x), (tuple(list_obj), tuple)): + c = cursor(l) + # test __len__ + 
self.assertTrue(same_value(len(c), len(l))) + # test __iter__ + for i, child_c in enumerate(c): + child_c[1] += i + 1 + self.assertEqual(c.build(), tuple_wrap([(1, 3), (3, 6)])) + # test __reversed__ + for i, child_c in enumerate(reversed(c)): + child_c[1] += i + 1 + self.assertEqual(c.build(), tuple_wrap([(1, 5), (3, 7)])) + # test __iter__ error + with self.assertRaisesRegex( + NotImplementedError, + '__iter__ method only implemented for tuples and lists, not type <class' + " 'dict'>", + ): + c = cursor({'a': 1, 'b': 2}) + for key in c: + c[key] *= -1 + # test __reversed__ error + with self.assertRaisesRegex( + NotImplementedError, + '__reversed__ method only implemented for tuples and lists, not type' + " <class 'dict'>", + ): + c = cursor({'a': 1, 'b': 2}) + for key in reversed(c): + c[key] *= -1 + + for obj_value in (2, jnp.array([[1, -2], [3, 4]])): + for c in ( + cursor(obj_value), + cursor([obj_value])[0], + cursor((obj_value,))[0], + cursor({0: obj_value})[0], + cursor(flax.core.freeze({0: obj_value}))[0], + cursor(GenericTuple(x=obj_value)).x, + cursor(GenericDataClass(x=obj_value)).x, + ): + # test __neg__ + self.assertTrue(same_value(-c, -obj_value)) + # test __pos__ + self.assertTrue(same_value(+c, +obj_value)) + # test __abs__ + self.assertTrue(same_value(abs(-c), abs(-obj_value))) + # test __invert__ + self.assertTrue(same_value(~c, ~obj_value)) + # test __round__ + self.assertTrue(same_value(round(c + 0.123), round(obj_value + 0.123))) + self.assertTrue( + same_value(round(c + 0.123, 2), round(obj_value + 0.123, 2)) + ) + + for other_value in (3, jnp.array([[5, 6], [7, 8]])): + # test __add__ + self.assertTrue(same_value(c + other_value, obj_value + other_value)) + # test __radd__ + self.assertTrue(same_value(other_value + c, other_value + obj_value)) + # test __sub__ + self.assertTrue(same_value(c - other_value, obj_value - other_value)) + # test __rsub__ + self.assertTrue(same_value(other_value - c, other_value - obj_value)) + # test __mul__ + self.assertTrue(same_value(c * 
other_value, obj_value * other_value)) + # test __rmul__ + self.assertTrue(same_value(other_value * c, other_value * obj_value)) + # test __truediv__ + self.assertTrue(same_value(c / other_value, obj_value / other_value)) + # test __rtruediv__ + self.assertTrue(same_value(other_value / c, other_value / obj_value)) + # test __floordiv__ + self.assertTrue( + same_value(c // other_value, obj_value // other_value) + ) + # test __rfloordiv__ + self.assertTrue( + same_value(other_value // c, other_value // obj_value) + ) + # test __mod__ + self.assertTrue(same_value(c % other_value, obj_value % other_value)) + # test __rmod__ + self.assertTrue(same_value(other_value % c, other_value % obj_value)) + # test __divmod__ + self.assertTrue( + same_value(divmod(c, other_value), divmod(obj_value, other_value)) + ) + # test __rdivmod__ + self.assertTrue( + same_value(divmod(other_value, c), divmod(other_value, obj_value)) + ) + # test __pow__ + self.assertTrue( + same_value(pow(c, other_value), pow(obj_value, other_value)) + ) + # test __rpow__ + self.assertTrue( + same_value(pow(other_value, c), pow(other_value, obj_value)) + ) + # test __lshift__ + self.assertTrue( + same_value(c << other_value, obj_value << other_value) + ) + # test __rlshift__ + self.assertTrue( + same_value(other_value << c, other_value << obj_value) + ) + # test __rshift__ + self.assertTrue( + same_value(c >> other_value, obj_value >> other_value) + ) + # test __rrshift__ + self.assertTrue( + same_value(other_value >> c, other_value >> obj_value) + ) + # test __and__ + self.assertTrue(same_value(c & other_value, obj_value & other_value)) + # test __rand__ + self.assertTrue(same_value(other_value & c, other_value & obj_value)) + # test __xor__ + self.assertTrue(same_value(c ^ other_value, obj_value ^ other_value)) + # test __rxor__ + self.assertTrue(same_value(other_value ^ c, other_value ^ obj_value)) + # test __or__ + self.assertTrue(same_value(c | other_value, obj_value | other_value)) + # test __ror__ + 
self.assertTrue(same_value(other_value | c, other_value | obj_value)) + + if isinstance(obj_value, jax.Array) and isinstance( + other_value, jax.Array + ): + # test __matmul__ + self.assertTrue( + same_value(c @ other_value, obj_value @ other_value) + ) + # test __rmatmul__ + self.assertTrue( + same_value(other_value @ c, other_value @ obj_value) + ) + + # test __lt__ + self.assertTrue(same_value(c < other_value, obj_value < other_value)) + self.assertTrue(same_value(other_value < c, other_value < obj_value)) + # test __le__ + self.assertTrue( + same_value(c <= other_value, obj_value <= other_value) + ) + self.assertTrue( + same_value(other_value <= c, other_value <= obj_value) + ) + # test __eq__ + self.assertTrue( + same_value(c == other_value, obj_value == other_value) + ) + self.assertTrue( + same_value(other_value == c, other_value == obj_value) + ) + # test __ne__ + self.assertTrue( + same_value(c != other_value, obj_value != other_value) + ) + self.assertTrue( + same_value(other_value != c, other_value != obj_value) + ) + # test __gt__ + self.assertTrue(same_value(c > other_value, obj_value > other_value)) + self.assertTrue(same_value(other_value > c, other_value > obj_value)) + # test __ge__ + self.assertTrue( + same_value(c >= other_value, obj_value >= other_value) + ) + self.assertTrue( + same_value(other_value >= c, other_value >= obj_value) + ) + + def test_path(self): + c = cursor( + GenericTuple( + x=[ + 0, + {'a': 1, 'b': (2, 3), ('c', 'd'): [4, 5]}, + (100, 200), + [3, 4, 5], + ], + y=train_state.TrainState.create( + apply_fn=lambda x: x, + params=freeze({'a': 1, 'b': (2, 3), 'c': [4, 5]}), + tx=optax.adam(1e-3), + ), + ) + ) + self.assertEqual(c.x[1][('c', 'd')][0]._path, ".x[1][('c', 'd')][0]") + self.assertEqual(c.x[2][1]._path, '.x[2][1]') + self.assertEqual(c.y.params['b'][1]._path, ".y.params['b'][1]") + + # test path when first access type is item access + c = cursor([1, GenericTuple('a', 2), (3, 4)]) + self.assertEqual(c[1].x._path, '[1].x') 
+ self.assertEqual(c[2][0]._path, '[2][0]') + + def test_traverse_tree(self): + c = cursor( + GenericTuple( + x=[ + 0, + {'a': 1, 'b': (2, 3), ('c', 'd'): [4, 5]}, + (100, 200), + [3, 4, 5], + ], + y=3, + ) + ) + + def update_fn(path, value): + if value == 4: + return -4 + return value + + def cond_fn(path, value): + return value == 3 + + with self.assertRaisesRegex( + TraverseTreeError, + 'Both update_fn and cond_fn are None. Exactly one of them must be' + ' None.', + ): + next(_traverse_tree((), c._obj)) + with self.assertRaisesRegex( + TraverseTreeError, + 'Both update_fn and cond_fn are not None. Exactly one of them must be' + ' not None.', + ): + next(_traverse_tree((), c._obj, update_fn=update_fn, cond_fn=cond_fn)) + + (p, v), (p2, v2) = _traverse_tree((), c._obj, update_fn=update_fn) + self.assertEqual( + p, + ( + ('x', AccessType.ATTR), + (1, AccessType.ITEM), + (('c', 'd'), AccessType.ITEM), + (0, AccessType.ITEM), + ), + ) + self.assertEqual(v, -4) + self.assertEqual( + p2, (('x', AccessType.ATTR), (3, AccessType.ITEM), (1, AccessType.ITEM)) + ) + self.assertEqual(v2, -4) + + p, p2, p3 = _traverse_tree((), c._obj, cond_fn=cond_fn) + self.assertEqual( + p, + ( + ('x', AccessType.ATTR), + (1, AccessType.ITEM), + ('b', AccessType.ITEM), + (1, AccessType.ITEM), + ), + ) + self.assertEqual( + p2, (('x', AccessType.ATTR), (3, AccessType.ITEM), (0, AccessType.ITEM)) + ) + self.assertEqual(p3, (('y', AccessType.ATTR),)) + def test_set_and_build(self): # test regular dict and FrozenDict dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} @@ -222,7 +585,7 @@ def update_fn(path, value): # test NamedTuple def update_fn(path, value): - """Add 5 to all x-attribute values that are ints""" + """Add 5 to all x-attribute values that are ints.""" if path[-1] == 'x' and isinstance(value, int): return value + 5 return value @@ -327,6 +690,157 @@ def test_named_tuple_multi_access(self): c.build(), GenericTuple(GenericTuple(4, -8), GenericTuple(6, 7)) ) + def test_find(self): + c = 
cursor( + GenericTuple( + x=[ + 0, + {'a': 1, 'b': (2, 3), ('c', 'd'): [4, 5]}, + (100, 200), + [3, 4, 5], + ], + y=train_state.TrainState.create( + apply_fn=lambda x: x, + params=freeze({'a': 1, 'b': (2, 3), 'c': [4, 5]}), + tx=optax.adam(1e-3), + ), + ) + ) + + with self.assertRaisesRegex( + CursorFindError, + 'More than one object found given the conditions of the cond_fn\\. ' + 'The first two objects found have the following paths: ' + "\\.x\\[1]\\['b'] and \\.y\\.params\\['b'] ", + ): + c.find(lambda path, value: 'b' in path and isinstance(value, tuple)) + with self.assertRaisesRegex( + CursorFindError, + 'No object found given the conditions of the cond_fn\\.', + ): + c.find(lambda path, value: 'b' in path and isinstance(value, str)) + + self.assertEqual( + c.find(lambda path, value: path.endswith('params/b'))[1] + .set(30) + .y.params, + freeze({'a': 1, 'b': (2, 30), 'c': [4, 5]}), + ) + + def test_find_all(self): + # test list and tuple + def cond_fn(path, value): + """Get all lists that are not the first element in its parent.""" + return path[-1] != '0' and isinstance(value, (tuple, list)) + + for tuple_wrap in (lambda x: x, tuple): + l = tuple_wrap( + [tuple_wrap([1, 2]), tuple_wrap([3, 4]), tuple_wrap([5, 6])] + ) + c = cursor(l) + c2, c3 = c.find_all(cond_fn) + c2[0] *= -1 + c3[1] *= -2 + self.assertEqual( + c.build(), + tuple_wrap( + [tuple_wrap([1, 2]), tuple_wrap([-3, 4]), tuple_wrap([5, -12])] + ), + ) + self.assertEqual( + l, + tuple_wrap( + [tuple_wrap([1, 2]), tuple_wrap([3, 4]), tuple_wrap([5, 6])] + ), + ) # make sure the original object is unchanged + + # test regular dict and FrozenDict + def cond_fn(path, value): + """Get the second and third dense params.""" + return 'Dense_1' in path or 'Dense_2' in path + + class Model(nn.Module): + + @nn.compact + def __call__(self, x): + x = nn.Dense(3)(x) + x = nn.relu(x) + x = nn.Dense(3)(x) + x = nn.relu(x) + x = nn.Dense(3)(x) + x = nn.relu(x) + return x + + for freeze_wrap in (lambda x: x, 
freeze): + params = freeze_wrap( + Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] + ) + c = cursor(params) + for i, c2 in enumerate(c.find_all(cond_fn)): + self.assertEqual( + c2['kernel'].set(123)[f'Dense_{i+1}'], + freeze_wrap( + {'kernel': 123, 'bias': params[f'Dense_{i+1}']['bias']} + ), + ) + self.assertTrue( + jax.tree_util.tree_all( + jax.tree_util.tree_map( + lambda x, y: (x == y).all(), + params, + freeze_wrap( + Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ + 'params' + ] + ), + ) + ) + ) # make sure original params are unchanged + + # test TrainState + def cond_fn(path, value): + """Find TrainState params.""" + return 'params' in path + + state = train_state.TrainState.create( + apply_fn=lambda x: x, + params={'a': 1, 'b': 2}, + tx=optax.adam(1e-3), + ) + c = cursor(state) + c2 = list(c.find_all(cond_fn)) + self.assertEqual(len(c2), 1) + c2 = c2[0] + self.assertEqual(c2['b'].set(-1).params, {'a': 1, 'b': -1}) + self.assertEqual( + state.params, {'a': 1, 'b': 2} + ) # make sure original params are unchanged + + # test NamedTuple + def cond_fn(path, value): + """Get all GenericTuples that have int x-attribute values.""" + return isinstance(value, GenericTuple) and isinstance(value.x, int) + + t = GenericTuple( + GenericTuple(0, 'a'), GenericTuple(1, 'b'), GenericTuple('c', 2) + ) + c = cursor(t) + c2, c3 = c.find_all(cond_fn) + c2.x += 5 + c3.x += 6 + self.assertEqual( + c.build(), + GenericTuple( + GenericTuple(5, 'a'), GenericTuple(7, 'b'), GenericTuple('c', 2) + ), + ) + self.assertEqual( + t, + GenericTuple( + GenericTuple(0, 'a'), GenericTuple(1, 'b'), GenericTuple('c', 2) + ), + ) # make sure original object is unchanged + if __name__ == '__main__': absltest.main()