diff --git a/src/components/tl/ucc_tl.h b/src/components/tl/ucc_tl.h
index 53e62052dc..5c3171cbb8 100644
--- a/src/components/tl/ucc_tl.h
+++ b/src/components/tl/ucc_tl.h
@@ -64,6 +64,12 @@ typedef struct ucc_tl_service_coll {
 typedef struct ucc_tl_coll_plugin_iface {
     ucc_component_iface_t          super;
     ucs_config_global_list_entry_t config;
+    ucc_status_t (*context_create)(const ucc_base_context_params_t *params,
+                                   const ucc_base_config_t *config,
+                                   ucc_base_context_t *tl_ctx,
+                                   void **plugin_ctx);
+    ucc_status_t (*context_destroy)(ucc_base_context_t *ctx,
+                                    void *plugin_ctx);
     ucc_get_coll_scores_fn_t       get_scores;
     uint32_t                       id;
 } ucc_tl_coll_plugin_iface_t;
@@ -88,8 +94,9 @@ typedef struct ucc_tl_lib {
 UCC_CLASS_DECLARE(ucc_tl_lib_t, ucc_tl_iface_t *, const ucc_tl_lib_config_t *);
 
 typedef struct ucc_tl_context {
-    ucc_base_context_t super;
-    int                ref_count;
+    ucc_base_context_t super;
+    int                ref_count;
+    void              *coll_plugin_context;
 } ucc_tl_context_t;
 UCC_CLASS_DECLARE(ucc_tl_context_t, const ucc_tl_context_config_t *,
                   ucc_context_t *);
diff --git a/src/components/tl/ucp/coll_plugins/example/Makefile.am b/src/components/tl/ucp/coll_plugins/example/Makefile.am
index 4d27082ec1..5a695ef5bb 100644
--- a/src/components/tl/ucp/coll_plugins/example/Makefile.am
+++ b/src/components/tl/ucp/coll_plugins/example/Makefile.am
@@ -1,9 +1,13 @@
 #
-# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 
 if TLCP_UCP_EXAMPLE_ENABLED
-sources = example.c
+
+sources =          \
+	example_ctx.h  \
+	example_ctx.c  \
+	example_coll.c
 
 module_LTLIBRARIES = libucc_tlcp_ucp_example.la
 libucc_tlcp_ucp_example_la_SOURCES = $(sources)
diff --git a/src/components/tl/ucp/coll_plugins/example/example.c b/src/components/tl/ucp/coll_plugins/example/example.c
deleted file mode 100644
index 60a529dc6c..0000000000
--- a/src/components/tl/ucp/coll_plugins/example/example.c
+++ /dev/null
@@ -1,126 +0,0 @@
-/**
- * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- *
- * See file LICENSE for terms.
- */
-
-#include "config.h"
-#include "components/tl/ucp/tl_ucp.h"
-#include "components/tl/ucp/tl_ucp_coll.h"
-#include "core/ucc_progress_queue.h"
-#include "components/tl/ucp/tl_ucp_sendrecv.h"
-#include "coll_patterns/recursive_knomial.h"
-#include "coll_score/ucc_coll_score.h"
-#include "utils/ucc_math.h"
-
-ucc_tl_coll_plugin_iface_t ucc_tlcp_ucp_example;
-
-typedef struct ucc_tlcp_ucp_example_config {
-    char *score_str;
-} ucc_tlcp_ucp_example_config_t;
-
-#define CONFIG(_lib) ((ucc_tlcp_ucp_example_config_t*)((_lib)->tlcp_configs[ucc_tlcp_ucp_example.id]))
-
-static ucc_config_field_t ucc_tlcp_ucp_example_table[] = {
-    {"TLCP_EXAMPLE_TUNE", "", "Collective score modifier",
-     ucc_offsetof(ucc_tlcp_ucp_example_config_t, score_str), UCC_CONFIG_TYPE_STRING},
-
-    {NULL}};
-
-static ucs_config_global_list_entry_t ucc_tlcp_ucp_example_cfg_entry =
-{
-    .name   = "TLCP_EXAMPLE",
-    .prefix = "TL_UCP_",
-    .table  = ucc_tlcp_ucp_example_table,
-    .size   = sizeof(ucc_tlcp_ucp_example_config_t)
-};
-
-UCC_CONFIG_REGISTER_TABLE_ENTRY(&ucc_tlcp_ucp_example_cfg_entry,
-                                &ucc_config_global_list);
-
-#define UCC_TLCP_UCP_EXAMPLE_SCORE 100
-void ucc_tlcp_ucp_example_progress(ucc_coll_task_t *coll_task)
-{
-    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
-
-    tl_info(TASK_LIB(task), "completing tl_ucp_example coll task");
-
-    ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
-    task->super.status = UCC_OK;
-}
-
-ucc_status_t ucc_tlcp_ucp_example_start(ucc_coll_task_t *coll_task)
-{
-    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
-    ucc_tl_ucp_team_t *team = TASK_TEAM(task);
-
-    tl_info(TASK_LIB(task), "starting tl_ucp_example coll task");
-
-    task->super.status = UCC_INPROGRESS;
-    ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
-
-    return UCC_OK;
-}
-
-ucc_status_t ucc_tlcp_ucp_example_coll_init(ucc_base_coll_args_t *coll_args,
-                                            ucc_base_team_t *team,
-                                            ucc_coll_task_t **task_h)
-{
-    ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
-    ucc_tl_ucp_task_t *task    = ucc_tl_ucp_get_task(tl_team);
-
-    ucc_coll_task_init(&task->super, coll_args, team);
-    task->tagged.tag     = tl_team->seq_num;
-    tl_team->seq_num     = (tl_team->seq_num + 1) % UCC_TL_UCP_MAX_COLL_TAG;
-    task->super.finalize = ucc_tl_ucp_coll_finalize;
-    task->super.post     = ucc_tlcp_ucp_example_start;
-    task->super.progress = ucc_tlcp_ucp_example_progress;
-    *task_h              = &task->super;
-    return UCC_OK;
-}
-
-ucc_status_t ucc_tlcp_ucp_example_get_scores(ucc_base_team_t *tl_team,
-                                             ucc_coll_score_t **score_p)
-{
-    ucc_tl_ucp_team_t *team = ucc_derived_of(tl_team, ucc_tl_ucp_team_t);
-    ucc_tl_ucp_lib_t  *lib  = UCC_TL_UCP_TEAM_LIB(team);
-    const char        *score_str;
-    ucc_coll_score_t  *score;
-    ucc_status_t       status;
-
-    /* There can be a different logic for different coll_type/mem_type.
-       Right now just init everything the same way.
-    */
-    status = ucc_coll_score_alloc(&score);
-    if (UCC_OK != status) {
-        tl_error(lib, "failed to alloc score");
-        return status;
-    }
-    status = ucc_coll_score_add_range(score, UCC_COLL_TYPE_ALLTOALL, UCC_MEMORY_TYPE_HOST,
-                                      0, 4096, UCC_TLCP_UCP_EXAMPLE_SCORE,
-                                      ucc_tlcp_ucp_example_coll_init, tl_team);
-    if (UCC_OK != status) {
-        tl_error(lib, "failed to add range");
-        return status;
-    }
-    score_str = CONFIG(lib)->score_str;
-    if (strlen(score_str) > 0) {
-        status = ucc_coll_score_update_from_str(score_str, score, UCC_TL_TEAM_SIZE(team),
-                                                ucc_tlcp_ucp_example_coll_init,
-                                                &team->super.super, UCC_TLCP_UCP_EXAMPLE_SCORE,
-                                                NULL);
-        if (status == UCC_ERR_INVALID_PARAM) {
-            /* User provided incorrect input - try to proceed */
-            status = UCC_OK;
-        }
-    }
-    *score_p = score;
-    return status;
-}
-
-ucc_tl_coll_plugin_iface_t ucc_tlcp_ucp_example = {
-    .super.name   = "tl_ucp_example",
-    .super.score  = UCC_TLCP_UCP_EXAMPLE_SCORE,
-    .config.table = ucc_tlcp_ucp_example_table,
-    .config.size  = sizeof(ucc_tlcp_ucp_example_config_t),
-    .get_scores   = ucc_tlcp_ucp_example_get_scores
-};
diff --git a/src/components/tl/ucp/coll_plugins/example/example_coll.c b/src/components/tl/ucp/coll_plugins/example/example_coll.c
new file mode 100644
index 0000000000..2154e31a18
--- /dev/null
+++ b/src/components/tl/ucp/coll_plugins/example/example_coll.c
@@ -0,0 +1,341 @@
+/**
+ * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "example_ctx.h"
+#include "components/tl/ucp/tl_ucp_sendrecv.h"
+#include "utils/ucc_dt_reduce.h"
+
+#define SAVE_STATE(_phase)                                                     \
+    do {                                                                       \
+        task->allreduce_kn.phase = _phase;                                     \
+    } while (0)
+
+static inline ucc_status_t
+ucc_tl_ucp_send_am(void *buffer, size_t msglen, ucc_memory_type_t mtype,
+                   ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team,
+                   ucc_tl_ucp_task_t *task)
+{
+    ucc_coll_args_t     *args = &TASK_ARGS(task);
+    ucc_status_t         status;
+    ucp_ep_h             ep;
+    ucp_request_param_t  req_param;
+    ucs_status_ptr_t     ucp_status;
+    ucp_tag_t            ucp_tag;
+
+    ucp_tag = UCC_TL_UCP_MAKE_SEND_TAG((args->mask & UCC_COLL_ARGS_FIELD_TAG),
+                                       task->tagged.tag, UCC_TL_TEAM_RANK(team),
+                                       team->super.super.params.id,
+                                       team->super.super.params.scope_id,
+                                       team->super.super.params.scope);
+    req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
+                             UCP_OP_ATTR_FIELD_DATATYPE |
+                             UCP_OP_ATTR_FIELD_USER_DATA |
+                             UCP_OP_ATTR_FIELD_MEMORY_TYPE |
+                             UCP_OP_ATTR_FIELD_FLAGS;
+    req_param.datatype     = ucp_dt_make_contig(msglen);
+    req_param.cb.send      = ucc_tl_ucp_send_completion_cb;
+    req_param.memory_type  = ucc_memtype_to_ucs[mtype];
+    req_param.user_data    = (void*)task;
+    req_param.flags        = UCP_AM_SEND_FLAG_EAGER |
+                             UCP_AM_SEND_FLAG_COPY_HEADER;
+    status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep);
+    if (ucc_unlikely(UCC_OK != status)) {
+        return status;
+    }
+    task->tagged.send_posted++;
+    ucp_status = ucp_am_send_nbx(ep, 1, &ucp_tag, sizeof(ucp_tag), buffer,
+                                 1, &req_param);
+    if (UCS_OK != ucp_status) {
+        UCC_TL_UCP_CHECK_REQ_STATUS();
+    } else {
+        task->tagged.send_completed++;
+    }
+    return UCC_OK;
+}
+
+static inline ucc_status_t
+ucc_tl_ucp_check_am_recv(ucc_tlcp_ucp_example_am_msg_t **recv,
+                         ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team,
+                         ucc_tl_ucp_task_t *task)
+{
+    ucc_coll_args_t                *args       = &TASK_ARGS(task);
+    ucc_tl_ucp_context_t           *tl_ucp_ctx = TASK_CTX(task);
+    ucc_tlcp_ucp_example_context_t *plugin_ctx;
+    ucc_tlcp_ucp_example_am_msg_t  *entry;
+    ucp_tag_t                       ucp_tag, ucp_tag_mask;
+
+    plugin_ctx = (ucc_tlcp_ucp_example_context_t*)
+                 tl_ucp_ctx->super.coll_plugin_context;
+
+    UCC_TL_UCP_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask,
+                             (args->mask & UCC_COLL_ARGS_FIELD_TAG),
+                             task->tagged.tag, dest_group_rank,
+                             team->super.super.params.id,
+                             team->super.super.params.scope_id,
+                             team->super.super.params.scope);
+
+    ucc_assert(ucp_tag_mask != 0);
+    ucc_list_for_each(entry, &plugin_ctx->am_list, list_elem) {
+        if (entry->tag == ucp_tag) {
+            *recv = entry;
+            return UCC_OK;
+        }
+    }
+    ucp_worker_progress(UCC_TL_UCP_TASK_TEAM(task)->worker->ucp_worker);
+    return UCC_INPROGRESS;
+}
+
+static inline void
+ucc_tl_ucp_put_am_msg(ucc_tl_ucp_task_t *task,
+                      ucc_tlcp_ucp_example_am_msg_t *recv)
+{
+    ucp_am_data_release(TASK_CTX(task)->worker.ucp_worker, recv->msg);
+    ucc_list_del(&recv->list_elem);
+    ucc_free(recv);
+}
+
+ucc_status_t ucc_tl_ucp_allreduce_knomial_am_finalize(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    ucc_status_t       st;
+
+    st = ucc_tl_ucp_coll_finalize(&task->super);
+    if (ucc_unlikely(st != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task), "failed to finalize collective");
+    }
+    return st;
+}
+
+void ucc_tl_ucp_allreduce_knomial_am_progress(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t     *task       = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    ucc_coll_args_t       *args       = &TASK_ARGS(task);
+    ucc_tl_ucp_team_t     *team       = TASK_TEAM(task);
+    int                    avg_pre_op = team->cfg.reduce_avg_pre_op;
+    ucc_kn_radix_t         radix      = task->allreduce_kn.p.radix;
+    uint8_t                node_type  = task->allreduce_kn.p.node_type;
+    ucc_knomial_pattern_t *p          = &task->allreduce_kn.p;
+    void                  *sbuf       = args->src.info.buffer;
+    void                  *rbuf       = args->dst.info.buffer;
+    ucc_memory_type_t      mem_type   = args->dst.info.mem_type;
+    size_t                 count      = args->dst.info.count;
+    ucc_datatype_t         dt         = args->dst.info.datatype;
+    size_t                 data_size  = count * ucc_dt_size(dt);
+    ucc_rank_t             rank       = task->subset.myrank;
+    void                  *send_buf;
+    ptrdiff_t              recv_offset;
+    ucc_rank_t             peer;
+    ucc_status_t           status;
+    ucc_kn_radix_t         loop_step;
+    int                    is_avg, k;
+    void                  *srcs[8];
+    ucc_tlcp_ucp_example_am_msg_t *recv;
+
+    if (UCC_IS_INPLACE(*args)) {
+        sbuf = rbuf;
+    }
+    UCC_KN_REDUCE_GOTO_PHASE(task->allreduce_kn.phase);
+
+    if (KN_NODE_EXTRA == node_type) {
+        peer = ucc_ep_map_eval(task->subset.map,
+                               ucc_knomial_pattern_get_proxy(p, rank));
+        UCPCHECK_GOTO(
+            ucc_tl_ucp_send_am(sbuf, data_size, mem_type, peer, team, task),
+            task, out);
+        UCPCHECK_GOTO(
+            ucc_tl_ucp_recv_nb(rbuf, data_size, mem_type, peer, team, task),
+            task, out);
+    }
+
+UCC_KN_PHASE_EXTRA:
+    if (KN_NODE_PROXY == node_type || KN_NODE_EXTRA == node_type) {
+        if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
+            SAVE_STATE(UCC_KN_PHASE_EXTRA);
+            return;
+        }
+        if (KN_NODE_EXTRA == node_type) {
+            goto completion;
+        } else {
+            peer = ucc_ep_map_eval(task->subset.map,
+                                   ucc_knomial_pattern_get_extra(p, rank));
+            status = ucc_tl_ucp_check_am_recv(&recv, peer, team, task);
+            if (status == UCC_INPROGRESS) {
+                SAVE_STATE(UCC_KN_PHASE_EXTRA);
+                return;
+            }
+            status = ucc_dt_reduce(sbuf, recv->msg, rbuf, count, dt, args, 0, 0,
+                                   task->allreduce_kn.executor,
+                                   &task->allreduce_kn.etask);
+
+            if (ucc_unlikely(status != UCC_OK)) {
+                tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
+                task->super.status = status;
+                return;
+            }
+UCC_KN_PHASE_EXTRA_REDUCE:
+            EXEC_TASK_TEST(UCC_KN_PHASE_EXTRA_REDUCE,
+                           "failed to perform dt reduction",
+                           task->allreduce_kn.etask);
+            ucc_tl_ucp_put_am_msg(task, recv);
+        }
+    }
+    while(!ucc_knomial_pattern_loop_done(p)) {
+        for (loop_step = 1; loop_step < radix; loop_step++) {
+            peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
+            if (peer == UCC_KN_PEER_NULL) {
+                continue;
+            }
+            peer = ucc_ep_map_eval(task->subset.map, peer);
+            if ((ucc_knomial_pattern_loop_first_iteration(p)) &&
+                (KN_NODE_PROXY != node_type) && !UCC_IS_INPLACE(*args)) {
+                send_buf = sbuf;
+            } else {
+                send_buf = rbuf;
+            }
+            UCPCHECK_GOTO(
+                ucc_tl_ucp_send_am(send_buf, data_size, mem_type, peer, team,
+                                   task),
+                task, out);
+        }
+
+UCC_KN_PHASE_LOOP:
+        if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
+            SAVE_STATE(UCC_KN_PHASE_LOOP);
+            return;
+        }
+        recv_offset = 0;
+        for (loop_step = 1, k = 1; loop_step < radix; loop_step++) {
+            peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
+            if (peer == UCC_KN_PEER_NULL) {
+                continue;
+            }
+            peer = ucc_ep_map_eval(task->subset.map, peer);
+            status = ucc_tl_ucp_check_am_recv(&recv, peer, team, task);
+            if (status != UCC_OK) {
+                SAVE_STATE(UCC_KN_PHASE_LOOP);
+                return;
+            }
+            srcs[k] = recv->msg;
+            recv_offset += data_size;
+            k++;
+        }
+
+        if (task->tagged.send_posted > p->iteration * (radix - 1)) {
+            if ((ucc_knomial_pattern_loop_first_iteration(p)) &&
+                (KN_NODE_PROXY != node_type) && !UCC_IS_INPLACE(*args)) {
+                send_buf = sbuf;
+            } else {
+                send_buf = rbuf;
+            }
+            is_avg = args->op == UCC_OP_AVG &&
+                     (avg_pre_op ? ucc_knomial_pattern_loop_first_iteration(p)
+                                 : ucc_knomial_pattern_loop_last_iteration(p));
+            srcs[0] = send_buf;
+            status = ucc_dt_reduce_vec(
+                srcs, rbuf,
+                task->tagged.send_posted - p->iteration * (radix - 1) + 1, count,
+                dt, args,
+                UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT |
+                    (is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0),
+                AVG_ALPHA(task), task->allreduce_kn.executor,
+                &task->allreduce_kn.etask);
+
+            if (ucc_unlikely(UCC_OK != status)) {
+                tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
+                task->super.status = status;
+                return;
+            }
+UCC_KN_PHASE_REDUCE:
+            EXEC_TASK_TEST(UCC_KN_PHASE_REDUCE,
+                           "failed to perform dt reduction",
+                           task->allreduce_kn.etask);
+            for (loop_step = 1; loop_step < radix; loop_step++) {
+                peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
+                if (peer == UCC_KN_PEER_NULL)
+                    continue;
+                peer = ucc_ep_map_eval(task->subset.map, peer);
+                status = ucc_tl_ucp_check_am_recv(&recv, peer, team, task);
+                ucc_tl_ucp_put_am_msg(task, recv);
+            }
+        }
+        ucc_knomial_pattern_next_iteration(p);
+    }
+    if (KN_NODE_PROXY == node_type) {
+        peer = ucc_ep_map_eval(task->subset.map,
+                               ucc_knomial_pattern_get_extra(p, rank));
+        UCPCHECK_GOTO(
+            ucc_tl_ucp_send_nb(rbuf, data_size, mem_type, peer, team, task),
+            task, out);
+        goto UCC_KN_PHASE_PROXY;
+    } else {
+        goto completion;
+    }
+
+UCC_KN_PHASE_PROXY:
+    if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
+        SAVE_STATE(UCC_KN_PHASE_PROXY);
+        return;
+    }
+
+completion:
+    ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
+    task->super.status = UCC_OK;
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_kn_done", 0);
+UCC_KN_PHASE_COMPLETE: /* unused label */
+out:
+    return;
+}
+
+ucc_status_t ucc_tl_ucp_allreduce_knomial_am_start(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
+    ucc_rank_t         size      = (ucc_rank_t)task->subset.map.ep_num;
+    ucc_rank_t         rank      = task->subset.myrank;
+    ucc_memory_type_t  mem_type  = TASK_ARGS(task).dst.info.mem_type;
+    size_t             count     = TASK_ARGS(task).dst.info.count;
+    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
+    size_t             data_size = count * ucc_dt_size(dt);
+    ucc_mrange_uint_t *p         = &team->cfg.allreduce_kn_radix;
+    ucc_kn_radix_t     cfg_radix;
+    ucc_status_t       status;
+
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_kn_start", 0);
+    task->allreduce_kn.phase = UCC_KN_PHASE_INIT;
+    ucc_assert(UCC_IS_INPLACE(TASK_ARGS(task)) ||
+               (TASK_ARGS(task).src.info.mem_type == mem_type));
+    cfg_radix = ucc_tl_ucp_get_radix_from_range(team, data_size,
+                                                mem_type, p);
+    ucc_knomial_pattern_init(size, rank, ucc_min(cfg_radix, size),
+                             &task->allreduce_kn.p);
+    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
+    status =
+        ucc_coll_task_get_executor(&task->super, &task->allreduce_kn.executor);
+    if (ucc_unlikely(status != UCC_OK)) {
+        return status;
+    }
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
+}
+
+ucc_status_t ucc_tl_ucp_allreduce_knomial_am_init(ucc_base_coll_args_t *coll_args,
+                                                  ucc_base_team_t *tl_team,
+                                                  ucc_coll_task_t **task_h)
+{
+    ucc_tl_ucp_task_t *task;
+
+    task = ucc_tl_ucp_init_task(coll_args, tl_team);
+    if (!task) {
+        return UCC_ERR_NO_MEMORY;
+    }
+
+    task->super.post     = ucc_tl_ucp_allreduce_knomial_am_start;
+    task->super.progress = ucc_tl_ucp_allreduce_knomial_am_progress;
+    task->super.finalize = ucc_tl_ucp_allreduce_knomial_am_finalize;
+    task->super.flags   |= UCC_COLL_TASK_FLAG_EXECUTOR;
+
+    *task_h = &task->super;
+    return UCC_OK;
+}
diff --git a/src/components/tl/ucp/coll_plugins/example/example_ctx.c b/src/components/tl/ucp/coll_plugins/example/example_ctx.c
new file mode 100644
index 0000000000..ee4d9700cf
--- /dev/null
+++ b/src/components/tl/ucp/coll_plugins/example/example_ctx.c
@@ -0,0 +1,166 @@
+/**
+ * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "config.h"
+#include "example_ctx.h"
+#include "core/ucc_progress_queue.h"
+#include "coll_patterns/recursive_knomial.h"
+#include "coll_score/ucc_coll_score.h"
+#include "utils/ucc_math.h"
+
+ucc_tl_coll_plugin_iface_t ucc_tlcp_ucp_example;
+
+#define CONFIG(_lib) ((ucc_tlcp_ucp_example_config_t*)((_lib)->tlcp_configs[ucc_tlcp_ucp_example.id]))
+
+static ucc_config_field_t ucc_tlcp_ucp_example_table[] = {
+    {"TLCP_EXAMPLE_TUNE", "", "Collective score modifier",
+     ucc_offsetof(ucc_tlcp_ucp_example_config_t, score_str), UCC_CONFIG_TYPE_STRING},
+
+    {NULL}};
+
+static ucs_config_global_list_entry_t ucc_tlcp_ucp_example_cfg_entry =
+{
+    .name   = "TLCP_EXAMPLE",
+    .prefix = "TL_UCP_",
+    .table  = ucc_tlcp_ucp_example_table,
+    .size   = sizeof(ucc_tlcp_ucp_example_config_t)
+};
+
+UCC_CONFIG_REGISTER_TABLE_ENTRY(&ucc_tlcp_ucp_example_cfg_entry,
+                                &ucc_config_global_list);
+
+#define UCC_TLCP_UCP_EXAMPLE_SCORE 100
+
+ucc_status_t ucc_tlcp_ucp_example_get_scores(ucc_base_team_t *tl_team,
+                                             ucc_coll_score_t **score_p)
+{
+    ucc_tl_ucp_team_t *team = ucc_derived_of(tl_team, ucc_tl_ucp_team_t);
+    ucc_tl_ucp_lib_t  *lib  = UCC_TL_UCP_TEAM_LIB(team);
+    const char        *score_str;
+    ucc_coll_score_t  *score;
+    ucc_status_t       status;
+    ucc_memory_type_t  mt   = UCC_MEMORY_TYPE_HOST;
+    /* There can be a different logic for different coll_type/mem_type.
+       Right now just init everything the same way.
+    */
+    status = ucc_coll_score_alloc(&score);
+    if (UCC_OK != status) {
+        tl_error(lib, "failed to alloc score");
+        return status;
+    }
+    status = ucc_coll_score_add_range(score, UCC_COLL_TYPE_ALLREDUCE,
+                                      UCC_MEMORY_TYPE_HOST,
+                                      0, 4096, UCC_TLCP_UCP_EXAMPLE_SCORE,
+                                      ucc_tl_ucp_allreduce_knomial_am_init,
+                                      tl_team);
+    if (UCC_OK != status) {
+        tl_error(lib, "failed to add range");
+        return status;
+    }
+    score_str = CONFIG(lib)->score_str;
+    if (strlen(score_str) > 0) {
+
+        status = ucc_coll_score_update_from_str(score_str, score,
+                                                UCC_TL_TEAM_SIZE(team),
+                                                ucc_tl_ucp_allreduce_knomial_am_init,
+                                                &team->super.super,
+                                                UCC_TLCP_UCP_EXAMPLE_SCORE,
+                                                NULL, &mt, 1);
+        if (status == UCC_ERR_INVALID_PARAM) {
+            /* User provided incorrect input - try to proceed */
+            status = UCC_OK;
+        }
+    }
+    *score_p = score;
+    return status;
+}
+
+ucs_status_t ucc_tlcp_ucp_example_am_recv_handler(void *arg, const void *header,
+                                                  size_t header_length,
+                                                  void *data, size_t length,
+                                                  const ucp_am_recv_param_t *param)
+{
+    ucc_tlcp_ucp_example_context_t *ctx = (ucc_tlcp_ucp_example_context_t *)arg;
+    ucc_tlcp_ucp_example_am_msg_t  *entry;
+
+    uint64_t *tag = (uint64_t*)header;
+
+    entry = ucc_malloc(sizeof(ucc_tlcp_ucp_example_am_msg_t),
+                       "tlcp_ucp_example_am_msg");
+    ucc_assert(header_length == 8);
+    if (!entry) {
+        ucc_error("failed to allocate %zd bytes for am entry",
+                  sizeof(*entry));
+        return UCS_ERR_NO_MEMORY;
+    }
+    entry->tag = *tag;
+    entry->msg = data;
+    ucc_list_add_tail(&ctx->am_list, &entry->list_elem);
+    return UCS_INPROGRESS;
+}
+
+ucc_status_t ucc_tlcp_ucp_example_context_create(const ucc_base_context_params_t *params,
+                                                 const ucc_base_config_t *config,
+                                                 ucc_base_context_t *tl_ctx,
+                                                 void **plugin_ctx)
+{
+    ucc_tl_ucp_lib_t               *lib        = ucc_derived_of(tl_ctx->lib,
+                                                                ucc_tl_ucp_lib_t);
+    ucc_tl_ucp_context_t           *tl_ucp_ctx = ucc_derived_of(tl_ctx,
+                                                                ucc_tl_ucp_context_t);
+    ucc_tlcp_ucp_example_context_t *ctx;
+    ucc_status_t                    status;
+    ucs_status_t                    ucs_status;
+    ucp_am_handler_param_t          am_handler_param;
+
+    ctx = ucc_malloc(sizeof(ucc_tlcp_ucp_example_context_t),
+                     "tlcp_ucp_example_context");
+    if (!ctx) {
+        tl_error(lib, "failed to alloc memory for plugin context");
+        return UCC_ERR_NO_MEMORY;
+    }
+
+    tl_ucp_ctx = ucc_derived_of(tl_ctx, ucc_tl_ucp_context_t);
+    am_handler_param.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID |
+                                  UCP_AM_HANDLER_PARAM_FIELD_FLAGS |
+                                  UCP_AM_HANDLER_PARAM_FIELD_CB |
+                                  UCP_AM_HANDLER_PARAM_FIELD_ARG;
+    am_handler_param.id         = 1;
+    am_handler_param.flags      = UCP_AM_FLAG_WHOLE_MSG |
+                                  UCP_AM_FLAG_PERSISTENT_DATA;
+    am_handler_param.cb         = ucc_tlcp_ucp_example_am_recv_handler;
+    am_handler_param.arg        = ctx;
+
+    ucs_status = ucp_worker_set_am_recv_handler(tl_ucp_ctx->worker.ucp_worker,
+                                                &am_handler_param);
+    if (ucs_status != UCS_OK) {
+        tl_error(lib, "failed to set am recv handler");
+        status = ucs_status_to_ucc_status(ucs_status);
+        goto free_ctx;
+    }
+    ucc_list_head_init(&ctx->am_list);
+
+    *plugin_ctx = ctx;
+    return UCC_OK;
+free_ctx:
+    ucc_free(ctx);
+    return status;
+}
+
+ucc_status_t ucc_tlcp_ucp_example_context_destroy(ucc_base_context_t *tl_ctx,
+                                                  void *plugin_ctx)
+{
+    ucc_free(plugin_ctx);
+    return UCC_OK;
+}
+
+ucc_tl_coll_plugin_iface_t ucc_tlcp_ucp_example = {
+    .super.name      = "tl_ucp_example",
+    .super.score     = UCC_TLCP_UCP_EXAMPLE_SCORE,
+    .config.table    = ucc_tlcp_ucp_example_table,
+    .config.size     = sizeof(ucc_tlcp_ucp_example_config_t),
+    .get_scores      = ucc_tlcp_ucp_example_get_scores,
+    .context_create  = ucc_tlcp_ucp_example_context_create,
+    .context_destroy = ucc_tlcp_ucp_example_context_destroy,
+};
diff --git a/src/components/tl/ucp/coll_plugins/example/example_ctx.h b/src/components/tl/ucp/coll_plugins/example/example_ctx.h
new file mode 100644
index 0000000000..28f9fd4ee8
--- /dev/null
+++ b/src/components/tl/ucp/coll_plugins/example/example_ctx.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#ifndef UCC_TLCP_UCP_EXAMPLE_CTX_H_
+#define UCC_TLCP_UCP_EXAMPLE_CTX_H_
+
+#include "components/tl/ucp/tl_ucp.h"
+#include "components/tl/ucp/tl_ucp_coll.h"
+
+typedef struct ucc_tlcp_ucp_example_config {
+    char *score_str;
+} ucc_tlcp_ucp_example_config_t;
+
+typedef struct ucc_tlcp_ucp_example_am_msg {
+    ucc_list_link_t list_elem;
+    uint64_t        tag;
+    void           *msg;
+} ucc_tlcp_ucp_example_am_msg_t;
+
+typedef struct ucc_tlcp_ucp_example_context {
+    ucc_list_link_t am_list;
+} ucc_tlcp_ucp_example_context_t;
+
+ucc_status_t ucc_tl_ucp_allreduce_knomial_am_init(ucc_base_coll_args_t *coll_args,
+                                                  ucc_base_team_t *tl_team,
+                                                  ucc_coll_task_t **task_h);
+
+#endif
diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c
index e00109ad95..6dd841e830 100644
--- a/src/components/tl/ucp/tl_ucp_context.c
+++ b/src/components/tl/ucp/tl_ucp_context.c
@@ -272,7 +272,19 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
     }
     ucc_free(prefix);
     prefix = NULL;
 
+    ucc_component_framework_t *plugins = &UCC_TL_CTX_IFACE(&self->super)->coll_plugins;
+    ucc_tl_coll_plugin_iface_t *tlcp;
+    if (plugins->n_components != 0) {
+        if (plugins->n_components > 1) {
+            tl_warn(self->super.super.lib,
+                    "multiple plugins are not supported");
+        }
+        tlcp = ucc_derived_of(plugins->components[0],
+                              ucc_tl_coll_plugin_iface_t);
+        tlcp->context_create(params, config, &self->super.super,
+                             &self->super.coll_plugin_context);
+    }
     tl_debug(self->super.super.lib, "initialized tl context: %p", self);
     return UCC_OK;
 
@@ -376,7 +388,18 @@ static inline void ucc_tl_ucp_worker_cleanup(ucc_tl_ucp_worker_t worker)
 
 UCC_CLASS_CLEANUP_FUNC(ucc_tl_ucp_context_t)
 {
+    ucc_component_framework_t *plugins = &UCC_TL_CTX_IFACE(&self->super)->coll_plugins;
+    ucc_tl_coll_plugin_iface_t *tlcp;
+
     tl_debug(self->super.super.lib, "finalizing tl context: %p", self);
+
+    if (plugins->n_components != 0) {
+        tlcp = ucc_derived_of(plugins->components[0],
+                              ucc_tl_coll_plugin_iface_t);
+        tlcp->context_destroy(&self->super.super,
+                              self->super.coll_plugin_context);
+    }
+
     if (self->remote_info) {
         ucc_tl_ucp_rinfo_destroy(self);
     }
diff --git a/src/utils/ucc_dt_reduce.h b/src/utils/ucc_dt_reduce.h
index c3caee4c25..6b18b25467 100644
--- a/src/utils/ucc_dt_reduce.h
+++ b/src/utils/ucc_dt_reduce.h
@@ -69,4 +69,33 @@ static inline ucc_status_t ucc_dt_reduce(void *src1, void *src2, void *dst,
                          alpha, exec, task);
 }
 
+static inline ucc_status_t
+ucc_dt_reduce_vec(void **srcs, void *dst, size_t n_vectors,
+                  size_t count, ucc_datatype_t dt,
+                  ucc_coll_args_t *args, uint16_t flags, double alpha,
+                  ucc_ee_executor_t *exec, ucc_ee_executor_task_t **task)
+{
+    ucc_ee_executor_task_args_t eargs;
+
+    if (count == 0) {
+        *task = NULL;
+        return UCC_OK;
+    }
+    if (!UCC_DT_IS_PREDEFINED(dt)) {
+        return UCC_ERR_NOT_IMPLEMENTED;
+    } else {
+        eargs.flags           = flags;
+        eargs.task_type       = UCC_EE_EXECUTOR_TASK_REDUCE;
+        eargs.reduce.alpha    = alpha;
+        eargs.reduce.count    = count;
+        eargs.reduce.dst      = dst;
+        eargs.reduce.dt       = dt;
+        eargs.reduce.n_srcs   = n_vectors;
+        eargs.reduce.op       = args->op;
+        eargs.reduce.srcs_ext = srcs;
+
+        return ucc_ee_executor_task_post(exec, &eargs, task);
+    }
+}
+
 #endif
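
Illustrative note, not part of the patch: the TL-UCP context init above calls tlcp->context_create unconditionally whenever a coll plugin component is present, and the cleanup path calls tlcp->context_destroy the same way, so a plugin built against this change should provide both callbacks. Below is a minimal sketch of how another plugin could wire up the new optional hooks; the my_tlcp_* names are hypothetical, and only facilities already used in the patch (ucc_malloc/ucc_free and the iface struct) are assumed, with include paths following the ones used elsewhere in the tree.

#include "components/tl/ucc_tl.h"
#include "utils/ucc_malloc.h"

typedef struct my_tlcp_context {
    int dummy; /* per-context plugin state would live here */
} my_tlcp_context_t;

static ucc_status_t my_tlcp_context_create(const ucc_base_context_params_t *params,
                                           const ucc_base_config_t *config,
                                           ucc_base_context_t *tl_ctx,
                                           void **plugin_ctx)
{
    /* Allocate per-context state; the TL stores the returned pointer in
       ucc_tl_context_t::coll_plugin_context and hands it back on destroy. */
    my_tlcp_context_t *ctx = ucc_malloc(sizeof(*ctx), "my_tlcp_context");

    if (!ctx) {
        return UCC_ERR_NO_MEMORY;
    }
    *plugin_ctx = ctx;
    return UCC_OK;
}

static ucc_status_t my_tlcp_context_destroy(ucc_base_context_t *tl_ctx,
                                            void *plugin_ctx)
{
    ucc_free(plugin_ctx);
    return UCC_OK;
}

/* Wired up next to get_scores in the plugin's iface descriptor:
 *     .context_create  = my_tlcp_context_create,
 *     .context_destroy = my_tlcp_context_destroy,
 */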
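
A second illustrative sketch, also outside the patch: driving the new ucc_dt_reduce_vec helper from ucc_dt_reduce.h to reduce several received fragments into one destination buffer and then poll the executor task to completion. The reduce_fragments name is hypothetical, and the executor test/finalize calls mirror the EXEC_TASK_TEST pattern used by example_coll.c; an already-started executor and valid buffers are assumed.

#include "utils/ucc_dt_reduce.h"

/* Hypothetical helper: element-wise reduce n_srcs fragments into dst with the
 * executor and busy-poll the resulting executor task until it completes. */
static ucc_status_t reduce_fragments(void **srcs, void *dst, size_t n_srcs,
                                     size_t count, ucc_coll_args_t *args,
                                     ucc_ee_executor_t *exec)
{
    ucc_ee_executor_task_t *etask;
    ucc_status_t            st;

    st = ucc_dt_reduce_vec(srcs, dst, n_srcs, count, args->dst.info.datatype,
                           args, UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT, 0.0,
                           exec, &etask);
    if (st != UCC_OK) {
        return st;
    }
    if (!etask) {
        /* count == 0: ucc_dt_reduce_vec posts nothing and returns UCC_OK */
        return UCC_OK;
    }
    do {
        st = ucc_ee_executor_task_test(etask);
    } while (st == UCC_INPROGRESS);
    ucc_ee_executor_task_finalize(etask);
    return st;
}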