simd: support vector_aligned_tag (kokkos#6243)
* support vector_aligned_tag

this SIMD load/store tag is defined in
the ISO C++ Parallelism TS and specifies
that the load or store operates on a
pointer aligned to the vector width.
The Intel backends (AVX2, AVX512) have
dedicated intrinsics for this case, and
this enhancement lets users of Kokkos
SIMD take advantage of them (see the
usage sketch after the file summary below).

* Made aligned_tags into aliases of the corresponding loadstore_flags

* clang-formatted

* Replaced loadstore_flags with simd_flags

* clang-formatted

---------

Co-authored-by: Dong Hun Lee <donlee@sandia.gov>
ibaned and ldh4 authored Feb 6, 2024
1 parent 5f128d2 commit 31fb476
Showing 10 changed files with 419 additions and 38 deletions.
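Before the diff, a minimal usage sketch of the new tag (not part of the commit): it assumes the Kokkos::Experimental::simd API shown below; the function name is hypothetical, the avx2_fixed_size<4> ABI is chosen to match this file, and the caller must guarantee that ptr is aligned to the 32-byte vector width.

#include <Kokkos_SIMD.hpp>

// Hypothetical helper: ptr must be 32-byte aligned for vector_aligned_tag.
void scale_by_two(double* ptr) {
  namespace KE = Kokkos::Experimental;
  using simd_t = KE::simd<double, KE::simd_abi::avx2_fixed_size<4>>;
  simd_t v;
  // vector_aligned_tag selects the aligned intrinsic (_mm256_load_pd);
  // element_aligned_tag would fall back to _mm256_loadu_pd.
  v.copy_from(ptr, KE::vector_aligned_tag());
  v = v * simd_t(2.0);  // broadcast-construct, then lane-wise multiply
  v.copy_to(ptr, KE::vector_aligned_tag());
}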
124 changes: 124 additions & 0 deletions simd/src/Kokkos_SIMD_AVX2.hpp
@@ -565,10 +565,18 @@ class simd<double, simd_abi::avx2_fixed_size<4>> {
                                                       element_aligned_tag) {
    m_value = _mm256_loadu_pd(ptr);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                       vector_aligned_tag) {
    m_value = _mm256_load_pd(ptr);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(
      value_type* ptr, element_aligned_tag) const {
    _mm256_storeu_pd(ptr, m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(value_type* ptr,
                                                     vector_aligned_tag) const {
    _mm256_store_pd(ptr, m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m256d()
      const {
    return m_value;
@@ -820,10 +828,18 @@ class simd<float, simd_abi::avx2_fixed_size<4>> {
                                                       element_aligned_tag) {
    m_value = _mm_loadu_ps(ptr);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                       vector_aligned_tag) {
    m_value = _mm_load_ps(ptr);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(
      value_type* ptr, element_aligned_tag) const {
    _mm_storeu_ps(ptr, m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(value_type* ptr,
                                                     vector_aligned_tag) const {
    _mm_store_ps(ptr, m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m128()
      const {
    return m_value;
@@ -1067,12 +1083,25 @@ class simd<std::int32_t, simd_abi::avx2_fixed_size<4>> {
    m_value = _mm_loadu_si128(reinterpret_cast<__m128i const*>(ptr));
#else
    m_value = _mm_maskload_epi32(ptr, static_cast<__m128i>(mask_type(true)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                       vector_aligned_tag) {
// FIXME_HIP ROCm 5.6 can't compile with the intrinsic used here.
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
    m_value = _mm_load_si128(reinterpret_cast<__m128i const*>(ptr));
#else
    m_value = _mm_maskload_epi32(ptr, static_cast<__m128i>(mask_type(true)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(
      value_type* ptr, element_aligned_tag) const {
    _mm_maskstore_epi32(ptr, static_cast<__m128i>(mask_type(true)), m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(value_type* ptr,
                                                     vector_aligned_tag) const {
    _mm_maskstore_epi32(ptr, static_cast<__m128i>(mask_type(true)), m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m128i()
      const {
    return m_value;
@@ -1257,13 +1286,27 @@ class simd<std::int64_t, simd_abi::avx2_fixed_size<4>> {
#else
    m_value = _mm256_maskload_epi64(reinterpret_cast<long long const*>(ptr),
                                    static_cast<__m256i>(mask_type(true)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                       vector_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
    m_value = _mm256_load_si256(reinterpret_cast<__m256i const*>(ptr));
#else
    m_value = _mm256_maskload_epi64(reinterpret_cast<long long const*>(ptr),
                                    static_cast<__m256i>(mask_type(true)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(
      value_type* ptr, element_aligned_tag) const {
    _mm256_maskstore_epi64(reinterpret_cast<long long*>(ptr),
                           static_cast<__m256i>(mask_type(true)), m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(value_type* ptr,
                                                     vector_aligned_tag) const {
    _mm256_maskstore_epi64(reinterpret_cast<long long*>(ptr),
                           static_cast<__m256i>(mask_type(true)), m_value);
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m256i()
      const {
    return m_value;
@@ -1461,6 +1504,15 @@ class simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> {
#else
    m_value = _mm256_maskload_epi64(reinterpret_cast<long long const*>(ptr),
                                    static_cast<__m256i>(mask_type(true)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                       vector_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
    m_value = _mm256_load_si256(reinterpret_cast<__m256i const*>(ptr));
#else
    m_value = _mm256_maskload_epi64(reinterpret_cast<long long const*>(ptr),
                                    static_cast<__m256i>(mask_type(true)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m256i()
@@ -1613,6 +1665,11 @@ class const_where_expression<simd_mask<double, simd_abi::avx2_fixed_size<4>>,
                        static_cast<__m256d>(m_value));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_to(double* mem, vector_aligned_tag) const {
    _mm256_maskstore_pd(mem, _mm256_castpd_si256(static_cast<__m256d>(m_mask)),
                        static_cast<__m256d>(m_value));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void scatter_to(
      double* mem,
      simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) const {
@@ -1649,6 +1706,11 @@ class where_expression<simd_mask<double, simd_abi::avx2_fixed_size<4>>,
        mem, _mm256_castpd_si256(static_cast<__m256d>(m_mask))));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_from(double const* mem, vector_aligned_tag) {
    m_value = value_type(_mm256_maskload_pd(
        mem, _mm256_castpd_si256(static_cast<__m256d>(m_mask))));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void gather_from(
      double const* mem,
      simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
@@ -1692,6 +1754,11 @@ class const_where_expression<simd_mask<float, simd_abi::avx2_fixed_size<4>>,
                     static_cast<__m128>(m_value));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_to(float* mem, vector_aligned_tag) const {
    _mm_maskstore_ps(mem, _mm_castps_si128(static_cast<__m128>(m_mask)),
                     static_cast<__m128>(m_value));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void scatter_to(
      float* mem,
      simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) const {
@@ -1728,6 +1795,11 @@ class where_expression<simd_mask<float, simd_abi::avx2_fixed_size<4>>,
        _mm_maskload_ps(mem, _mm_castps_si128(static_cast<__m128>(m_mask))));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_from(float const* mem, vector_aligned_tag) {
    m_value = value_type(
        _mm_maskload_ps(mem, _mm_castps_si128(static_cast<__m128>(m_mask))));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void gather_from(
      float const* mem,
      simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
@@ -1771,6 +1843,12 @@ class const_where_expression<
    _mm_maskstore_epi32(mem, static_cast<__m128i>(m_mask),
                        static_cast<__m128i>(m_value));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_to(std::int32_t* mem, vector_aligned_tag) const {
    _mm_maskstore_epi32(mem, static_cast<__m128i>(m_mask),
                        static_cast<__m128i>(m_value));
  }

  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void scatter_to(
      std::int32_t* mem,
@@ -1811,6 +1889,16 @@ class where_expression<simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>>,
    m_value = value_type(_mm_maskload_epi32(mem, static_cast<__m128i>(m_mask)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_from(std::int32_t const* mem, vector_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
    __m128i tmp = _mm_load_si128(reinterpret_cast<__m128i const*>(mem));
    m_value = value_type(_mm_and_si128(tmp, static_cast<__m128i>(m_mask)));
#else
    m_value = value_type(_mm_maskload_epi32(mem, static_cast<__m128i>(m_mask)));
#endif
  }

  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void gather_from(
      std::int32_t const* mem,
@@ -1858,6 +1946,13 @@ class const_where_expression<
                           static_cast<__m256i>(m_mask),
                           static_cast<__m256i>(m_value));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(std::int64_t* mem,
                                                     vector_aligned_tag) const {
    _mm256_maskstore_epi64(reinterpret_cast<long long*>(mem),
                           static_cast<__m256i>(m_mask),
                           static_cast<__m256i>(m_value));
  }

  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void scatter_to(
      std::int64_t* mem,
@@ -1899,6 +1994,17 @@ class where_expression<simd_mask<std::int64_t, simd_abi::avx2_fixed_size<4>>,
        reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(std::int64_t const* mem,
                                                       vector_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
    __m256i tmp = _mm256_load_si256(reinterpret_cast<__m256i const*>(mem));
    m_value = value_type(_mm256_and_si256(tmp, static_cast<__m256i>(m_mask)));
#else
    m_value = value_type(_mm256_maskload_epi64(
        reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
#endif
  }

  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void gather_from(
      std::int64_t const* mem,
@@ -1947,6 +2053,13 @@ class const_where_expression<
                           static_cast<__m256i>(m_mask),
                           static_cast<__m256i>(m_value));
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(std::uint64_t* mem,
                                                     vector_aligned_tag) const {
    _mm256_maskstore_epi64(reinterpret_cast<long long*>(mem),
                           static_cast<__m256i>(m_mask),
                           static_cast<__m256i>(m_value));
  }

  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void scatter_to(
      std::uint64_t* mem,
@@ -1988,6 +2101,17 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>>,
        reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
#endif
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(std::uint64_t const* mem,
                                                       vector_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
    __m256i tmp = _mm256_load_si256(reinterpret_cast<__m256i const*>(mem));
    m_value = value_type(_mm256_and_si256(tmp, static_cast<__m256i>(m_mask)));
#else
    m_value = value_type(_mm256_maskload_epi64(
        reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
#endif
  }

  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void gather_from(
      std::uint64_t const* mem,
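The diff also adds vector_aligned_tag overloads to the where_expression classes; a hedged sketch of how those masked variants are reached (the function name is hypothetical, mem is assumed 32-byte aligned, and the all-true mask mirrors the mask_type(true) broadcast used in the implementation itself):

#include <Kokkos_SIMD.hpp>

// Hypothetical helper: masked aligned load, lane-wise add, masked aligned store.
void masked_increment(double* mem) {
  namespace KE = Kokkos::Experimental;
  using simd_t = KE::simd<double, KE::simd_abi::avx2_fixed_size<4>>;
  simd_t v(0.0);
  // All-true mask for illustration; real code would usually derive the
  // mask from a lane-wise comparison.
  simd_t::mask_type mask(true);
  // Dispatches to where_expression::copy_from(double const*, vector_aligned_tag).
  where(mask, v).copy_from(mem, KE::vector_aligned_tag());
  v = v + simd_t(1.0);
  // Dispatches to const_where_expression::copy_to(double*, vector_aligned_tag).
  where(mask, v).copy_to(mem, KE::vector_aligned_tag());
}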
