From d4aa2996d1d47a1a63dcd48b4f27da78778b8db6 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 7 Nov 2024 06:04:02 +0800 Subject: [PATCH] [JAX] Add back the xla deterministic flag (#1301) Add back the xla deterministic flag Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- qa/L0_jax_unittest/test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 9efec6f2e5..db3aa31951 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -18,5 +18,7 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py