Skip to content

Commit

Permalink
TL/CUDA: wip multistep
Browse files Browse the repository at this point in the history
  • Loading branch information
ikryukov authored and Ilya Kryukov committed Jul 4, 2024
1 parent 7fc7d7b commit 9333f41
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
66 changes: 51 additions & 15 deletions src/components/tl/cuda/bcast/bcast_linear.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ ucc_status_t ucc_tl_cuda_bcast_linear_setup_test(ucc_tl_cuda_task_t *task)
return ucc_tl_cuda_shm_barrier_test(UCC_TL_TEAM_RANK(team), task->bar);
}

static inline size_t get_scratch_size(ucc_tl_cuda_team_t *team,
ucc_datatype_t dt)
{
size_t dt_size = ucc_dt_size(dt);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);

ucc_assert((dt_size > 0) && (tsize > 0));

return UCC_TL_CUDA_TEAM_LIB(team)->cfg.scratch_size;
}

static inline ucc_status_t ecopy(void *dst, void *src, size_t size,
ucc_ee_executor_t *exec,
ucc_ee_executor_task_t **etask)
Expand Down Expand Up @@ -90,6 +101,7 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
ucc_tl_cuda_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
ucc_datatype_t dt = task->bcast_linear.dt;
ucc_status_t st;
(void)team;
(void)st;
Expand Down Expand Up @@ -136,14 +148,23 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
break;
}

size_t scratch_size = get_scratch_size(team, dt);
size_t chunk_size = task->bcast_linear.step < task->bcast_linear.num_steps
? ucc_min(scratch_size, task->bcast_linear.size)
: task->bcast_linear.size -
(task->bcast_linear.step - 1) * scratch_size;
size_t offset_buff = task->bcast_linear.step * scratch_size;

// ucc_info("chunk_size: %ld", chunk_size);

if (trank == task->bcast_linear.root) {
// fall-through between cases is intentional
switch (task->bcast_linear.stage) {
case STAGE_COPY:
// copy from src buffer to scratch
dbuf = TASK_SCRATCH(task, trank);
sbuf = task->bcast_linear.sbuf;
status = ecopy(dbuf, sbuf, task->bcast_linear.size, exec,
sbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff);
status = ecopy(dbuf, sbuf, chunk_size, exec,
&task->bcast_linear.exec_task);
task->bcast_linear.stage = STAGE_WAIT_COPY;
break;
Expand All @@ -156,7 +177,8 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
task->bcast_linear.exec_task = NULL;
// signal others
++task->bcast_linear.step;
set_rank_step(task, task->bcast_linear.root, task->bcast_linear.step, 0);
set_rank_step(task, task->bcast_linear.root,
task->bcast_linear.step, 0);
task->bcast_linear.stage = STAGE_WAIT_ALL;
}
}
Expand All @@ -173,8 +195,11 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
}
task->bcast_linear.stage = STAGE_COPY;
// ucc_info("all others ready for next step");
// TODO: remove
task->bcast_linear.stage = STAGE_DONE;
if (task->bcast_linear.stage < task->bcast_linear.num_steps) {
task->bcast_linear.stage = STAGE_COPY;
} else {
task->bcast_linear.stage = STAGE_DONE;
}
break;
case STAGE_DONE:
task->super.status = UCC_OK;
Expand All @@ -196,11 +221,13 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
}
break;
case STAGE_CLIENT_COPY:
dbuf = task->bcast_linear.sbuf;
sbuf = TASK_SCRATCH(task,
task->bcast_linear.root); // need to copy from root's scratch buffer
status = ecopy(dbuf, sbuf, task->bcast_linear.size, exec,
&task->bcast_linear.exec_task);
dbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff);
sbuf = TASK_SCRATCH(
task,
task->bcast_linear
.root); // need to copy from root's scratch buffer
status = ecopy(dbuf, sbuf, chunk_size, exec,
&task->bcast_linear.exec_task);
task->bcast_linear.stage = STAGE_CLIENT_COPY_WAIT;
break;
case STAGE_CLIENT_COPY_WAIT:
Expand All @@ -212,8 +239,14 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
task->bcast_linear.exec_task = NULL;
++task->bcast_linear.step;
set_rank_step(task, trank, task->bcast_linear.step, 0);
task->bcast_linear.stage =
STAGE_DONE; // TODO: just for debug
// task->bcast_linear.stage =
// STAGE_DONE; // TODO: just for debug
if (task->bcast_linear.stage <
task->bcast_linear.num_steps) {
task->bcast_linear.stage = STAGE_COPY;
} else {
task->bcast_linear.stage = STAGE_DONE;
}
}
}
break;
Expand All @@ -235,7 +268,6 @@ ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task)
ucc_datatype_t dt = task->bcast_linear.dt;

(void)tsize;
(void)args;
(void)dt;

task->bcast_linear.stage = STAGE_SYNC;
Expand All @@ -244,8 +276,12 @@ ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task)
args->src.info.count);

task->bcast_linear.size = ucc_dt_size(dt) * args->src.info.count;
size_t scratch_size = get_scratch_size(team, dt);
task->bcast_linear.num_steps =
ucc_div_round_up(task->bcast_linear.size, scratch_size);

ucc_info("bcast buffer size: %ld", task->bcast_linear.size);
ucc_info("bcast buffer size: %ld, num_steps: %d", task->bcast_linear.size,
task->bcast_linear.num_steps);

task->bcast_linear.sbuf = args->src.info.buffer;
task->bcast_linear.step = 0;
Expand Down Expand Up @@ -275,7 +311,7 @@ ucc_status_t ucc_tl_cuda_bcast_linear_init(ucc_base_coll_args_t *coll_args,
}

task->bcast_linear.root = coll_args->args.root;
task->bcast_linear.dt = coll_args->args.src.info.datatype;
task->bcast_linear.dt = coll_args->args.src.info.datatype;
ucc_info("bcast init with dt: %s", ucc_datatype_str(task->bcast_linear.dt));

task->bcast_linear.sbuf = coll_args->args.src.info.buffer;
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/cuda/tl_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ struct ucc_tl_cuda_task {
ucc_datatype_t dt;
ucc_rank_t root;
size_t size;
int num_steps;
ucc_ee_executor_task_t *exec_task;
} bcast_linear;
struct {
Expand Down

0 comments on commit 9333f41

Please sign in to comment.