[SYCL][Matrix] Add joint matrix query for CUDA and HIP backends #12075
@@ -482,6 +482,339 @@ struct matrix_params<
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

//////////////////////////////////////////////
/// AMD Matrix Cores - GFX90A architecture ///
//////////////////////////////////////////////
template <typename Ta, typename Tc>
constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN,
                                               size_t sK) {
  return ((std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
           ((sM == 32 && sN == 32 && sK == 8) ||
            (sM == 16 && sN == 16 && sK == 16))) ||
          (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
           ((sM == 32 && sN == 32 && sK == 8) ||
            (sM == 16 && sN == 16 && sK == 16))) ||
          (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
           ((sM == 32 && sN == 32 && sK == 8) ||
            (sM == 16 && sN == 16 && sK == 16))) ||
          (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
           (sM == 16 && sN == 16 && sK == 4)));
}
Review comment: I'd use only one instance of […]

Reply (konradkusiak97): I've just tried it this way but the code now looks quite unreadable due to the one extra OR in that case:

    if ((((sM == 32 && sN == 32 && sK == 8) ||
          (sM == 16 && sN == 16 && sK == 16)) &&
             (std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
         (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
         (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)) ||
        ((sM == 16 && sN == 16 && sK == 4) &&
         (std::is_same_v<Ta, double> && std::is_same_v<Tc, double>)))

btw, this is already after applying clang-format. I think for the sake of readability this should be left as is.
template <typename Ta, typename Tc>
constexpr bool are_types_valid_amd_gfx90a() {
  return ((std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
          (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
          (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
          (std::is_same_v<Ta, double> && std::is_same_v<Tc, double>));
}
Review comment: You may return the statement without if / else.
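To illustrate how these gfx90a helpers behave as plain constexpr predicates, here is a minimal sketch (illustrative static_asserts only; it assumes the same `half`/`bfloat16` aliases that are in scope in this header):

```c++
// Illustrative compile-time checks against the gfx90a support list.
static_assert(is_combination_valid_amd_gfx90a<half, float>(32, 32, 8));
static_assert(is_combination_valid_amd_gfx90a<double, double>(16, 16, 4));
static_assert(!is_combination_valid_amd_gfx90a<double, double>(32, 32, 8));
static_assert(are_types_valid_amd_gfx90a<bfloat16, float>());
static_assert(!are_types_valid_amd_gfx90a<float, float>());
```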
// Default-values query:
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
    architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, 0, 0, 0,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td>)>> {
  static_assert(
      (are_types_valid_amd_gfx90a<Ta, Tc>()),
      "Invalid types for AMD gfx90a, supported types are half, float, "
      "int8_t, int32_t, double and bfloat16");
Reply (konradkusiak97): Changed to […]
  // Default sizes for AMD gfx90a were chosen to represent a square matrix
  static constexpr std::size_t M = 16;
  static constexpr std::size_t N = 16;
  static constexpr std::size_t K = ((sizeof(Ta) == 8) ? 4 : 16);

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
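A minimal usage sketch of this default-values query (the alias name `amd_params` is hypothetical; it assumes the enclosing `sycl::ext::oneapi::experimental::matrix` namespace is in scope):

```c++
// Only types are given; M, N, K are filled in with the gfx90a defaults.
using amd_params = matrix_params<architecture::amd_gpu_gfx90a,
                                 half, half, float, float>;
static_assert(amd_params::M == 16 && amd_params::N == 16 &&
              amd_params::K == 16);

// The nested aliases then produce correctly sized joint_matrix types, e.g.:
// amd_params::joint_matrix_a<sycl::sub_group, layout::row_major> tA;
```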
// Validation query
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, sM, sN, sK,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td> && sM != 0 &&
        sN != 0 && sK != 0)>> {
  static_assert(
      is_combination_valid_amd_gfx90a<Ta, Tc>(sM, sN, sK),
      "Invalid parameters for AMD gfx90a, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
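And a sketch of the validation path (again with hypothetical alias names): explicit sizes either compile cleanly or trip the static_assert above:

```c++
// Valid: 32x32x8 is a supported bfloat16 -> float combination on gfx90a.
using ok_params = matrix_params<architecture::amd_gpu_gfx90a,
                                bfloat16, bfloat16, float, float, 32, 32, 8>;

// Invalid: 8x8x4 is not in the gfx90a list; uncommenting this would fire
// the static_assert at compile time.
// using bad_params = matrix_params<architecture::amd_gpu_gfx90a,
//                                  double, double, double, double, 8, 8, 4>;
```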
/////////////////////////////////////////////////
///  CUDA Tensor Cores - sm70, sm72 and sm80  ///
/////////////////////////////////////////////////
template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm70() {
  return (((std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
            std::is_same_v<Td, float>) ||
           (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
            std::is_same_v<Td, half>) ||
           (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
            std::is_same_v<Td, half>) ||
           (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
            std::is_same_v<Td, float>)));
}

Review comment: nit: Think it would be better to just call […]
Reply (konradkusiak97): Yep, that could definitely make use of […]
template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm72() {
  return (((std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
            std::is_same_v<Td, int32_t>) ||
           (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tc, int32_t> &&
            std::is_same_v<Td, int32_t>)));
}
template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm80() {
  return ((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
           std::is_same_v<Td, float>) ||
          (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
           std::is_same_v<Td, float>) ||
          (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
           std::is_same_v<Td, double>));
}
template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK) {
  return (are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
          ((sM == 8 && sN == 32 && sK == 16) ||
           (sM == 16 && sN == 16 && sK == 16) ||
           (sM == 32 && sN == 8 && sK == 16)));
}
template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK) {
  return (are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
          ((sM == 8 && sN == 32 && sK == 16) ||
           (sM == 16 && sN == 16 && sK == 16) ||
           (sM == 32 && sN == 8 && sK == 16)));
}
template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK) {
  return (((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
            std::is_same_v<Td, float>) &&
           (sM == 16 && sN == 16 && sK == 8)) ||
          ((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
            std::is_same_v<Td, float>) &&
           ((sM == 16 && sN == 16 && sK == 16) ||
            (sM == 8 && sN == 32 && sK == 16) ||
            (sM == 32 && sN == 8 && sK == 16))) ||
          ((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
            std::is_same_v<Td, double>) &&
           (sM == 8 && sN == 8 && sK == 4)));
}
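As with the AMD helpers, these are plain constexpr predicates; a minimal sketch of what they accept (illustrative static_asserts only):

```c++
// sm70: the half/float accumulator mixes at the three supported shapes.
static_assert(is_combination_valid_cuda_sm70<half, half, float>(32, 8, 16));
// sm80: double is only supported at 8x8x4.
static_assert(is_combination_valid_cuda_sm80<double, double, double>(8, 8, 4));
static_assert(
    !is_combination_valid_cuda_sm80<double, double, double>(16, 16, 8));
```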
// Default-values query (nvidia sm70):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
    architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, 0, 0, 0,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb>)>> {
  static_assert(
      (are_types_valid_cuda_sm70<Ta, Tc, Td>()),
      "Invalid types for nvidia sm70, supported types are half and float");
  // Default sizes for nvidia sm70 were chosen to represent a square matrix
  static constexpr std::size_t M = 16;
  static constexpr std::size_t N = 16;
  static constexpr std::size_t K = 16;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
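A usage sketch mirroring the AMD one above (the alias name is hypothetical):

```c++
// sm70 defaults: a 16x16x16 shape for the half/float combinations.
using sm70_params = matrix_params<architecture::nvidia_gpu_sm_70,
                                  half, half, float, float>;
static_assert(sm70_params::M == 16 && sm70_params::N == 16 &&
              sm70_params::K == 16);
```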
// Default-values query (nvidia sm72):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
    architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, 0, 0, 0,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb>)>> {
  static_assert(
      (are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
       are_types_valid_cuda_sm72<Ta, Tc, Td>()),
      "Invalid types for nvidia sm72, supported types are half, float, "
      "int8_t, uint8_t and int32_t");
  static constexpr std::size_t M = 16;
  static constexpr std::size_t N = 16;
  static constexpr std::size_t K = 16;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
// Default-values query (nvidia sm80):
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<
    architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, 0, 0, 0,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb>)>> {
  static_assert(
      (are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
       are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
       are_types_valid_cuda_sm80<Ta, Tc, Td>()),
      "Invalid types for nvidia sm80, supported types are half, float, "
      "int8_t, uint8_t, int32_t, double, tf32 and bfloat16");
  static constexpr std::size_t M = ((sizeof(Ta) == 8) ? 8 : 16);
  static constexpr std::size_t N = ((sizeof(Ta) == 8) ? 8 : 16);
  static constexpr std::size_t K =
      ((std::is_same_v<Ta, precision::tf32>) ? 8
                                             : ((sizeof(Ta) == 8) ? 4 : 16));
  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
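Unlike sm70/sm72, the sm80 defaults depend on the input type; a sketch of what the query yields (hypothetical alias names):

```c++
// tf32 inputs (sizeof == 4): 16x16x8.
using tf32_params = matrix_params<architecture::nvidia_gpu_sm_80,
                                  precision::tf32, precision::tf32,
                                  float, float>;
static_assert(tf32_params::M == 16 && tf32_params::N == 16 &&
              tf32_params::K == 8);

// double inputs (sizeof == 8): 8x8x4.
using fp64_params = matrix_params<architecture::nvidia_gpu_sm_80,
                                  double, double, double, double>;
static_assert(fp64_params::M == 8 && fp64_params::N == 8 &&
              fp64_params::K == 4);
```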
// Validation query (nvidia sm70)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, sM, sN, sK,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
  static_assert(
      is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK),
      "Invalid parameters for nvidia sm70, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
// Validation query (nvidia sm72)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, sM, sN, sK,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
  static_assert(
      (is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
       is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK)),
      "Invalid parameters for nvidia sm72, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
// Validation query (nvidia sm80)
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
          size_t sN, size_t sK>
struct matrix_params<
    architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, sM, sN, sK,
    typename std::enable_if_t<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
        std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
  static_assert(
      (is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
       is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
       is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK)),
      "Invalid parameters for nvidia sm80, query valid combinations "
      "using: "
      "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
  static constexpr std::size_t M = sM;
  static constexpr std::size_t N = sN;
  static constexpr std::size_t K = sK;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;
  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;
  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};
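The static_assert messages above all point at the runtime combinations query; a sketch of that call (the info descriptor spelling is taken verbatim from those messages and should be treated as illustrative rather than the final API):

```c++
sycl::queue q;
// Enumerate the shape/type combinations the device actually supports.
auto combinations =
    q.get_device().get_info<sycl::info::device::matrix::combinations>();
for (const auto &c : combinations) {
  // Each entry describes one supported (M, N, K) / type combination.
}
```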
} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
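Putting the two query modes together, a sketch of how user code might derive its GEMM tile sizes from the query (the `gemm_tile` wrapper and its members are hypothetical; only `matrix_params`, `joint_matrix`, and `layout` come from the header above):

```c++
// Hypothetical helper: pick tile sizes for a GEMM on sm80 from the query.
template <typename T, typename Tc>
struct gemm_tile {
  using params = matrix_params<architecture::nvidia_gpu_sm_80, T, T, Tc, Tc>;
  static constexpr std::size_t M = params::M;
  static constexpr std::size_t N = params::N;
  static constexpr std::size_t K = params::K;

  // Matrix types for a sub_group-scope multiply-accumulate:
  using a_t = typename params::template joint_matrix_a<sycl::sub_group,
                                                       layout::row_major>;
  using b_t = typename params::template joint_matrix_b<sycl::sub_group,
                                                       layout::col_major>;
  using c_t = typename params::template joint_matrix_c<sycl::sub_group>;
};
```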
Review comment: Not sure why using sM, sN and sK to represent dimensions. I appreciate you followed them for consistency, though.