Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT : Register memory #1010

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions src/components/ec/base/ucc_ec_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,15 @@ typedef struct ucc_ee_executor_task {
ucc_ee_executor_t *eee;
ucc_ee_executor_task_args_t args;
ucc_status_t status;
void *completion;
} ucc_ee_executor_task_t;

/*
 * Singly linked list node that references one executor task.
 *
 * The typedef is introduced once, ahead of the struct body, so that the
 * 'next' member can use it. Declaring the same typedef name a second time on
 * the struct definition is only permitted starting with C11, so the struct is
 * defined without repeating the typedef to stay valid for C99 compilers.
 */
typedef struct node_ucc_ee_executor_task node_ucc_ee_executor_task_t;
struct node_ucc_ee_executor_task {
    ucc_ee_executor_task_t      *etask; /* executor task held by this node */
    node_ucc_ee_executor_task_t *next;  /* following node in the list */
};

typedef struct ucc_ee_executor_ops {
ucc_status_t (*init)(const ucc_ee_executor_params_t *params,
ucc_ee_executor_t **executor);
Expand Down
209 changes: 194 additions & 15 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "coll_patterns/sra_knomial.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include <stdio.h>

#define SAVE_STATE(_phase) \
do { \
Expand Down Expand Up @@ -50,6 +51,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
args->root : 0;
ucc_rank_t rank = VRANK(task->subset.myrank, broot, size);
size_t local = GET_LOCAL_COUNT(args, size, rank);
ucp_mem_h *mh_list = task->allgather_kn.mh_list;
void *sbuf;
ptrdiff_t peer_seg_offset, local_seg_offset;
ucc_rank_t peer, peer_dist;
Expand All @@ -65,32 +67,36 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
if (KN_NODE_EXTRA == node_type) {
peer = ucc_knomial_pattern_get_proxy(p, rank);
if (p->type != KN_PATTERN_ALLGATHERX) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->allgather_kn.sbuf,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(task->allgather_kn.sbuf,
local * dt_size, mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer,broot,size)),
team, task),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);

}
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(rbuf, data_size, mem_type,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(rbuf, data_size, mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer,broot,size)),
team, task),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}
if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) {
peer = ucc_knomial_pattern_get_extra(p, rank);
extra_count = GET_LOCAL_COUNT(args, size, peer);
peer = ucc_ep_map_eval(task->subset.map, peer);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(task->allgather_kn.sbuf,
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb_with_mem(PTR_OFFSET(task->allgather_kn.sbuf,
local * dt_size), extra_count * dt_size,
mem_type, peer, team, task),
mem_type, peer, team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}

UCC_KN_PHASE_EXTRA:
if ((KN_NODE_EXTRA == node_type) || (KN_NODE_PROXY == node_type)) {
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) {
SAVE_STATE(UCC_KN_PHASE_EXTRA);
return;
}
Expand All @@ -114,12 +120,13 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
continue;
}
}
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(sbuf, local_seg_count * dt_size,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(sbuf, local_seg_count * dt_size,
mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
team, task),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}

for (loop_step = 1; loop_step < radix; loop_step++) {
Expand All @@ -137,15 +144,16 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
}
}
UCPCHECK_GOTO(
ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, peer_seg_offset * dt_size),
ucc_tl_ucp_recv_nb_with_mem(PTR_OFFSET(rbuf, peer_seg_offset * dt_size),
peer_seg_count * dt_size, mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
team, task),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}
UCC_KN_PHASE_LOOP:
if (UCC_INPROGRESS == ucc_tl_ucp_test_recv(task)) {
if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks(task)) {
SAVE_STATE(UCC_KN_PHASE_LOOP);
return;
}
Expand All @@ -154,20 +162,22 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)

