Skip to content

Commit

Permalink
some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mehmetyusufoglu committed Jul 19, 2024
1 parent d8a8ca1 commit b49a304
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
5 changes: 1 addition & 4 deletions example/matrixAddWithMdspan/src/matrixAddMdSpan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ auto example(TAccTag const&) -> int
auto mdDevB = alpaka::experimental::getMdSpan(bufDevB);
auto mdDevC = alpaka::experimental::getMdSpan(bufDevC);


// Let alpaka calculate good block and grid sizes given our full problem extent.
auto const workDiv = alpaka::getValidWorkDiv<Acc>(
devAcc,
Expand All @@ -142,11 +141,9 @@ auto example(TAccTag const&) -> int
// Execute the kernel
alpaka::exec<Acc>(queue, workDiv, MatrixAddKernel{}, mdDevA, mdDevB, mdDevC);

// Wait for the kernel to finish
alpaka::wait(queue);

// Copy result back to host
alpaka::memcpy(queue, bufHostC, bufDevC);
// Make sure the copy to the host has completed before verifying the result.
// This wait is not necessary if the queue is a blocking queue.
alpaka::wait(queue);

// Verify the result
Expand Down
13 changes: 7 additions & 6 deletions example/matrixMulWithMdspan/src/matrixMulMdSpan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ struct MatrixMulKernel
//! \param C Output matrix where the result of A * B will be stored
//! \param K The shared dimension between A and B
template<typename TAcc, typename MdSpan>
ALPAKA_FN_ACC void operator()(TAcc const& acc, MdSpan A, MdSpan B, MdSpan C, Idx K) const
ALPAKA_FN_ACC void operator()(TAcc const& acc, MdSpan A, MdSpan B, MdSpan C) const
{
// compile time check
// compile time checks
static_assert(isMdspan<MdSpan>::value, "The type MdSpan should be an std mdspan");

// A is MxK and B is KxN, so the shared dimension K is the second extent of A
auto const K = static_cast<Idx>(A.extent(1));

auto const i = alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0];
auto const j = alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[1];

Expand Down Expand Up @@ -146,13 +149,11 @@ auto example(TAccTag const&) -> int
alpaka::GridBlockExtentSubDivRestrictions::Unrestricted);

// Execute the kernel
alpaka::exec<Acc>(queue, workDiv, MatrixMulKernel{}, mdDevA, mdDevB, mdDevC, K);

// Wait for the kernel to finish
alpaka::wait(queue);
alpaka::exec<Acc>(queue, workDiv, MatrixMulKernel{}, mdDevA, mdDevB, mdDevC);

// Copy result back to host
alpaka::memcpy(queue, bufHostC, bufDevC);
// Make sure the copy to the host has completed before verifying the result.
// This wait is not necessary if the queue is a blocking queue.
alpaka::wait(queue);

// Verify the result
Expand Down

0 comments on commit b49a304

Please sign in to comment.