Skip to content

Commit

Permalink
[software] Fix f16 cfft
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertuletti committed Jan 8, 2024
1 parent aa0f848 commit f0871b7
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 189 deletions.
63 changes: 31 additions & 32 deletions software/apps/cfft_radix4_f16/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@
#define ASM

#if !(defined(N_FFT_ROW) && defined(N_FFTs_COL))
#define N_FFTs_ROW 2
#define N_FFTs_COL 2
#define N_FFTs_ROW 1
#define N_FFTs_COL 1
#endif

#include "kernel/mempool_checks.h"
#include "kernel/mempool_radix4_cfft_butterfly_f16.h"
#include "kernel/mempool_radix4_cfft_f16p.h"
#include "kernel/mempool_radix4_cfft_q16_bitreversal.h"
Expand Down Expand Up @@ -73,19 +74,27 @@ int main() {
// Each FFT is folded over 4 memory rows
// Each memory row is 2 * N_BANKS samples
for (uint32_t j = 0; j < N_FFTs_ROW; j++) {
dma_memcpy_blocking(l1_pSrc + j * (8 * N_BANKS), l2_pSrc, (N_RSAMPLES * N_FFTs_COL) * sizeof(int32_t));
dma_memcpy_blocking(l1_pSrc + j * (8 * N_BANKS), l2_pSrc,
(N_RSAMPLES * N_FFTs_COL) * sizeof(int32_t));
}
dma_memcpy_blocking(l1_BitRevIndexTable, l2_BitRevIndexTable, BITREVINDEXTABLE_LENGTH * sizeof(int16_t));
dma_memcpy_blocking(l1_twiddleCoef_f16_src, l2_twiddleCoef_f16, 3 * (N_CSAMPLES / 4) * sizeof(int32_t));
dma_memcpy_blocking(l1_BitRevIndexTable, l2_BitRevIndexTable,
BITREVINDEXTABLE_LENGTH * sizeof(int16_t));
dma_memcpy_blocking(l1_twiddleCoef_f16_src, l2_twiddleCoef_f16,
3 * (N_CSAMPLES / 4) * sizeof(int32_t));
}
// Initialize the Twiddles folded
#ifdef FOLDED_TWIDDLES
for (uint32_t j = 0; j < N_FFTs_COL; j++) {
uint32_t N_WORDS_COL = (N_CSAMPLES / 4);
for (uint32_t i = core_id; i < N_WORDS_COL; i += num_cores) {
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4))] = *(v2h *)&l2_twiddleCoef_f16[2U * i];
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4) + 1 * N_BANKS)] = *(v2h *)&l2_twiddleCoef_f16[2U * (i * 2U)];
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4) + 2 * N_BANKS)] = *(v2h *)&l2_twiddleCoef_f16[2U * (i * 3U)];
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4))] =
*(v2h *)&l2_twiddleCoef_f16[2U * i];
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4) +
1 * N_BANKS)] =
*(v2h *)&l2_twiddleCoef_f16[2U * (i * 2U)];
*(v2h *)&l1_twiddleCoef_f16_src[2U * (i + j * (N_CSAMPLES / 4) +
2 * N_BANKS)] =
*(v2h *)&l2_twiddleCoef_f16[2U * (i * 3U)];
}
}
#endif
Expand All @@ -99,16 +108,20 @@ int main() {
///////////////////////////////////////////////////////////////////////////////////////////////////
/* MULTI-CORE FOLDED */
#ifdef FOLDED
__fp16 *pRes;
__fp16 *pRes = NULL;
if (core_id < (N_CSAMPLES / 16)) {
mempool_start_benchmark();
#ifdef FOLDED_TWIDDLES
mempool_radix4_cfft_f16p_folded(l1_pSrc, l1_pDst, (uint16_t)N_CSAMPLES, l1_twiddleCoef_f16_src, l1_twiddleCoef_f16_dst, (N_CSAMPLES / 16));
mempool_radix4_cfft_f16p_folded(l1_pSrc, l1_pDst, (uint16_t)N_CSAMPLES,
l1_twiddleCoef_f16_src,
l1_twiddleCoef_f16_dst, (N_CSAMPLES / 16));
#else
mempool_radix4_cfft_f16p_folded(l1_pSrc, l1_pDst, (uint16_t)N_CSAMPLES, l1_twiddleCoef_f16_src, (N_CSAMPLES / 16));
mempool_radix4_cfft_f16p_folded(l1_pSrc, l1_pDst, (uint16_t)N_CSAMPLES,
l1_twiddleCoef_f16_src, (N_CSAMPLES / 16));
#endif
pRes = ((LOG2 / 2) % 2) == 0 ? l1_pSrc : l1_pDst;
mempool_bitrevtable_q16p_xpulpimg((uint16_t *)pRes, BITREVINDEXTABLE_LENGTH, l1_BitRevIndexTable, (N_CSAMPLES / 16));
mempool_bitrevtable_q16p_xpulpimg((uint16_t *)pRes, BITREVINDEXTABLE_LENGTH,
l1_BitRevIndexTable, (N_CSAMPLES / 16));
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
Expand All @@ -117,14 +130,17 @@ int main() {
///////////////////////////////////////////////////////////////////////////////////////////////////
/* MULTI-CORE SCHEDULED */
#ifdef SCHEDULED
__fp16 *pRes;
__fp16 *pRes = NULL;
if (core_id < N_FFTs_COL * (N_CSAMPLES / 16)) {
mempool_start_benchmark();
uint32_t col_fftLen = N_CSAMPLES / 4;
uint32_t col_id = core_id / (N_CSAMPLES / 16);
// Distribute FFTs over columns
if (col_id < N_FFTs_COL) {
mempool_radix4_cfft_f16p_scheduler(l1_pSrc, l1_pDst, N_CSAMPLES, l1_pCoef_src + 2 * col_id * col_fftLen, l1_pCoef_dst + 2 * col_id * col_fftLen, l1_pRevT16, BITREVINDEXTABLE_FIXED_TABLE_LENGTH, 1, (N_CSAMPLES / 16));
mempool_radix4_cfft_f16p_scheduler(
l1_pSrc, l1_pDst, N_CSAMPLES, l1_pCoef_src + 2 * col_id * col_fftLen,
l1_pCoef_dst + 2 * col_id * col_fftLen, l1_pRevT16,
BITREVINDEXTABLE_FIXED_TABLE_LENGTH, 1, (N_CSAMPLES / 16));
}
pRes = ((LOG2 / 2) % 2) == 0 ? l1_pSrc : l1_pDst;
mempool_log_partial_barrier(2, core_id, N_FFTs_COL * (N_CSAMPLES / 16));
Expand All @@ -135,24 +151,7 @@ int main() {

///////////////////////////////////////////////////////////////////////////////////////////////////
/* CHECK */
if (core_id == 0) {
printf("Done!\n");
for (uint32_t i = 0; i < 2 * N_CSAMPLES; i++) {
__fp16 exp = l2_pRes[i];
__fp16 res = pRes[i];
__fp16 dif;
float tol = (__fp16)0.05f;
float dif_f32;
asm volatile("fsub.h %[dif], %[res], %[exp];"
"fcvt.h.s %[dif_f32], %[dif];"
: [dif] "+&r"(dif), [dif_f32] "+&r"(dif_f32)
: [res] "r"(res), [exp] "r"(exp)
:);
if ((dif_f32 > tol) || (dif_f32 < (-tol))) {
printf("%d %x %x\n", i, *(int32_t *)&exp, *(int32_t *)&res);
}
}
}
mempool_check_f16(pRes, l2_pRes, 2 * N_CSAMPLES, 0.5f, 0);
mempool_barrier(num_cores);

return 0;
Expand Down
4 changes: 2 additions & 2 deletions software/apps/cfft_radix4_q16/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
- ASM: Use asm_volatile statements
*/

#define SCHEDULED
#define FOLDED
#define FOLDED_TWIDDLES
#define BITREVERSETABLE
#define ASM // Use asm_volatile statements
Expand Down Expand Up @@ -153,7 +153,7 @@ int main() {
#endif
pRes = ((LOG2 / 2) % 2) == 0 ? l1_pSrc : l1_pDst;
mempool_bitrevtable_q16p_xpulpimg((uint16_t *)pRes, BITREVINDEXTABLE_LENGTH,
pRevT16, (N_CSAMPLES / 16));
l1_BitRevIndexTable, (N_CSAMPLES / 16));
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
Expand Down
2 changes: 1 addition & 1 deletion software/runtime/data/data_cfft_radix4_f16.h.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
i = 0
out += '\n'
for a in array:
out += '(__fp16){:0.5}f, '.format(a)
out += '(__fp16){:0.4}f, '.format(a)
i += 1
if i % 8 == 0:
out += '\n'
Expand Down
11 changes: 6 additions & 5 deletions software/runtime/data/data_cfft_radix4_f16.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,17 @@ def main():
args = parser.parse_args()
Len = args.dimension

src = np.random.rand(Len) + 1.j * np.random.rand(Len)
src = np.random.rand(Len).astype(np.float16)
src = src + 1.j * np.random.rand(Len).astype(np.float16)
dst = np.fft.fft(src)
src = np.column_stack((src.real, src.imag)).astype(np.float16).flatten()
dst = np.column_stack((dst.real, dst.imag)).astype(np.float16).flatten()
src = np.column_stack((src.imag, src.real)).astype(np.float16).flatten()
dst = np.column_stack((dst.imag, dst.real)).astype(np.float16).flatten()
Bitreversal = np.ndarray.flatten(np.array(compute_bitreversal(Len, 2)))

twi = np.zeros(int(2 * 3 * Len / 4), np.float16)
for i in range(0, int(3 * Len / 4)):
twi[2 * i] = np.cos(i * 2 * np.pi / Len).astype(np.float16)
twi[2 * i + 1] = np.sin(i * 2 * np.pi / Len).astype(np.float16)
twi[2 * i] = np.sin(i * 2 * np.pi / Len).astype(np.float16)
twi[2 * i + 1] = np.cos(i * 2 * np.pi / Len).astype(np.float16)

kwargs = {'name': 'data_cfft_radix4_f16',
'src': src,
Expand Down
Loading

0 comments on commit f0871b7

Please sign in to comment.