Skip to content

Commit

Permalink
Revert "Regroup the code by ISA"
Browse files Browse the repository at this point in the history
This reverts commit 04bf6d6.
  • Loading branch information
iver56 committed Jul 29, 2024
1 parent 04bf6d6 commit 75db064
Showing 1 changed file with 83 additions and 73 deletions.
156 changes: 83 additions & 73 deletions numpy_minmax/_minmax.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,6 @@

#define IS_X86_64 (defined(__x86_64__) || defined(_M_X64))

typedef struct {
float min_val;
float max_val;
} MinMaxResult;

typedef unsigned char Byte;

#if IS_X86_64
#include <immintrin.h>

Expand All @@ -22,7 +15,17 @@ typedef unsigned char Byte;
#ifndef bit_AVX512F
#define bit_AVX512F (1 << 16)
#endif
#endif


typedef struct {
float min_val;
float max_val;
} MinMaxResult;

typedef unsigned char Byte;

#if IS_X86_64
bool system_supports_avx512() {
unsigned int eax, ebx, ecx, edx;

Expand All @@ -40,7 +43,29 @@ bool system_supports_avx512() {
// Check the AVX512F bit in EBX
return (ebx & bit_AVX512F) != 0;
}
#endif

static inline MinMaxResult minmax_pairwise(const float *a, size_t length) {
// Initialize min and max with the last element of the array.
// This ensures that it works correctly for odd length arrays as well as even.
MinMaxResult result = {.min_val = a[length -1], .max_val = a[length-1]};

// Process elements in pairs
for (size_t i = 0; i < length - 1; i += 2) {
float smaller = a[i] < a[i + 1] ? a[i] : a[i + 1];
float larger = a[i] < a[i + 1] ? a[i + 1] : a[i];

if (smaller < result.min_val) {
result.min_val = smaller;
}
if (larger > result.max_val) {
result.max_val = larger;
}
}
return result;
}

#if IS_X86_64
static inline MinMaxResult reduce_result_from_mm256(__m256 min_vals, __m256 max_vals, MinMaxResult result) {
float temp_min[8], temp_max[8];
_mm256_storeu_ps(temp_min, min_vals);
Expand All @@ -51,7 +76,9 @@ static inline MinMaxResult reduce_result_from_mm256(__m256 min_vals, __m256 max_
}
return result;
}
#endif

#if IS_X86_64
MinMaxResult minmax_avx(const float *a, size_t length) {
MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX };

Expand All @@ -72,7 +99,9 @@ MinMaxResult minmax_avx(const float *a, size_t length) {

return reduce_result_from_mm256(min_vals, max_vals, result);
}
#endif

#if IS_X86_64
static inline MinMaxResult reduce_result_from_mm512(__m512 min_vals, __m512 max_vals, MinMaxResult result) {
float temp_min[16], temp_max[16];
_mm512_storeu_ps(temp_min, min_vals);
Expand All @@ -83,7 +112,9 @@ static inline MinMaxResult reduce_result_from_mm512(__m512 min_vals, __m512 max_
}
return result;
}
#endif

#if IS_X86_64
MinMaxResult minmax_avx512(const float *a, size_t length) {
MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX };

Expand All @@ -105,74 +136,8 @@ MinMaxResult minmax_avx512(const float *a, size_t length) {

return reduce_result_from_mm512(min_vals, max_vals, result);
}

// Takes the avx min/max on strided input. Strides are in number of bytes,
// which is why the data pointer is Byte (i.e. unsigned char)
MinMaxResult minmax_avx_strided(const Byte *a, size_t length, long stride) {
MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX };

// This is faster than intrinsic gather on tested platforms
__m256 min_vals = _mm256_set_ps(
*(float*)(a),
*(float*)(a + stride),
*(float*)(a + 2*stride),
*(float*)(a + 3*stride),
*(float*)(a + 4*stride),
*(float*)(a + 5*stride),
*(float*)(a + 6*stride),
*(float*)(a + 7*stride)
);
__m256 max_vals = min_vals;

// Process elements in chunks of eight
size_t i = 8*stride;
for (; i <= (length - 8)*stride; i += 8*stride) {
__m256 vals = _mm256_set_ps(
*(float*)(a + i),
*(float*)(a + i + stride),
*(float*)(a + i + 2*stride),
*(float*)(a + i + 3*stride),
*(float*)(a + i + 4*stride),
*(float*)(a + i + 5*stride),
*(float*)(a + i + 6*stride),
*(float*)(a + i + 7*stride)
);
min_vals = _mm256_min_ps(min_vals, vals);
max_vals = _mm256_max_ps(max_vals, vals);
}

// Process remainder elements
if (i < length*stride){
result = minmax_pairwise_strided(a + i, length - i / stride, stride);
}

return reduce_result_from_mm256(min_vals, max_vals, result);
}
#endif



static inline MinMaxResult minmax_pairwise(const float *a, size_t length) {
// Initialize min and max with the last element of the array.
// This ensures that it works correctly for odd length arrays as well as even.
MinMaxResult result = {.min_val = a[length -1], .max_val = a[length-1]};

// Process elements in pairs
for (size_t i = 0; i < length - 1; i += 2) {
float smaller = a[i] < a[i + 1] ? a[i] : a[i + 1];
float larger = a[i] < a[i + 1] ? a[i + 1] : a[i];

if (smaller < result.min_val) {
result.min_val = smaller;
}
if (larger > result.max_val) {
result.max_val = larger;
}
}
return result;
}


MinMaxResult minmax_contiguous(const float *a, size_t length) {
// Return early for empty arrays
if (length == 0) {
Expand Down Expand Up @@ -226,6 +191,51 @@ MinMaxResult minmax_pairwise_strided(const Byte *a, size_t length, long stride)
return result;
}

#if IS_X86_64
// Takes the avx min/max on strided input. Strides are in number of bytes,
// which is why the data pointer is Byte (i.e. unsigned char)
MinMaxResult minmax_avx_strided(const Byte *a, size_t length, long stride) {
MinMaxResult result = { .min_val = FLT_MAX, .max_val = -FLT_MAX };

// This is faster than intrinsic gather on tested platforms
__m256 min_vals = _mm256_set_ps(
*(float*)(a),
*(float*)(a + stride),
*(float*)(a + 2*stride),
*(float*)(a + 3*stride),
*(float*)(a + 4*stride),
*(float*)(a + 5*stride),
*(float*)(a + 6*stride),
*(float*)(a + 7*stride)
);
__m256 max_vals = min_vals;

// Process elements in chunks of eight
size_t i = 8*stride;
for (; i <= (length - 8)*stride; i += 8*stride) {
__m256 vals = _mm256_set_ps(
*(float*)(a + i),
*(float*)(a + i + stride),
*(float*)(a + i + 2*stride),
*(float*)(a + i + 3*stride),
*(float*)(a + i + 4*stride),
*(float*)(a + i + 5*stride),
*(float*)(a + i + 6*stride),
*(float*)(a + i + 7*stride)
);
min_vals = _mm256_min_ps(min_vals, vals);
max_vals = _mm256_max_ps(max_vals, vals);
}

// Process remainder elements
if (i < length*stride){
result = minmax_pairwise_strided(a + i, length - i / stride, stride);
}

return reduce_result_from_mm256(min_vals, max_vals, result);
}
#endif


MinMaxResult minmax_1d_strided(const float *a, size_t length, long stride) {
// Return early for empty arrays
Expand Down

0 comments on commit 75db064

Please sign in to comment.