Skip to content

Commit

Permalink
Merge pull request #2995 from xhzheng1895/mnn_kai
Browse files Browse the repository at this point in the history
Integrate kleidiAI release v0.3.0 into MNN 2.9.6
  • Loading branch information
wangzhaode authored Oct 28, 2024
2 parents 1ee0313 + 8f6a123 commit 630d593
Show file tree
Hide file tree
Showing 19 changed files with 2,461 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ option(MNN_OPENCL "Enable OpenCL" OFF)
option(MNN_OPENGL "Enable OpenGL" OFF)
option(MNN_VULKAN "Enable Vulkan" OFF)
option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON)
option(MNN_KLEIDIAI "Enable KLEIDIAI" OFF)
option(MNN_ONEDNN "Enable oneDNN" OFF)
option(MNN_AVX512 "Enable AVX512" OFF)
option(MNN_CUDA "Enable CUDA" OFF)
Expand Down Expand Up @@ -253,6 +254,7 @@ message(STATUS "\tOpenCL: ${MNN_OPENCL}")
message(STATUS "\tOpenGL: ${MNN_OPENGL}")
message(STATUS "\tVulkan: ${MNN_VULKAN}")
message(STATUS "\tARM82: ${MNN_ARM82}")
message(STATUS "\tKleidiAI: ${MNN_KLEIDIAI}")
message(STATUS "\toneDNN: ${MNN_ONEDNN}")
message(STATUS "\tTensorRT: ${MNN_TENSORRT}")
message(STATUS "\tCoreML: ${MNN_COREML}")
Expand Down
7 changes: 7 additions & 0 deletions source/backend/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,10 @@ IF(MNN_ARM82)
ENDIF()
ENDIF()

# Optional KleidiAI kernels.
# When MNN_KLEIDIAI is ON this defines MNN_KLEIDIAI_ENABLED for the sources of
# this directory, pulls in the script that declares the MNN_KleidiAI target,
# and registers that target and its object files with the MNN build lists.
IF(MNN_KLEIDIAI)
    # Directory-scoped on purpose: the macro is tested by code outside the
    # MNN_KleidiAI target itself.
    add_definitions("-DMNN_KLEIDIAI_ENABLED=1")
    # include() (not add_subdirectory()): the script runs in this scope.
    include("${CMAKE_CURRENT_LIST_DIR}/arm/kleidiAI/CMakeLists.txt")
    list(APPEND MNN_TARGETS "MNN_KleidiAI")
    list(APPEND MNN_OBJECTS_TO_LINK "$<TARGET_OBJECTS:MNN_KleidiAI>")
ENDIF()
4 changes: 4 additions & 0 deletions source/backend/cpu/CPUBackend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#include "core/BufferAllocator.hpp"
#include "MNN_generated.h"

#ifdef MNN_KLEIDIAI_ENABLED
#include "arm/kleidiAI/mnn_kleidiai.h"
#endif

namespace MNN {
class CPURuntime : public Runtime {
public:
Expand Down
63 changes: 63 additions & 0 deletions source/backend/cpu/arm/kleidiAI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#
# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#

# NOTE(review): this script is pulled in with include() rather than
# add_subdirectory(), so this project() call also resets the PROJECT_*
# variables of the including scope -- confirm that this is intended.
project(MNN_KleidiAI LANGUAGES C CXX ASM)

# Oldest compiler major versions that do not trigger the warnings below.
set(KLEIDIAI_MIN_CLANG_VERSION 11)
set(KLEIDIAI_MIN_GNU_VERSION 11)

# Older compilers only produce a warning; configuration continues.
if(CMAKE_C_COMPILER_ID STREQUAL "Clang")
  if(CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_CLANG_VERSION})
    message(WARNING "KleidiAI: Using non-supported Clang version. Expected ${KLEIDIAI_MIN_CLANG_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.")
  endif()
endif()

if(CMAKE_C_COMPILER_ID STREQUAL "GNU")
  if(CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_GNU_VERSION})
    message(WARNING "KleidiAI: Using non-supported GCC version. Expected ${KLEIDIAI_MIN_GNU_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.")
  endif()
endif()

# MNN-side glue code for the KleidiAI kernels.
list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp)
list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h)

# The including script consumes this target through
# $<TARGET_OBJECTS:MNN_KleidiAI>, i.e. only its object files are used. An
# OBJECT library expresses exactly that; a SHARED library would additionally
# link (and require to link successfully) a standalone library that duplicates
# the same code.
add_library(
  MNN_KleidiAI
  OBJECT
  ${MNN_KleidiAI_SOURCES} ${MNN_KleidiAI_HEADERS}
)

