diff --git a/test/gtest/asym_mem/test_asymmetric_memory.cc b/test/gtest/asym_mem/test_asymmetric_memory.cc index 46b3ed031d..9fd997efb9 100644 --- a/test/gtest/asym_mem/test_asymmetric_memory.cc +++ b/test/gtest/asym_mem/test_asymmetric_memory.cc @@ -16,7 +16,7 @@ class test_asymmetric_memory : public ucc::test, public: UccCollCtxVec ctxs; void data_init(ucc_coll_type_t coll_type, ucc_memory_type_t src_mem_type, - ucc_memory_type_t dst_mem_type, UccTeam_h team) { + ucc_memory_type_t dst_mem_type, UccTeam_h team, bool persistent = false) { ucc_rank_t tsize = team->procs.size(); int root = 0; size_t msglen = 2048; @@ -42,25 +42,31 @@ class test_asymmetric_memory : public ucc::test, coll->src.info.count = (ucc_count_t)msglen * src_modifier; coll->src.info.datatype = UCC_DT_INT8; coll->root = root; + if (persistent) { + coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll->flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; + } - UCC_CHECK(ucc_mc_alloc(&ctxs[i]->src_mc_header, - msglen * src_modifier, src_mem_type)); - coll->src.info.buffer = ctxs[i]->src_mc_header->addr; - - ctxs[i]->init_buf = ucc_malloc(msglen * src_modifier, - "init buf"); - EXPECT_NE(ctxs[i]->init_buf, nullptr); - uint8_t *sbuf = (uint8_t*)ctxs[i]->init_buf; - for (int j = 0; j < msglen * src_modifier; j++) { - sbuf[j] = (uint8_t) 1; + if (i == root || coll_type != UCC_COLL_TYPE_SCATTER) { + UCC_CHECK(ucc_mc_alloc(&ctxs[i]->src_mc_header, + msglen * src_modifier, src_mem_type)); + coll->src.info.buffer = ctxs[i]->src_mc_header->addr; + + ctxs[i]->init_buf = ucc_malloc(msglen * src_modifier, + "init buf"); + EXPECT_NE(ctxs[i]->init_buf, nullptr); + uint8_t *sbuf = (uint8_t*)ctxs[i]->init_buf; + for (int j = 0; j < msglen * src_modifier; j++) { + sbuf[j] = (uint8_t) 1; + } + UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, + ctxs[i]->init_buf, + msglen * src_modifier, src_mem_type, + UCC_MEMORY_TYPE_HOST)); } - UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, - ctxs[i]->init_buf, - msglen * src_modifier, src_mem_type, - UCC_MEMORY_TYPE_HOST)); ctxs[i]->rbuf_size = msglen * dst_modifier; - if (i == root) { + if (i == root || coll_type == UCC_COLL_TYPE_SCATTER) { UCC_CHECK(ucc_mc_alloc(&ctxs[i]->dst_mc_header, ctxs[i]->rbuf_size, dst_mem_type)); coll->dst.info.buffer = ctxs[i]->dst_mc_header->addr; @@ -79,27 +85,29 @@ class test_asymmetric_memory : public ucc::test, continue; } ucc_coll_args_t* coll = ctx->args; - UCC_CHECK(ucc_mc_free(ctx->src_mc_header)); - if (i == coll->root) { + if (i == coll->root || coll->coll_type != UCC_COLL_TYPE_SCATTER) { + ucc_free(ctx->init_buf); + UCC_CHECK(ucc_mc_free(ctx->src_mc_header)); + } + if (i == coll->root || coll->coll_type == UCC_COLL_TYPE_SCATTER) { UCC_CHECK(ucc_mc_free(ctx->dst_mc_header)); } - ucc_free(ctx->init_buf); free(coll); free(ctx); } ctxs.clear(); } - bool data_validate() + bool data_validate(uint8_t data = 1) { bool ret = true; int root = 0; - uint8_t result = 1; + uint8_t result = data; ucc_memory_type_t dst_mem_type; uint8_t *rst; if (ctxs[0]->args->coll_type == UCC_COLL_TYPE_REDUCE) { - result = (uint8_t) ctxs.size(); + result *= (uint8_t) ctxs.size(); } for (int i = 0; i < ctxs.size(); i++) { @@ -132,6 +140,30 @@ class test_asymmetric_memory : public ucc::test, return ret; } + + void data_update(uint8_t data) { + ucc_rank_t tsize = ctxs.size(); + size_t msglen = 2048; + size_t src_modifier = 1; + ucc_coll_type_t coll_type = ctxs[0]->args->coll_type; + ucc_memory_type_t src_mem_type = ctxs[0]->args->src.info.mem_type; + + if (coll_type == UCC_COLL_TYPE_SCATTER) { + src_modifier = tsize; + } + + for (int i = 0; i < tsize; i++) { + ucc_coll_args_t *coll = ctxs[i]->args; + uint8_t *sbuf = (uint8_t*)ctxs[i]->init_buf; + for (int j = 0; j < msglen * src_modifier; j++) { + sbuf[j] = (uint8_t) data; + } + UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, + ctxs[i]->init_buf, + msglen * src_modifier, src_mem_type, + UCC_MEMORY_TYPE_HOST)); + } + } }; @@ -384,6 +416,33 @@ UCC_TEST_P(test_asymmetric_memory, single) TEST_ASYM_DECLARE } +UCC_TEST_P(test_asymmetric_memory, persistent) +{ + const ucc_coll_type_t coll_type = std::get<0>(GetParam()); + const ucc_memory_type_t src_mem_type = std::get<1>(GetParam()); + const ucc_memory_type_t dst_mem_type = std::get<2>(GetParam()); + const int n_procs = std::get<3>(GetParam()); + int times = 3; + + UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL); + UccTeam_h team = job.create_team(n_procs); + + data_init(coll_type, src_mem_type, dst_mem_type, team, /*persistent*/true); + UccReq req(team, ctxs); + if (req.status != UCC_OK) { + data_fini(); + GTEST_SKIP() << "ucc_collective_init returned " + << ucc_status_string(req.status); + } + for (; times > 0; times--) { + data_update(times); // Set each element in src to times + req.start(); + req.wait(); + EXPECT_EQ(true, data_validate(times)); // Check that the dst was correct based on times + } + data_fini(); +} + INSTANTIATE_TEST_CASE_P ( , test_asymmetric_memory, @@ -413,4 +472,6 @@ INSTANTIATE_TEST_CASE_P ) ); + + #endif