Skip to content

Commit

Permalink
Remove stride2x2 for neureka
Browse files Browse the repository at this point in the history
  • Loading branch information
lukamac committed Jan 18, 2024
1 parent 741b5d7 commit 27664cd
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 97 deletions.
16 changes: 0 additions & 16 deletions inc/pulp_nnx_neureka.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,3 @@ int neureka_nnx_resolve_check(neureka_dev_t *dev, neureka_task_t *task);
* Block until you can resolve the task.
*/
void neureka_nnx_resolve_wait(neureka_dev_t *dev, neureka_task_t *task);

/* Additional helper functions */

/** neureka_nnx_dispatch_stride2x2
*
* It uses Neureka's 2x2 strided mode which reduces the number of writes Neureka
* does. This mode doesn't stride the Neureka's subtile input pointer, so we
* have to tile the tile to the subtile's spatial dimensions (in this case 3x3
* output). Works only if the k_out is divisible by 2.
*/
void neureka_nnx_dispatch_stride2x2(
neureka_dev_t *dev, neureka_task_t *task, const uint32_t w_in,
const uint32_t k_in, const uint32_t w_in_stride, const uint32_t k_in_stride,
const uint32_t h_out, const uint32_t w_out, const uint32_t k_out,
const uint32_t w_out_stride, const uint32_t k_out_stride,
const uint8_t h_ker, const uint8_t w_ker);
8 changes: 0 additions & 8 deletions neureka/hal/neureka_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,5 @@ void neureka_task_set_dims(
const uint32_t w_out_stride, const uint32_t k_out_stride,
const uint8_t padding_top, const uint8_t padding_bottom,
const uint8_t padding_right, const uint8_t padding_left);
void neureka_task_set_dims_stride2x2(
neureka_task_t *task, const uint32_t h_in, const uint32_t w_in,
const uint32_t k_in, const uint32_t w_in_stride, const uint32_t k_in_stride,
const uint32_t h_out, const uint32_t w_out, const uint32_t k_out,
const uint32_t w_out_stride, const uint32_t k_out_stride,
const uint8_t h_ker, const uint8_t w_ker, const uint8_t padding_top,
const uint8_t padding_bottom, const uint8_t padding_right,
const uint8_t padding_left);

#endif // !__NEUREKA_TASK_H__
55 changes: 0 additions & 55 deletions src/pulp_nnx_neureka.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,58 +74,3 @@ void neureka_nnx_resolve_wait(neureka_dev_t *dev, neureka_task_t *task) {
neureka_siracusa_event_wait_and_clear();
}
}

static inline uint32_t _get_tile_ptr(uint32_t ptr, int i, int j, int size_i,
uint32_t size_j, uint32_t size_k,
uint32_t stride_j, uint32_t stride_k,
uint32_t overlap_i, uint32_t overlap_j,
uint32_t offset_i, uint32_t offset_j,
uint8_t data_size) {
return ptr +
(i * (size_i - overlap_i) - offset_i) * stride_j * stride_k *
data_size / 8 +
(j * (size_j - overlap_j) - offset_j) * stride_k * data_size / 8;
}

void neureka_nnx_dispatch_stride2x2(
neureka_dev_t *dev, neureka_task_t *task, const uint32_t w_in,
const uint32_t k_in, const uint32_t w_in_stride, const uint32_t k_in_stride,
const uint32_t h_out, const uint32_t w_out, const uint32_t k_out,
const uint32_t w_out_stride, const uint32_t k_out_stride,
const uint8_t h_ker, const uint8_t w_ker) {
const uint8_t stride = 2;
const uint8_t bits = 8;

const uint32_t n_h = divnceil(h_out, stride);
const uint32_t n_w = divnceil(w_out, stride);
const uint32_t input_height_offset = h_out % stride == 1 ? stride : 0;
const uint32_t input_width_offset = w_out % stride == 1 ? stride : 0;
const uint32_t output_height_offset = h_out % stride == 1 ? 1 : 0;
const uint32_t output_width_offset = w_out % stride == 1 ? 1 : 0;

const uint32_t input_base = task->data.infeat_ptr;
const uint32_t output_base = task->data.outfeat_ptr;
const uint32_t tile_padding = task->data.cfg.padding;

for (int i = 0; i < n_h; i++) {
for (int j = 0; j < n_w; j++) {
task->data.infeat_ptr =
_get_tile_ptr(input_base, i, j, 3 + h_ker - 1, 3 + w_ker - 1, k_in,
w_in_stride, k_in_stride, h_ker - stride,
w_ker - stride, i == 0 ? 0 : input_height_offset,
j == 0 ? 0 : input_width_offset, bits);
task->data.outfeat_ptr =
_get_tile_ptr(output_base, i, j, 2, 2, k_out, w_out_stride,
k_out_stride, 0, 0, i == 0 ? 0 : output_height_offset,
j == 0 ? 0 : output_width_offset, bits);

task->data.cfg.padding =
neureka_get_tile_padding(tile_padding, i, j, n_h, n_w);

// Altered dispatch to wait if cannot acquire
while (neureka_nnx_dispatch(dev, task)) {
neureka_siracusa_event_wait_and_clear();
}
}
}
}
11 changes: 1 addition & 10 deletions test/NeurekaTestConf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def check_valid_kernel_shape(cls, v: KernelShape) -> KernelShape:
@field_validator("stride")
@classmethod
def check_valid_stride(cls, v: Stride) -> Stride:
assert v == Stride(height=1, width=1) or v == Stride(
height=2, width=2
), f"Unsupported stride {v}. Supported 1x1 and 2x2."
assert v == Stride(height=1, width=1), f"Unsupported stride {v}. Supported 1x1."
return v