if (KN_NODE_PROXY == node_type) {
peer = ucc_knomial_pattern_get_extra(p, rank);
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(args->dst.info.buffer, data_size,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(args->dst.info.buffer, data_size,
mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
team, task),
team, task, mh_list[task->allgather_kn.count_mh++]),
task, out);
ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh);
}
UCC_KN_PHASE_PROXY:
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) {
SAVE_STATE(UCC_KN_PHASE_PROXY);
return;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an assertion at the `out:` label, e.g.
`ucc_assert(count_mh == max_count)` or something similar.


out:
ucc_assert(task->allgather_kn.count_mh-1 == task->allgather_kn.max_mh);
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
task->super.status = UCC_OK;
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_done", 0);
Expand Down Expand Up @@ -234,25 +244,194 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

/*
 * Walks the knomial allgather/bcast communication schedule once, ahead of
 * time, and invokes MEM_MAP() for every buffer region that
 * ucc_tl_ucp_allgather_knomial_progress() later passes to a send/recv call.
 * The resulting handle list is stored in task->allgather_kn.mh_list.
 *
 * The order of MEM_MAP() invocations follows the order in which the progress
 * function consumes mh_list[count_mh++].
 *
 * NOTE(review): MEM_MAP() is a macro defined outside this view. It presumably
 * maps mmap_params through ctx, appends the handle to mh_list (growing it via
 * size_of_list), increments count_mh and reports errors through 'status' --
 * confirm against the macro definition.
 *
 * Returns UCC_OK on the path visible here.
 */
ucc_status_t register_memory(ucc_coll_task_t *coll_task){

ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_type_t ct = args->coll_type;
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
void *rbuf = args->dst.info.buffer;
ucc_memory_type_t mem_type = args->dst.info.mem_type;
size_t count = args->dst.info.count;
size_t dt_size = ucc_dt_size(args->dst.info.datatype);
size_t data_size = count * dt_size;
ucc_rank_t size = task->subset.map.ep_num;
ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ?
args->root : 0;
ucc_rank_t rank = VRANK(task->subset.myrank, broot, size);
size_t local = GET_LOCAL_COUNT(args, size, rank);
void *sbuf;
ptrdiff_t peer_seg_offset, local_seg_offset;
ucc_rank_t peer, peer_dist;
ucc_kn_radix_t loop_step;
size_t peer_seg_count, local_seg_count;
ucc_status_t status;
size_t extra_count;

ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
ucp_mem_map_params_t mmap_params;
/* The handle list starts with room for a single entry. The malloc() result
 * is not checked anywhere in this function body. */
int size_of_list = 1;
int count_mh = 0;
ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h));

ptrdiff_t offset;

/* Initialize the task's knomial pattern; allgather uses the "ag" variant,
 * every other collective type uses the "agx" variant. */
if (ct == UCC_COLL_TYPE_ALLGATHER) {
ucc_kn_ag_pattern_init(size, rank, radix, args->dst.info.count,
&task->allgather_kn.p);
} else {
ucc_kn_agx_pattern_init(size, rank, radix, args->dst.info.count,
&task->allgather_kn.p);
}

/* sbuf is the destination buffer advanced by the offset computed for this
 * rank by ucc_sra_kn_get_offset(). */
offset = ucc_sra_kn_get_offset(count,
dt_size, rank,
size, radix);
task->allgather_kn.sbuf = PTR_OFFSET(args->dst.info.buffer, offset);

ucc_knomial_pattern_t *p = &task->allgather_kn.p;
uint8_t node_type = task->allgather_kn.p.node_type;

/* Fields shared by every mapping request; only address and length change
 * between MEM_MAP() invocations below. */
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
mmap_params.memory_type = ucc_memtype_to_ucs[mem_type];
/* EXTRA node: the local contribution (skipped for the ALLGATHERX pattern)
 * followed by the whole destination buffer. */
if (KN_NODE_EXTRA == node_type) {
if (p->type != KN_PATTERN_ALLGATHERX) {
mmap_params.address = task->allgather_kn.sbuf;
mmap_params.length = local * dt_size;
MEM_MAP();
}

mmap_params.address = rbuf;
mmap_params.length = data_size;
MEM_MAP();
}
/* PROXY node: the region directly after the local contribution, sized by the
 * extra peer's local count. */
