Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added cursor api #3246

Merged
merged 1 commit into from
Aug 10, 2023
Merged

added cursor api #3246

merged 1 commit into from
Aug 10, 2023

Conversation

chiamp
Copy link
Collaborator

@chiamp chiamp commented Aug 1, 2023

Continuing from #3236.

Added the Cursor API, which allows for mutability of pytrees. This API provides a more ergonomic solution to making partial-updates of deeply nested immutable data structures, compared to making many nested dataclasses.replace calls.

To illustrate, consider the example below:

from flax.cursor import cursor
import dataclasses
from typing import Any

@dataclasses.dataclass
class A:
  x: Any

a = A(A(A(A(A(A(A(0)))))))

To replace the int 0 using dataclasses.replace, we would have to write many nested calls:

a2 = dataclasses.replace(
  a,
  x=dataclasses.replace(
    a.x,
    x=dataclasses.replace(
      a.x.x,
      x=dataclasses.replace(
        a.x.x.x,
        x=dataclasses.replace(
          a.x.x.x.x,
          x=dataclasses.replace(
            a.x.x.x.x.x,
            x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
          ),
        ),
      ),
    ),
  ),
)

The equivalent can be achieved much more simply using the Cursor API:

a3 = cursor(a).x.x.x.x.x.x.x.set(1)
assert a2 == a3

The Cursor object keeps tracks of changes made to it and when .build is called, generates a new object with the accumulated changes. Basic usage involves wrapping the object in a Cursor, making changes to the Cursor object and generating a new copy of the original object with the accumulated changes.

There are three ways to use the Cursor API:

  • multiple changes using the .build method
  • single line change using the .set method
  • multiple conditional changes using the .apply_update method

Once you wrap the object in a Cursor and make changes to it, calling the .build method will generate a new object with the accumulated changes. For example:

from flax.cursor import cursor
from flax.training import train_state
import optax

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
c = cursor(dict_obj)          # wrap the object in a Cursor
c['b'][0] = 10                # change 1
c['a'] = (100, 200)           # change 2
modified_dict_obj = c.build() # generate new object with the above changes
assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params=dict_obj,
    tx=optax.adam(1e-3),
)
new_fn = lambda x: x + 1
c = cursor(state)          # wrap the object in a Cursor
c.params['b'][1] = 10      # change 1
c.apply_fn = new_fn        # change 2
modified_state = c.build() # generate new object with the above changes
assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
assert modified_state.apply_fn == new_fn

Calling the .set method will apply a single-line change and call .build immediately after. Therefore, you don't need to manually call .build after .set. For example:

from flax.cursor import cursor
from flax.training import train_state
import optax

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
modified_dict_obj = cursor(dict_obj)['b'][0].set(10) # apply change and generate new object with the change
assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params=dict_obj,
    tx=optax.adam(1e-3),
)
modified_state = cursor(state).params['b'][1].set(10) # apply change and generate new object with the change
assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}

Calling .apply_update will traverse the Cursor object and apply conditional changes recursively via a filter function.
The filter function has a function signature of (str, Any) -> Any:

  • The input arguments are the current key path (in the form of a string delimited by ‘/’) and value at that current key path
  • The output is the new value (either modified by the update_fn or same as the input value if the condition wasn’t fulfilled)

To generate a copy of the original object with the accumulated changes, call the .build method. For example:

import flax.linen as nn
from flax.cursor import cursor
import jax
import jax.numpy as jnp

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    return x

params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

def update_fn(path, value):
  '''Multiply all dense kernel params by 2 and add 1.
  Subtract the Dense_1 bias param by 1.'''
  if 'kernel' in path:
    return value * 2 + 1
  elif 'Dense_1' in path and 'bias' in path:
    return value - 1
  return value

c = cursor(params)
new_params = c.apply_update(update_fn).build()
for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
  assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
  if layer == 'Dense_1':
    assert (new_params[layer]['bias'] == jnp.array([-1, -1, -1])).all()
  else:
    assert (new_params[layer]['bias'] == params[layer]['bias']).all()

@chiamp chiamp self-assigned this Aug 1, 2023
@chiamp chiamp marked this pull request as draft August 1, 2023 05:37
@chiamp chiamp force-pushed the cursor branch 2 times, most recently from a15ae0b to a0f910d Compare August 1, 2023 21:18
@cgarciae cgarciae mentioned this pull request Aug 2, 2023
@chiamp chiamp force-pushed the cursor branch 4 times, most recently from 19027fa to 5acc1cb Compare August 3, 2023 04:48
@chiamp chiamp marked this pull request as ready for review August 3, 2023 05:08
@ASEM000
Copy link

ASEM000 commented Aug 3, 2023

Nice API, would you take a look at https://github.com/ASEM000/PyTreeClass, I attempted a slightly different API with the same objective as this API.

Specifically https://pytreeclass.readthedocs.io/en/latest/notebooks/common_recipes.html#[7]-Use-PyTreeClass-with-Flax/Equinox

flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
@chiamp chiamp force-pushed the cursor branch 13 times, most recently from 93ce3e1 to bd0762d Compare August 4, 2023 04:35
@chiamp chiamp force-pushed the cursor branch 7 times, most recently from 3f33ba1 to f79473a Compare August 4, 2023 05:26
@codecov-commenter
Copy link

codecov-commenter commented Aug 4, 2023

Codecov Report

Merging #3246 (8d4d3eb) into main (675e34d) will increase coverage by 0.16%.
Report is 4 commits behind head on main.
The diff coverage is 93.68%.

@@            Coverage Diff             @@
##             main    #3246      +/-   ##
==========================================
+ Coverage   82.35%   82.51%   +0.16%     
==========================================
  Files          54       55       +1     
  Lines        6087     6183      +96     
==========================================
+ Hits         5013     5102      +89     
- Misses       1074     1081       +7     
Files Changed Coverage Δ
flax/cursor.py 93.68% <93.68%> (ø)

... and 1 file with indirect coverage changes

Copy link
Collaborator

@levskaya levskaya left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a few questions

flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Show resolved Hide resolved
flax/cursor.py Show resolved Hide resolved
@chiamp chiamp force-pushed the cursor branch 5 times, most recently from e9a7493 to cbd7efc Compare August 9, 2023 00:33
Copy link
Collaborator

@levskaya levskaya left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a good start, we might iterate on a few details/extensions soon.

@copybara-service copybara-service bot merged commit 3ea6381 into google:main Aug 10, 2023
16 checks passed
@chiamp chiamp deleted the cursor branch September 25, 2023 23:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants