-
Notifications
You must be signed in to change notification settings - Fork 0
/
gemm-multi-stage.cu
500 lines (403 loc) · 16.5 KB
/
gemm-multi-stage.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
#include <cublas_v2.h>
#include <cuda.h>
#include <stdarg.h>
#include <stdio.h>
#include <cute/tensor.hpp>
#include "detail/cublaslt-gemm.h"
#include "detail/data.h"
// Multi-stage (software-pipelined) tensor-core GEMM built with CuTe:
//   D(m, n) = sum_k A(m, k) * B(n, k)
// A is row-major (m x k), B is row-major (n x k) and D is row-major (m x n);
// all three hold elements of type Config::T.
//
// Launch layout (as used by main): grid = (n / kTileN, m / kTileM),
// block = Config::kThreadNum threads, dynamic shared memory = Config::kShmSize
// bytes. blockIdx.x selects the N tile, blockIdx.y the M tile.
//
// Pipeline: kStage shared-memory stages; up to kStage - 1 K tiles are kept in
// flight with cp.async (gmem -> smem) while the current tile is moved
// smem -> registers (ldmatrix per Config) one MMA k-block ahead of the MMA.
//
// Preconditions (NOT checked here -- there is no tail predication):
//   * m % kTileM == 0 and n % kTileN == 0, otherwise the edge CTAs read and
//     write out of bounds;
//   * k % kTileK == 0, otherwise the trailing partial K tile is dropped
//     (ntile = k / kTileK).
//   * NOTE(review): the 128-bit copy atoms in Config presumably require
//     16-byte aligned rows/pointers -- confirm for non-default configs.
// Uses cp.async and SM80 MMA atoms (per Config), so presumably SM80+ only.
template <typename Config>
__global__ void /* __launch_bounds__(128, 1) */
gemm_multi_stage(void *Dptr, const void *Aptr, const void *Bptr, int m, int n,
                 int k)
{
  using namespace cute;
  using T = typename Config::T;
  using SmemLayoutA = typename Config::SmemLayoutA;
  using SmemLayoutB = typename Config::SmemLayoutB;
  using SmemLayoutC = typename Config::SmemLayoutC;
  using TiledMMA = typename Config::MMA;
  using S2RCopyAtomA = typename Config::S2RCopyAtomA;
  using S2RCopyAtomB = typename Config::S2RCopyAtomB;
  using G2SCopyA = typename Config::G2SCopyA;
  using G2SCopyB = typename Config::G2SCopyB;
  using R2SCopyAtomC = typename Config::R2SCopyAtomC;
  using S2GCopyC = typename Config::S2GCopyC;
  constexpr int kTileM = Config::kTileM;
  constexpr int kTileN = Config::kTileN;
  constexpr int kTileK = Config::kTileK;
  constexpr int kStage = Config::kStage;

  // Dynamic shared memory: all kStage stages of A, followed by those of B.
  extern __shared__ T shm_data[];
  T *Ashm = shm_data;
  T *Bshm = shm_data + cute::cosize(SmemLayoutA{});

  int idx = threadIdx.x;
  int ix = blockIdx.x; // tile index along N
  int iy = blockIdx.y; // tile index along M

  // Wrap the raw device pointers as row-major CuTe tensors.
  Tensor A = make_tensor(make_gmem_ptr((T *)Aptr), make_shape(m, k),
                         make_stride(k, Int<1>{})); // (M, K)
  Tensor B = make_tensor(make_gmem_ptr((T *)Bptr), make_shape(n, k),
                         make_stride(k, Int<1>{})); // (N, K)
  Tensor D = make_tensor(make_gmem_ptr((T *)Dptr), make_shape(m, n),
                         make_stride(n, Int<1>{})); // (M, N)

  // This CTA's slices; the trailing mode of gA/gB iterates over K tiles.
  Tensor gA = local_tile(A, make_tile(Int<kTileM>{}, Int<kTileK>{}),
                         make_coord(iy, _)); // (kTileM, kTileK, k)
  Tensor gB = local_tile(B, make_tile(Int<kTileN>{}, Int<kTileK>{}),
                         make_coord(ix, _)); // (kTileN, kTileK, k)
  Tensor gD = local_tile(D, make_tile(Int<kTileM>{}, Int<kTileN>{}),
                         make_coord(iy, ix)); // (kTileM, kTileN)

  // Shared-memory staging buffers, one slice per pipeline stage.
  auto sA = make_tensor(make_smem_ptr(Ashm),
                        SmemLayoutA{}); // (kTileM, kTileK, kStage)
  auto sB = make_tensor(make_smem_ptr(Bshm),
                        SmemLayoutB{}); // (kTileN, kTileK, kStage)

  // Per-thread MMA register fragments.
  TiledMMA tiled_mma;
  auto thr_mma = tiled_mma.get_slice(idx);
  auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); // (MMA, MMA_M, MMA_K)
  auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // (MMA, MMA_N, MMA_K)
  auto tCrD = thr_mma.partition_fragment_C(gD);          // (MMA, MMA_M, MMA_N)
  clear(tCrD); // zero the accumulator

  // smem -> reg copies, laid out to match the MMA operand fragments.
  auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
  auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(idx);
  auto tAsA = s2r_thr_copy_a.partition_S(sA);     // (CPY, CPY_M, CPY_K, kStage)
  auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA); // (CPY, CPY_M, CPY_K)
  auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma);
  auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(idx);
  auto tBsB = s2r_thr_copy_b.partition_S(sB);     // (CPY, CPY_N, CPY_K, kStage)
  auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB); // (CPY, CPY_N, CPY_K)

  // gmem -> smem asynchronous copies.
  G2SCopyA g2s_tiled_copy_a;
  auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(idx);
  auto tAgA_copy = g2s_thr_copy_a.partition_S(gA); // (CPY, CPY_M, CPY_K, k)
  auto tAsA_copy =
      g2s_thr_copy_a.partition_D(sA); // (CPY, CPY_M, CPY_K, kStage)
  G2SCopyB g2s_tiled_copy_b;
  auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(idx);
  auto tBgB_copy = g2s_thr_copy_b.partition_S(gB); // (CPY, CPY_N, CPY_K, k)
  auto tBsB_copy =
      g2s_thr_copy_b.partition_D(sB); // (CPY, CPY_N, CPY_K, kStage)

  // Number of full K tiles; a partial trailing tile is not handled.
  int ntile = k / kTileK;
  int itile_to_read = 0; // next K tile to fetch from gmem
  int ismem_read = 0;    // smem stage feeding the MMA
  int ismem_write = 0;   // smem stage the next gmem fetch lands in

  // Prologue: fetch the first kStage - 1 K tiles. Every iteration commits
  // exactly one cp.async group -- even an empty one -- so the
  // cp_async_wait<kStage - 2> accounting below stays valid. The guard keeps
  // us from reading past the end of A/B when there are fewer K tiles than
  // prologue stages.
