Skip to content

Commit

Permalink
Split normquant ptrs (#4)
Browse files Browse the repository at this point in the history
* Fix remove unused import

* Split set_ptrs into set_ptrs_conv and set_ptrs_norm_quant
  • Loading branch information
lukamac authored Feb 9, 2024
1 parent a8765d1 commit 1250806
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 56 deletions.
13 changes: 8 additions & 5 deletions ne16/hal/ne16_task.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,18 @@ uint32_t ne16_pad_ptr(uint32_t ptr, const uint32_t width, uint32_t width_stride,
return ptr - (padding_top * width + padding_left) * width_stride;
}

void ne16_task_set_ptrs(ne16_task_t *task, uint32_t input_ptr, uint32_t w_in,
uint32_t w_in_stride, uint8_t padding_top,
uint8_t padding_left, uint32_t output_ptr,
uint32_t weights_ptr, uint32_t scale_ptr,
uint32_t shift_ptr, uint32_t bias_ptr) {
void ne16_task_set_ptrs_conv(ne16_task_t *task, uint32_t input_ptr,
uint32_t w_in, uint32_t w_in_stride,
uint8_t padding_top, uint8_t padding_left,
uint32_t output_ptr, uint32_t weights_ptr) {
task->data.infeat_ptr =
ne16_pad_ptr(input_ptr, w_in, w_in_stride, padding_top, padding_left);
task->data.outfeat_ptr = output_ptr;
task->data.weights_ptr = weights_ptr;
}

void ne16_task_set_ptrs_norm_quant(ne16_task_t *task, uint32_t scale_ptr,
uint32_t shift_ptr, uint32_t bias_ptr) {
task->data.scale_ptr = scale_ptr;
task->data.scale_shift_ptr = shift_ptr;
task->data.scale_bias_ptr = bias_ptr;
Expand Down
11 changes: 6 additions & 5 deletions ne16/hal/ne16_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,12 @@ uint32_t ne16_get_tile_padding(uint32_t padding, uint32_t i_height,
uint32_t ne16_pad_ptr(uint32_t ptr, const uint32_t width,
const uint32_t width_stride, const uint8_t padding_top,
const uint8_t padding_left);
void ne16_task_set_ptrs(ne16_task_t *task, uint32_t input_ptr, uint32_t w_in,
uint32_t w_in_stride, uint8_t padding_top,
uint8_t padding_left, uint32_t output_ptr,
uint32_t weights_ptr, uint32_t scale_ptr,
uint32_t shift_ptr, uint32_t bias_ptr);
void ne16_task_set_ptrs_conv(ne16_task_t *task, uint32_t input_ptr,
uint32_t w_in, uint32_t w_in_stride,
uint8_t padding_top, uint8_t padding_left,
uint32_t output_ptr, uint32_t weights_ptr);
void ne16_task_set_ptrs_norm_quant(ne16_task_t *task, uint32_t scale_ptr,
uint32_t shift_ptr, uint32_t bias_ptr);
/** ne16_task_set_strides
*
* All the strides variables are strides between elements alongside that
Expand Down
14 changes: 8 additions & 6 deletions neureka/hal/neureka_task.c
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,18 @@ uint32_t neureka_pad_ptr(uint32_t ptr, const uint32_t width,
return ptr - (padding_top * width + padding_left) * width_stride;
}

void neureka_task_set_ptrs(neureka_task_t *task, uint32_t input_ptr,
uint32_t w_in, uint32_t w_in_stride,
uint8_t padding_top, uint8_t padding_left,
uint32_t output_ptr, uint32_t weights_ptr,
uint32_t scale_ptr, uint32_t shift_ptr,
uint32_t bias_ptr) {
void neureka_task_set_ptrs_conv(neureka_task_t *task, uint32_t input_ptr,
uint32_t w_in, uint32_t w_in_stride,
uint8_t padding_top, uint8_t padding_left,
uint32_t output_ptr, uint32_t weights_ptr) {
task->data.infeat_ptr =
neureka_pad_ptr(input_ptr, w_in, w_in_stride, padding_top, padding_left);
task->data.outfeat_ptr = output_ptr;
task->data.weights_ptr = weights_ptr;
}

void neureka_task_set_ptrs_norm_quant(neureka_task_t *task, uint32_t scale_ptr,
uint32_t shift_ptr, uint32_t bias_ptr) {
task->data.scale_ptr = scale_ptr;
task->data.scale_shift_ptr = shift_ptr;
task->data.scale_bias_ptr = bias_ptr;
Expand Down
12 changes: 6 additions & 6 deletions neureka/hal/neureka_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ uint32_t neureka_get_tile_padding(uint32_t padding, uint32_t i_height,
uint32_t neureka_pad_ptr(uint32_t ptr, const uint32_t width,
const uint32_t width_stride, const uint8_t padding_top,
const uint8_t padding_left);
void neureka_task_set_ptrs(neureka_task_t *task, uint32_t input_ptr,
uint32_t w_in, uint32_t w_in_stride,
uint8_t padding_top, uint8_t padding_left,
uint32_t output_ptr, uint32_t weights_ptr,
uint32_t scale_ptr, uint32_t shift_ptr,
uint32_t bias_ptr);
void neureka_task_set_ptrs_conv(neureka_task_t *task, uint32_t input_ptr,
uint32_t w_in, uint32_t w_in_stride,
uint8_t padding_top, uint8_t padding_left,
uint32_t output_ptr, uint32_t weights_ptr);
void neureka_task_set_ptrs_norm_quant(neureka_task_t *task, uint32_t scale_ptr,
uint32_t shift_ptr, uint32_t bias_ptr);
/** neureka_task_set_strides
*
* All the strides variables are strides between elements alongside that
Expand Down
2 changes: 0 additions & 2 deletions test/NeurekaMemoryLayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import numpy as np
import numpy.typing as npt

from TestClasses import IntegerType


class NeurekaMemoryLayout:
_WEIGHT_BANDWIDTH = 256
Expand Down
66 changes: 34 additions & 32 deletions test/app/src/nnx_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@

typedef ne16_norm_mode_e nnx_norm_mode_e;
typedef ne16_quant_t nnx_quant_t;
typedef ne16_quant_function_e nnx_quant_function_e;
typedef ne16_norm_t nnx_norm_t;
typedef ne16_task_t nnx_task_t;
typedef ne16_dev_t nnx_dev_t;
typedef ne16_pulp_conf_t nnx_bsp_conf_t;
typedef ne16_task_flag_e nnx_task_flag_e;

#define nnxTaskFlagTrue ne16TaskFlagTrue
#define nnxTaskFlagFalse ne16TaskFlagFalse
Expand All @@ -46,7 +48,8 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t;
#define nnx_task_set_weight_offset ne16_task_set_weight_offset
#define nnx_task_set_dims ne16_task_set_dims
#define nnx_task_set_dims_stride2x2 ne16_task_set_dims_stride2x2
#define nnx_task_set_ptrs ne16_task_set_ptrs
#define nnx_task_set_ptrs_conv ne16_task_set_ptrs_conv
#define nnx_task_set_ptrs_norm_quant ne16_task_set_ptrs_norm_quant

#define NNX_GVSOC_LOG_LEVEL NE16_GVSOC_LOG_LEVEL_ALL
#define NNX_GVSOC_LOG_FORMAT NE16_GVSOC_LOG_FORMAT_HEXADECIMAL
Expand All @@ -72,10 +75,12 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t;

typedef neureka_norm_mode_e nnx_norm_mode_e;
typedef neureka_quant_t nnx_quant_t;
typedef neureka_quant_function_e nnx_quant_function_e;
typedef neureka_norm_t nnx_norm_t;
typedef neureka_task_t nnx_task_t;
typedef neureka_dev_t nnx_dev_t;
typedef neureka_siracusa_conf_t nnx_bsp_conf_t;
typedef neureka_task_flag_e nnx_task_flag_e;

#define nnxTaskFlagTrue neurekaTaskFlagTrue
#define nnxTaskFlagFalse neurekaTaskFlagFalse
Expand All @@ -86,7 +91,8 @@ typedef neureka_siracusa_conf_t nnx_bsp_conf_t;
#define nnx_task_set_norm_quant neureka_task_set_norm_quant
#define nnx_task_set_weight_offset neureka_task_set_weight_offset
#define nnx_task_set_dims neureka_task_set_dims
#define nnx_task_set_ptrs neureka_task_set_ptrs
#define nnx_task_set_ptrs_conv neureka_task_set_ptrs_conv
#define nnx_task_set_ptrs_norm_quant neureka_task_set_ptrs_norm_quant

#define NNX_GVSOC_LOG_LEVEL NEUREKA_GVSOC_LOG_LEVEL_ALL
#define NNX_GVSOC_LOG_FORMAT NEUREKA_GVSOC_LOG_FORMAT_HEXADECIMAL
Expand Down Expand Up @@ -120,24 +126,6 @@ static void task_prepare(nnx_task_t *task) {
#endif
nnx_task_set_bits(task, INPUT_BITS, OUTPUT_BITS, WEIGHT_BITS);

#if HAS_NORM_QUANT == 1
#if SCALE_BITS == 8
const nnx_norm_mode_e normMode = normMode8Bit;
#elif SCALE_BITS == 32
const nnx_norm_mode_e normMode = normMode32Bit;
#endif

nnx_task_set_norm_quant(
task,
(nnx_quant_t){.shift_amount = OUTSHIFT,
.function =
HAS_RELU ? quantFunctionRelu : quantFunctionIdentity,
.flag_rounding = nnxTaskFlagFalse},
(nnx_norm_t){.mode = normMode,
.flag_bias = HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse,
.flag_shift = nnxTaskFlagFalse});
#endif // HAS_NORM_QUANT

nnx_task_set_weight_offset(task, weightOffsetModeLayerWise, WEIGHT_OFFSET);

#ifdef NNX_NEUREKA
Expand Down Expand Up @@ -171,20 +159,34 @@ static void task_prepare(nnx_task_t *task) {
PADDING_RIGHT);
#endif

nnx_task_set_ptrs(task, (uint32_t)input, INPUT_WIDTH, w_in_stride,
PADDING_TOP, PADDING_LEFT, (uint32_t)output,
(uint32_t)weight,
nnx_task_set_ptrs_conv(task, (uint32_t)input, INPUT_WIDTH, w_in_stride,
PADDING_TOP, PADDING_LEFT, (uint32_t)output,
(uint32_t)weight);

#if HAS_NORM_QUANT == 1
(uint32_t)scale, NULL,
#if HAS_BIAS == 1
(uint32_t)bias
#else
NULL
#endif
#else
NULL, NULL, NULL
#if SCALE_BITS == 8
const nnx_norm_mode_e normMode = normMode8Bit;
#elif SCALE_BITS == 32
const nnx_norm_mode_e normMode = normMode32Bit;
#endif
);

const nnx_task_flag_e flag_bias =
HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse;
const uint32_t bias_ptr = (uint32_t)(HAS_BIAS ? bias : NULL);

nnx_quant_function_e quant_function =
HAS_RELU ? quantFunctionRelu : quantFunctionIdentity;

nnx_task_set_norm_quant(task,
(nnx_quant_t){.shift_amount = OUTSHIFT,
.function = quant_function,
.flag_rounding = nnxTaskFlagFalse},
(nnx_norm_t){.mode = normMode,
.flag_bias = flag_bias,
.flag_shift = nnxTaskFlagFalse});

nnx_task_set_ptrs_norm_quant(task, (uint32_t)scale, NULL, bias_ptr);
#endif // HAS_NORM_QUANT
}

static void task_execute(nnx_task_t *task) {
Expand Down

0 comments on commit 1250806

Please sign in to comment.