Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Matrix] Add joint matrix query for CUDA and HIP backends #12075

Merged
merged 22 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 333 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,339 @@ struct matrix_params<
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

//////////////////////////////////////////////
/// AMD Matrix Cores - GFX90A architecture ///
//////////////////////////////////////////////

template <typename Ta, typename Tc>
constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN,
size_t sK) {
Copy link
Contributor

@mmoadeli mmoadeli Dec 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why using sM, sN and sK to represent dimensions. I appreciate you followed them for consistency, though.

return ((std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
(sM == 16 && sN == 16 && sK == 4)));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd use only one instance of ((sM == 32 && sN == 32 && sK == 8) || (sM == 16 && sN == 16 && sK == 16))) to be &&ed with ORed std::is_same_vs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just tried it this way but the code now looks quite unreadable due to the one more OR in that case:
This would take shape of: the above conditions ORed with the extra case for double:

  if ((((sM == 32 && sN == 32 && sK == 8) ||
        (sM == 16 && sN == 16 && sK == 16)) &&
           (std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
       (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
       (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)) ||
      ((sM == 16 && sN == 16 && sK == 4) &&
       (std::is_same_v<Ta, double> && std::is_same_v<Tc, double>)))

btw, this is already after applying clang-format. I think for the sake of readability this should be left as is.

template <typename Ta, typename Tc>
constexpr bool are_types_valid_amd_gfx90a() {
return ((std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double>));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may return the statement without if / else.

// Default-values query:
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, 0, 0, 0,
typename std::enable_if_t<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td>)>> {
static_assert(
(are_types_valid_amd_gfx90a<Ta, Tc>()),
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
"Invalid types for AMD gfx90a, supported types are half, float, "
"int8_t, int32_t, double and bfloat16 ");

Copy link
Contributor

@mmoadeli mmoadeli Dec 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bfloat16 is used in DPC++ code for instance in joint_matrix_hip_gfx90a.cpp test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to bfloat16

// Default sizes for AMD gfx90a were chosen to represent a square matrix
static constexpr std::size_t M = 16;
static constexpr std::size_t N = 16;
static constexpr std::size_t K = ((sizeof(Ta) == 8) ? 16 : 4);

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

// Validation query
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
size_t sN, size_t sK>
struct matrix_params<
architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, sM, sN, sK,
typename std::enable_if_t<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td> && sM != 0 &&
sN != 0 && sK != 0)>> {
static_assert(
is_combination_valid_amd_gfx90a<Ta, Tc>(sM, sN, sK),
"Invalid parameters for AMD gfx90a, query valid combinations "
"using: "
"q.get_device().get_info<sycl::info::device::matrix::combinations>()");

static constexpr std::size_t M = sM;
static constexpr std::size_t N = sN;
static constexpr std::size_t K = sK;

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

/////////////////////////////////////////////////
/// CUDA Tensor Cores - sm70, sm72 and sm80 ///
/////////////////////////////////////////////////

template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm70() {
return (((std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Think it would be better to just call are_types_valid_cuda_sm70 here instead of repeating the logic

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that could definitely make use of are_types_valid. Changed it now

std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
std::is_same_v<Td, half>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, half>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
std::is_same_v<Td, float>)));
}
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved

template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm72() {
return (((std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
std::is_same_v<Td, int32_t>) ||
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tc, int32_t> &&
std::is_same_v<Td, int32_t>)));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm80() {
return ((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
std::is_same_v<Td, double>));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK) {
return (are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
((sM == 8 && sN == 32 && sK == 16) ||
(sM == 16 && sN == 16 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16)));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK) {
return (are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
((sM == 8 && sN == 32 && sK == 16) ||
(sM == 16 && sN == 16 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16)));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK) {
return (((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>)&&(sM == 16 && sN == 16 && sK == 8)) ||
((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>)&&((sM == 16 && sN == 16 && sK == 16) ||
(sM == 8 && sN == 32 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16))) ||
((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
std::is_same_v<Td, double>)&&(sM == 8 && sN == 8 && sK == 4)));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
}

// Default-values query (nvidia sm70):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, 0, 0, 0,
typename std::enable_if_t<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb>)>> {
static_assert(
(are_types_valid_cuda_sm70<Ta, Tc, Td>()),
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
"Invalid types for nvidia sm70, supported types are half and float ");

// Default sizes for nvidia sm70 were chosen to represent a square matrix
static constexpr std::size_t M = 16;
static constexpr std::size_t N = 16;
static constexpr std::size_t K = 16;

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

// Default-values query (nvidia sm72):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, 0, 0, 0,
typename std::enable_if<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb>)>::type> {
static_assert(
(are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
are_types_valid_cuda_sm72<Ta, Tc, Td>()),
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
"Invalid types for nvidia sm72, supported types are half, float "
"int8_t, uint8_t and int32_t ");

static constexpr std::size_t M = 16;
static constexpr std::size_t N = 16;
static constexpr std::size_t K = 16;

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

// Default-values query (nvidia sm80):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, 0, 0, 0,
typename std::enable_if_t<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb>)>> {
static_assert(
(are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
are_types_valid_cuda_sm80<Ta, Tc, Td>()),
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
"Invalid types for nvidia sm80, supported types are half, float "
"int8_t, uint8_t, int32_t, double, tf32 and bfloat16 ");

static constexpr std::size_t M = ((sizeof(Ta) == 8) ? 8 : 16);
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
static constexpr std::size_t N = ((sizeof(Ta) == 8) ? 8 : 16);
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
static constexpr std::size_t K =
((std::is_same_v<Ta, precision::tf32>) ? 8
: ((sizeof(Ta) == 8) ? 4 : 16));
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

// Validation query (nvidia sm70)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
size_t sN, size_t sK>
struct matrix_params<
architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, sM, sN, sK,
typename std::enable_if_t<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
static_assert(
is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK),
"Invalid parameters for nvidia sm70, query valid combinations "
"using: "
"q.get_device().get_info<sycl::info::device::matrix::combinations>()");

static constexpr std::size_t M = sM;
static constexpr std::size_t N = sN;
static constexpr std::size_t K = sK;

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

// Validation query (nvidia sm72)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
size_t sN, size_t sK>
struct matrix_params<
architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, sM, sN, sK,
typename std::enable_if_t<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
static_assert(
(is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK)),
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
"Invalid parameters for nvidia sm72, query valid combinations "
"using: "
"q.get_device().get_info<sycl::info::device::matrix::combinations>()");

static constexpr std::size_t M = sM;
static constexpr std::size_t N = sN;
static constexpr std::size_t K = sK;

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

// Validation query (nvidia sm80)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
size_t sN, size_t sK>
struct matrix_params<
architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, sM, sN, sK,
typename std::enable_if_t<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
static_assert(
(is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK)),
konradkusiak97 marked this conversation as resolved.
Show resolved Hide resolved
"Invalid parameters for nvidia sm80, query valid combinations "
"using: "
"q.get_device().get_info<sycl::info::device::matrix::combinations>()");

static constexpr std::size_t M = sM;
static constexpr std::size_t N = sN;
static constexpr std::size_t K = sK;

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
template <typename Group, layout Layout>
using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
template <typename Group>
using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
template <typename Group>
using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
Expand Down
Loading