#pragma unroll
  for (int istage = 0; istage < kStage - 1; ++istage)
  {
    if (istage < ntile)
    {
      cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, istage),
                 tAsA_copy(_, _, _, istage));
      cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, istage),
                 tBsB_copy(_, _, _, istage));
    }
    cp_async_fence();
    ++itile_to_read;
    ++ismem_write;
  }

  // Wait until the first committed group (tile 0) has landed.
  cp_async_wait<kStage - 2>();
  __syncthreads();

  // Preload MMA k-block 0 of the first stage into registers.
  cute::copy(s2r_tiled_copy_a, tAsA(_, _, 0, ismem_read), tCrA_view(_, _, 0));
  cute::copy(s2r_tiled_copy_b, tBsB(_, _, 0, ismem_read), tCrB_view(_, _, 0));

  // MMA k-blocks per K tile (a compile-time size in CuTe; loop invariant).
  int nk = size<2>(tCrA);

  // Main loop: for each K tile, overlap smem->reg of k-block ik+1, the next
  // gmem->smem tile fetch (at ik == 0) and the MMA on k-block ik.
#pragma unroll 1
  for (int itile = 0; itile < ntile; ++itile)
  {
#pragma unroll
    for (int ik = 0; ik < nk; ++ik)
    {
      int ik_next = (ik + 1) % nk;
      if (ik == nk - 1)
      {
        // About to read k-block 0 of the next tile: make sure it has landed
        // and that nobody is still reading the stage we advance away from.
        cp_async_wait<kStage - 2>();
        __syncthreads();
        ismem_read = (ismem_read + 1) % kStage;
      }
      // shm -> reg s[itile][ik + 1] -> r[ik + 1]
      cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik_next, ismem_read),
                 tCrA_view(_, _, ik_next));
      cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik_next, ismem_read),
                 tCrB_view(_, _, ik_next));
      if (ik == 0)
      {
        if (itile_to_read < ntile)
        {
          cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, itile_to_read),
                     tAsA_copy(_, _, _, ismem_write));
          cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, itile_to_read),
                     tBsB_copy(_, _, _, ismem_write));
          ++itile_to_read;
          ismem_write = (ismem_write + 1) % kStage;
        }
        // Always commit a group (possibly empty) to keep wait counts aligned.
        cp_async_fence();
      }
      cute::gemm(tiled_mma, tCrD, tCrA(_, _, ik), tCrB(_, _, ik), tCrD);
    } // for ik
  } // itile

  // Epilogue: Dreg -> shm -> global, staging through one A stage so the
  // global stores can use a wide instruction.
  // All real cp.async copies were already waited on inside the loop, but
  // threads may still be issuing the (discarded) look-ahead ldmatrix from
  // stage ismem_read. Barrier before overwriting that stage with C.
  __syncthreads();
  auto sC = make_tensor(sA(_, _, ismem_read).data(), SmemLayoutC{});
  auto r2s_tiled_copy_c = make_tiled_copy_C(R2SCopyAtomC{}, tiled_mma);
  auto r2s_thr_copy_c = r2s_tiled_copy_c.get_slice(idx);
  auto tCrC_r2s = r2s_thr_copy_c.retile_S(tCrD);  // (CPY, CPY_M, CPY_N)
  auto tCsC_r2s = r2s_thr_copy_c.partition_D(sC); // (CPY, _1, _1, pipe)
  S2GCopyC s2g_tiled_copy_c;
  auto s2g_thr_copy_c = s2g_tiled_copy_c.get_thread_slice(idx);
  auto tCsC_s2g = s2g_thr_copy_c.partition_S(sC); // (CPY, _1, _1, pipe)
  auto tCgC_s2g = s2g_thr_copy_c.partition_D(gD); // (CPY, CPY_M, CPY_N)
  auto tCgC_s2gx = group_modes<1, 3>(tCgC_s2g);   // (CPY_, CPY_MN)
  auto tCrC_r2sx = group_modes<1, 3>(tCrC_r2s);   // (CPY_, CPY_MN)
  int step = size<3>(tCsC_r2s);                   // pipe
