diff --git a/ne16/hal/ne16_task.c b/ne16/hal/ne16_task.c index 94909dc..a519ce8 100644 --- a/ne16/hal/ne16_task.c +++ b/ne16/hal/ne16_task.c @@ -113,15 +113,18 @@ uint32_t ne16_pad_ptr(uint32_t ptr, const uint32_t width, uint32_t width_stride, return ptr - (padding_top * width + padding_left) * width_stride; } -void ne16_task_set_ptrs(ne16_task_t *task, uint32_t input_ptr, uint32_t w_in, - uint32_t w_in_stride, uint8_t padding_top, - uint8_t padding_left, uint32_t output_ptr, - uint32_t weights_ptr, uint32_t scale_ptr, - uint32_t shift_ptr, uint32_t bias_ptr) { +void ne16_task_set_ptrs_conv(ne16_task_t *task, uint32_t input_ptr, + uint32_t w_in, uint32_t w_in_stride, + uint8_t padding_top, uint8_t padding_left, + uint32_t output_ptr, uint32_t weights_ptr) { task->data.infeat_ptr = ne16_pad_ptr(input_ptr, w_in, w_in_stride, padding_top, padding_left); task->data.outfeat_ptr = output_ptr; task->data.weights_ptr = weights_ptr; +} + +void ne16_task_set_ptrs_norm_quant(ne16_task_t *task, uint32_t scale_ptr, + uint32_t shift_ptr, uint32_t bias_ptr) { task->data.scale_ptr = scale_ptr; task->data.scale_shift_ptr = shift_ptr; task->data.scale_bias_ptr = bias_ptr; diff --git a/ne16/hal/ne16_task.h b/ne16/hal/ne16_task.h index 5fd167f..e18c256 100644 --- a/ne16/hal/ne16_task.h +++ b/ne16/hal/ne16_task.h @@ -133,11 +133,12 @@ uint32_t ne16_get_tile_padding(uint32_t padding, uint32_t i_height, uint32_t ne16_pad_ptr(uint32_t ptr, const uint32_t width, const uint32_t width_stride, const uint8_t padding_top, const uint8_t padding_left); -void ne16_task_set_ptrs(ne16_task_t *task, uint32_t input_ptr, uint32_t w_in, - uint32_t w_in_stride, uint8_t padding_top, - uint8_t padding_left, uint32_t output_ptr, - uint32_t weights_ptr, uint32_t scale_ptr, - uint32_t shift_ptr, uint32_t bias_ptr); +void ne16_task_set_ptrs_conv(ne16_task_t *task, uint32_t input_ptr, + uint32_t w_in, uint32_t w_in_stride, + uint8_t padding_top, uint8_t padding_left, + uint32_t output_ptr, uint32_t weights_ptr); +void ne16_task_set_ptrs_norm_quant(ne16_task_t *task, uint32_t scale_ptr, + uint32_t shift_ptr, uint32_t bias_ptr); /** ne16_task_set_strides * * All the strides variables are strides between elements alongside that diff --git a/neureka/hal/neureka_task.c b/neureka/hal/neureka_task.c index e451af9..a0ea712 100644 --- a/neureka/hal/neureka_task.c +++ b/neureka/hal/neureka_task.c @@ -126,16 +126,18 @@ uint32_t neureka_pad_ptr(uint32_t ptr, const uint32_t width, return ptr - (padding_top * width + padding_left) * width_stride; } -void neureka_task_set_ptrs(neureka_task_t *task, uint32_t input_ptr, - uint32_t w_in, uint32_t w_in_stride, - uint8_t padding_top, uint8_t padding_left, - uint32_t output_ptr, uint32_t weights_ptr, - uint32_t scale_ptr, uint32_t shift_ptr, - uint32_t bias_ptr) { +void neureka_task_set_ptrs_conv(neureka_task_t *task, uint32_t input_ptr, + uint32_t w_in, uint32_t w_in_stride, + uint8_t padding_top, uint8_t padding_left, + uint32_t output_ptr, uint32_t weights_ptr) { task->data.infeat_ptr = neureka_pad_ptr(input_ptr, w_in, w_in_stride, padding_top, padding_left); task->data.outfeat_ptr = output_ptr; task->data.weights_ptr = weights_ptr; +} + +void neureka_task_set_ptrs_norm_quant(neureka_task_t *task, uint32_t scale_ptr, + uint32_t shift_ptr, uint32_t bias_ptr) { task->data.scale_ptr = scale_ptr; task->data.scale_shift_ptr = shift_ptr; task->data.scale_bias_ptr = bias_ptr; diff --git a/neureka/hal/neureka_task.h b/neureka/hal/neureka_task.h index 6d001fe..7af5cb5 100644 --- a/neureka/hal/neureka_task.h +++ b/neureka/hal/neureka_task.h @@ -142,12 +142,12 @@ uint32_t neureka_get_tile_padding(uint32_t padding, uint32_t i_height, uint32_t neureka_pad_ptr(uint32_t ptr, const uint32_t width, const uint32_t width_stride, const uint8_t padding_top, const uint8_t padding_left); -void neureka_task_set_ptrs(neureka_task_t *task, uint32_t input_ptr, - uint32_t w_in, uint32_t w_in_stride, - uint8_t padding_top, uint8_t padding_left, - uint32_t output_ptr, uint32_t weights_ptr, - uint32_t scale_ptr, uint32_t shift_ptr, - uint32_t bias_ptr); +void neureka_task_set_ptrs_conv(neureka_task_t *task, uint32_t input_ptr, + uint32_t w_in, uint32_t w_in_stride, + uint8_t padding_top, uint8_t padding_left, + uint32_t output_ptr, uint32_t weights_ptr); +void neureka_task_set_ptrs_norm_quant(neureka_task_t *task, uint32_t scale_ptr, + uint32_t shift_ptr, uint32_t bias_ptr); /** neureka_task_set_strides * * All the strides variables are strides between elements alongside that diff --git a/test/NeurekaMemoryLayout.py b/test/NeurekaMemoryLayout.py index 80a2786..028c7a3 100644 --- a/test/NeurekaMemoryLayout.py +++ b/test/NeurekaMemoryLayout.py @@ -20,8 +20,6 @@ import numpy as np import numpy.typing as npt -from TestClasses import IntegerType - class NeurekaMemoryLayout: _WEIGHT_BANDWIDTH = 256 diff --git a/test/app/src/nnx_layer.c b/test/app/src/nnx_layer.c index d398399..001029f 100644 --- a/test/app/src/nnx_layer.c +++ b/test/app/src/nnx_layer.c @@ -31,10 +31,12 @@ typedef ne16_norm_mode_e nnx_norm_mode_e; typedef ne16_quant_t nnx_quant_t; +typedef ne16_quant_function_e nnx_quant_function_e; typedef ne16_norm_t nnx_norm_t; typedef ne16_task_t nnx_task_t; typedef ne16_dev_t nnx_dev_t; typedef ne16_pulp_conf_t nnx_bsp_conf_t; +typedef ne16_task_flag_e nnx_task_flag_e; #define nnxTaskFlagTrue ne16TaskFlagTrue #define nnxTaskFlagFalse ne16TaskFlagFalse @@ -46,7 +48,8 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t; #define nnx_task_set_weight_offset ne16_task_set_weight_offset #define nnx_task_set_dims ne16_task_set_dims #define nnx_task_set_dims_stride2x2 ne16_task_set_dims_stride2x2 -#define nnx_task_set_ptrs ne16_task_set_ptrs +#define nnx_task_set_ptrs_conv ne16_task_set_ptrs_conv +#define nnx_task_set_ptrs_norm_quant ne16_task_set_ptrs_norm_quant #define NNX_GVSOC_LOG_LEVEL NE16_GVSOC_LOG_LEVEL_ALL #define NNX_GVSOC_LOG_FORMAT NE16_GVSOC_LOG_FORMAT_HEXADECIMAL @@ -72,10 +75,12 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t; typedef neureka_norm_mode_e nnx_norm_mode_e; typedef neureka_quant_t nnx_quant_t; +typedef neureka_quant_function_e nnx_quant_function_e; typedef neureka_norm_t nnx_norm_t; typedef neureka_task_t nnx_task_t; typedef neureka_dev_t nnx_dev_t; typedef neureka_siracusa_conf_t nnx_bsp_conf_t; +typedef neureka_task_flag_e nnx_task_flag_e; #define nnxTaskFlagTrue neurekaTaskFlagTrue #define nnxTaskFlagFalse neurekaTaskFlagFalse @@ -86,7 +91,8 @@ typedef neureka_siracusa_conf_t nnx_bsp_conf_t; #define nnx_task_set_norm_quant neureka_task_set_norm_quant #define nnx_task_set_weight_offset neureka_task_set_weight_offset #define nnx_task_set_dims neureka_task_set_dims -#define nnx_task_set_ptrs neureka_task_set_ptrs +#define nnx_task_set_ptrs_conv neureka_task_set_ptrs_conv +#define nnx_task_set_ptrs_norm_quant neureka_task_set_ptrs_norm_quant #define NNX_GVSOC_LOG_LEVEL NEUREKA_GVSOC_LOG_LEVEL_ALL #define NNX_GVSOC_LOG_FORMAT NEUREKA_GVSOC_LOG_FORMAT_HEXADECIMAL @@ -120,24 +126,6 @@ static void task_prepare(nnx_task_t *task) { #endif nnx_task_set_bits(task, INPUT_BITS, OUTPUT_BITS, WEIGHT_BITS); -#if HAS_NORM_QUANT == 1 -#if SCALE_BITS == 8 - const nnx_norm_mode_e normMode = normMode8Bit; -#elif SCALE_BITS == 32 - const nnx_norm_mode_e normMode = normMode32Bit; -#endif - - nnx_task_set_norm_quant( - task, - (nnx_quant_t){.shift_amount = OUTSHIFT, - .function = - HAS_RELU ? quantFunctionRelu : quantFunctionIdentity, - .flag_rounding = nnxTaskFlagFalse}, - (nnx_norm_t){.mode = normMode, - .flag_bias = HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse, - .flag_shift = nnxTaskFlagFalse}); -#endif // HAS_NORM_QUANT - nnx_task_set_weight_offset(task, weightOffsetModeLayerWise, WEIGHT_OFFSET); #ifdef NNX_NEUREKA @@ -171,20 +159,34 @@ static void task_prepare(nnx_task_t *task) { PADDING_RIGHT); #endif - nnx_task_set_ptrs(task, (uint32_t)input, INPUT_WIDTH, w_in_stride, - PADDING_TOP, PADDING_LEFT, (uint32_t)output, - (uint32_t)weight, + nnx_task_set_ptrs_conv(task, (uint32_t)input, INPUT_WIDTH, w_in_stride, + PADDING_TOP, PADDING_LEFT, (uint32_t)output, + (uint32_t)weight); + #if HAS_NORM_QUANT == 1 - (uint32_t)scale, NULL, -#if HAS_BIAS == 1 - (uint32_t)bias -#else - NULL -#endif -#else - NULL, NULL, NULL +#if SCALE_BITS == 8 + const nnx_norm_mode_e normMode = normMode8Bit; +#elif SCALE_BITS == 32 + const nnx_norm_mode_e normMode = normMode32Bit; #endif - ); + + const nnx_task_flag_e flag_bias = + HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse; + const uint32_t bias_ptr = (uint32_t)(HAS_BIAS ? bias : NULL); + + nnx_quant_function_e quant_function = + HAS_RELU ? quantFunctionRelu : quantFunctionIdentity; + + nnx_task_set_norm_quant(task, + (nnx_quant_t){.shift_amount = OUTSHIFT, + .function = quant_function, + .flag_rounding = nnxTaskFlagFalse}, + (nnx_norm_t){.mode = normMode, + .flag_bias = flag_bias, + .flag_shift = nnxTaskFlagFalse}); + + nnx_task_set_ptrs_norm_quant(task, (uint32_t)scale, NULL, bias_ptr); +#endif // HAS_NORM_QUANT } static void task_execute(nnx_task_t *task) {