if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) {
peer = ucc_knomial_pattern_get_extra(p, rank);
extra_count = GET_LOCAL_COUNT(args, size, peer);
peer = ucc_ep_map_eval(task->subset.map, peer);
mmap_params.address = PTR_OFFSET(task->allgather_kn.sbuf,
local * dt_size);
mmap_params.length = extra_count * dt_size;
MEM_MAP();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (KN_NODE_EXTRA == node_type) {
goto out;
}
/* Main knomial loop. NOTE(review): this advances task->allgather_kn.p itself
 * (p aliases it), so the pattern is left in its "loop done" state when this
 * function returns; presumably it is re-initialized before progress runs --
 * confirm. */
while (!ucc_knomial_pattern_loop_done(p)) {
ucc_kn_ag_pattern_peer_seg(rank, p, &local_seg_count,
&local_seg_offset);
sbuf = PTR_OFFSET(rbuf, local_seg_offset * dt_size);
/* One mapping of the local segment per non-skipped peer, peers visited from
 * radix - 1 down to 1. */
for (loop_step = radix - 1; loop_step > 0; loop_step--) {
peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
if (peer == UCC_KN_PEER_NULL)
continue;
if (coll_task->bargs.args.coll_type == UCC_COLL_TYPE_BCAST) {
peer_dist = ucc_knomial_calc_recv_dist(size - p->n_extra,
ucc_knomial_pattern_loop_rank(p, peer), p->radix, 0);
if (peer_dist < task->allgather_kn.recv_dist) {
continue;
}
}
mmap_params.address = sbuf;
mmap_params.length = local_seg_count * dt_size;
MEM_MAP();
}

/* One mapping of each non-skipped peer's segment inside rbuf, peers visited
 * from 1 up to radix - 1. */
for (loop_step = 1; loop_step < radix; loop_step++) {
peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
if (peer == UCC_KN_PEER_NULL)
continue;
ucc_kn_ag_pattern_peer_seg(peer, p, &peer_seg_count,
&peer_seg_offset);

if (coll_task->bargs.args.coll_type == UCC_COLL_TYPE_BCAST) {
peer_dist = ucc_knomial_calc_recv_dist(size - p->n_extra,
ucc_knomial_pattern_loop_rank(p, peer), p->radix, 0);
if (peer_dist > task->allgather_kn.recv_dist) {
continue;
}
}
mmap_params.address = PTR_OFFSET(rbuf, peer_seg_offset * dt_size);
mmap_params.length = peer_seg_count * dt_size;
MEM_MAP();
}
ucc_kn_ag_pattern_next_iter(p);
}

/* PROXY node: the whole destination buffer. */
if (KN_NODE_PROXY == node_type) {
mmap_params.address = args->dst.info.buffer;
mmap_params.length = data_size;
MEM_MAP();
}

out:
/* Publish the list on the task. max_mh is stored as count_mh - 1, and
 * count_mh on the task is reset to 0. */
task->allgather_kn.mh_list = mh_list;
task->allgather_kn.max_mh = count_mh-1;
task->allgather_kn.count_mh = 0;
return UCC_OK;
}

/**
 * Finalize callback of the knomial allgather task.
 *
 * Cleans up the executor-task node pool, unmaps every memory handle stored in
 * task->allgather_kn.mh_list (max_mh + 1 entries), frees the list and finally
 * passes the task to ucc_tl_ucp_coll_finalize().
 *
 * @param coll_task  Collective task to finalize.
 *
 * @return The status reported by ucc_tl_ucp_coll_finalize(); a failure is
 *         logged and propagated instead of being replaced by UCC_OK.
 */
