Skip to content

Commit

Permalink
Add test of results round trip when using MPI.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Feb 10, 2023
1 parent d0f4ca0 commit 80e9666
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/mpi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,18 @@ def do_mpi_cov(comm, method, output=True):
print(comm.rank, "Running RR process")
rr.process(ran_cat, comm=comm)

gg_save_file = os.path.join('data','mpi_gg_save.fits')
ng_save_file = os.path.join('data','mpi_ng_save.fits')
nn_save_file = os.path.join('data','mpi_nn_save.fits')
rr_save_file = os.path.join('data','mpi_rr_save.fits')
if comm.rank == 0:
if output:
print(comm.rank, "Writing to save files")
gg.write(gg_save_file, write_patch_results=True)
ng.write(ng_save_file, write_patch_results=True)
nn.write(nn_save_file, write_patch_results=True)
rr.write(rr_save_file, write_patch_results=True)

# Only the root process gets the complete version
# when you call the above with comm
if output:
Expand Down Expand Up @@ -324,6 +336,26 @@ def do_mpi_cov(comm, method, output=True):
np.testing.assert_allclose(A1b, A2b, atol=tol)
np.testing.assert_allclose(w1b, w2b, atol=tol)

# Finally, read back in from the save file and redo the covariance.
rng = np.random.RandomState(31415)
gg = treecorr.GGCorrelation(bin_size=0.3, min_sep=10., max_sep=50., rng=rng)
ng = treecorr.NGCorrelation(bin_size=0.3, min_sep=10., max_sep=50., rng=rng)
nn = treecorr.NNCorrelation(bin_size=0.3, min_sep=10., max_sep=50., rng=rng)
rr = treecorr.NNCorrelation(bin_size=0.3, min_sep=10., max_sep=50., rng=rng)
gg.read(gg_save_file)
ng.read(ng_save_file)
nn.read(nn_save_file)
rr.read(rr_save_file)

ng.calculateXi()
nn.calculateXi(rr=rr)
corrs = [gg, ng, nn]
cov3 = treecorr.estimate_multi_cov(corrs, method, comm=comm)
if output:
print("\nCOV 3\n", cov3[0:3,0:3], " for ", comm.rank, "\n")

np.testing.assert_allclose(cov1, cov3, atol=tol)


if __name__ == '__main__':
from mpi4py import MPI
Expand Down

0 comments on commit 80e9666

Please sign in to comment.