Skip to content

Commit

Permalink
[Mosaic GPU] Expose wait_parity on collective barrier
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666761011
  • Loading branch information
apaszke authored and jax authors committed Aug 23, 2024
1 parent c430b0c commit c767875
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions jax/experimental/mosaic/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,11 @@ def arrive(self):
has_side_effects=True,
)

def wait(self):
self.barrier.wait()
def wait(self, *args, **kwargs):
self.barrier.wait(*args, **kwargs)

def wait_parity(self, *args, **kwargs):
self.barrier.wait_parity(*args, **kwargs)


class Partition:
Expand Down

0 comments on commit c767875

Please sign in to comment.