ucc_status_t ucc_tl_ucp_allgather_knomial_finalize(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t    *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t    *team = TASK_TEAM(task);
    ucc_tl_ucp_context_t *ctx  = UCC_TL_UCP_TEAM_CTX(team);
    ucc_status_t          status;
    int                   i;

    ucc_mpool_cleanup(&task->allgather_kn.etask_node_mpool, 1);

    /* The list may be absent if its allocation failed during init. */
    if (task->allgather_kn.mh_list != NULL) {
        for (i = 0; i < task->allgather_kn.max_mh + 1; i++) {
            ucp_mem_unmap(ctx->worker.ucp_context,
                          task->allgather_kn.mh_list[i]);
        }
        free(task->allgather_kn.mh_list);
        /* Avoid a dangling pointer if the task object is reused. */
        task->allgather_kn.mh_list = NULL;
    }

    status = ucc_tl_ucp_coll_finalize(&task->super);
    if (status < 0) {
        /* NOTE(review): 'task' is read here after it was handed to
         * ucc_tl_ucp_coll_finalize(); confirm it is still valid on the
         * failure path. */
        tl_error(UCC_TASK_LIB(task), "failed to finalize collective task");
    }

    return status;
}

/**
 * Creates and sets up a knomial allgather task with the given radix.
 *
 * Besides wiring the post/progress/finalize callbacks, this initializes the
 * memory pool used for executor-task list nodes and registers the buffers
 * through register_memory().
 *
 * @param coll_args  Collective arguments used to create the task.
 * @param team       Base team; converted to the TL/UCP team.
 * @param task_h     Receives the new task. Written only on success.
 * @param radix      Knomial radix stored in the task's pattern.
 *
 * @return UCC_OK on success; otherwise the error status of the step that
 *         failed (pool initialization or memory registration). On failure
 *         the partially constructed task is released.
 */
ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
    ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
    ucc_coll_task_t **task_h, ucc_kn_radix_t radix)
{
    ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
    ucc_tl_ucp_task_t *task;
    ucc_sbgp_t        *sbgp;
    ucc_status_t       status;

    task   = ucc_tl_ucp_init_task(coll_args, team);
    status = ucc_mpool_init(
        &task->allgather_kn.etask_node_mpool, 0,
        sizeof(node_ucc_ee_executor_task_t), 0, UCC_CACHE_LINE_SIZE, 16,
        UINT_MAX, NULL,
        tl_team->super.super.context->ucc_context->thread_mode,
        "etasks_linked_list_nodes");
    if (status < 0) {
        tl_error(UCC_TASK_LIB(task), "failed to initialize ucc_mpool");
        goto err_task;
    }

    if (tl_team->cfg.use_reordering &&
        coll_args->args.coll_type == UCC_COLL_TYPE_ALLREDUCE) {
        sbgp                = ucc_topo_get_sbgp(tl_team->topo,
                                                UCC_SBGP_FULL_HOST_ORDERED);
        task->subset.myrank = sbgp->group_rank;
        task->subset.map    = sbgp->map;
    }
    task->allgather_kn.etask_linked_list_head = NULL;
    task->allgather_kn.p.radix                = radix;
    task->super.flags                        |= UCC_COLL_TASK_FLAG_EXECUTOR;
    task->super.post     = ucc_tl_ucp_allgather_knomial_start;
    task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
    task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize;

    status = register_memory(&task->super);
    if (status < 0) {
        tl_error(UCC_TASK_LIB(task), "failed to register memory");
        goto err_mpool;
    }

    *task_h = &task->super;
    return UCC_OK;

err_mpool:
    /* NOTE(review): handles that register_memory() may have mapped before
     * failing are not released here because its failure state is not visible
     * from this function -- confirm and extend the unwind if needed. */
    ucc_mpool_cleanup(&task->allgather_kn.etask_node_mpool, 1);
err_task:
    /* Hand the task back through the generic TL/UCP finalize; the
     * knomial-specific finalize is not used because the pool and handle list
     * are not (fully) set up on these paths. */
    ucc_tl_ucp_coll_finalize(&task->super);
    return status;
}
Expand Down
Loading