Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Dynamic Symmetric Memory #909

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ typedef struct ucc_tl_ucp_context {
ucc_tl_ucp_remote_info_t *dynamic_remote_info;
void *dyn_seg_buf;
ucp_rkey_h *dyn_rkeys;
size_t dyn_seg_size;
size_t n_dynrinfo_segs;
} ucc_tl_ucp_context_t;
UCC_CLASS_DECLARE(ucc_tl_ucp_context_t, const ucc_base_context_params_t *,
Expand Down
136 changes: 100 additions & 36 deletions src/components/tl/ucp/tl_ucp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,45 +234,101 @@ ucc_status_t ucc_tl_ucp_memmap_segment(ucc_tl_ucp_task_t *task,
ucc_status_t ucc_tl_ucp_coll_dynamic_segment_init(ucc_coll_args_t *coll_args,
ucc_tl_ucp_task_t *task)
{
ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task);
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
int i = 0;
ucc_status_t status;
ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task);
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
int i = 0;
uint64_t need_map = 0x7;
wfaderhold21 marked this conversation as resolved.
Show resolved Hide resolved
ucc_mem_map_t *maps = coll_args->mem_map.segments;
size_t n_segments = coll_args->mem_map.n_segments;
ucc_mem_map_t *seg_maps = NULL;
size_t n_segments = 3;
ucc_status_t status;

if (n_segments == 0) {
maps = ucc_calloc(2, sizeof(ucc_mem_map_t));
if (!maps) {
return UCC_ERR_NO_MEMORY;

/* check if src, dst, global work in ctx mapped segments */
for (i = 0; i < ctx->n_rinfo_segs && n_segments > 0; i++) {
uint64_t base = (uint64_t)ctx->remote_info[i].va_base;
uint64_t end = (uint64_t)(base + ctx->remote_info[i].len);
if ((uint64_t)coll_args->src.info.buffer >= base &&
(uint64_t)coll_args->src.info.buffer < end) {
// found it
need_map ^= 1;
--n_segments;
}
if ((uint64_t)coll_args->dst.info.buffer >= base &&
(uint64_t)coll_args->dst.info.buffer < end) {
// found it
need_map ^= 2;
--n_segments;
}

maps[0].address = coll_args->src.info.buffer;
maps[0].len = (coll_args->src.info.count / UCC_TL_TEAM_SIZE(tl_team)) *
ucc_dt_size(coll_args->src.info.datatype);
maps[0].resource = NULL;
if ((uint64_t)coll_args->global_work_buffer >= base &&
(uint64_t)coll_args->global_work_buffer < end) {
// found it
need_map ^= 4;
--n_segments;
}

maps[1].address = coll_args->dst.info.buffer;
maps[1].len = (coll_args->dst.info.count / UCC_TL_TEAM_SIZE(tl_team)) *
ucc_dt_size(coll_args->dst.info.datatype);
maps[1].resource = NULL;
if (n_segments == 0) {
break;
}
}

/* add any valid segments */
if (n_segments > 0) {
int index = 0;
seg_maps = ucc_calloc(n_segments, sizeof(ucc_mem_map_t));
if (!seg_maps) {
tl_error(UCC_TL_UCP_TEAM_LIB(tl_team), "Out of Memory");
return UCC_ERR_NO_MEMORY;
}

n_segments = 2;
if (need_map & 0x1) {
seg_maps[index].address = coll_args->src.info.buffer;
seg_maps[index].len = (coll_args->src.info.count) *
ucc_dt_size(coll_args->src.info.datatype);
seg_maps[index++].resource = NULL;
}
if (need_map & 0x2) {
seg_maps[index].address = coll_args->dst.info.buffer;
seg_maps[index].len = (coll_args->dst.info.count) *
ucc_dt_size(coll_args->dst.info.datatype);
seg_maps[index++].resource = NULL;
}
if (need_map & 0x4) {
seg_maps[index].address = coll_args->global_work_buffer;
seg_maps[index].len = (ONESIDED_SYNC_SIZE + ONESIDED_REDUCE_SIZE) * sizeof(long);
seg_maps[index++].resource = NULL;
}
}

ctx->dynamic_remote_info =
ucc_calloc(n_segments, sizeof(ucc_tl_ucp_remote_info_t), "dynamic remote info");
/* map memory and fill in local segment information */
for (i = 0; i < n_segments; i++) {
status = ucc_tl_ucp_memmap_segment(task, &maps[i], i);
if (status != UCC_OK) {
tl_error(UCC_TASK_LIB(task), "failed to memory map a segment");
if (n_segments > 0) {
ctx->dynamic_remote_info =
ucc_calloc(n_segments, sizeof(ucc_tl_ucp_remote_info_t), "dynamic remote info");
if (!ctx->dynamic_remote_info) {
tl_error(UCC_TL_UCP_TEAM_LIB(tl_team), "Out of Memory");
status = UCC_ERR_NO_MEMORY;
goto failed_memory_map;
}
++ctx->n_dynrinfo_segs;
}
if (coll_args->mem_map.n_segments == 0) {
free(maps);
/* map memory and fill in local segment information */
for (i = 0; i < n_segments; i++) {
status = ucc_tl_ucp_memmap_segment(task, &seg_maps[i], i);
if (status != UCC_OK) {
tl_error(UCC_TASK_LIB(task), "failed to memory map a segment");
goto failed_memory_map;
}
++ctx->n_dynrinfo_segs;
}
for (i = 0; i < coll_args->mem_map.n_segments; i++) {
status = ucc_tl_ucp_memmap_segment(task, &maps[i], i + n_segments);
if (status != UCC_OK) {
tl_error(UCC_TASK_LIB(task), "failed to memory map a segment");
goto failed_memory_map;
}
++ctx->n_dynrinfo_segs;
}
if (n_segments) {
free(seg_maps);
}
}
return UCC_OK;
failed_memory_map:
Expand All @@ -289,8 +345,8 @@ ucc_status_t ucc_tl_ucp_coll_dynamic_segment_init(ucc_coll_args_t *coll_args,
}
}
ctx->n_dynrinfo_segs = 0;
if (coll_args->mem_map.n_segments == 0) {
free(maps);
if (n_segments) {
ucc_free(seg_maps);
}
return status;
}
Expand Down Expand Up @@ -387,6 +443,7 @@ ucc_status_t ucc_tl_ucp_coll_dynamic_segment_exchange(ucc_tl_ucp_task_t *task)
status = UCC_ERR_NO_MEMORY;
goto failed_data_exch;
}
ctx->dyn_seg_size = seg_pack_size;
ucc_free(ex_buffer);
}
return UCC_OK;
Expand All @@ -405,19 +462,24 @@ ucc_status_t ucc_tl_ucp_coll_dynamic_segment_exchange(ucc_tl_ucp_task_t *task)
return status;
}

void ucc_tl_ucp_coll_dynamic_segment_finalize(ucc_tl_ucp_task_t *task)
ucc_status_t ucc_tl_ucp_coll_dynamic_segment_finalize(ucc_tl_ucp_task_t *task)
{
ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task);
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
int i = 0;
int j = 0;
ucs_status_t status;
/* free library resources, unmap user resources */
if (ctx->dyn_seg_buf) {
/* unmap and release packed buffers */
for (i = 0; i < ctx->n_dynrinfo_segs; i++) {
if (ctx->dynamic_remote_info[i].mem_h) {
ucp_mem_unmap(ctx->worker.ucp_context,
status = ucp_mem_unmap(ctx->worker.ucp_context,
ctx->dynamic_remote_info[i].mem_h);
if (UCS_OK != status) {
tl_error(UCC_TL_UCP_TEAM_LIB(tl_team), "Failed to unmap memory");
return ucs_status_to_ucc_status(status);
wfaderhold21 marked this conversation as resolved.
Show resolved Hide resolved
}
}
if (ctx->dynamic_remote_info[i].packed_key) {
ucp_rkey_buffer_release(ctx->dynamic_remote_info[i].packed_key);
Expand All @@ -435,15 +497,17 @@ void ucc_tl_ucp_coll_dynamic_segment_finalize(ucc_tl_ucp_task_t *task)
}
}
}
free(ctx->dynamic_remote_info);
free(ctx->dyn_rkeys);
free(ctx->dyn_seg_buf);
ucc_free(ctx->dynamic_remote_info);
ucc_free(ctx->dyn_rkeys);
ucc_free(ctx->dyn_seg_buf);

ctx->dynamic_remote_info = NULL;
ctx->dyn_rkeys = NULL;
ctx->dyn_seg_buf = NULL;
ctx->dyn_seg_size = 0;
ctx->n_dynrinfo_segs = 0;
}
return UCC_OK;
}

ucc_status_t ucc_tl_ucp_coll_init(ucc_base_coll_args_t *coll_args,
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,6 @@ ucc_status_t ucc_tl_ucp_coll_dynamic_segment_init(ucc_coll_args_t *coll_args,

ucc_status_t ucc_tl_ucp_coll_dynamic_segment_exchange(ucc_tl_ucp_task_t *task);

void ucc_tl_ucp_coll_dynamic_segment_finalize(ucc_tl_ucp_task_t *task);
ucc_status_t ucc_tl_ucp_coll_dynamic_segment_finalize(ucc_tl_ucp_task_t *task);

#endif
2 changes: 1 addition & 1 deletion src/components/tl/ucp/tl_ucp_sendrecv.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, size_t msglen,
}

section_offset = sizeof(uint64_t) * ctx->n_dynrinfo_segs;
base_offset = (ptrdiff_t)(ctx->dyn_seg_buf);
base_offset = (ptrdiff_t)PTR_OFFSET(ctx->dyn_seg_buf, peer * ctx->dyn_seg_size);
rvas = (uint64_t *)base_offset;
key_sizes = PTR_OFFSET(base_offset, (section_offset * 2));
keys = PTR_OFFSET(base_offset, (section_offset * 3));
Expand Down
Loading