Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onesided collective perftest #1004

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tools/perf/ucc_perftest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ int main(int argc, char *argv[])
ucc_pt_cuda_init();
ucc_pt_rocm_init();
try {
comm = new ucc_pt_comm(pt_config.comm);
comm = new ucc_pt_comm(pt_config.comm,pt_config.bench);
} catch(std::exception &e) {
std::cerr << e.what() << std::endl;
std::exit(1);
Expand Down
2 changes: 1 addition & 1 deletion tools/perf/ucc_pt_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ucc_pt_benchmark::ucc_pt_benchmark(ucc_pt_benchmark_config cfg,
break;
case UCC_PT_OP_TYPE_ALLTOALL:
coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
cfg.persistent, cfg.onesided, comm);
break;
case UCC_PT_OP_TYPE_ALLTOALLV:
coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace,
Expand Down
5 changes: 4 additions & 1 deletion tools/perf/ucc_pt_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ extern "C" {
#include <components/ec/ucc_ec.h>
#include <components/mc/ucc_mc.h>
}
/*
 * True when the collective arguments carry a valid flags field and that field
 * requests memory-mapped buffers.
 */
#define UCC_IS_ONESIDED(_args)                                              \
    ((((_args).mask & UCC_COLL_ARGS_FIELD_FLAGS) != 0) &&                   \
     (((_args).flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS) != 0))

ucc_status_t ucc_pt_alloc(ucc_mc_buffer_header_t **h_ptr, size_t len,
ucc_memory_type_t mem_type);
Expand Down Expand Up @@ -87,7 +90,7 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll {
class ucc_pt_coll_alltoall: public ucc_pt_coll {
public:
ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, bool is_persistent,
bool is_inplace, bool is_persistent, bool is_onesided,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand Down
11 changes: 10 additions & 1 deletion tools/perf/ucc_pt_coll_alltoall.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

ucc_pt_coll_alltoall::ucc_pt_coll_alltoall(ucc_datatype_t dt,
ucc_memory_type mt, bool is_inplace,
bool is_persistent,
bool is_persistent,bool is_onesided,
ucc_pt_comm *communicator) : ucc_pt_coll(communicator)
{
has_inplace_ = true;
Expand All @@ -38,6 +38,10 @@ ucc_pt_coll_alltoall::ucc_pt_coll_alltoall(ucc_datatype_t dt,
coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
}
if(is_onesided){
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS | UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER;
coll_args.flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS;
}
}

ucc_status_t ucc_pt_coll_alltoall::init_args(size_t single_rank_count,
Expand All @@ -60,6 +64,11 @@ ucc_status_t ucc_pt_coll_alltoall::init_args(size_t single_rank_count,
free_dst, st);
args.src.info.buffer = src_header->addr;
}
if(UCC_IS_ONESIDED(args)){
args.src.info.buffer = comm->get_global_buffer(0);
args.dst.info.buffer = comm->get_global_buffer(1);
args.global_work_buffer = comm->get_global_buffer(2);
}
return UCC_OK;
free_dst:
ucc_pt_free(dst_header);
Expand Down
45 changes: 44 additions & 1 deletion tools/perf/ucc_pt_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ extern "C" {
#include "utils/ucc_coll_utils.h"
#include "components/mc/ucc_mc.h"
}
/*
 * Abort the whole MPI job when an allocation came back as a null pointer.
 *
 * _obj: pointer (or any expression testable with '!') returned by an allocator.
 *
 * The body is wrapped in do { } while (0) so the macro expands to exactly one
 * statement and remains safe inside an unbraced if/else; call sites must end
 * the invocation with a semicolon.
 */
#define UCC_MALLOC_CHECK(_obj)                                              \
    do {                                                                    \
        if (!(_obj)) {                                                      \
            std::cerr << "*** UCC MALLOC FAIL \n";                          \
            MPI_Abort(MPI_COMM_WORLD, -1);                                  \
        }                                                                   \
    } while (0)

ucc_pt_comm::ucc_pt_comm(ucc_pt_comm_config config)
/**
 * Builds the communicator wrapper from the communicator and benchmark
 * settings.
 *
 * Keeps a copy of both configuration structs and creates the MPI bootstrap
 * object; nothing else is set up in the constructor.
 *
 * @param config     communicator configuration, copied into cfg
 * @param ben_config benchmark configuration, copied into bcfg
 */
ucc_pt_comm::ucc_pt_comm(ucc_pt_comm_config config,
                         ucc_pt_benchmark_config ben_config)
    : bcfg(ben_config),
      cfg(config),
      bootstrap(new ucc_pt_bootstrap_mpi())
{
}

Expand Down Expand Up @@ -124,6 +130,7 @@ ucc_status_t ucc_pt_comm::init()
ucc_status_t st;
std::string cfg_mod;

ucc_mem_map_t segments[UCC_TEST_N_MEM_SEGMENTS];
ee = nullptr;
executor = nullptr;
stream = nullptr;
Expand Down Expand Up @@ -157,6 +164,29 @@ ucc_status_t ucc_pt_comm::init()
UCC_CONTEXT_PARAM_FIELD_OOB;
ctx_params.type = UCC_CONTEXT_SHARED;
ctx_params.oob = bootstrap->get_context_oob();
if (bcfg.onesided)
{
for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++)
{
onesided_buffers[i] = ucc_calloc(UCC_TEST_MEM_SEGMENT_SIZE,
bootstrap->get_size(),"onesided buffers");
UCC_MALLOC_CHECK(onesided_buffers[i]);
segments[i].address = onesided_buffers[i];
segments[i].len = UCC_TEST_MEM_SEGMENT_SIZE * (bootstrap->get_size());
}
ctx_params.mask |= UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS;
ctx_params.mem_params.segments = segments;
ctx_params.mem_params.n_segments = UCC_TEST_N_MEM_SEGMENTS;
}
if(!bcfg.onesided)
{
for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++)
{
onesided_buffers[i] = NULL;
}

}

UCCCHECK_GOTO(ucc_context_create(lib, &ctx_params, ctx_config, &context),
free_ctx_config, st);
team_params.mask = UCC_TEAM_PARAM_FIELD_EP |
Expand All @@ -165,6 +195,10 @@ ucc_status_t ucc_pt_comm::init()
team_params.oob = bootstrap->get_team_oob();
team_params.ep = bootstrap->get_rank();
team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
if(bcfg.onesided){
team_params.mask |= UCC_TEAM_PARAM_FIELD_FLAGS;
team_params.flags = UCC_TEAM_FLAG_COLL_WORK_BUFFER;
}
UCCCHECK_GOTO(ucc_team_create_post(&context, 1, &team_params, &team),
free_ctx, st);
do {
Expand Down Expand Up @@ -219,6 +253,15 @@ ucc_status_t ucc_pt_comm::finalize()
if (status != UCC_OK) {
std::cerr << "ucc team destroy error: " << ucc_status_string(status);
}
if (onesided_buffers[0])
{
for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++)
{
ucc_free(onesided_buffers[i]);
}

}

ucc_context_destroy(context);
ucc_finalize(lib);
return UCC_OK;
Expand Down
14 changes: 13 additions & 1 deletion tools/perf/ucc_pt_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
extern "C" {
#include "components/ec/ucc_ec.h"
}
/* Number of buffers allocated and registered as context memory segments when
 * the one-sided mode is enabled. The alltoall test uses slot 0 as source,
 * slot 1 as destination and slot 2 as the global work buffer. */
#define UCC_TEST_N_MEM_SEGMENTS 3
/* Size unit of one segment: 1 << 21 = 2 MiB. Each buffer is allocated as this
 * value multiplied by bootstrap->get_size() (presumably the number of ranks;
 * confirm against the bootstrap implementation). */
#define UCC_TEST_MEM_SEGMENT_SIZE (1 << 21)

class ucc_pt_comm {
ucc_pt_benchmark_config bcfg;
ucc_pt_comm_config cfg;
ucc_lib_h lib;
ucc_context_h context;
Expand All @@ -24,11 +27,20 @@ class ucc_pt_comm {
ucc_ee_h ee;
ucc_ee_executor_t *executor;
ucc_pt_bootstrap *bootstrap;
void *onesided_buffers[3];
void set_gpu_device();
public:
ucc_pt_comm(ucc_pt_comm_config config);
ucc_pt_comm(ucc_pt_comm_config config,ucc_pt_benchmark_config ben_config);
int get_rank();
int get_size();
/**
 * Returns the one-sided buffer stored in the given slot.
 *
 * @param index slot in onesided_buffers; must lie inside the array.
 * @return the stored pointer. It may be null: init() stores NULL in every
 *         slot when the one-sided mode is disabled.
 * @throws std::out_of_range when index is negative or past the last slot.
 */
void* get_global_buffer(int index)
{
    // Derive the bound from the array itself so the check cannot drift from
    // the declaration if the number of buffers ever changes.
    const int n_buffers = static_cast<int>(sizeof(onesided_buffers) /
                                           sizeof(onesided_buffers[0]));
    if (index < 0 || index >= n_buffers) {
        throw std::out_of_range("Index out of range");
    }
    return onesided_buffers[index];
}
ucc_ee_executor_t* get_executor();
ucc_ee_h get_ee();
ucc_team_h get_team();
Expand Down
7 changes: 6 additions & 1 deletion tools/perf/ucc_pt_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ ucc_pt_config::ucc_pt_config() {
bench.root_shift = 0;
bench.mult_factor = 2;
comm.mt = bench.mt;
bench.onesided = false;
}

const std::map<std::string, ucc_reduction_op_t> ucc_pt_reduction_op_map = {
Expand Down Expand Up @@ -91,7 +92,7 @@ ucc_status_t ucc_pt_config::process_args(int argc, char *argv[])
int c;
ucc_status_t st;

while ((c = getopt(argc, argv, "c:b:e:d:m:n:w:o:N:r:S:iphFT")) != -1) {
while ((c = getopt(argc, argv, "c:b:e:d:m:n:w:o:N:r:S:iphFTJ")) != -1) {
switch (c) {
case 'c':
if (ucc_pt_op_map.count(optarg) == 0) {
Expand Down Expand Up @@ -172,6 +173,9 @@ ucc_status_t ucc_pt_config::process_args(int argc, char *argv[])
case 'F':
bench.full_print = true;
break;
case 'J':
bench.onesided = true;
break;
case 'h':
default:
print_help();
Expand Down Expand Up @@ -201,5 +205,6 @@ void ucc_pt_config::print_help()
std::cout << " -F: enable full print"<<std::endl;
std::cout << " -S: <number>: root shift for rooted collectives"<<std::endl;
std::cout << " -h: show this help message"<<std::endl;
std::cout << " -J: onesided collective"<<std::endl;
std::cout << std::endl;
}
1 change: 1 addition & 0 deletions tools/perf/ucc_pt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct ucc_pt_benchmark_config {
int root;
int root_shift;
int mult_factor;
bool onesided;
};

struct ucc_pt_config {
Expand Down
Loading