Skip to content

Commit

Permalink
Merge pull request #25 from nomonosound/ij/arm-linux
Browse files Browse the repository at this point in the history
Add ARM support (without NEON for now)
  • Loading branch information
iver56 authored Jul 29, 2024
2 parents 1a6398d + 1d6aba9 commit b96008f
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
os: [ubuntu-latest, windows-latest, ubuntu-22.04-arm64, macos-latest]
python-version: [3.9]

if: startsWith(github.event.head_commit.message, 'Release') || github.event_name == 'workflow_dispatch'
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include numpy_minmax *.c
45 changes: 35 additions & 10 deletions numpy_minmax/_minmax.c
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
#include <float.h>
#include <immintrin.h>
#include <stdbool.h>

#ifdef _MSC_VER
#include <intrin.h> // MSVC
#else
#include <cpuid.h> // GCC and Clang
#endif
#define IS_X86_64 (defined(__x86_64__) || defined(_M_X64))

#if IS_X86_64
#include <immintrin.h>

#ifdef _MSC_VER
#include <intrin.h> // MSVC
#else
#include <cpuid.h> // GCC and Clang
#endif

#ifndef bit_AVX512F
#define bit_AVX512F (1 << 16)
#ifndef bit_AVX512F
#define bit_AVX512F (1 << 16)
#endif
#endif


typedef struct {
float min_val;
float max_val;
} MinMaxResult;

typedef unsigned char Byte;

#if IS_X86_64
bool system_supports_avx512() {
unsigned int eax, ebx, ecx, edx;

Expand All @@ -36,6 +43,7 @@ bool system_supports_avx512() {
// Check the AVX512F bit in EBX
return (ebx & bit_AVX512F) != 0;
}
#endif

static inline MinMaxResult minmax_pairwise(const float *a, size_t length) {
// Initialize min and max with the last element of the array.
Expand All @@ -57,6 +65,7 @@ static inline MinMaxResult minmax_pairwise(const float *a, size_t length) {
return result;
}

#if IS_X86_64
static inline MinMaxResult reduce_result_from_mm256(__m256 min_vals, __m256 max_vals, MinMaxResult result) {
float temp_min[8], temp_max[8];
_mm256_storeu_ps(temp_min, min_vals);
Expand Down Expand Up @@ -121,12 +130,15 @@ MinMaxResult minmax_avx512(const float *a, size_t length) {

return reduce_result_from_mm512(min_vals, max_vals, result);
}
#endif

MinMaxResult minmax_contiguous(const float *a, size_t length) {
// Return early for empty arrays
if (length == 0) {
return (MinMaxResult){0.0, 0.0};
}

#if IS_X86_64
if (length >= 16) {
if (system_supports_avx512()) {
return minmax_avx512(a, length);
Expand All @@ -136,6 +148,9 @@ MinMaxResult minmax_contiguous(const float *a, size_t length) {
} else {
return minmax_pairwise(a, length);
}
#else
return minmax_pairwise(a, length);
#endif
}

// Takes the pairwise min/max on strided input. Strides are in number of bytes,
Expand Down Expand Up @@ -170,6 +185,7 @@ MinMaxResult minmax_pairwise_strided(const Byte *a, size_t length, long stride)
return result;
}

#if IS_X86_64
// Takes the avx min/max on strided input. Strides are in number of bytes,
// which is why the data pointer is Byte (i.e. unsigned char)
MinMaxResult minmax_avx_strided(const Byte *a, size_t length, long stride) {
Expand Down Expand Up @@ -205,33 +221,42 @@ MinMaxResult minmax_avx_strided(const Byte *a, size_t length, long stride) {
max_vals = _mm256_max_ps(max_vals, vals);
}


// Process remainder elements
if (i < length*stride){
result = minmax_pairwise_strided(a + i, length - i / stride, stride);
}

return reduce_result_from_mm256(min_vals, max_vals, result);
}
#endif


MinMaxResult minmax_1d_strided(const float *a, size_t length, long stride) {
// Return early for empty arrays
if (length == 0) {
return (MinMaxResult){0.0, 0.0};
}

if (stride < 0){
if (-stride == sizeof(float)){
return minmax_contiguous(a - length + 1, length);
}
if (length < 16){
return minmax_pairwise_strided((Byte*)(a) + (length - 1)*stride, length, -stride);
}
#if IS_X86_64
return minmax_avx_strided((Byte*)(a) + (length - 1)*stride, length, -stride);

#else
return minmax_pairwise_strided((Byte*)(a) + (length - 1)*stride, length, -stride);
#endif
}

#if IS_X86_64
if (length < 16){
return minmax_pairwise_strided((Byte*)a, length, stride);
}
return minmax_avx_strided((Byte*)a, length, stride);
#else
return minmax_pairwise_strided((Byte*)a, length, stride);
#endif
}
9 changes: 7 additions & 2 deletions numpy_minmax/_minmax_cffi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import platform
from cffi import FFI


ffibuilder = FFI()
ffibuilder.cdef("""
typedef struct {
Expand All @@ -18,10 +18,15 @@
with open(c_file_path, "r") as file:
c_code = file.read()

extra_compile_args = ["-mavx", "-mavx512f", "-O3", "-Wall"]
extra_compile_args = ["-O3", "-Wall"]
if os.name == "posix":
extra_compile_args.append("-Wextra")

# Detect architecture and set appropriate SIMD-related compile args
if platform.machine().lower() in ["x86_64", "amd64", "i386", "i686"]:
extra_compile_args.append("-mavx")
extra_compile_args.append("-mavx512f")

ffibuilder.set_source("_numpy_minmax", c_code, extra_compile_args=extra_compile_args)


Expand Down
1 change: 1 addition & 0 deletions packaging.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Wait for build workflow in GitHub Actions to complete
* Download wheels artifact from the build workflow
* Place all the fresh whl files in dist/
* `python setup.py sdist`
* `python -m twine upload dist/*`
* Add a tag with name "x.y.z" to the commit
* Go to https://github.com/nomonosound/numpy-minmax/releases and create a release where you choose the new tag
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ name = "numpy-minmax"
version = "0.2.1"
description = "A fast python library for finding both min and max value in a NumPy array"
dependencies = [
"numpy>=1.21"
"cffi>=1.0.0",
"numpy>=1.21,<2"
]
readme = "README.md"
license = { file = "LICENSE" }
Expand Down

0 comments on commit b96008f

Please sign in to comment.