From c76787571b81c5538226f8856dc2c3f87e6d4a2a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 23 Aug 2024 05:48:29 -0700 Subject: [PATCH] [Mosaic GPU] Expose wait_parity on collective barrier PiperOrigin-RevId: 666761011 --- jax/experimental/mosaic/gpu/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 64c9f409ef7f..3c0cedfc2807 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -712,8 +712,11 @@ def arrive(self): has_side_effects=True, ) - def wait(self): - self.barrier.wait() + def wait(self, *args, **kwargs): + self.barrier.wait(*args, **kwargs) + + def wait_parity(self, *args, **kwargs): + self.barrier.wait_parity(*args, **kwargs) class Partition: