Skip to content

Commit

Permalink
[software] Fix f16 cfft
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertuletti committed Jul 5, 2024
1 parent 4d60166 commit d3f5650
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 91 deletions.
89 changes: 31 additions & 58 deletions software/apps/baremetal/cfft_radix4_f16/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,18 @@
#include "data/data_cfft_radix4_f16.h"

/*
CHOOSE ONE
- SINGLE: Single core FFT
- PARALLEL: Parallel FFT, not "memory-aware"
- FOLDED: Parallel FFT with "memory-aware" load/store scheme
- SCHEDULED: Scheduling of multiple parallel FFTs with "memory-aware"
load/store scheme
load/store scheme
- N_FFTs_COL: Independent FFTs scheduled on one row (default 1)
- N_FFTs_ROW: Independent FFTs scheduled on columns (default 1)
(OPTIONALLY ENABLE)
- FOLDED_TWIDDLES: Also the twiddles have "memory-aware" load/stores
- BITREVERSETABLE: The bit-reversal indices are loaded from a table
- ASM: Use asm_volatile statements
*/

#define FOLDED
#define FOLDED_TWIDDLES
#define BITREVERSETABLE
#define ASM

/*
Defaults for the FFT-count macros N_FFTs_ROW and N_FFTs_COL.
Each macro falls back to 1 on its own, and only when it has not already
been defined (e.g. through a -D compiler flag), so a value supplied from
outside is never overridden or redefined.
*/
#ifndef N_FFTs_ROW
#define N_FFTs_ROW 1
#endif
#ifndef N_FFTs_COL
#define N_FFTs_COL 1
#endif

#include "kernel/mempool_checks.h"
#include "kernel/mempool_radix4_cfft_butterfly_f16.h"
Expand All @@ -61,98 +50,82 @@ __fp16 l1_twiddleCoef_f16_dst[8 * N_BANKS]
uint16_t l1_BitRevIndexTable[BITREVINDEXTABLE_LENGTH]
__attribute__((aligned(4 * N_BANKS), section(".l1_prio")));

///////////////////////////////////////////////////////////////////////////////////////////////////
/* MAIN */
int main() {

uint32_t core_id = mempool_get_core_id();
uint32_t num_cores = mempool_get_core_count();
mempool_barrier_init(core_id);
__fp16 *pRes = (__fp16 *)0;

///////////////////////////////////////////////////////////////////////////////////////////////////
/* INITIALIZATION */
if (core_id == 0) {
// Each FFT is folded over 4 memory rows
// Each memory row is 2 * N_BANKS samples
// Each memory row is 2 * N_BANKS (real-imag) samples
for (uint32_t j = 0; j < N_FFTs_ROW; j++) {
dma_memcpy_blocking(l1_pSrc + j * (8 * N_BANKS), l2_pSrc,
(N_RSAMPLES * N_FFTs_COL) * sizeof(int32_t));
(N_CSAMPLES * N_FFTs_COL) * sizeof(int32_t));
}
dma_memcpy_blocking(l1_pSrc, l2_pSrc, N_CSAMPLES * sizeof(int32_t));
dma_memcpy_blocking(l1_BitRevIndexTable, l2_BitRevIndexTable,
BITREVINDEXTABLE_LENGTH * sizeof(int16_t));
dma_memcpy_blocking(l1_twiddleCoef_f16_src, l2_twiddleCoef_f16,
3 * (N_CSAMPLES / 4) * sizeof(int32_t));
}
// Initialize the Twiddles folded
// Initialize the Twiddles folded
#ifdef FOLDED_TWIDDLES
for (uint32_t j = 0; j < N_FFTs_COL; j++) {
uint32_t N_WORDS_COL = (N_CSAMPLES / 4);
for (uint32_t i = core_id; i < N_WORDS_COL; i += num_cores) {
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4))] =
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * N_WORDS_COL)] =
*(v2h *)&l2_twiddleCoef_f16[2U * i];
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4) +
1 * N_BANKS)] =
*(v2h *)&l1_twiddleCoef_f16_src[2U *
(i + j * N_WORDS_COL + 1 * N_BANKS)] =
*(v2h *)&l2_twiddleCoef_f16[2U * (i * 2U)];
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4) +
2 * N_BANKS)] =
*(v2h *)&l1_twiddleCoef_f16_src[2U *
(i + j * N_WORDS_COL + 2 * N_BANKS)] =
*(v2h *)&l2_twiddleCoef_f16[2U * (i * 3U)];
}
}
#endif
mempool_barrier(num_cores);

if (core_id == 0) {
printf("On the run...\n");
printf("01: END INITIALIZATION\n");
}
mempool_barrier(num_cores);

///////////////////////////////////////////////////////////////////////////////////////////////////
/* MULTI-CORE FOLDED */
#ifdef FOLDED
__fp16 *pRes = NULL;
#if (defined(FOLDED) && defined(FOLDED_TWIDDLES))
if (core_id < (N_CSAMPLES / 16)) {
mempool_start_benchmark();
#ifdef FOLDED_TWIDDLES
mempool_radix4_cfft_f16p_folded(l1_pSrc, l1_pDst, (uint16_t)N_CSAMPLES,
mempool_radix4_cfft_f16p_folded(l1_pSrc, l1_pDst, N_CSAMPLES,
l1_twiddleCoef_f16_src,
l1_twiddleCoef_f16_dst, (N_CSAMPLES / 16));
#else
mempool_radix4_cfft_f16p_folded(l1_pSrc, l1_pDst, (uint16_t)N_CSAMPLES,
l1_twiddleCoef_f16_src, (N_CSAMPLES / 16));
#endif
pRes = ((LOG2 / 2) % 2) == 0 ? l1_pSrc : l1_pDst;
mempool_bitrevtable_q16p_xpulpimg((uint16_t *)pRes, BITREVINDEXTABLE_LENGTH,
l1_BitRevIndexTable, (N_CSAMPLES / 16));
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
#endif

///////////////////////////////////////////////////////////////////////////////////////////////////
/* MULTI-CORE SCHEDULED */
#ifdef SCHEDULED
__fp16 *pRes = NULL;
if (core_id < N_FFTs_COL * (N_CSAMPLES / 16)) {
uint32_t nPE = (N_CSAMPLES / 16);
if (core_id < N_FFTs_COL * nPE) {
mempool_start_benchmark();
uint32_t col_fftLen = N_CSAMPLES / 4;
uint32_t col_id = core_id / (N_CSAMPLES / 16);
// Distribute FFTs over columns
if (col_id < N_FFTs_COL) {
mempool_radix4_cfft_f16p_scheduler(
l1_pSrc, l1_pDst, N_CSAMPLES, l1_pCoef_src + 2 * col_id * col_fftLen,
l1_pCoef_dst + 2 * col_id * col_fftLen, l1_pRevT16,
BITREVINDEXTABLE_FIXED_TABLE_LENGTH, 1, (N_CSAMPLES / 16));
}
pRes = ((LOG2 / 2) % 2) == 0 ? l1_pSrc : l1_pDst;
mempool_log_partial_barrier(2, core_id, N_FFTs_COL * (N_CSAMPLES / 16));
uint32_t N_WORDS_COL = N_CSAMPLES / 4;
uint32_t col_id = core_id / nPE;
mempool_radix4_cfft_f16p_scheduler(
l1_pSrc, l1_pDst, N_CSAMPLES, N_FFTs_ROW, N_FFTs_COL,
l1_twiddleCoef_f16_src + 2 * col_id * N_WORDS_COL,
l1_twiddleCoef_f16_dst + 2 * col_id * N_WORDS_COL, l1_BitRevIndexTable,
BITREVINDEXTABLE_LENGTH, 1, nPE);
pRes = l1_pDst;
mempool_log_partial_barrier(2, core_id, N_FFTs_COL * nPE);
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
#endif

///////////////////////////////////////////////////////////////////////////////////////////////////
/* CHECK */
mempool_check_f16(pRes, l2_pRes, 2 * N_CSAMPLES, 0.5f, 0);
mempool_barrier(num_cores);

if (core_id == 0) {
printf("02: END COMPUTATION\n");
}
mempool_check_f16(pRes, l2_pRes, 2 * N_CSAMPLES, 0.5, 0);
mempool_barrier(num_cores);
return 0;
}
Empty file.
2 changes: 0 additions & 2 deletions software/runtime/data/data_cfft_f16.h.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@

#define LOG2 (${Log2Len})
#define N_CSAMPLES (${Len})
#define N_RSAMPLES (2 * N_CSAMPLES)
#define N_TWIDDLES (3 * N_CSAMPLES / 4)
#define N_BANKS (NUM_CORES * BANKING_FACTOR)
#define BITREVINDEXTABLE_LENGTH (${BitrevLen})

Expand Down
2 changes: 0 additions & 2 deletions software/runtime/data/data_cfft_q16.h.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
%> \
#define LOG2 (${Log2Len})
#define N_CSAMPLES (${Len})
#define N_RSAMPLES (2 * N_CSAMPLES)
#define N_TWIDDLES (3 * N_CSAMPLES / 4)
#define N_BANKS (NUM_CORES * BANKING_FACTOR)
#define BITREVINDEXTABLE_LENGTH (${BitrevLen})

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// Author: Marco Bertuletti, ETH Zurich

#pragma once
#include "xpulp/builtins_v2.h"

/**
Expand Down
43 changes: 14 additions & 29 deletions software/runtime/kernel/mempool_radix4_cfft_f16p.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

// Author: Marco Bertuletti, ETH Zurich

#pragma once
#define BITREVERSETABLE
#include "xpulp/builtins_v2.h"
#define MIN(x, y) (((x) < (y)) ? (x) : (y))

Expand Down Expand Up @@ -58,7 +60,6 @@
CoSi3 = *(v2h *)&pCoef_src[2U * (ic * 3U)];
#endif

#ifdef FOLDED_TWIDDLES
/**
@brief Full FFT butterfly
@param[in] pSrc16 points to input buffer of 16b data, Re and Im parts are
Expand All @@ -73,25 +74,8 @@
*/
void mempool_radix4_cfft_f16p_folded(__fp16 *pSrc16, __fp16 *pDst16,
uint32_t fftLen, __fp16 *pCoef_src,
__fp16 *pCoef_dst, uint32_t nPE)
#else
/**
Twiddles are not folded in memory
@brief Full FFT butterfly
@param[in] pSrc16 points to input buffer of 16b data, Re and Im parts are
interleaved
@param[out] pDst16 points to output buffer of 16b data, Re and Im parts
are interleaved
@param[in] fftLen Length of the complex input vector
@param[in] pCoef_src Twiddle coefficients vector
@param[in] nPE Number of PE
@return pointer to output vector
*/
void mempool_radix4_cfft_f16p_folded(__fp16 *pSrc16, __fp16 *pDst16,
uint32_t fftLen, __fp16 *pCoef_src,
uint32_t nPE)
#endif
{
__fp16 __attribute__((unused)) * pCoef_dst,
uint32_t nPE) {

uint32_t absolute_core_id = mempool_get_core_id();
uint32_t core_id = absolute_core_id;
Expand Down Expand Up @@ -218,8 +202,9 @@ void mempool_radix4_cfft_f16p_folded(__fp16 *pSrc16, __fp16 *pDst16,
*/

void mempool_radix4_cfft_f16p_scheduler(
__fp16 *pSrc16, __fp16 *pDst16, uint32_t fftLen, __fp16 *pCoef_src,
__fp16 *pCoef_dst, __attribute__((unused)) uint16_t *pBitRevTable,
__fp16 *pSrc16, __fp16 *pDst16, uint32_t fftLen, uint32_t n_FFTs_ROW,
uint32_t n_FFTs_COL, __fp16 *pCoef_src, __fp16 *pCoef_dst,
__attribute__((unused)) uint16_t *pBitRevTable,
__attribute__((unused)) uint16_t bitReverseLen, uint8_t bitReverseFlag,
uint32_t nPE) {

Expand Down Expand Up @@ -251,7 +236,7 @@ void mempool_radix4_cfft_f16p_scheduler(
#endif
LOAD_STORE_TWIDDLEFACT;
SHUFFLE_TWIDDLEFACT;
for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
for (uint32_t idx_row = 0; idx_row < n_FFTs_ROW; idx_row++) {
__fp16 *pIn = pSrc16 + idx_row * (N_BANKS * 8) + 2 * col_id * fftLen;
__fp16 *pOut =
pDst16 + idx_row * (N_BANKS * 8) + 2 * col_id * (fftLen / 4);
Expand Down Expand Up @@ -282,7 +267,7 @@ void mempool_radix4_cfft_f16p_scheduler(
LOAD_STORE_TWIDDLEFACT;
SHUFFLE_TWIDDLEFACT;

for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
for (uint32_t idx_row = 0; idx_row < n_FFTs_ROW; idx_row++) {
__fp16 *pIn =
pSrc16 + idx_row * (N_BANKS * 8) + 2 * col_id * (fftLen / 4);
__fp16 *pOut =
Expand All @@ -297,12 +282,12 @@ void mempool_radix4_cfft_f16p_scheduler(
pTmp = pCoef_src;
pCoef_src = pCoef_dst;
pCoef_dst = pTmp;
mempool_log_partial_barrier(2, absolute_core_id, N_FFTs_COL * nPE);
mempool_log_partial_barrier(2, absolute_core_id, n_FFTs_COL * nPE);
}

/* LAST STAGE */
for (i0 = core_id * 4; i0 < MIN(core_id * 4 + 4, fftLen >> 2U); i0++) {
for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
for (uint32_t idx_row = 0; idx_row < n_FFTs_ROW; idx_row++) {
__fp16 *pIn =
pSrc16 + idx_row * (N_BANKS * 8) + 2 * col_id * (fftLen / 4);
__fp16 *pOut =
Expand All @@ -313,7 +298,7 @@ void mempool_radix4_cfft_f16p_scheduler(
pTmp = pSrc16;
pSrc16 = pDst16;
pDst16 = pTmp;
mempool_log_partial_barrier(2, absolute_core_id, N_FFTs_COL * nPE);
mempool_log_partial_barrier(2, absolute_core_id, n_FFTs_COL * nPE);
mempool_stop_benchmark();
mempool_start_benchmark();
/* BITREVERSAL */
Expand Down Expand Up @@ -362,7 +347,7 @@ void mempool_radix4_cfft_f16p_scheduler(
b2_load = (b2 % 4) * 2 * N_BANKS + 2 * (b2 / 4);
b3_load = (b3 % 4) * 2 * N_BANKS + 2 * (b3 / 4);
b4_load = (b4 % 4) * 2 * N_BANKS + 2 * (b4 / 4);
for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
for (uint32_t idx_row = 0; idx_row < n_FFTs_ROW; idx_row++) {
uint16_t *ptr1 = (uint16_t *)(pSrc16 + idx_row * (N_BANKS * 8));
uint16_t *ptr2 = (uint16_t *)(pDst16 + idx_row * (N_BANKS * 8));
// Load at address a
Expand Down Expand Up @@ -410,7 +395,7 @@ void mempool_radix4_cfft_f16p_scheduler(
idx2 = idx2 >> 1U;
idx3 = idx3 >> 1U;
}
for (uint32_t idx_row = 0; idx_row < N_FFTs_ROW; idx_row++) {
for (uint32_t idx_row = 0; idx_row < n_FFTs_ROW; idx_row++) {
uint32_t addr_src0 = (idx0 / 4) + (idx0 % 4) * N_BANKS;
uint32_t addr_src1 = (idx1 / 4) + (idx1 % 4) * N_BANKS;
uint32_t addr_src2 = (idx2 / 4) + (idx2 % 4) * N_BANKS;
Expand Down

0 comments on commit d3f5650

Please sign in to comment.