diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 9efec6f2e5..db3aa31951 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -18,5 +18,7 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py