UCC Allreduce example

The code snippet below demonstrates how the UCC API can be used to execute an Allreduce collective operation over a group of processes. The example is an MPI-based application; MPI is used only to bootstrap the job: it spawns the processes and provides the out-of-band (OOB) allgather exchange among them that UCC needs for wire-up.

Main steps to execute a UCC allreduce (the corresponding API calls are sketched right after this list):

  1. Read UCC lib configuration
  2. Initialize UCC library
  3. Read UCC context configuration
  4. Initialize UCC context
  5. Initialize UCC team
  6. Fill the collective descriptor and initialize the collective request
  7. Post collective
  8. Test for completion and progress UCC
  9. Clean up the collective request
  10. Clean up UCC (team, context, library)
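
For orientation, the UCC API calls behind these steps can be condensed roughly as follows (variable declarations, parameter setup, and error handling are omitted; the complete program further below is the authoritative version):

ucc_lib_config_read(NULL, NULL, &lib_config);                /* step 1 */
ucc_init(&lib_params, lib_config, &lib);                     /* step 2 */
ucc_lib_config_release(lib_config);
ucc_context_config_read(lib, NULL, &ctx_config);             /* step 3 */
ucc_context_create(lib, &ctx_params, ctx_config, &ctx);      /* step 4 */
ucc_context_config_release(ctx_config);
ucc_team_create_post(&ctx, 1, &team_params, &team);          /* step 5 */
while (UCC_INPROGRESS == ucc_team_create_test(team))
    ucc_context_progress(ctx);
ucc_collective_init(&args, &req, team);                      /* step 6 */
ucc_collective_post(req);                                    /* step 7 */
while (UCC_INPROGRESS == ucc_collective_test(req))           /* step 8 */
    ucc_context_progress(ctx);
ucc_collective_finalize(req);                                /* step 9 */
ucc_team_destroy(team);                                      /* step 10 */
ucc_context_destroy(ctx);
ucc_finalize(lib);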

If UCC is compiled and installed into ${UCC_PATH}, and MPI (mpicc/mpirun) is available in PATH, the example can be built and run with the commands below:

mpicc ucc_allreduce.c -g -o ucc_allreduce -I${UCC_PATH}/include -L${UCC_PATH}/lib -lucc -Wl,-rpath="${UCC_PATH}/lib"
mpirun -x UCC_TLS=ucp -np 4 ./ucc_allreduce
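
Here UCC_TLS=ucp restricts UCC to its UCP transport layer. Each of the 4 ranks contributes rank + 1 to every element of the send buffer, so every element of the result buffer should equal 1 + 2 + 3 + 4 = 10, which is what the program verifies against the closed form (size + 1) * size / 2. The full source of ucc_allreduce.c follows.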
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <ucc/api/ucc.h>

#define STR(x) #x
#define UCC_CHECK(_call)                                            \
    if (UCC_OK != (_call)) {                                        \
        fprintf(stderr, "*** UCC TEST FAIL: %s\n", STR(_call));     \
        MPI_Abort(MPI_COMM_WORLD, -1);                              \
    }

static ucc_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen,
                                  void *coll_info, void **req)
{
    MPI_Comm    comm = (MPI_Comm)coll_info;
    MPI_Request request;

    MPI_Iallgather(sbuf, msglen, MPI_BYTE, rbuf, msglen, MPI_BYTE, comm,
                   &request);
    *req = (void *)request;
    return UCC_OK;
}

static ucc_status_t oob_allgather_test(void *req)
{
    MPI_Request request = (MPI_Request)req;
    int         completed;

    MPI_Test(&request, &completed, MPI_STATUS_IGNORE);
    return completed ? UCC_OK : UCC_INPROGRESS;
}

static ucc_status_t oob_allgather_free(void *req)
{
    return UCC_OK;
}

/* Creates UCC team for a group of processes represented by MPI
   communicator. UCC API provides different ways to create a team,
   one of them is to use out-of-band (OOB) allgather provided by
   the calling runtime. */
static ucc_team_h create_ucc_team(MPI_Comm comm, ucc_context_h ctx)
{
    int               rank, size;
    ucc_team_h        team;
    ucc_team_params_t team_params;
    ucc_status_t      status;

    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);

    team_params.mask          = UCC_TEAM_PARAM_FIELD_OOB;
    team_params.oob.allgather = oob_allgather;
    team_params.oob.req_test  = oob_allgather_test;
    team_params.oob.req_free  = oob_allgather_free;
    team_params.oob.coll_info = (void*)comm;
    team_params.oob.n_oob_eps = size;
    team_params.oob.oob_ep    = rank;

    UCC_CHECK(ucc_team_create_post(&ctx, 1, &team_params, &team));
    while (UCC_INPROGRESS == (status = ucc_team_create_test(team))) {
        UCC_CHECK(ucc_context_progress(ctx));
    }
    if (UCC_OK != status) {
        fprintf(stderr, "failed to create ucc team\n");
        MPI_Abort(MPI_COMM_WORLD, status);
    }
    return team;
}

int main (int argc, char **argv) {
    ucc_lib_config_h     lib_config;
    ucc_context_config_h ctx_config;
    int                  rank, size, i;
    ucc_team_h           team;
    ucc_context_h        ctx;
    ucc_lib_h            lib;
    size_t               msglen;
    size_t               count;
    int                 *sbuf, *rbuf;
    ucc_coll_req_h       req;
    ucc_coll_args_t      args;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    /* Init ucc library */
    ucc_lib_params_t lib_params = {
        .mask        = UCC_LIB_PARAM_FIELD_THREAD_MODE,
        .thread_mode = UCC_THREAD_SINGLE
    };
    UCC_CHECK(ucc_lib_config_read(NULL, NULL, &lib_config));
    UCC_CHECK(ucc_init(&lib_params, lib_config, &lib));
    ucc_lib_config_release(lib_config);

    /* Init ucc context; transport selection is controlled by the UCC_TLS
       environment variable (set on the mpirun command line above) */
    ucc_context_params_t ctx_params = {
        .mask             = UCC_CONTEXT_PARAM_FIELD_OOB,
        .oob.allgather    = oob_allgather,
        .oob.req_test     = oob_allgather_test,
        .oob.req_free     = oob_allgather_free,
        .oob.coll_info    = (void*)MPI_COMM_WORLD,
        .oob.n_oob_eps    = size,
        .oob.oob_ep       = rank
    };

    UCC_CHECK(ucc_context_config_read(lib, NULL, &ctx_config));
    UCC_CHECK(ucc_context_create(lib, &ctx_params, ctx_config, &ctx));
    ucc_context_config_release(ctx_config);

    /* Create a UCC team spanning all ranks of MPI_COMM_WORLD */
    team = create_ucc_team(MPI_COMM_WORLD, ctx);

    /* The number of elements can be set via the first command line argument */
    count = argc > 1 ? atoi(argv[1]) : 1;
    msglen = count * sizeof(int);

    sbuf = malloc(msglen);
    rbuf = malloc(msglen);    
    for (i = 0; i < count; i++) {
        sbuf[i] = rank + 1;
        rbuf[i] = 0;
    }
    
    /* Describe the collective: an int32 SUM allreduce over "count"
       elements residing in host memory */
    args.mask              = 0;
    args.coll_type         = UCC_COLL_TYPE_ALLREDUCE;
    args.src.info.buffer   = sbuf;
    args.src.info.count    = count;
    args.src.info.datatype = UCC_DT_INT32;
    args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
    args.dst.info.buffer   = rbuf;
    args.dst.info.count    = count;
    args.dst.info.datatype = UCC_DT_INT32;
    args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST;
    args.op                = UCC_OP_SUM;

    /* Initialize the collective request, post it, and progress the
       context until the operation completes */
    UCC_CHECK(ucc_collective_init(&args, &req, team));
    UCC_CHECK(ucc_collective_post(req));
    while (UCC_INPROGRESS == ucc_collective_test(req)) {
        UCC_CHECK(ucc_context_progress(ctx));
    }
    ucc_collective_finalize(req);

    /* Check result */
    int sum = ((size + 1) * size) / 2;
    for (i = 0; i < count; i++) {
        if (rbuf[i] != sum) {
            printf("ERROR at rank %d, pos %d, value %d, expected %d\n", rank, i, rbuf[i], sum);
            break;
        }
    }

    /* Cleanup UCC: team destruction is non-blocking and may return
       UCC_INPROGRESS, so retry until it completes */
    ucc_status_t status;
    while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) {}
    UCC_CHECK(status);
    UCC_CHECK(ucc_context_destroy(ctx));
    UCC_CHECK(ucc_finalize(lib));

    MPI_Finalize();

    free(sbuf);
    free(rbuf);
    return 0;
}