# SHARED libraries are compiled as position independent code implicitly.
# Keep that for the object files so they can still be linked into a shared
# library by the consumer.
set_target_properties(MNN_KleidiAI PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Root of the vendored KleidiAI sources (the directory of this script).
set(KLEIDIAI_SRC ${CMAKE_CURRENT_LIST_DIR})

# Header search paths for the KleidiAI wrapper and micro-kernels.
# include_directories() is directory-scoped, so every target defined in the
# scope that runs this script sees these paths, not only MNN_KleidiAI.
# NOTE(review): presumably required because code outside MNN_KleidiAI includes
# mnn_kleidiai.h -- confirm before narrowing this to
# target_include_directories().
include_directories(
${KLEIDIAI_SRC}/
${KLEIDIAI_SRC}/kai/
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)

# Kernel sources, grouped by the -march level each group is compiled with.
set(KLEIDIAI_FILES_SCALAR
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
)

set(KLEIDIAI_FILES_NEON_DOTPROD
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
)

set(KLEIDIAI_FILES_NEON_I8MM
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
)

# The scalar group is always part of the target.
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SCALAR})

# The dotprod / i8mm kernels and the per-file -march flags are only added for
# 64-bit Arm builds with a non-MSVC toolchain.
if(NOT MSVC)
  if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
    target_sources(MNN_KleidiAI PRIVATE
      ${KLEIDIAI_FILES_NEON_DOTPROD}
      ${KLEIDIAI_FILES_NEON_I8MM}
    )

    # Each group gets the architecture level it needs, set per source file so
    # the flags do not leak to the other groups.
    set_property(SOURCE ${KLEIDIAI_FILES_SCALAR} PROPERTY COMPILE_OPTIONS -march=armv8-a)
    set_property(SOURCE ${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+dotprod)
    set_property(SOURCE ${KLEIDIAI_FILES_NEON_I8MM} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+i8mm)
  endif()
endif()
194 changes: 194 additions & 0 deletions source/backend/cpu/arm/kleidiAI/kai/kai_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef __cplusplus
extern "C" {
#endif

// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)
//
// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros.
// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected.
// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting.
// Reports a fatal error and terminates the process.
// Flushes stdout, writes "<file>:<line> <msg>" to stderr (the format string has
// no trailing newline) and calls exit(EXIT_FAILURE). `msg` is passed to a "%s"
// conversion, so it must be a C string.
#define KAI_ERROR(msg) \
do { \
fflush(stdout); \
fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \
exit(EXIT_FAILURE); \
} while (0)

// Terminates the process through KAI_ERROR(msg) when `cond` evaluates to false.
// The check is not guarded by NDEBUG or any other macro in this header, so it
// is always compiled in.
#define KAI_ASSERT_MSG(cond, msg) \
do { \
if (!(cond)) { \
KAI_ERROR(msg); \
} \
} while (0)

// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)

// Asserts `cond`, using the stringified condition as the error message.
#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond)

// Conditional assertions: `cond` only has to hold when `precond` is true.
// The default message has the form "<precond> |-> <cond>".
#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg)
#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond)

// The KAI_ASSUME* family is defined as plain aliases of the KAI_ASSERT* family.
#define KAI_ASSUME_MSG KAI_ASSERT_MSG
#define KAI_ASSUME KAI_ASSERT
#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG
#define KAI_ASSUME_IF KAI_ASSERT_IF

// Marks a variable or parameter as intentionally unused.
#define KAI_UNUSED(x) (void)(x)
// Minimum / maximum of two values. The selected argument is evaluated twice,
// so do not pass expressions with side effects.
#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b))
#define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b))

/// KleidiAI data types
///
/// Encoding, as used by the enumerators below and by
/// kai_get_datatype_size_in_bytes() (which returns `dt >> 8`):
///   bits 8 and above : element size in bytes
///   bits 7..4        : type family (1: floating point, 2: signed integer,
///                      3: unsigned integer, 4: boolean)
///   bits 3..0        : variant within the family
///
/// NOTE(review): kai_dt_bool encodes a size of 4 bytes -- confirm this is intended.
enum kai_datatype {
kai_dt_unknown = 0x0000,
kai_dt_f32 = 0x0411,
kai_dt_f16 = 0x0212,
kai_dt_bf16 = 0x0213,
kai_dt_int32 = 0x0421,
kai_dt_int16 = 0x0222,
kai_dt_int8 = 0x0124,
kai_dt_uint32 = 0x0431,
kai_dt_uint16 = 0x0232,
kai_dt_uint8 = 0x0134,
kai_dt_bool = 0x0441
};

/// Returns the element size, in bytes, encoded in a KleidiAI data type.
///
/// The size is stored above the low byte of the enumerator value, so it is
/// obtained by shifting the value right by 8 bits.
///
/// @param[in] dt KleidiAI data type
///
/// @return the number of bytes of one element of type @p dt
inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) {
    const size_t num_bytes = (size_t)(dt >> 8);
    return num_bytes;
}

/// Converts a scalar f16 value to f32
///
/// With Arm NEON available the conversion goes through the compiler's `__fp16`
/// type. On every other target a portable bit-level conversion is used, so the
/// function always returns a value (previously it fell off the end without a
/// return statement when `__ARM_NEON` was not defined).
///
/// @param[in] f16 The f16 value, as its IEEE 754 binary16 bit pattern
///
/// @return the f32 value
inline static float kai_cast_f32_f16(uint16_t f16) {
#if defined(__ARM_NEON)
    __fp16 half_value = 0;
    memcpy(&half_value, &f16, sizeof(uint16_t));
    return (float)half_value;
#else
    const uint32_t sign = ((uint32_t)(f16 & 0x8000U)) << 16;
    uint32_t exponent = ((uint32_t)f16 >> 10) & 0x1FU;
    uint32_t mantissa = (uint32_t)f16 & 0x3FFU;
    uint32_t bits = sign;

    if (exponent == 0x1FU) {
        // Infinity or NaN: all-ones exponent, mantissa payload kept.
        bits |= 0x7F800000U | (mantissa << 13);
    } else if (exponent != 0) {
        // Normal number: re-bias the exponent from 15 to 127.
        bits |= ((exponent + 112U) << 23) | (mantissa << 13);
    } else if (mantissa != 0) {
        // Subnormal half: normalise it, every such value is a normal float.
        exponent = 113U;
        while ((mantissa & 0x400U) == 0) {
            mantissa <<= 1;
            --exponent;
        }
        bits |= (exponent << 23) | ((mantissa & 0x3FFU) << 13);
    }
    // Otherwise the value is +/-0 and only the sign bit is set.

    float f32 = 0.0F;
    memcpy(&f32, &bits, sizeof(f32));
    return f32;
#endif
}

/// Converts a scalar bf16 value to f32
///
/// The 16 bits of the bf16 value become the upper 16 bits of the f32 bit
/// pattern; the lower 16 bits are zero.
///
/// @param[in] bf16 The bf16 value
///
/// @return the f32 value
inline static float kai_cast_f32_bf16(uint16_t bf16) {
    // Widen before shifting: a uint16_t operand is promoted to (signed) int,
    // and shifting a set bit 15 into the sign bit of an int is undefined.
    const uint32_t i32 = ((uint32_t)bf16) << 16;
    float f32;
    memcpy(&f32, &i32, sizeof(i32));
    return f32;
}

/// Converts a f32 value to bf16
///
/// When `__ARM_FEATURE_BF16` is defined the `bfcvt` instruction performs the
/// conversion. Otherwise the upper 16 bits of the f32 bit pattern are returned,
/// i.e. the low 16 mantissa bits are truncated.
///
/// NOTE(review): the instruction path may round rather than truncate, so the two
/// paths can differ in the last bit for inexact values -- confirm whether that
/// matters to callers.
///
/// @param[in] f32 The f32 value
///
/// @return the bf16 value
inline static uint16_t kai_cast_bf16_f32(float f32) {
    uint16_t bf16 = 0;
#ifdef __ARM_FEATURE_BF16
    __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32));
