From 8645c6dbf3c8cffdaf66e092027f6c3258598ca2 Mon Sep 17 00:00:00 2001 From: Colin Cotter Date: Fri, 12 Jul 2024 15:15:09 +0100 Subject: [PATCH] add dJdm to taylor_test calls --- tests/ensemble_reduced_functional/test_reduced_functional.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/ensemble_reduced_functional/test_reduced_functional.py b/tests/ensemble_reduced_functional/test_reduced_functional.py index 6a7edb4c9f..80099110ee 100644 --- a/tests/ensemble_reduced_functional/test_reduced_functional.py +++ b/tests/ensemble_reduced_functional/test_reduced_functional.py @@ -38,7 +38,6 @@ def test_verification(): dJdm = rf.derivative() assert_allclose(ensemble_J, size, rtol=1e-12) assert_allclose(dJdm.dat.data_ro, 2.0 * size, rtol=1e-12) - dJdm = ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM) assert taylor_test(rf, x, Function(R, val=0.1)) > 1.9 @@ -63,7 +62,7 @@ def test_verification_gather_functional_adjfloat(): assert_allclose(ensemble_J, 1.0**4+2.0**4, rtol=1e-12) assert_allclose(dJdm.dat.data_ro, 4*(rank+1)**3, rtol=1e-12) dJdm = ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM) - assert taylor_test(rf, x, Function(R, val=0.1)) > 1.9 + assert taylor_test(rf, x, Function(R, val=0.1), dJdm=dJdm) > 1.9 @pytest.mark.parallel(nprocs=4) @@ -88,7 +87,7 @@ def test_verification_gather_functional_Function(): assert_allclose(ensemble_J, 1.0**4+2.0**4, rtol=1e-12) assert_allclose(dJdm.dat.data_ro, 4*(rank+1)**3, rtol=1e-12) dJdm = ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM) - assert taylor_test(rf, x, Function(R, val=0.1)) > 1.9 + assert taylor_test(rf, x, Function(R, val=0.1), dJdm=dJdm) > 1.9 @pytest.mark.parallel(nprocs=6)