#pragma unroll
  for (int i = 0; i < size<1>(tCrC_r2sx); i += step)
  {
    // reg -> shm
#pragma unroll
    for (int j = 0; j < step; ++j)
    {
      // Convert the accumulator fragment to the output element type before
      // storing it (accumulator and output types may differ).
      auto t = make_tensor_like<T>(tCrC_r2sx(_, i + j));
      cute::copy(tCrC_r2sx(_, i + j), t);
      cute::copy(r2s_tiled_copy_c, t, tCsC_r2s(_, 0, 0, j));
    }
    __syncthreads();
    // shm -> global
#pragma unroll
    for (int j = 0; j < step; ++j)
    {
      cute::copy(s2g_tiled_copy_c, tCsC_s2g(_, 0, 0, j), tCgC_s2gx(_, i + j));
    }
    __syncthreads();
  }
}
namespace config
{
using namespace cute;
// Compile-time configuration consumed by gemm_multi_stage<Config>.
//
// Template parameters:
//   T_                       element type of A, B and D.
//   kTileM_/kTileN_/kTileK_  CTA tile sizes along M, N and K.
//   kStage_                  number of shared-memory pipeline stages for A/B
//                            (the third mode of SmemLayoutA / SmemLayoutB).
//   kSmemLayoutCBatch_       number of kMmaPM x kMmaPN sub-tiles of C staged
//                            in shared memory per epilogue step.
//   ComputeType              currently unused by this struct.
//
// Several layouts below are hard-wired (16-bit element MMA atom, 32x4 thread
// copy layouts moving 8 elements per thread, 2x2 warp MMA repeat). The
// static_asserts turn the assumptions the kernel relies on into compile-time
// errors instead of silent out-of-bounds accesses.
template <typename T_, int kTileM_ = 128, int kTileN_ = 128, int kTileK_ = 32,
          int kStage_ = 5, int kSmemLayoutCBatch_ = 2,
          typename ComputeType = T_>
struct GemmConfig
{
  using T = T_;
  // The 128-bit cp.async/store atoms move 1x8 values per thread and the MMA
  // atom is an F16 one, so both assume a 16-bit element type.
  static_assert(sizeof(T) == 2,
                "GemmConfig's copy/MMA atoms assume a 16-bit element type");