#else
    // Copy the bits out with memcpy; reading the float through a uint32_t
    // pointer would violate the strict aliasing rules.
    uint32_t i32 = 0;
    memcpy(&i32, &f32, sizeof(i32));
    bf16 = (uint16_t)(i32 >> 16);
#endif
    return bf16;
}

/// Converts a scalar f32 value to f16
///
/// With Arm NEON available the conversion goes through the compiler's `__fp16`
/// type. On every other target a portable bit-level conversion with
/// round-to-nearest-even is used, so the function always returns a value
/// (previously it fell off the end without a return statement when
/// `__ARM_NEON` was not defined).
///
/// @param[in] f32 The f32 value
///
/// @return the f16 value, as its IEEE 754 binary16 bit pattern
inline static uint16_t kai_cast_f16_f32(float f32) {
#if defined(__ARM_NEON)
    uint16_t f16 = 0;
    __fp16 tmp = f32;
    memcpy(&f16, &tmp, sizeof(uint16_t));
    return f16;
#else
    uint32_t bits = 0;
    memcpy(&bits, &f32, sizeof(bits));

    const uint32_t sign = (bits >> 16) & 0x8000U;
    const uint32_t exp32 = (bits >> 23) & 0xFFU;
    const uint32_t mant32 = bits & 0x007FFFFFU;

    if (exp32 == 0xFFU) {
        if (mant32 != 0) {
            // NaN: force a quiet NaN and keep the top payload bits.
            return (uint16_t)(sign | 0x7E00U | (mant32 >> 13));
        }
        return (uint16_t)(sign | 0x7C00U);  // Infinity
    }

    // Exponent re-biased from 127 (f32) to 15 (f16).
    const int32_t exp16 = (int32_t)exp32 - 127 + 15;

    if (exp16 >= 31) {
        return (uint16_t)(sign | 0x7C00U);  // Too large: overflows to infinity.
    }
    if (exp16 < -10) {
        return (uint16_t)sign;  // Below half of the smallest subnormal: +/-0.
    }

    // 24-bit significand with the implicit leading one made explicit.
    const uint32_t significand = mant32 | 0x00800000U;

    // Normal results drop 13 bits; subnormal results drop more, depending on
    // how far the exponent is below the smallest normal exponent.
    const uint32_t shift = (exp16 <= 0) ? (uint32_t)(14 - exp16) : 13U;
    uint32_t half = significand >> shift;
    if (exp16 > 0) {
        // The leading one of the significand lands on bit 10 and accounts for
        // one exponent step, hence (exp16 - 1).
        half += ((uint32_t)(exp16 - 1)) << 10;
    }

    // Round to nearest, ties to even. A carry out of the mantissa correctly
    // bumps the exponent (up to and including infinity).
    const uint32_t remainder = significand & ((1U << shift) - 1U);
    const uint32_t halfway = 1U << (shift - 1U);
    if (remainder > halfway || (remainder == halfway && (half & 1U) != 0)) {
        ++half;
    }

    return (uint16_t)(sign | half);
#endif
}