@staticmethod
Expand Down Expand Up @@ -81,13 +79,6 @@ def check_valid_bias_type(cls, v: Optional[IntegerType]) -> Optional[IntegerType
NeurekaTestConf._check_type("bias_type", v, ["int32"])
return v

@model_validator(mode="after") # type: ignore
def check_valid_out_channel_with_stride_2x2(self) -> NeurekaTestConf:
assert implies(
self.stride == Stride(height=2, width=2), self.out_channel % 2 == 0
), f"With stride 2x2 supported only even output channel sizes. Given output channel {self.out_channel}"
return self

@model_validator(mode="after") # type: ignore
def check_valid_depthwise(self) -> NeurekaTestConf:
assert implies(
Expand Down
14 changes: 6 additions & 8 deletions test/app/src/nnx_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ typedef neureka_siracusa_conf_t nnx_bsp_conf_t;

#define nnx_task_init neureka_task_init
#define nnx_task_set_dims neureka_task_set_dims
#define nnx_task_set_dims_stride2x2 neureka_task_set_dims_stride2x2
#define nnx_task_set_ptrs neureka_task_set_ptrs

#define NNX_GVSOC_LOG_LEVEL NEUREKA_GVSOC_LOG_LEVEL_ALL
Expand All @@ -88,7 +87,6 @@ typedef neureka_siracusa_conf_t nnx_bsp_conf_t;

#define nnx_init neureka_nnx_init
#define nnx_dispatch_wait neureka_nnx_dispatch_wait
#define nnx_dispatch_stride2x2 neureka_nnx_dispatch_stride2x2
#define nnx_dispatch neureka_nnx_dispatch
#define nnx_resolve_wait neureka_nnx_resolve_wait
#define nnx_term neureka_nnx_term
Expand Down Expand Up @@ -117,18 +115,18 @@ static void task_prepare(nnx_task_t *task) {
.flag_shift = nnxTaskFlagFalse},
STRIDE_HEIGHT);

if (STRIDE_WIDTH == 2 && STRIDE_HEIGHT == 2) {
#if STRIDE_HEIGHT == 2 && STRIDE_WIDTH == 2
nnx_task_set_dims_stride2x2(
task, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNEL, INPUT_WIDTH,
INPUT_CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH, OUTPUT_CHANNEL,
OUTPUT_WIDTH, OUTPUT_CHANNEL, WEIGHT_HEIGHT, WEIGHT_WIDTH, PADDING_TOP,
PADDING_BOTTOM, PADDING_RIGHT, PADDING_LEFT);
} else {
#else
nnx_task_set_dims(task, INPUT_WIDTH, INPUT_CHANNEL, INPUT_WIDTH,
INPUT_CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH,
OUTPUT_CHANNEL, OUTPUT_WIDTH, OUTPUT_CHANNEL, PADDING_TOP,
PADDING_BOTTOM, PADDING_RIGHT, PADDING_LEFT);
}
#endif

nnx_task_set_ptrs(task, (uint32_t)input, INPUT_WIDTH, INPUT_CHANNEL,
INPUT_BITS, PADDING_TOP, PADDING_LEFT, (uint32_t)output,
Expand All @@ -151,14 +149,14 @@ static void task_execute(nnx_task_t *task) {

nnx_dispatch_wait(dev);

if (STRIDE_WIDTH == 2 && STRIDE_HEIGHT == 2) {
#if STRIDE_HEIGHT == 2 && STRIDE_WIDTH == 2
nnx_dispatch_stride2x2(dev, task, INPUT_WIDTH, INPUT_CHANNEL, INPUT_WIDTH,
INPUT_CHANNEL, OUTPUT_HEIGHT, OUTPUT_WIDTH,
OUTPUT_CHANNEL, OUTPUT_WIDTH, OUTPUT_CHANNEL,
WEIGHT_HEIGHT, WEIGHT_WIDTH);
} else {
#else
nnx_dispatch(dev, task);
}
#endif

nnx_resolve_wait(dev, task);

Expand Down

0 comments on commit 27664cd

Please sign in to comment.