BLAS compatibility library #7

Status: Open. Wants to merge 74 commits into base: main. Changes shown are from all commits.

74 commits
7cbf4de
BLAS library WIP
ChrisPattison Nov 14, 2021
2d72035
Empty matrix is a noop
ChrisPattison Nov 26, 2021
08461cc
put mpf_t in one place
ChrisPattison Nov 26, 2021
b531adc
PackFloat ToMpfr and ToGmp now const
ChrisPattison Nov 26, 2021
a59dd65
mpf_t/mpfr_t wrapper type
ChrisPattison Nov 26, 2021
d11fc10
unsigned long for precision
ChrisPattison Nov 26, 2021
0745071
overload of ApfpInterfaceType constructor with precision specified
ChrisPattison Nov 26, 2021
1da49df
ApfpInterfaceWrapper move semantics
ChrisPattison Nov 27, 2021
dd0db45
Fix memory leaks in BLAS library
ChrisPattison Nov 27, 2021
6ecea7a
Merge commit 'b3c3232369122bda9a551eb6777cc90d7721124f' into blas
ChrisPattison Dec 14, 2021
2dd8b21
Matrix Addition dummy
ChrisPattison Dec 14, 2021
e4165cb
mpf_t |-> mpf_ptr in PackedFloat
ChrisPattison Dec 14, 2021
7da55be
const ToGmp
ChrisPattison Dec 14, 2021
aeb9ce1
Hostlib takes mpf_ptr. Host transpose/add syrk
ChrisPattison Dec 14, 2021
4c499ae
Merge commit 'cd2be5046e33205e11bf54814d96263f8e66efed' into blas
ChrisPattison Dec 18, 2021
bfcacd9
MPFR BLAS interface
ChrisPattison Dec 19, 2021
b50c80e
Add unsigned long init and mul to wrapper header
ChrisPattison Dec 20, 2021
4947c85
Generate takes mpfr_ptr
ChrisPattison Dec 20, 2021
d96de48
BLAS syrk unit test
ChrisPattison Dec 20, 2021
4eb4881
Blas unit tests in separate executable
ChrisPattison Dec 21, 2021
c49e272
Search for kernel in current working directory
ChrisPattison Dec 22, 2021
bdd9f35
Throw an exception if we can't find the kernel
ChrisPattison Dec 26, 2021
adaba04
Guard against calling unitialized library
ChrisPattison Dec 26, 2021
13aa0e9
Add mechanism to get ApfpBlas error strings
ChrisPattison Dec 26, 2021
e2c32d8
Guard error code for ApfpInit in UnitTests
ChrisPattison Dec 26, 2021
41f75a8
More sophisticated kernel search routine
ChrisPattison Dec 26, 2021
4ab5d11
Setup/teardown test case
ChrisPattison Dec 26, 2021
3c37a15
Fix buffer size check on TransferToHost
ChrisPattison Dec 27, 2021
98f4721
CopyTransposeFromMatrix destination LDA
ChrisPattison Dec 27, 2021
ee113ba
Blas unit tests pass
ChrisPattison Dec 27, 2021
1d53578
Move interface type <gmp/mpfr> to Config.h
ChrisPattison Dec 28, 2021
c6a86a7
install kernels to lib
ChrisPattison Dec 28, 2021
4c22885
Compile under GMP interface type
ChrisPattison Dec 28, 2021
2aa28f5
Fix closeness check in BlasUnitTest for a=b=0
ChrisPattison Dec 28, 2021
47022fd
Use generators for SYRK test case
ChrisPattison Dec 28, 2021
8c8e2c2
Add config.h to install dirs
ChrisPattison Dec 28, 2021
b1449e1
Support 'T' argument in syrk
ChrisPattison Dec 28, 2021
572a4ca
Check upper/lower Syrk mode
ChrisPattison Dec 28, 2021
102c818
Fix MPFR wrapper argument order
ChrisPattison Dec 29, 2021
d22a4ee
Merge branch 'main' into blas
ChrisPattison Dec 29, 2021
b223108
Remove mystery character in CMakeLists.txt
ChrisPattison Dec 29, 2021
71a4cf8
Fix LD_LIBRARY_PATH search for FPGA kernel
ChrisPattison Dec 30, 2021
14274fc
Marginally more helpful error handling
ChrisPattison Dec 30, 2021
0b435be
GMP allows aliasing inputs
ChrisPattison Dec 30, 2021
1f43348
Install hw emu kernel
ChrisPattison Dec 30, 2021
9856042
Do SYRK addition on the FPGA
ChrisPattison Dec 30, 2021
1e8cffe
Move Apfp lib into namespace
ChrisPattison Dec 30, 2021
c1fdfa9
Make the interface type wrapping nicer
ChrisPattison Dec 30, 2021
0677b71
Rename ErrorDescription
ChrisPattison Dec 30, 2021
03859ee
Enum class Uplo/Trans
ChrisPattison Dec 30, 2021
b1a768d
Formatting because I keep forgetting
ChrisPattison Dec 30, 2021
78a4594
apfpHostlib naming convention
ChrisPattison Dec 31, 2021
119185c
Switch kernel to column major ordering
ChrisPattison Jan 3, 2022
8338d02
Remove extremely large volume simulation test cases
ChrisPattison Jan 3, 2022
1f8cd57
Merge branch 'main' into blas
ChrisPattison Jan 3, 2022
137d11d
ApfpIsInitialized |-> IsInitialized
ChrisPattison Jan 4, 2022
1cb5c90
Merge branch 'main' into col_major
ChrisPattison Jan 6, 2022
98434fc
Scale back directory search for kernel
ChrisPattison Jan 6, 2022
4e23918
Missing function renames
ChrisPattison Jan 6, 2022
7f1fa90
Add cwd to kernel search path
ChrisPattison Jan 6, 2022
6fd095e
Set INTERFACE_TYPE to SEMANTICS
ChrisPattison Jan 6, 2022
cacf283
Merge branch 'main' into blas
ChrisPattison Jan 6, 2022
9d1173c
BlasError is scoped enum
ChrisPattison Jan 6, 2022
9dbcb97
Merge branch 'col_major' into blas
ChrisPattison Jan 6, 2022
d0ab102
class Apfp -> Context
ChrisPattison Jan 6, 2022
155d0b8
Use RNDZ MPFR rounding mode everywhere
ChrisPattison Jan 6, 2022
55c9fe8
Throw KernelNotFoundException if APFP_KERNEL misset
ChrisPattison Jan 6, 2022
4e5c34d
Add comment about memory layout
ChrisPattison Jan 7, 2022
8b22b71
More descriptive SYRK unit tests
ChrisPattison Jan 8, 2022
ad5d63f
Add GEMM
ChrisPattison Jan 8, 2022
40a61a6
Missing syrk test case
ChrisPattison Jan 8, 2022
66ecc7e
Fix M and N in GEMM
ChrisPattison Jan 8, 2022
f9c7c3c
GEMM unit tests
ChrisPattison Jan 8, 2022
1685fd2
Go fast and break things - just not the unit tests!
ChrisPattison Jan 8, 2022
31 changes: 24 additions & 7 deletions CMakeLists.txt
@@ -15,6 +15,10 @@ set(APFP_PROFILING OFF CACHE BOOL "Save temporary files from kernel builds." is defined below)
set(APFP_SAVE_TEMPS OFF CACHE BOOL "Save temporary files from kernel builds.")
set_property(CACHE APFP_SEMANTICS PROPERTY STRINGS GMP MPFR)

# One day we might accept both
set(APFP_INTERFACE_TYPE ${APFP_SEMANTICS})
# but not today

# Validation and derived numbers
math(EXPR APFP_ALIGNED "${APFP_BITS} % 512")
if(NOT APFP_ALIGNED EQUAL 0)
@@ -30,7 +34,7 @@ find_package(GMP REQUIRED)
find_package(Threads REQUIRED)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wpedantic -Wno-unused-label -Wno-unknown-pragmas -Wno-class-memaccess -DAPFP_${APFP_SEMANTICS}_SEMANTICS")
include_directories(${CMAKE_BINARY_DIR} include SYSTEM hlslib/include ${Vitis_INCLUDE_DIRS} )
include_directories(${CMAKE_BINARY_DIR} include SYSTEM hlslib/include ${Vitis_INCLUDE_DIRS} interface)

configure_file(include/Config.h.in Config.h)

@@ -40,7 +44,7 @@ set(APFP_KERNEL_FILES device/MatrixMultiplication.cpp

# Setup FPGA kernel targets
add_vitis_kernel(MatrixMultiplication FILES ${APFP_KERNEL_FILES}
INCLUDE_DIRS include hlslib/include ${CMAKE_BINARY_DIR}
INCLUDE_DIRS include hlslib/include ${CMAKE_BINARY_DIR} ${GMP_INCLUDES}
HLS_FLAGS "-DAP_INT_MAX_W=${APFP_MAX_BITS} -DAPFP_${APFP_SEMANTICS}_SEMANTICS"
HLS_CONFIG "config_compile -pipeline_style frp\nconfig_dataflow -fifo_depth 16"
DEPENDS ${CMAKE_BINARY_DIR}/Config.h
@@ -66,9 +70,9 @@ add_library(simulation ${APFP_KERNEL_FILES})
target_compile_options(simulation PRIVATE -Wno-unknown-pragmas -DAP_INT_MAX_W=${APFP_MAX_BITS})
target_link_libraries(simulation ${CMAKE_THREAD_LIBS_INIT})

add_library(ApfpHostlib SHARED interface/Apfp.cpp)
target_link_libraries(ApfpHostlib ${Vitis_LIBRARIES} ${GMP_LIBRARIES})
target_compile_definitions(ApfpHostlib PRIVATE HLSLIB_SIMULATE_OPENCL)
add_library(apfpHostlib SHARED interface/Apfp.cpp interface/ApfpBlas.cpp interface/ApfpInterfaceType.cpp)
target_link_libraries(apfpHostlib ${Vitis_LIBRARIES} ${GMP_LIBRARIES})
target_compile_definitions(apfpHostlib PRIVATE HLSLIB_SIMULATE_OPENCL)

# Executable used to run in simulation mode, calling the kernel as a C++ function directly
add_executable(TestSimulation host/TestProgram.cpp)
@@ -84,7 +88,20 @@ enable_testing()
add_test(TestSimulation TestSimulation 4 4 4)
add_library(Catch host/Catch.cpp)
add_executable(UnitTests host/UnitTests.cpp)
target_link_libraries(UnitTests Catch ${GMP_LIBRARIES} ${MPFR_LIBRARIES} apfp simulation)
target_link_libraries(UnitTests Catch ${GMP_LIBRARIES} ${MPFR_LIBRARIES} apfp apfpHostlib simulation)
add_test(UnitTests UnitTests)

install(TARGETS ApfpHostlib)
add_executable(BlasUnitTests host/BlasUnitTests.cpp)
target_link_libraries(BlasUnitTests Catch ${GMP_LIBRARIES} ${MPFR_LIBRARIES} apfp apfpHostlib simulation)

install(TARGETS apfpHostlib)
install(FILES
interface/Apfp.h
interface/ApfpBlas.h
interface/ApfpInterfaceType.h
${CMAKE_BINARY_DIR}/Config.h
DESTINATION include/apfp)
install(FILES
${CMAKE_BINARY_DIR}/MatrixMultiplication_hw.xclbin
${CMAKE_BINARY_DIR}/MatrixMultiplication_hw_emu.xclbin
DESTINATION lib)
43 changes: 24 additions & 19 deletions device/MatrixMultiplication.cpp
@@ -6,10 +6,15 @@

#include "ArithmeticOperations.h"

// All memory accesses are column-major!
// I.e. a(i,j) = a[i + LDA * j]
// (AB)(i,j) = sum_k a(i,k) * b(k,j) = sum_k a[i + LDA * k] * b[k + LDB * j]
// LDA/LDB = leading dimension (column stride) of A/B
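
As a reference for this convention (a minimal sketch in plain C++, not part of the kernel; ColMajor and NaiveGemm are illustrative names), element (i, j) of a column-major matrix with leading dimension ld lives at data[i + ld * j], and a naive GEMM over such buffers reads:

#include <cstddef>
#include <vector>

// Element (i, j) of a column-major matrix with leading dimension ld.
inline std::size_t ColMajor(std::size_t i, std::size_t j, std::size_t ld) {
    return i + ld * j;
}

// Reference GEMM over column-major buffers: c(i,j) += sum_k a(i,k) * b(k,j),
// where A is n x k (LDA = n), B is k x m (LDB = k), C is n x m (LDC = n).
void NaiveGemm(const std::vector<double>& a, const std::vector<double>& b,
               std::vector<double>& c, std::size_t n, std::size_t k, std::size_t m) {
    for (std::size_t j = 0; j < m; ++j) {
        for (std::size_t i = 0; i < n; ++i) {
            double acc = c[ColMajor(i, j, n)];
            for (std::size_t kk = 0; kk < k; ++kk) {
                acc += a[ColMajor(i, kk, n)] * b[ColMajor(kk, j, k)];
            }
            c[ColMajor(i, j, n)] = acc;
        }
    }
}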

// Annoyingly we have to specialize the innermost loop on whether multiple DRAM flits per number are required or not,
// because HLS otherwise gets confused by pragmas applied to a loop of size 1 in the latter case.
template <int lines_per_number>
void ReadAInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &a_to_feeder, const int size_k, const int n0,
void ReadAInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &a_to_feeder, const int size_n, const int n0,
const int k) {
#pragma HLS INLINE
DramLine num[kLinesPerNumber];
@@ -19,7 +24,7 @@ void ReadAInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &a_to_fee
for (int i = 0; i < kLinesPerNumber; ++i) {
#pragma HLS PIPELINE II = 1
#pragma HLS LOOP_FLATTEN
num[i] = mem[((n0 * kTileSizeN + n1) * size_k + k) * kLinesPerNumber + i];
num[i] = mem[((n0 * kTileSizeN + n1) + k * size_n) * kLinesPerNumber + i];
Reviewer suggested change:
- num[i] = mem[((n0 * kTileSizeN + n1) + k * size_n) * kLinesPerNumber + i];
+ num[i] = mem[(k * size_n + n0 * kTileSizeN + n1) * kLinesPerNumber + i];

if (i == kLinesPerNumber - 1) {
a_to_feeder.Push(PackedFloat(num));
}
Expand All @@ -28,15 +33,15 @@ void ReadAInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &a_to_fee
}

template <>
void ReadAInner<1>(DramLine const *const mem, hlslib::Stream<PackedFloat> &a_to_feeder, const int size_k, const int n0,
void ReadAInner<1>(DramLine const *const mem, hlslib::Stream<PackedFloat> &a_to_feeder, const int size_n, const int n0,
const int k) {
#pragma HLS INLINE
ReadA_N:
for (int n1 = 0; n1 < kTileSizeN; ++n1) {
#pragma HLS PIPELINE II = 1
#pragma HLS LOOP_FLATTEN
DramLine num[1];
num[0] = mem[(n0 * kTileSizeN + n1) * size_k + k];
num[0] = mem[((n0 * kTileSizeN + n1) + k * size_n) * kLinesPerNumber];
Reviewer suggested change:
- num[0] = mem[((n0 * kTileSizeN + n1) + k * size_n) * kLinesPerNumber];
+ num[0] = mem[k * size_n + n0 * kTileSizeN + n1];

a_to_feeder.Push(PackedFloat(num));
}
}
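
The specialization pattern described in the comment at the top of this file ("Annoyingly we have to specialize the innermost loop..."), reduced to a stand-alone sketch (illustrative only; ReadNumber and the plain int payload are made up, and the real kernel gathers DramLine flits under HLS pragmas): the generic template keeps an inner gather loop, while the <1> specialization has no loop at all, so no pipeline or flatten pragma is ever attached to a trip-count-1 loop.

#include <cstdio>

// Generic case: one number spans several lines, so an inner gather loop is needed.
template <int lines_per_number>
void ReadNumber(const int* mem, int index, int* out) {
    for (int i = 0; i < lines_per_number; ++i) {
        out[i] = mem[index * lines_per_number + i];
    }
}

// Full specialization for a single line per number: the loop disappears, so an
// HLS tool never sees pipeline/flatten pragmas attached to a trip-count-1 loop.
template <>
void ReadNumber<1>(const int* mem, int index, int* out) {
    out[0] = mem[index];
}

int main() {
    int data[6] = {10, 11, 20, 21, 30, 31};
    int num[2];
    ReadNumber<2>(data, 1, num);  // gathers data[2], data[3]
    ReadNumber<1>(data, 4, num);  // single line, no inner loop
    std::printf("%d %d\n", num[0], data[4]);
}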
@@ -51,7 +56,7 @@ void ReadA(DramLine const *const mem, hlslib::Stream<PackedFloat> &a_to_feeder,
for (int m0 = 0; m0 < tiles_m; ++m0) {
ReadA_K:
for (int k = 0; k < size_k; ++k) {
ReadAInner<kLinesPerNumber>(mem, a_to_feeder, size_k, n0, k);
ReadAInner<kLinesPerNumber>(mem, a_to_feeder, size_n, n0, k);
}
}
}
@@ -90,7 +95,7 @@ void FeedA(hlslib::Stream<PackedFloat> &a_to_feeder, hlslib::Stream<PackedFloat>
////////////////////////////////////////////////////////////////////////////////

template <int lines_per_number>
void ReadBInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &b_to_feeder, const int size_m, const int m0,
void ReadBInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &b_to_feeder, const int size_k, const int m0,
const int k) {
#pragma HLS INLINE
DramLine num[kLinesPerNumber];
@@ -100,7 +105,7 @@ void ReadBInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &b_to_fee
for (int i = 0; i < kLinesPerNumber; ++i) {
#pragma HLS PIPELINE II = 1
#pragma HLS LOOP_FLATTEN
num[i] = mem[(k * size_m + m0 * kTileSizeM + m1) * kLinesPerNumber + i];
num[i] = mem[(k + (m0 * kTileSizeM + m1) * size_k) * kLinesPerNumber + i];
Reviewer suggested change:
- num[i] = mem[(k + (m0 * kTileSizeM + m1) * size_k) * kLinesPerNumber + i];
+ num[i] = mem[((m0 * kTileSizeM + m1) * size_k + k) * kLinesPerNumber + i];

if (i == kLinesPerNumber - 1) {
b_to_feeder.Push(PackedFloat(num));
}
Expand All @@ -109,15 +114,15 @@ void ReadBInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &b_to_fee
}

template <>
void ReadBInner<1>(DramLine const *const mem, hlslib::Stream<PackedFloat> &b_to_feeder, const int size_m, const int m0,
void ReadBInner<1>(DramLine const *const mem, hlslib::Stream<PackedFloat> &b_to_feeder, const int size_k, const int m0,
const int k) {
#pragma HLS INLINE
ReadB_M:
for (int m1 = 0; m1 < kTileSizeM; ++m1) {
#pragma HLS PIPELINE II = 1
#pragma HLS LOOP_FLATTEN
DramLine num[1];
num[0] = mem[k * size_m + m0 * kTileSizeM + m1];
num[0] = mem[(k + (m0 * kTileSizeM + m1) * size_k) * kLinesPerNumber];
Reviewer suggested change:
- num[0] = mem[(k + (m0 * kTileSizeM + m1) * size_k) * kLinesPerNumber];
+ num[0] = mem[(m0 * kTileSizeM + m1) * size_k + k];

b_to_feeder.Push(PackedFloat(num));
}
}
@@ -132,7 +137,7 @@ void ReadB(DramLine const *const mem, hlslib::Stream<PackedFloat> &b_to_feeder,
for (int m0 = 0; m0 < tiles_m; ++m0) {
ReadB_K:
for (int k = 0; k < size_k; ++k) {
ReadBInner<kLinesPerNumber>(mem, b_to_feeder, size_m, m0, k);
ReadBInner<kLinesPerNumber>(mem, b_to_feeder, size_k, m0, k);
}
}
}
@@ -169,7 +174,7 @@ void FeedB(hlslib::Stream<PackedFloat> &b_to_feeder, hlslib::Stream<PackedFloat>
////////////////////////////////////////////////////////////////////////////////

template <int lines_per_number>
void ReadCInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &c_to_feeder, const int size_m, const int n0,
void ReadCInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &c_to_feeder, const int size_n, const int n0,
const int m0, const int n1) {
#pragma HLS INLINE
ReadC_M:
@@ -179,7 +184,7 @@ void ReadCInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &c_to_fee
for (int i = 0; i < kLinesPerNumber; ++i) {
#pragma HLS PIPELINE II = 1
#pragma HLS LOOP_FLATTEN
num[i] = mem[((n0 * kTileSizeN + n1) * size_m + m0 * kTileSizeM + m1) * kLinesPerNumber + i];
num[i] = mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber + i];
Reviewer suggested change:
- num[i] = mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber + i];
+ num[i] = mem[((m0 * kTileSizeM + m1) * size_n + n0 * kTileSizeN + n1) * kLinesPerNumber + i];

if (i == kLinesPerNumber - 1) {
c_to_feeder.Push(PackedFloat(num));
}
Expand All @@ -188,15 +193,15 @@ void ReadCInner(DramLine const *const mem, hlslib::Stream<PackedFloat> &c_to_fee
}

template <>
void ReadCInner<1>(DramLine const *const mem, hlslib::Stream<PackedFloat> &c_to_feeder, const int size_m, const int n0,
void ReadCInner<1>(DramLine const *const mem, hlslib::Stream<PackedFloat> &c_to_feeder, const int size_n, const int n0,
const int m0, const int n1) {
#pragma HLS INLINE
ReadC_M:
for (int m1 = 0; m1 < kTileSizeM; ++m1) {
#pragma HLS PIPELINE II = 1
#pragma HLS LOOP_FLATTEN
DramLine num[1];
num[0] = mem[(n0 * kTileSizeN + n1) * size_m + m0 * kTileSizeM + m1];
num[0] = mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber];
Reviewer suggested change:
- num[0] = mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber];
+ num[0] = mem[(m0 * kTileSizeM + m1) * size_n + n0 * kTileSizeN + n1];

c_to_feeder.Push(PackedFloat(num));
}
}
@@ -210,7 +215,7 @@ void ReadC(DramLine const *const mem, hlslib::Stream<PackedFloat> &c_to_feeder,
for (int m0 = 0; m0 < tiles_m; ++m0) {
ReadC_N:
for (int n1 = 0; n1 < kTileSizeN; ++n1) {
ReadCInner<kLinesPerNumber>(mem, c_to_feeder, size_m, n0, m0, n1);
ReadCInner<kLinesPerNumber>(mem, c_to_feeder, size_n, n0, m0, n1);
}
}
}
@@ -290,7 +295,7 @@ void WriteCInner(hlslib::Stream<PackedFloat> &from_kernel, DramLine *const mem,
}
const bool in_bounds = (n0 * kTileSizeN + n1 < size_n) && (m0 * kTileSizeM + m1 < size_m);
if (in_bounds) {
mem[((n0 * kTileSizeN + n1) * size_m + m0 * kTileSizeM + m1) * kLinesPerNumber + i] = num[i];
mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber + i] = num[i];
Reviewer suggested change:
- mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber + i] = num[i];
+ mem[((m0 * kTileSizeM + m1) * size_n + n0 * kTileSizeN + n1) * kLinesPerNumber + i] = num[i];

}
}
}
@@ -308,7 +313,7 @@ void WriteCInner<1>(hlslib::Stream<PackedFloat> &from_kernel, DramLine *const me
from_kernel.Pop().UnpackFlits(num);
const bool in_bounds = (n0 * kTileSizeN + n1 < size_n) && (m0 * kTileSizeM + m1 < size_m);
if (in_bounds) {
mem[(n0 * kTileSizeN + n1) * size_m + m0 * kTileSizeM + m1] = num[0];
mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber] = num[0];
Reviewer suggested change:
- mem[((n0 * kTileSizeN + n1) + (m0 * kTileSizeM + m1) * size_n) * kLinesPerNumber] = num[0];
+ mem[(m0 * kTileSizeM + m1) * size_n + n0 * kTileSizeN + n1] = num[0];

Reviewer comment: Wait, this is not good -- C has now also become strided access, so we access two arrays strided and only one sequentially. I'm afraid we have to completely change the iteration space of the computation for this, not just the memory indices.

}
}
}
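
A plain-C++ sketch of the access-pattern issue raised in the reviewer comment above (illustrative only; Produce, WriteSequential, and WriteStrided are made-up names, not kernel code): with column-major storage, the loop order rather than the address formula decides whether writes are sequential, so making the C stream sequential again means changing the iteration space, not just the index expressions.

#include <cstddef>
#include <vector>

// Stand-in for whatever produces an output element (illustrative only).
static double Produce(std::size_t i, std::size_t j) { return static_cast<double>(i + j); }

// Column-major C (n x m, leading dimension n), row index innermost:
// consecutive writes touch consecutive addresses (sequential access).
void WriteSequential(std::vector<double>& c, std::size_t n, std::size_t m) {
    for (std::size_t j = 0; j < m; ++j)
        for (std::size_t i = 0; i < n; ++i)
            c[i + n * j] = Produce(i, j);
}

// Same address formula, column index innermost: every consecutive write jumps
// by n elements. Rearranging the index expression cannot fix this; only
// changing which loop is innermost (the iteration space) restores sequential access.
void WriteStrided(std::vector<double>& c, std::size_t n, std::size_t m) {
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < m; ++j)
            c[i + n * j] = Produce(i, j);
}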
@@ -354,7 +359,7 @@ void Compute(hlslib::Stream<PackedFloat> &a_in, hlslib::Stream<PackedFloat> &b_i
const PackedFloat c_read = c_in.Pop();
const PackedFloat a = (m1 == 0) ? a_read : a_buffer;
const PackedFloat b = (n1 == 0) ? b_read : b_buffer[m1];
const PackedFloat c = (k == 0) ? c_read : c_buffer[n1 * kTileSizeM + m1];
const PackedFloat c = (k == 0) ? c_read : c_buffer[n1 + m1 * kTileSizeN];
Reviewer comment: There's no point in making this strided; it's a memory of size TN*TM anyway, so it doesn't matter how we order it.

a_buffer = a;
b_buffer[m1] = b;
// Ignore contributions from out-of-bound indices
@@ -363,7 +368,7 @@
const auto res = MultiplyAccumulate(in_bounds ? a : PackedFloat::Zero(),
in_bounds ? b : PackedFloat::Zero(), c);
// Write back to buffer
c_buffer[n1 * kTileSizeM + m1] = res;
c_buffer[n1 + m1 * kTileSizeN] = res;
Reviewer comment: Same as above; we shouldn't make these strided.

c_out.Push(res);
}
}