Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add find methods and magic methods for Cursor API #3306

Merged
merged 1 commit into from
Sep 19, 2023

Conversation

chiamp
Copy link
Collaborator

@chiamp chiamp commented Sep 2, 2023

Initial Cursor PR: #3246
Resolves #3309 and resolves #3311.

The PR adds the following features to the Cursor API:

  • Moved all Cursor attributes and properties to private (e.g. Cursor.obj -> Cursor._obj)
  • Added magic methods to support more intuitive usage and updates of Cursor objects
c = cursor(((0, 1), (1, 2)))
for child_c in c: # iterating over child Cursors is now possible if the underlying object is a tuple or list
  child_c[1] *= -1 # Cursor objects now support in-place updates
assert c.build() == (((0, -1), (1, -2)))
assert len(c) == 2
  • Added a .find_all method which finds all child Cursor objects that fulfill a cond_fn that is passed into the method
c = cursor({'a1': {'b': 1}, 'a2': {'b': 2}, 'c': 3})
for child_c in c.find_all(lambda path, value: path.endswith('a1') or path.endswith('a2')):
  child_c['b'] += 5
assert c.build() == {'a1': {'b': 6}, 'a2': {'b': 7}, 'c': 3}
  • Added a .find method which finds a single child Cursor object that fulfills a cond_fn passed into the method. An error is raised if no object or more than one object is found.
c = cursor({'a1': {'b': 1}, 'a2': {'b': 2}, 'c': 3})
child_c = c.find(lambda path, value: path.endswith('a1'))
child_c['b'] += 5
assert c.build() == {'a1': {'b': 6}, 'a2': {'b': 2}, 'c': 3}
  • Added ._path property for Cursor objects
@dataclasses.dataclass
class GenericDataClass:
    x: Any
    y: Any

c = cursor({'a': ['x', ('y', 'z')], 'b': (1, GenericDataClass(x=2, y=3))})
assert c.find(lambda path, value: value == 2)._path == "['b'][1].x"
  • Added a more informative representation for Cursor objects
@dataclasses.dataclass
class GenericDataClass:
    x: Any
    y: Any

c = cursor({'a': ['x', ('y', 'z')], 'b': (1, GenericDataClass(x=2, y=3))})
print(repr(c))
# Cursor(
#   _obj={'a': ['x', ('y', 'z')], 'b': (1, GenericDataClass(x=2, y=3))},
#   _changes={}
# )

c['a'][1][0] += 'yy'
c['b'][0] = []
c['b'][1].x *= -1
print(repr(c))
# Cursor(
#   _obj={'a': ['x', ('y', 'z')], 'b': (1, GenericDataClass(x=2, y=3))},
#   _changes={
#     'a': Cursor(
#            _obj=['x', ('y', 'z')],
#            _changes={
#              1: Cursor(
#                   _obj=('y', 'z'),
#                   _changes={
#                     0: Cursor(
#                          _obj='yyy',
#                          _changes={}
#                        )
#                   }
#                 )
#            }
#          ),
#     'b': Cursor(
#            _obj=(1, GenericDataClass(x=2, y=3)),
#            _changes={
#              0: Cursor(
#                   _obj=[],
#                   _changes={}
#                 ),
#              1: Cursor(
#                   _obj=GenericDataClass(x=2, y=3),
#                   _changes={
#                     'x': Cursor(
#                            _obj=-2,
#                            _changes={}
#                          )
#                   }
#                 )
#            }
#          )
#   }
# )

@chiamp chiamp self-assigned this Sep 2, 2023
@chiamp chiamp marked this pull request as draft September 2, 2023 00:44
@codecov-commenter
Copy link

codecov-commenter commented Sep 2, 2023

Codecov Report

Merging #3306 (eb2200f) into main (1127475) will increase coverage by 0.36%.
Report is 15 commits behind head on main.
The diff coverage is 97.68%.

@@            Coverage Diff             @@
##             main    #3306      +/-   ##
==========================================
+ Coverage   82.65%   83.01%   +0.36%     
==========================================
  Files          55       55              
  Lines        6359     6518     +159     
==========================================
+ Hits         5256     5411     +155     
- Misses       1103     1107       +4     
Files Changed Coverage Δ
flax/cursor.py 97.35% <97.57%> (+1.52%) ⬆️
flax/errors.py 87.60% <100.00%> (+1.11%) ⬆️

... and 3 files with indirect coverage changes

flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
@chiamp chiamp force-pushed the cursor_find branch 9 times, most recently from 38a080d to 0f2e326 Compare September 8, 2023 23:17
@chiamp chiamp changed the title Add find feature for Cursor API Add find methods and magic methods for Cursor API Sep 8, 2023
@chiamp chiamp marked this pull request as ready for review September 8, 2023 23:21
@chiamp chiamp force-pushed the cursor_find branch 2 times, most recently from 336bcfa to 7a23f8d Compare September 11, 2023 20:47
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Outdated Show resolved Hide resolved
flax/cursor.py Show resolved Hide resolved
@chiamp chiamp removed the request for review from levskaya September 14, 2023 16:22
@copybara-service copybara-service bot merged commit 7f27932 into google:main Sep 19, 2023
16 checks passed
@chiamp chiamp deleted the cursor_find branch September 25, 2023 23:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Should we override magic/dunder methods for the Cursor API? Add .find and .find_all to Cursor API
4 participants