diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index af4af7851d..c6dffe6bea 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -74,7 +74,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) INV_VRANK(peer,broot,size)), team, task, mh_list[task->allgather_kn.count_mh++]), task, out); - ucc_assert(task->allgather_kn.count_mh >= max_mh); + ucc_assert(task->allgather_kn.count_mh-1 <= max_mh); } UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(rbuf, data_size, mem_type, @@ -82,7 +82,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) INV_VRANK(peer,broot,size)), team, task, mh_list[task->allgather_kn.count_mh++]), task, out); - ucc_assert(task->allgather_kn.count_mh >= max_mh); + ucc_assert(task->allgather_kn.count_mh-1 <= max_mh); } if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) { peer = ucc_knomial_pattern_get_extra(p, rank); @@ -92,7 +92,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) local * dt_size), extra_count * dt_size, mem_type, peer, team, task, mh_list[task->allgather_kn.count_mh++]), task, out); - ucc_assert(task->allgather_kn.count_mh >= max_mh); + ucc_assert(task->allgather_kn.count_mh-1 <= max_mh); } UCC_KN_PHASE_EXTRA: @@ -121,14 +121,13 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) continue; } } - printf("progress : count_mh: %d, mh: %lx\n", task->allgather_kn.count_mh, (unsigned long)mh_list[task->allgather_kn.count_mh]); UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(sbuf, local_seg_count * dt_size, mem_type, ucc_ep_map_eval(task->subset.map, INV_VRANK(peer, broot, size)), team, task, mh_list[task->allgather_kn.count_mh++]), task, out); - ucc_assert(task->allgather_kn.count_mh >= max_mh); + ucc_assert(task->allgather_kn.count_mh-1 <= max_mh); } for (loop_step = 1; loop_step < radix; loop_step++) { @@ -152,7 +151,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) INV_VRANK(peer, broot, size)), team, task, mh_list[task->allgather_kn.count_mh++]), task, out); - ucc_assert(task->allgather_kn.count_mh >= max_mh); + ucc_assert(task->allgather_kn.count_mh-1 <= max_mh); } UCC_KN_PHASE_LOOP: if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks(task)) { @@ -170,7 +169,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) INV_VRANK(peer, broot, size)), team, task, mh_list[task->allgather_kn.count_mh++]), task, out); - ucc_assert(task->allgather_kn.count_mh >= max_mh); + ucc_assert(task->allgather_kn.count_mh-1 <= max_mh); } UCC_KN_PHASE_PROXY: if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) { @@ -252,6 +251,7 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){ ucc_tl_ucp_task_t); ucc_coll_args_t *args = &TASK_ARGS(task); ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_type_t ct = args->coll_type; ucc_kn_radix_t radix = task->allgather_kn.p.radix; uint8_t node_type = task->allgather_kn.p.node_type; ucc_knomial_pattern_t *p = &task->allgather_kn.p; @@ -273,18 +273,28 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){ ucc_status_t status; size_t extra_count; - ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); - ucp_mem_map_params_t mmap_params; - ucp_mem_h mh; - int size_of_list = 1; - int count_mh = 0; - ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h)); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); + ucp_mem_map_params_t mmap_params; + // ucp_mem_h mh; + int size_of_list = 1; + int count_mh = 0; + ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h)); + + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0); + task->allgather_kn.etask = NULL; + task->allgather_kn.phase = UCC_KN_PHASE_INIT; + if (ct == UCC_COLL_TYPE_ALLGATHER) { + ucc_kn_ag_pattern_init(size, rank, radix, args->dst.info.count, + &task->allgather_kn.p); + } else { + ucc_kn_agx_pattern_init(size, rank, radix, args->dst.info.count, + &task->allgather_kn.p); + } mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH | UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; mmap_params.memory_type = ucc_memtype_to_ucs[mem_type]; - printf("I'm in register memory"); if (KN_NODE_EXTRA == node_type) { if (p->type != KN_PATTERN_ALLGATHERX) { mmap_params.address = task->allgather_kn.sbuf; @@ -310,13 +320,10 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){ goto out; } while (!ucc_knomial_pattern_loop_done(p)) { - printf("in the while loop"); ucc_kn_ag_pattern_peer_seg(rank, p, &local_seg_count, &local_seg_offset); sbuf = PTR_OFFSET(rbuf, local_seg_offset * dt_size); - for (loop_step = radix - 1; loop_step > 0; loop_step--) { - printf("in the for loop"); peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step); if (peer == UCC_KN_PEER_NULL) continue; @@ -329,7 +336,6 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){ } mmap_params.address = sbuf; mmap_params.length = local_seg_count * dt_size; - printf("register memory : count_mh: %d, mh: %lx\n", count_mh, (unsigned long)mh_list[count_mh]); MEM_MAP(); } @@ -370,12 +376,23 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){ ucc_status_t ucc_tl_ucp_allgather_knomial_finalize(ucc_coll_task_t *coll_task){ ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_status_t status; + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); ucc_mpool_cleanup(&task->allgather_kn.etask_node_mpool, 1); + for (int i=0; iallgather_kn.max_mh+1; i++){ + ucp_mem_unmap(ctx->worker.ucp_context, task->allgather_kn.mh_list[i]); + } free(task->allgather_kn.mh_list); + status = ucc_tl_ucp_coll_finalize(&task->super); + if (status < 0){ + tl_error(UCC_TASK_LIB(task), + "failed to initialize ucc_mpool"); + } return UCC_OK; -}; +} ucc_status_t ucc_tl_ucp_allgather_knomial_init_r( ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, @@ -401,17 +418,17 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r( task->subset.myrank = sbgp->group_rank; task->subset.map = sbgp->map; } - status = register_memory(&task->super); - if (status < 0){ - tl_error(UCC_TASK_LIB(task), - "failed to register memory"); - } task->allgather_kn.etask_linked_list_head = NULL; task->allgather_kn.p.radix = radix; task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; task->super.post = ucc_tl_ucp_allgather_knomial_start; task->super.progress = ucc_tl_ucp_allgather_knomial_progress; task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize; + status = register_memory(&task->super); + if (status < 0){ + tl_error(UCC_TASK_LIB(task), + "failed to register memory"); + } *task_h = &task->super; return UCC_OK; } diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 8bda0eb94a..4347ab2874 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -58,7 +58,7 @@ void ucc_tl_ucp_team_default_score_str_free( } while(0) #define MEM_MAP() do { \ - status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); \ + status = ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh_list[count_mh++])); \ if (UCC_OK != status) { \ return status; \ } \ @@ -66,7 +66,6 @@ void ucc_tl_ucp_team_default_score_str_free( size_of_list *= 2; \ mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); \ } \ - mh_list[count_mh++] = mh; \ } while(0) #define EXEC_TASK_WAIT(_etask, ...) \ @@ -503,7 +502,7 @@ static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *t while(current_node != NULL) { status = ucc_ee_executor_task_test(current_node->etask); if (status > 0) { - ucp_memcpy_device_complete(current_node->etask->completion, status); + ucp_memcpy_device_complete(current_node->etask->completion, ucc_status_to_ucs_status(status)); status_2 = ucc_ee_executor_task_finalize(current_node->etask); ucc_mpool_put(current_node); if (ucc_unlikely(status_2 < 0)){ @@ -517,9 +516,7 @@ static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *t task->allgather_kn.etask_linked_list_head = current_node->next; } } - else { - prev_node = current_node; - } + prev_node = current_node; current_node = current_node->next; //to iterate to next node } if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) {