  // tile configuration
  static constexpr int kTileM = kTileM_;
  static constexpr int kTileN = kTileN_;
  static constexpr int kTileK = kTileK_;
  static constexpr int kStage = kStage_;
  static constexpr int kSmemLayoutCBatch = kSmemLayoutCBatch_;
  // The kernel waits with cp_async_wait<kStage - 2> and reads one stage while
  // others are in flight, so it needs at least two stages.
  static_assert(kStage >= 2, "kStage must be at least 2");
  static_assert(kSmemLayoutCBatch >= 1, "kSmemLayoutCBatch must be >= 1");
  // G2S tiled copy below covers 32 rows x (4 threads * 8 values) columns.
  static_assert(kTileM % 32 == 0 && kTileN % 32 == 0 && kTileK % 32 == 0,
                "tile sizes must be multiples of the 32x32 gmem->smem copy");

  // Shared-memory layout atom for A/B: 8 rows x kTileK (row-major) composed
  // with Swizzle<3, 3, 3> (presumably to avoid ldmatrix bank conflicts).
  static constexpr int kShmLoadSwizzleM = 3;
  static constexpr int kShmLoadSwizzleS = 3;
  static constexpr int kShmLoadSwizzleB = 3;
  using SmemLayoutAtom = decltype(composition(
      Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
      make_layout(make_shape(Int<8>{}, Int<kTileK>{}),
                  make_stride(Int<kTileK>{}, Int<1>{}))));
  using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtom{},
      make_shape(Int<kTileM>{}, Int<kTileK>{}, Int<kStage>{})));
  using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtom{},
      make_shape(Int<kTileN>{}, Int<kTileK>{}, Int<kStage>{})));

  // Tensor-core MMA: SM80 16x8x16 fp16 atom, repeated 2x2x1 across warps.
  using mma_op = SM80_16x8x16_F16F16F16F16_TN;
  using mma_traits = MMA_Traits<mma_op>;
  using mma_atom = MMA_Atom<mma_traits>;
  static constexpr int kMmaEURepeatM = 2;
  static constexpr int kMmaEURepeatN = 2;
  static constexpr int kMmaEURepeatK = 1;
  using mma_atom_shape = mma_traits::Shape_MNK;
  // Permutation tile of the tiled MMA (N is doubled on top of the repeat).
  static constexpr int kMmaPM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
  static constexpr int kMmaPN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
  static constexpr int kMmaPK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});
  static_assert(kTileM % kMmaPM == 0 && kTileN % kMmaPN == 0 &&
                    kTileK % kMmaPK == 0,
                "CTA tile must be a multiple of the tiled-MMA tile");
  // The epilogue walks the (kTileM/kMmaPM) x (kTileN/kMmaPN) C sub-tiles in
  // steps of kSmemLayoutCBatch without a remainder loop.
  static_assert(((kTileM / kMmaPM) * (kTileN / kMmaPN)) % kSmemLayoutCBatch ==
                    0,
                "C sub-tile count must be a multiple of kSmemLayoutCBatch");
  using MMA_EU_RepeatT = decltype(make_layout(make_shape(
      Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
  using MMA_P_T = Tile<Int<kMmaPM>, Int<kMmaPN>, Int<kMmaPK>>;
  using MMA = decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_P_T{}));

  // global memory to shared memory: 128-bit cp.async (cache-global),
  // 32x4 threads (row-major) x 1x8 values each.
  using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
  using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
  using g2s_copy_atom = Copy_Atom<g2s_copy_traits, T>;
  using G2SCopyA =
      decltype(make_tiled_copy(g2s_copy_atom{},
                               make_layout(make_shape(Int<32>{}, Int<4>{}),
                                           make_stride(Int<4>{}, Int<1>{})),
                               make_layout(make_shape(Int<1>{}, Int<8>{}))));
  using G2SCopyB = G2SCopyA;

  // shared memory to register copy (ldmatrix x4)
  using s2r_copy_op = SM75_U32x4_LDSM_N;
  using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
  using s2r_copy_atom = Copy_Atom<s2r_copy_traits, T>;
  using S2RCopyAtomA = s2r_copy_atom;
  using S2RCopyAtomB = s2r_copy_atom;

  // epilogue: register to global via shared memory. The C staging buffer
  // aliases one stage of A, hence the size check below.
  using SmemLayoutAtomC = decltype(composition(
      Swizzle<2, 3, 3>{}, make_layout(make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}),
                                      make_stride(Int<kMmaPN>{}, Int<1>{}))));
  using SmemLayoutC = decltype(tile_to_shape(
      SmemLayoutAtomC{},
      make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}, Int<kSmemLayoutCBatch>{})));
  static_assert(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) >=
                    size(SmemLayoutC{}),
                "C shared memory request is large than A's one pipe");
  using R2SCopyAtomC = Copy_Atom<UniversalCopy<int>, T>;
  using S2GCopyAtomC = Copy_Atom<UniversalCopy<cute::uint128_t>, T>;
  using S2GCopyC =
      decltype(make_tiled_copy(S2GCopyAtomC{},
                               make_layout(make_shape(Int<32>{}, Int<4>{}),
                                           make_stride(Int<4>{}, Int<1>{})),
                               make_layout(make_shape(Int<1>{}, Int<8>{}))));

  // Threads per CTA = threads of the tiled MMA.
  static constexpr int kThreadNum = size(MMA{});
  // The G2S/S2G thread layouts above are 32 x 4 = 128 threads.
  static_assert(kThreadNum == 32 * 4,
                "copy thread layouts are hard-wired for 128 threads");
  // Dynamic shared memory in bytes; C reuses A's storage, so take the max.
  static constexpr int shm_size_AB =
      cute::cosize(SmemLayoutA{}) + cute::cosize(SmemLayoutB{});
  static constexpr int shm_size_C = cute::cosize(SmemLayoutC{});
  static constexpr int kShmSize =
      cute::max(shm_size_AB, shm_size_C) * sizeof(T);
};
} // namespace config
// Benchmark / correctness driver. Runs cublasHgemm, optionally the cuBLASLt
// wrapper, and gemm_multi_stage nt times on inputs filled by cpu_rand_data,
// compares the GPU results with gpu_compare, and prints the top-left 8x8
// corner of each output. Returns 1 if the problem size is unsupported.
int main(int argc, char *argv[])
{
  using T = cute::half_t;
  using namespace cute;
  srand(10086);
  cublasHandle_t handle;
  cublasCreate(&handle);
  int cublas_version;
  cublasGetVersion_v2(handle, &cublas_version);
  printf("cuBLAS version: %d\n", cublas_version);
  // default;
  int M = 81920;
  int N = 256;
  int K = 256;
  int enable_cpu = 0;
  int enable_cublaslt = 1;
  int nt = 11;
  using ComputeType = T;
  using Config = config::GemmConfig<T, 128, 128, 32, 3>;

  // gemm_multi_stage has no tail handling: a partial K tile would be dropped
  // and partial M/N tiles would be accessed out of bounds. Refuse such sizes.
  if (M % Config::kTileM != 0 || N % Config::kTileN != 0 ||
      K % Config::kTileK != 0)
  {
    printf("unsupported problem size M = %d, N = %d, K = %d: must be "
           "multiples of (%d, %d, %d)\n",
           M, N, K, Config::kTileM, Config::kTileN, Config::kTileK);
    cublasDestroy(handle);
    return 1;
  }

  T *Aptr;
  T *Bptr;
  T *Dptr;
  T *Dptr_cublas;
  T *Dptr_cublaslt;
  T *Aptr_host;
  T *Bptr_host;
  T *Dptr_host;
  T *Dptr_host_cpu;
  T *Dptr_host_blas;
  T *Dptr_host_cublaslt;
  Aptr_host = (T *)malloc(sizeof(T) * M * K);
  Bptr_host = (T *)malloc(sizeof(T) * N * K);
  Dptr_host = (T *)malloc(sizeof(T) * M * N);
  Dptr_host_cpu = (T *)malloc(sizeof(T) * M * N);
  Dptr_host_blas = (T *)malloc(sizeof(T) * M * N);
  Dptr_host_cublaslt = (T *)malloc(sizeof(T) * M * N);
  cudaMalloc(&Aptr, sizeof(T) * M * K);
  cudaMalloc(&Bptr, sizeof(T) * N * K);
  cudaMalloc(&Dptr, sizeof(T) * M * N);
  cudaMalloc(&Dptr_cublas, sizeof(T) * M * N);
  cudaMalloc(&Dptr_cublaslt, sizeof(T) * M * N);

  // Host views: A is (M, K), B is (N, K), D is (M, N), all row-major.
  auto tA = make_tensor(Aptr_host, make_shape(M, K), make_stride(K, 1));
  auto tB = make_tensor(Bptr_host, make_shape(N, K), make_stride(K, 1));
  auto tD = make_tensor(Dptr_host, make_shape(M, N), make_stride(N, 1));
  cpu_rand_data(&tA);
  cpu_rand_data(&tB);
  clear(tD);
  cudaMemcpy(Aptr, Aptr_host, sizeof(T) * M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(Bptr, Bptr_host, sizeof(T) * N * K, cudaMemcpyHostToDevice);
  cudaMemcpy(Dptr, Dptr_host, sizeof(T) * M * N, cudaMemcpyHostToDevice);
  cudaMemset(Dptr_cublas, 0, sizeof(T) * M * N);
  cudaMemset(Dptr_cublaslt, 0, sizeof(T) * M * N);

  CublasLtGemm<T, ComputeType> cublaslt_gemm;
  if (enable_cublaslt)
  {
    // NOTE(review): operands/dims passed swapped (B, A, N, M), presumably the
    // same column-major trick as the cublasHgemm call below -- confirm
    // against detail/cublaslt-gemm.h.
    cublaslt_gemm.init(Dptr_cublaslt, Bptr, Aptr, N, M, K);
  }

  print(typename Config::MMA{});
  dim3 block = Config::kThreadNum;
  dim3 grid((N + Config::kTileN - 1) / Config::kTileN,
            (M + Config::kTileM - 1) / Config::kTileM);
  int shm_size = Config::kShmSize;

  // Opt in to the kernel's dynamic shared-memory size once, outside the
  // benchmark loop (needed when it exceeds the default per-block limit).
  cudaError_t attr_err = cudaFuncSetAttribute(
      gemm_multi_stage<Config>, cudaFuncAttributeMaxDynamicSharedMemorySize,
      shm_size);
  if (attr_err != cudaSuccess)
  {
    printf("cudaFuncSetAttribute err = %d, str = %s\n", attr_err,
           cudaGetErrorString(attr_err));
  }

  half alpha = 1.f;
  half beta = 0.f;
  for (int it = 0; it < nt; ++it)
  {
    // blas: cuBLAS is column-major, so the row-major (M, N) result
    // D = A * B^T is computed as the column-major (N, M) product B^T-op * A.
    cudaMemset(Dptr_cublas, 0, sizeof(T) * M * N);
    cublasStatus_t ret = cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, K,
                                     &alpha, (half *)Bptr, K, (half *)Aptr, K,
                                     &beta, (half *)Dptr_cublas, N);
    if (ret != CUBLAS_STATUS_SUCCESS)
    {
      printf("cublas err = %d, str = %s\n", ret, cublasGetStatusString(ret));
    }
    if (enable_cublaslt)
    {
      cudaMemset(Dptr_cublaslt, 0, sizeof(T) * M * N);
      cublaslt_gemm.run();
    }
    // multi-stage
    cudaMemset(Dptr, 0, sizeof(T) * M * N);
    gemm_multi_stage<Config>
        <<<grid, block, shm_size>>>(Dptr, Aptr, Bptr, M, N, K);
  }

  cudaMemcpy(Dptr_host, Dptr, sizeof(T) * M * N, cudaMemcpyDeviceToHost);
  cudaMemcpy(Dptr_host_blas, Dptr_cublas, sizeof(T) * M * N,
             cudaMemcpyDeviceToHost);
  cudaMemcpy(Dptr_host_cublaslt, Dptr_cublaslt, sizeof(T) * M * N,
             cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();
  // Reports the most recent error from any runtime call above (including
  // kernel launches).
  auto err = cudaGetLastError();
  printf("block = (%d, %d), gird = (%d, %d), shm = %d\n", block.x, block.y,
         grid.x, grid.y, shm_size);
  if (err == cudaSuccess)
  {
    printf("err = %d, str = %s\n", err, cudaGetErrorString(err));
  }
  else
  {
    printf_fail("err = %d, str = %s\n", err, cudaGetErrorString(err));
  }
  gpu_compare(Dptr, Dptr_cublas, M * N);
  if (enable_cublaslt)
  {
    gpu_compare(Dptr, Dptr_cublaslt, M * N);
  }

  auto tD_host = make_tensor(Dptr_host, make_shape(M, N), make_stride(N, 1));
  auto tD_host_cpu =
      make_tensor(Dptr_host_cpu, make_shape(M, N), make_stride(N, 1));
  auto tD_host_blas =
      make_tensor(Dptr_host_blas, make_shape(M, N), make_stride(N, 1));
  auto tD_host_cublaslt =
      make_tensor(Dptr_host_cublaslt, make_shape(M, N), make_stride(N, 1));
  if (enable_cpu)
  {
    cpu_gemm(&tD_host_cpu, tA, tB);
    cpu_compare(tD_host_cpu, tD_host, 0.1f);
  }

  // Print the top-left (up to) 8x8 corner of each result.
  auto tile = make_tile(min(8, M), min(8, N));
  auto t32x32 = local_tile(tD_host, tile, make_coord(0, 0));
  auto t32x32_cpu = local_tile(tD_host_cpu, tile, make_coord(0, 0));
  auto t32x32_blas = local_tile(tD_host_blas, tile, make_coord(0, 0));
  auto t32x32_cublaslt = local_tile(tD_host_cublaslt, tile, make_coord(0, 0));
  printf("M = %d, N = %d, K = %d\n", M, N, K);
  printf("our-impl:\n");
  print_tensor(t32x32);
  if (enable_cpu)
  {
    printf("cpu:\n");
    print_tensor(t32x32_cpu);
  }
  printf("cublas:\n");
  print_tensor(t32x32_blas);
  if (enable_cublaslt)
  {
    printf("cublaslt:\n");
    print_tensor(t32x32_cublaslt);
  }

  // Release host/device buffers and the cuBLAS handle.
  // NOTE(review): cublaslt_gemm is destroyed after this point; this assumes
  // its destructor does not dereference the A/B/D buffers -- confirm.
  free(Aptr_host);
  free(Bptr_host);
  free(Dptr_host);
  free(Dptr_host_cpu);
  free(Dptr_host_blas);
  free(Dptr_host_cublaslt);
  cudaFree(Aptr);
  cudaFree(Bptr);
  cudaFree(Dptr);
  cudaFree(Dptr_cublas);
  cudaFree(Dptr_cublaslt);
  cublasDestroy(handle);
  return 0;
}