diff --git a/numpy_minmax/_minmax.c b/numpy_minmax/_minmax.c index 449c55b..4fac089 100644 --- a/numpy_minmax/_minmax.c +++ b/numpy_minmax/_minmax.c @@ -3,13 +3,6 @@ #define IS_X86_64 (defined(__x86_64__) || defined(_M_X64)) -typedef struct { - float min_val; - float max_val; -} MinMaxResult; - -typedef unsigned char Byte; - #if IS_X86_64 #include @@ -22,7 +15,17 @@ typedef unsigned char Byte; #ifndef bit_AVX512F #define bit_AVX512F (1 << 16) #endif +#endif + + +typedef struct { + float min_val; + float max_val; +} MinMaxResult; + +typedef unsigned char Byte; +#if IS_X86_64 bool system_supports_avx512() { unsigned int eax, ebx, ecx, edx; @@ -40,7 +43,29 @@ bool system_supports_avx512() { // Check the AVX512F bit in EBX return (ebx & bit_AVX512F) != 0; } +#endif +static inline MinMaxResult minmax_pairwise(const float *a, size_t length) { + // Initialize min and max with the last element of the array. + // This ensures that it works correctly for odd length arrays as well as even. + MinMaxResult result = {.min_val = a[length -1], .max_val = a[length-1]}; + + // Process elements in pairs + for (size_t i = 0; i < length - 1; i += 2) { + float smaller = a[i] < a[i + 1] ? a[i] : a[i + 1]; + float larger = a[i] < a[i + 1] ? a[i + 1] : a[i]; + + if (smaller < result.min_val) { + result.min_val = smaller; + } + if (larger > result.max_val) { + result.max_val = larger; + } + } + return result; +} + +#if IS_X86_64 static inline MinMaxResult reduce_result_from_mm256(__m256 min_vals, __m256 max_vals, MinMaxResult result) { float temp_min[8], temp_max[8]; _mm256_storeu_ps(temp_min, min_vals); @@ -51,7 +76,9 @@ static inline MinMaxResult reduce_result_from_mm256(__m256 min_vals, __m256 max_ } return result; } +#endif +#if IS_X86_64 MinMaxResult minmax_avx(const float *a, size_t length) { MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX }; @@ -72,7 +99,9 @@ MinMaxResult minmax_avx(const float *a, size_t length) { return reduce_result_from_mm256(min_vals, max_vals, result); } +#endif +#if IS_X86_64 static inline MinMaxResult reduce_result_from_mm512(__m512 min_vals, __m512 max_vals, MinMaxResult result) { float temp_min[16], temp_max[16]; _mm512_storeu_ps(temp_min, min_vals); @@ -83,7 +112,9 @@ static inline MinMaxResult reduce_result_from_mm512(__m512 min_vals, __m512 max_ } return result; } +#endif +#if IS_X86_64 MinMaxResult minmax_avx512(const float *a, size_t length) { MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX }; @@ -105,74 +136,8 @@ MinMaxResult minmax_avx512(const float *a, size_t length) { return reduce_result_from_mm512(min_vals, max_vals, result); } - -// Takes the avx min/max on strided input. Strides are in number of bytes, -// which is why the data pointer is Byte (i.e. unsigned char) -MinMaxResult minmax_avx_strided(const Byte *a, size_t length, long stride) { - MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX }; - - // This is faster than intrinsic gather on tested platforms - __m256 min_vals = _mm256_set_ps( - *(float*)(a), - *(float*)(a + stride), - *(float*)(a + 2*stride), - *(float*)(a + 3*stride), - *(float*)(a + 4*stride), - *(float*)(a + 5*stride), - *(float*)(a + 6*stride), - *(float*)(a + 7*stride) - ); - __m256 max_vals = min_vals; - - // Process elements in chunks of eight - size_t i = 8*stride; - for (; i <= (length - 8)*stride; i += 8*stride) { - __m256 vals = _mm256_set_ps( - *(float*)(a + i), - *(float*)(a + i + stride), - *(float*)(a + i + 2*stride), - *(float*)(a + i + 3*stride), - *(float*)(a + i + 4*stride), - *(float*)(a + i + 5*stride), - *(float*)(a + i + 6*stride), - *(float*)(a + i + 7*stride) - ); - min_vals = _mm256_min_ps(min_vals, vals); - max_vals = _mm256_max_ps(max_vals, vals); - } - - // Process remainder elements - if (i < length*stride){ - result = minmax_pairwise_strided(a + i, length - i / stride, stride); - } - - return reduce_result_from_mm256(min_vals, max_vals, result); -} #endif - - -static inline MinMaxResult minmax_pairwise(const float *a, size_t length) { - // Initialize min and max with the last element of the array. - // This ensures that it works correctly for odd length arrays as well as even. - MinMaxResult result = {.min_val = a[length -1], .max_val = a[length-1]}; - - // Process elements in pairs - for (size_t i = 0; i < length - 1; i += 2) { - float smaller = a[i] < a[i + 1] ? a[i] : a[i + 1]; - float larger = a[i] < a[i + 1] ? a[i + 1] : a[i]; - - if (smaller < result.min_val) { - result.min_val = smaller; - } - if (larger > result.max_val) { - result.max_val = larger; - } - } - return result; -} - - MinMaxResult minmax_contiguous(const float *a, size_t length) { // Return early for empty arrays if (length == 0) { @@ -226,6 +191,51 @@ MinMaxResult minmax_pairwise_strided(const Byte *a, size_t length, long stride) return result; } +#if IS_X86_64 +// Takes the avx min/max on strided input. Strides are in number of bytes, +// which is why the data pointer is Byte (i.e. unsigned char) +MinMaxResult minmax_avx_strided(const Byte *a, size_t length, long stride) { + MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX }; + + // This is faster than intrinsic gather on tested platforms + __m256 min_vals = _mm256_set_ps( + *(float*)(a), + *(float*)(a + stride), + *(float*)(a + 2*stride), + *(float*)(a + 3*stride), + *(float*)(a + 4*stride), + *(float*)(a + 5*stride), + *(float*)(a + 6*stride), + *(float*)(a + 7*stride) + ); + __m256 max_vals = min_vals; + + // Process elements in chunks of eight + size_t i = 8*stride; + for (; i <= (length - 8)*stride; i += 8*stride) { + __m256 vals = _mm256_set_ps( + *(float*)(a + i), + *(float*)(a + i + stride), + *(float*)(a + i + 2*stride), + *(float*)(a + i + 3*stride), + *(float*)(a + i + 4*stride), + *(float*)(a + i + 5*stride), + *(float*)(a + i + 6*stride), + *(float*)(a + i + 7*stride) + ); + min_vals = _mm256_min_ps(min_vals, vals); + max_vals = _mm256_max_ps(max_vals, vals); + } + + // Process remainder elements + if (i < length*stride){ + result = minmax_pairwise_strided(a + i, length - i / stride, stride); + } + + return reduce_result_from_mm256(min_vals, max_vals, result); +} +#endif + MinMaxResult minmax_1d_strided(const float *a, size_t length, long stride) { // Return early for empty arrays