/// Rounds a value up to a multiple of another value
///
/// @param[in] a The value to round up
/// @param[in] b The granularity; must not be zero because it is used as a divisor
///
/// @return the smallest multiple of @p b that is not less than @p a
inline static size_t kai_roundup(size_t a, size_t b) {
    const size_t num_blocks = (a + b - 1) / b;
    return num_blocks * b;
}

#ifdef __ARM_FEATURE_SVE

/// Gets the SME vector length for 8-bit elements.
///
/// Executes `cntb` between two hand-encoded instructions (emitted with `.inst`,
/// labelled SMSTART and SMSTOP in the assembly text) and returns the value that
/// `cntb` wrote to the output register. z0-z31 are declared as clobbered.
///
/// NOTE(review): 0xd503477f looks like the encoding of plain SMSTART (streaming
/// mode and ZA) rather than "SMSTART ZA" -- confirm against the Arm ARM.
/// NOTE(review): this function is compiled under `__ARM_FEATURE_SVE` although it
/// issues SME instructions -- confirm that callers check for SME at run time.
///
/// @return the value produced by `cntb`
inline static uint64_t kai_get_sme_vector_length_u8(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntb %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
}

/// Gets the SME vector length for 16-bit elements.
///
/// Executes `cnth` between two hand-encoded instructions (emitted with `.inst`,
/// labelled SMSTART and SMSTOP in the assembly text) and returns the value that
/// `cnth` wrote to the output register. z0-z31 are declared as clobbered.
///
/// NOTE(review): 0xd503477f looks like the encoding of plain SMSTART (streaming
/// mode and ZA) rather than "SMSTART ZA" -- confirm against the Arm ARM.
/// NOTE(review): this function is compiled under `__ARM_FEATURE_SVE` although it
/// issues SME instructions -- confirm that callers check for SME at run time.
///
/// @return the value produced by `cnth`
inline static uint64_t kai_get_sme_vector_length_u16(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cnth %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
}

/// Gets the SME vector length for 32-bit elements.
///
/// Executes `cntw` between two hand-encoded instructions (emitted with `.inst`,
/// labelled SMSTART and SMSTOP in the assembly text) and returns the value that
/// `cntw` wrote to the output register. z0-z31 are declared as clobbered.
///
/// NOTE(review): 0xd503477f looks like the encoding of plain SMSTART (streaming
/// mode and ZA) rather than "SMSTART ZA" -- confirm against the Arm ARM.
/// NOTE(review): this function is compiled under `__ARM_FEATURE_SVE` although it
/// issues SME instructions -- confirm that callers check for SME at run time.
///
/// @return the value produced by `cntw`
inline static uint64_t kai_get_sme_vector_length_u32(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntw %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
}

#endif // __ARM_FEATURE_SVE

/// Sign-extends a 4-bit integer held in the low nibble of an int8_t
///
/// Flipping bit 3 and then subtracting 8 maps the nibble values 0..7 to 0..7
/// and 8..15 to -8..-1.
///
/// @param[in] value The 4-bit int value
///
/// @return the int8_t value with sign extended
inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
    const int sign_flipped = value ^ 0x8;
    return (int8_t)(sign_flipped - 8);
}

#ifdef __cplusplus
}
#endif
Loading

0 comments on commit 630d593

Please sign in to comment.