From 6caea67298f1b4d54f180cc0cde6cebf30f4bd49 Mon Sep 17 00:00:00 2001 From: Ilya Kryukov Date: Thu, 4 Jul 2024 17:37:30 +0200 Subject: [PATCH] TL/CUDA: minor cleanup --- src/components/tl/cuda/bcast/bcast_linear.c | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/components/tl/cuda/bcast/bcast_linear.c b/src/components/tl/cuda/bcast/bcast_linear.c index ad55cfc1a3..67dd083043 100644 --- a/src/components/tl/cuda/bcast/bcast_linear.c +++ b/src/components/tl/cuda/bcast/bcast_linear.c @@ -62,14 +62,8 @@ ucc_status_t ucc_tl_cuda_bcast_linear_setup_test(ucc_tl_cuda_task_t *task) return ucc_tl_cuda_shm_barrier_test(UCC_TL_TEAM_RANK(team), task->bar); } -static inline size_t get_scratch_size(ucc_tl_cuda_team_t *team, - ucc_datatype_t dt) +static inline size_t get_raw_scratch_size(ucc_tl_cuda_team_t *team) { - size_t dt_size = ucc_dt_size(dt); - ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); - - ucc_assert((dt_size > 0) && (tsize > 0)); - return UCC_TL_CUDA_TEAM_LIB(team)->cfg.scratch_size; } @@ -101,8 +95,8 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) ucc_tl_cuda_team_t *team = TASK_TEAM(task); ucc_rank_t trank = UCC_TL_TEAM_RANK(team); ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); - ucc_datatype_t dt = task->bcast_linear.dt; - ucc_status_t st; + // ucc_datatype_t dt = task->bcast_linear.dt; + ucc_status_t st; (void)team; (void)st; ucc_ee_executor_t *exec; @@ -148,7 +142,7 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) break; } - size_t scratch_size = get_scratch_size(team, dt); + size_t scratch_size = get_raw_scratch_size(team); size_t chunk_size = task->bcast_linear.step < task->bcast_linear.num_steps ? ucc_min(scratch_size, task->bcast_linear.size) : task->bcast_linear.size - @@ -239,8 +233,6 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) task->bcast_linear.exec_task = NULL; ++task->bcast_linear.step; set_rank_step(task, trank, task->bcast_linear.step, 0); - // task->bcast_linear.stage = - // STAGE_DONE; // TODO: just for debug if (task->bcast_linear.step < task->bcast_linear.num_steps) { task->bcast_linear.stage = STAGE_WAIT_ROOT; @@ -276,7 +268,7 @@ ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task) args->src.info.count); task->bcast_linear.size = ucc_dt_size(dt) * args->src.info.count; - size_t scratch_size = get_scratch_size(team, dt); + size_t scratch_size = get_raw_scratch_size(team); task->bcast_linear.num_steps = ucc_div_round_up(task->bcast_linear.size, scratch_size);