From 9702a98f548ea1c16bbbbc3e4a8febdaba232012 Mon Sep 17 00:00:00 2001 From: Vin Huang Date: Thu, 11 Jul 2024 23:16:07 +0800 Subject: [PATCH] Add test for alpha vector scaling --- clients/benchmarks/client.cpp | 4 + clients/common/cblas_interface.cpp | 503 +++++++++++++----- clients/gtest/spmm_batched_gtest.yaml | 2 + clients/gtest/spmm_batched_gtest_1b.yaml | 2 + clients/gtest/spmm_batched_gtest_1b_row.yaml | 2 + clients/gtest/spmm_batched_gtest_row.yaml | 2 + clients/gtest/spmm_gtest.cpp | 5 + clients/gtest/spmm_gtest.yaml | 2 + clients/gtest/spmm_gtest_1b.yaml | 2 + clients/gtest/spmm_gtest_row.yaml | 2 + clients/gtest/spmm_strided_batched_gtest.yaml | 3 + .../gtest/spmm_strided_batched_gtest_1b.yaml | 3 + .../spmm_strided_batched_gtest_1b_row.yaml | 3 + .../gtest/spmm_strided_batched_gtest_row.yaml | 3 + clients/include/cblas_interface.hpp | 1 + clients/include/hipsparselt_arguments.hpp | 3 + clients/include/hipsparselt_common.yaml | 2 + clients/include/spmm/testing_spmm.hpp | 94 +++- 18 files changed, 498 insertions(+), 140 deletions(-) diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 4006954c..1f63204e 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -496,6 +496,10 @@ try value(&order_d)->default_value('N'), "C = Column Major, R = Row Major") + ("alpha_vector_scaling", + bool_switch(&arg.alpha_vector_scaling)->default_value(false), + "Apply alpha vector scaling") + ("help,h", "produces this help message") ("version", "Prints the version number"); diff --git a/clients/common/cblas_interface.cpp b/clients/common/cblas_interface.cpp index 2d60e44e..dca93406 100644 --- a/clients/common/cblas_interface.cpp +++ b/clients/common/cblas_interface.cpp @@ -72,6 +72,7 @@ void cblas_gemm(hipsparseOrder_t order, hip_bfloat16* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support hip_bfloat16, so convert to higher precision float @@ -86,22 +87,52 @@ void cblas_gemm(hipsparseOrder_t order, for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); + if(alphaVec != nullptr) + { + host_vector T_float(sizeC); + memset(T_float, 0, sizeC); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_float, + lda, + B_float, + ldb, + static_cast(0), + T_float, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? 
j * ldc + i : i * ldc + j; + C_float[pos] = T_float[pos] * alphaVec[i] + C_float[pos] * beta; + } + } + } + else + { + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_float, + lda, + B_float, + ldb, + beta, + C_float, + ldc); + } for(size_t i = 0; i < sizeC; i++) C[i] = static_cast(C_float[i]); @@ -125,34 +156,65 @@ void cblas_gemm(hipsparseOrder_t order, float* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support hip_bfloat16, so convert to higher precision float // This will give more precise result which is acceptable for testing - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); + host_vector A_float(sizeA), B_float(sizeB); for(size_t i = 0; i < sizeA; i++) A_float[i] = static_cast(A[i]); for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); + if(alphaVec != nullptr) + { + host_vector T_float(sizeC); + memset(T_float, 0, sizeC); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_float, + lda, + B_float, + ldb, + static_cast(0), + T_float, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? j * ldc + i : i * ldc + j; + C[pos] = T_float[pos] * alphaVec[i] + C[pos] * beta; + } + } + } + else + { + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_float, + lda, + B_float, + ldb, + beta, + C, + ldc); + } } template <> @@ -173,6 +235,7 @@ void cblas_gemm<__half, __half, float>(hipsparseOrder_t order, __half* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support __half, so convert to higher precision float @@ -199,22 +262,52 @@ void cblas_gemm<__half, __half, float>(hipsparseOrder_t order, C_float[i] = C[i]; } - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); + if(alphaVec != nullptr) + { + host_vector T_float(sizeC); + memset(T_float, 0, sizeC); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_float, + lda, + B_float, + ldb, + static_cast(0), + T_float, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? 
j * ldc + i : i * ldc + j; + C_float[pos] = T_float[pos] * alphaVec[i] + C_float[pos] * beta; + } + } + } + else + { + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_float, + lda, + B_float, + ldb, + beta, + C_float, + ldc); + } for(size_t i = 0; i < sizeC; i++) C[i] = __half(C_float[i]); @@ -238,6 +331,7 @@ void cblas_gemm<__half, float, float>(hipsparseOrder_t order, float* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support __half, so convert to higher precision float @@ -260,22 +354,53 @@ void cblas_gemm<__half, float, float>(hipsparseOrder_t order, B_float[i] = B[i]; } - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); + if(alphaVec != nullptr) + { + host_vector T_float(sizeC); + memset(T_float, 0, sizeC); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_float, + lda, + B_float, + ldb, + static_cast(0), + T_float, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? j * ldc + i : i * ldc + j; + C[pos] = T_float[pos] * alphaVec[i] + C[pos] * beta; + } + } + } + else + { + + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipsparselt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_float, + lda, + B_float, + ldb, + beta, + C, + ldc); + } } template <> @@ -296,6 +421,7 @@ void cblas_gemm(hipsparseOrder_t order, int8_t* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support int8_t input / int8_t output, however non-overflowing @@ -315,21 +441,52 @@ void cblas_gemm(hipsparseOrder_t order, for(size_t i = 0; i < sizeC; i++) C_double[i] = static_cast(C[i]); - // just directly cast, since transA, transB are integers in the enum - cblas_dgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_double, - lda, - B_double, - ldb, - beta, - C_double, - ldc); + if(alphaVec != nullptr) + { + host_vector T_double(sizeC); + memset(T_double, 0, sizeC); + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_double, + lda, + B_double, + ldb, + static_cast(0), + T_double, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? 
j * ldc + i : i * ldc + j; + C_double[pos] = T_double[pos] * static_cast(alphaVec[i]) + + C_double[pos] * static_cast(beta); + } + } + } + else + { + // just directly cast, since transA, transB are integers in the enum + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_double, + lda, + B_double, + ldb, + beta, + C_double, + ldc); + } auto saturate = [](double val) { val = std::nearbyint(val); @@ -359,6 +516,7 @@ void cblas_gemm(hipsparseOrder_t order, float* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support int8_t input / int8_t output, however non-overflowing @@ -378,21 +536,52 @@ void cblas_gemm(hipsparseOrder_t order, for(size_t i = 0; i < sizeC; i++) C_double[i] = static_cast(C[i]); - // just directly cast, since transA, transB are integers in the enum - cblas_dgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_double, - lda, - B_double, - ldb, - beta, - C_double, - ldc); + if(alphaVec != nullptr) + { + host_vector T_double(sizeC); + memset(T_double, 0, sizeC); + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_double, + lda, + B_double, + ldb, + static_cast(0), + T_double, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? j * ldc + i : i * ldc + j; + C_double[pos] = T_double[pos] * static_cast(alphaVec[i]) + + C_double[pos] * static_cast(beta); + } + } + } + else + { + // just directly cast, since transA, transB are integers in the enum + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_double, + lda, + B_double, + ldb, + beta, + C_double, + ldc); + } for(size_t i = 0; i < sizeC; i++) C[i] = static_cast(C_double[i]); @@ -416,6 +605,7 @@ void cblas_gemm(hipsparseOrder_t order, __half* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support int8_t input / int8_t output, however non-overflowing @@ -435,21 +625,52 @@ void cblas_gemm(hipsparseOrder_t order, for(size_t i = 0; i < sizeC; i++) C_double[i] = static_cast(C[i]); - // just directly cast, since transA, transB are integers in the enum - cblas_dgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_double, - lda, - B_double, - ldb, - beta, - C_double, - ldc); + if(alphaVec != nullptr) + { + host_vector T_double(sizeC); + memset(T_double, 0, sizeC); + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_double, + lda, + B_double, + ldb, + static_cast(0), + T_double, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? 
j * ldc + i : i * ldc + j; + C_double[pos] = T_double[pos] * static_cast(alphaVec[i]) + + C_double[pos] * static_cast(beta); + } + } + } + else + { + // just directly cast, since transA, transB are integers in the enum + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_double, + lda, + B_double, + ldb, + beta, + C_double, + ldc); + } for(size_t i = 0; i < sizeC; i++) C[i] = __half(C_double[i]); @@ -473,6 +694,7 @@ void cblas_gemm(hipsparseOrder_t order, hip_bfloat16* C, int64_t ldc, int64_t sizeC, + float* alphaVec, bool alt) { // cblas does not support int8_t input / int8_t output, however non-overflowing @@ -492,21 +714,52 @@ void cblas_gemm(hipsparseOrder_t order, for(size_t i = 0; i < sizeC; i++) C_double[i] = static_cast(C[i]); - // just directly cast, since transA, transB are integers in the enum - cblas_dgemm(HIPOrderToCBLASOrder(order), - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_double, - lda, - B_double, - ldb, - beta, - C_double, - ldc); + if(alphaVec != nullptr) + { + host_vector T_double(sizeC); + memset(T_double, 0, sizeC); + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + static_cast(1), + A_double, + lda, + B_double, + ldb, + static_cast(0), + T_double, + ldc); + for(int i = 0; i < m; i++) + { + for(int j = 0; j < n; j++) + { + size_t pos = order == HIPSPARSE_ORDER_COL ? j * ldc + i : i * ldc + j; + C_double[pos] = T_double[pos] * static_cast(alphaVec[i]) + + C_double[pos] * static_cast(beta); + } + } + } + else + { + // just directly cast, since transA, transB are integers in the enum + cblas_dgemm(HIPOrderToCBLASOrder(order), + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_double, + lda, + B_double, + ldb, + beta, + C_double, + ldc); + } for(size_t i = 0; i < sizeC; i++) C[i] = static_cast(C_double[i]); diff --git a/clients/gtest/spmm_batched_gtest.yaml b/clients/gtest/spmm_batched_gtest.yaml index 907f063f..0a0a51ec 100644 --- a/clients/gtest/spmm_batched_gtest.yaml +++ b/clients/gtest/spmm_batched_gtest.yaml @@ -51,6 +51,7 @@ Tests: transA_transB: *transA_transB_range batch_count: [ 1, 3 ] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_batched_small category: quick @@ -64,6 +65,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_batched_medium category: pre_checkin diff --git a/clients/gtest/spmm_batched_gtest_1b.yaml b/clients/gtest/spmm_batched_gtest_1b.yaml index 5849d9b5..1ab71c94 100644 --- a/clients/gtest/spmm_batched_gtest_1b.yaml +++ b/clients/gtest/spmm_batched_gtest_1b.yaml @@ -51,6 +51,7 @@ Tests: transA_transB: *transA_transB_range batch_count: [ 1, 3 ] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_batched_small category: quick @@ -63,6 +64,7 @@ Tests: bias_vector: [true] bias_stride: [0, -1, 256] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_batched_medium category: pre_checkin diff --git a/clients/gtest/spmm_batched_gtest_1b_row.yaml b/clients/gtest/spmm_batched_gtest_1b_row.yaml index 10c62966..ad8afdc1 100644 --- a/clients/gtest/spmm_batched_gtest_1b_row.yaml +++ b/clients/gtest/spmm_batched_gtest_1b_row.yaml @@ -55,6 +55,7 @@ Tests: transA_transB: *transA_transB_range batch_count: [ 
1, 3 ] bias_vector: [false] + alpha_vector_scaling: [true, false] sparse_b: [ true, false] orderA: [R] orderB: [R] @@ -71,6 +72,7 @@ Tests: batch_count: [ 1, 3 ] bias_vector: [true] bias_stride: [0, -1, 256] + alpha_vector_scaling: [true, false] sparse_b: [ true, false] orderA: [R] orderB: [R] diff --git a/clients/gtest/spmm_batched_gtest_row.yaml b/clients/gtest/spmm_batched_gtest_row.yaml index 6ed3c9e6..08b5ad38 100644 --- a/clients/gtest/spmm_batched_gtest_row.yaml +++ b/clients/gtest/spmm_batched_gtest_row.yaml @@ -56,6 +56,7 @@ Tests: batch_count: [ 1, 3 ] bias_vector: [false] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] @@ -73,6 +74,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] diff --git a/clients/gtest/spmm_gtest.cpp b/clients/gtest/spmm_gtest.cpp index f0f27d83..d9e248e5 100644 --- a/clients/gtest/spmm_gtest.cpp +++ b/clients/gtest/spmm_gtest.cpp @@ -129,6 +129,11 @@ namespace << hip_datatype_to_string(arg.bias_type); } + if(arg.alpha_vector_scaling) + { + name << "_avs"; + } + name << '_' << (char)std::toupper(arg.transA) << (char)std::toupper(arg.transB); name << '_' << arg.M << '_' << arg.N << '_' << arg.K << '_' << arg.alpha << '_' diff --git a/clients/gtest/spmm_gtest.yaml b/clients/gtest/spmm_gtest.yaml index 289951be..4f5b41e9 100644 --- a/clients/gtest/spmm_gtest.yaml +++ b/clients/gtest/spmm_gtest.yaml @@ -58,6 +58,7 @@ Tests: transA_transB: *transA_transB_range alpha_beta: *alpha_beta_range sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_small category: quick @@ -70,6 +71,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_medium category: pre_checkin diff --git a/clients/gtest/spmm_gtest_1b.yaml b/clients/gtest/spmm_gtest_1b.yaml index 15044502..041ba13a 100644 --- a/clients/gtest/spmm_gtest_1b.yaml +++ b/clients/gtest/spmm_gtest_1b.yaml @@ -59,6 +59,7 @@ Tests: transA_transB: *transA_transB_range alpha_beta: *alpha_beta_range sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_small category: quick @@ -70,6 +71,7 @@ Tests: bias_vector: [true] bias_stride: [0, -1, 256] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_medium category: pre_checkin diff --git a/clients/gtest/spmm_gtest_row.yaml b/clients/gtest/spmm_gtest_row.yaml index b44fbd11..91245923 100644 --- a/clients/gtest/spmm_gtest_row.yaml +++ b/clients/gtest/spmm_gtest_row.yaml @@ -60,6 +60,7 @@ Tests: alpha_beta: *alpha_beta_range bias_vector: [false] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] @@ -76,6 +77,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] diff --git a/clients/gtest/spmm_strided_batched_gtest.yaml b/clients/gtest/spmm_strided_batched_gtest.yaml index 0557fd50..d777e457 100644 --- a/clients/gtest/spmm_strided_batched_gtest.yaml +++ b/clients/gtest/spmm_strided_batched_gtest.yaml @@ -42,6 +42,7 @@ Tests: transB: N batch_count: [ 1, 3 ] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_strided_batched_small category: quick @@ -55,6 +56,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: 
spmm_strided_batched_small_stride_zero category: quick @@ -72,6 +74,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_strided_batched_medium category: pre_checkin diff --git a/clients/gtest/spmm_strided_batched_gtest_1b.yaml b/clients/gtest/spmm_strided_batched_gtest_1b.yaml index 43bec08d..7b01609c 100644 --- a/clients/gtest/spmm_strided_batched_gtest_1b.yaml +++ b/clients/gtest/spmm_strided_batched_gtest_1b.yaml @@ -42,6 +42,7 @@ Tests: transB: N batch_count: [ 1, 3 ] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_strided_batched_small category: quick @@ -54,6 +55,7 @@ Tests: bias_vector: [true] bias_stride: [0, -1, 256] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_strided_batched_small_stride_zero category: quick @@ -70,6 +72,7 @@ Tests: bias_vector: [true] bias_stride: [0, -1, 256] sparse_b: [true, false] + alpha_vector_scaling: [true, false] - name: spmm_strided_batched_medium category: pre_checkin diff --git a/clients/gtest/spmm_strided_batched_gtest_1b_row.yaml b/clients/gtest/spmm_strided_batched_gtest_1b_row.yaml index 8891c1e6..b9576d65 100644 --- a/clients/gtest/spmm_strided_batched_gtest_1b_row.yaml +++ b/clients/gtest/spmm_strided_batched_gtest_1b_row.yaml @@ -48,6 +48,7 @@ Tests: batch_count: [ 1, 3 ] bias_vector: [false] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] @@ -64,6 +65,7 @@ Tests: bias_vector: [true] bias_stride: [0, -1, 256] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] @@ -84,6 +86,7 @@ Tests: bias_vector: [true] bias_stride: [0, -1, 256] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] diff --git a/clients/gtest/spmm_strided_batched_gtest_row.yaml b/clients/gtest/spmm_strided_batched_gtest_row.yaml index fde83640..66b6fb09 100644 --- a/clients/gtest/spmm_strided_batched_gtest_row.yaml +++ b/clients/gtest/spmm_strided_batched_gtest_row.yaml @@ -48,6 +48,7 @@ Tests: batch_count: [ 1, 3 ] bias_vector: [false] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] @@ -65,6 +66,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] @@ -86,6 +88,7 @@ Tests: bias_stride: [0, -1, 256] bias_type: [f32_r, f16_r] sparse_b: [ true, false] + alpha_vector_scaling: [true, false] orderA: [R] orderB: [R] orderC: [R] diff --git a/clients/include/cblas_interface.hpp b/clients/include/cblas_interface.hpp index 4d06cddf..3f7f7997 100644 --- a/clients/include/cblas_interface.hpp +++ b/clients/include/cblas_interface.hpp @@ -54,4 +54,5 @@ void cblas_gemm(hipsparseOrder_t order, std::add_pointer_t C, int64_t ldc, int64_t sizeC, + Tc* alphaVec, bool alt = false); diff --git a/clients/include/hipsparselt_arguments.hpp b/clients/include/hipsparselt_arguments.hpp index 22825726..78435d80 100644 --- a/clients/include/hipsparselt_arguments.hpp +++ b/clients/include/hipsparselt_arguments.hpp @@ -126,6 +126,8 @@ struct Arguments bool sparse_b; int func_version; + bool alpha_vector_scaling; + char orderA; char orderB; char orderC; @@ -193,6 +195,7 @@ struct Arguments OPER(search_iters) SEP \ OPER(sparse_b) SEP \ OPER(func_version) SEP \ + OPER(alpha_vector_scaling) SEP \ OPER(orderA) SEP \ OPER(orderB) SEP \ OPER(orderC) SEP \ diff --git 
a/clients/include/hipsparselt_common.yaml b/clients/include/hipsparselt_common.yaml index c0c5dff7..f3efc6b1 100644 --- a/clients/include/hipsparselt_common.yaml +++ b/clients/include/hipsparselt_common.yaml @@ -128,6 +128,7 @@ Arguments: - search_iters: c_int32 - sparse_b: c_bool - func_version: c_int32 + - alpha_vector_scaling: c_bool - orderA: c_char - orderB: c_char - orderC: c_char @@ -208,6 +209,7 @@ Defaults: search_iters: 10 sparse_b: false func_version: 1 + alpha_vector_scaling: false orderA: C orderB: C orderC: C diff --git a/clients/include/spmm/testing_spmm.hpp b/clients/include/spmm/testing_spmm.hpp index acc55c9d..e89058b9 100644 --- a/clients/include/spmm/testing_spmm.hpp +++ b/clients/include/spmm/testing_spmm.hpp @@ -547,6 +547,26 @@ void testing_spmm(const Arguments& arg) #endif } + const size_t size_alpha_vec = arg.alpha_vector_scaling ? M : 0; + + device_vector dAlpahVector(size_alpha_vec, 1, HMM); + CHECK_DEVICE_ALLOCATION(dAlpahVector.memcheck()); + host_vector hAlpahVector(size_alpha_vec); + if(arg.alpha_vector_scaling) + { + hipsparselt_init(hAlpahVector, M, 1, M, size_alpha_vec, 1); + CHECK_HIP_ERROR(dAlpahVector.transfer_from(hAlpahVector)); + int alpha_vector_scaling = 1; + EXPECT_HIPSPARSE_STATUS( + hipsparseLtMatmulDescSetAttribute(handle, + matmul, + HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, + &alpha_vector_scaling, + sizeof(int)), + HIPSPARSE_STATUS_SUCCESS); + h_alpha = static_cast(1); + } + hipsparselt_local_matmul_alg_selection alg_sel(handle, matmul, HIPSPARSELT_MATMUL_ALG_DEFAULT); size_t workspace_size = 0, compressed_size = 0, compress_buffer_size = 0; @@ -733,16 +753,34 @@ void testing_spmm(const Arguments& arg) if(arg.search) EXPECT_HIPSPARSE_STATUS( - hipsparseLtMatmulSearch( - handle, plan, &h_alpha, dA_, dB_, &h_beta, dC, dD, dWorkspace, &stream, 1), + hipsparseLtMatmulSearch(handle, + plan, + arg.alpha_vector_scaling ? dAlpahVector : &h_alpha, + dA_, + dB_, + &h_beta, + dC, + dD, + dWorkspace, + &stream, + 1), HIPSPARSE_STATUS_SUCCESS); if(arg.unit_check || arg.norm_check) { CHECK_HIP_ERROR(hipStreamSynchronize(stream)); CHECK_HIP_ERROR(h_pruned.transfer_from(arg.sparse_b ? dB : dA)); EXPECT_HIPSPARSE_STATUS( - hipsparseLtMatmul( - handle, plan, &h_alpha, dA_, dB_, &h_beta, dC, dD, dWorkspace, &stream, 1), + hipsparseLtMatmul(handle, + plan, + arg.alpha_vector_scaling ? dAlpahVector : &h_alpha, + dA_, + dB_, + &h_beta, + dC, + dD, + dWorkspace, + &stream, + 1), HIPSPARSE_STATUS_SUCCESS); // now we can recycle gold matrix for reference purposes if(arg.timing) @@ -807,6 +845,7 @@ void testing_spmm(const Arguments& arg) hD_gold_act + stride_d * i, ldd, tSizeD, + arg.alpha_vector_scaling ? hAlpahVector : nullptr, false); auto pos = stride_d * i; @@ -877,6 +916,7 @@ void testing_spmm(const Arguments& arg) hD_gold + stride_d * i, ldd, tSizeD, + arg.alpha_vector_scaling ? 
hAlpahVector : nullptr, false); } #undef activation_param @@ -907,13 +947,17 @@ void testing_spmm(const Arguments& arg) } // Debug - //print_strided_batched("A", &hA[0], A_row_r, A_col_r, num_batches, 1, lda, stride_a); - //print_strided_batched("B", &hB[0], B_row_r, B_col_r, num_batches, 1, ldb, stride_b); - //print_strided_batched("C", &hC[0], C_row_r, C_col_r, num_batches, 1, ldc, stride_c); - //if(arg.bias_vector) - // print_strided_batched("bias", &hBias[0], M, 1, num_batches, 1, M, bias_stride); - //print_strided_batched("hD_gold", &hD_gold[0], tM, tN, num_batches, 1, ldd, stride_d); - //print_strided_batched("hD1", &hD_1[0], tM, tN, num_batches, 1, ldd, stride_d); +#if 0 + print_strided_batched("A", &hA_[0], A_row_r, A_col_r, num_batches, 1, lda, stride_a); + print_strided_batched("B", &hB_[0], B_row_r, B_col_r, num_batches, 1, ldb, stride_b); + print_strided_batched("C", &hC[0], C_row_r, C_col_r, num_batches, 1, ldc, stride_c); + if(arg.bias_vector) + print_strided_batched("bias", &hBias[0], M, 1, num_batches, 1, M, bias_stride); + if(arg.alpha_vector_scaling) + print_strided_batched("alpha_vec", &hAlpahVector[0], M, 1, 1, 1, M, M); + print_strided_batched("hD_gold", &hD_gold[0], tM, tN, num_batches, 1, ldd, stride_d); + print_strided_batched("hD1", &hD_1[0], tM, tN, num_batches, 1, ldd, stride_d); +#endif } if(arg.timing) @@ -923,8 +967,17 @@ void testing_spmm(const Arguments& arg) for(int i = 0; i < number_cold_calls; i++) { EXPECT_HIPSPARSE_STATUS( - hipsparseLtMatmul( - handle, plan, &h_alpha, dA_, dB_, &h_beta, dC, dD, dWorkspace, &stream, 1), + hipsparseLtMatmul(handle, + plan, + arg.alpha_vector_scaling ? dAlpahVector : &h_alpha, + dA_, + dB_, + &h_beta, + dC, + dD, + dWorkspace, + &stream, + 1), HIPSPARSE_STATUS_SUCCESS); } @@ -933,8 +986,17 @@ void testing_spmm(const Arguments& arg) for(int i = 0; i < number_hot_calls; i++) { EXPECT_HIPSPARSE_STATUS( - hipsparseLtMatmul( - handle, plan, &h_alpha, dA_, dB_, &h_beta, dC, dD, dWorkspace, &stream, 1), + hipsparseLtMatmul(handle, + plan, + arg.alpha_vector_scaling ? dAlpahVector : &h_alpha, + dA_, + dB_, + &h_beta, + dC, + dD, + dWorkspace, + &stream, + 1), HIPSPARSE_STATUS_SUCCESS); } CHECK_HIP_ERROR(hipStreamSynchronize(stream)); @@ -1317,6 +1379,7 @@ void testing_aux_plan_assign(const Arguments& arg) hD_gold_act + stride_d * i, ldd, ldd * N, + nullptr, false); auto pos = stride_d * i; @@ -1341,6 +1404,7 @@ void testing_aux_plan_assign(const Arguments& arg) hD_gold + stride_d * i, ldd, ldd * N, + nullptr, false); } #undef activation_param
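
Reference semantics of the alpha-vector path (illustrative note, not part of the patch): every cblas_gemm specialization above implements the same epilogue — first compute T = op(A) * op(B) with alpha = 1 and beta = 0, then scale row i of T by alphaVec[i] and accumulate beta * C. A minimal stand-alone sketch of that math for a plain column-major float GEMM, with no transposes; the function name and loop nest here are hypothetical, only the indexing mirrors the patch:

    /*
     * Sketch only: reference math for alpha vector scaling,
     *   D[i][j] = alphaVec[i] * (op(A) * op(B))[i][j] + beta * C[i][j]
     * Column-major storage; ref_gemm_alpha_vec is an assumed (illustrative) name.
     */
    #include <cstdint>
    #include <vector>

    void ref_gemm_alpha_vec(int64_t m, int64_t n, int64_t k,
                            const float* A, int64_t lda,   // m x k
                            const float* B, int64_t ldb,   // k x n
                            float beta,
                            const float* alphaVec,         // length m, one scale per row of D
                            float* C, int64_t ldc)         // m x n, updated in place
    {
        // T = A * B, i.e. the GEMM with alpha = 1 and beta = 0.
        std::vector<float> T(static_cast<size_t>(ldc) * n, 0.0f);
        for(int64_t j = 0; j < n; j++)
            for(int64_t l = 0; l < k; l++)
                for(int64_t i = 0; i < m; i++)
                    T[j * ldc + i] += A[l * lda + i] * B[j * ldb + l];

        // Epilogue: per-row alpha scaling plus beta * C, as in the patch's
        // alphaVec != nullptr branches.
        for(int64_t j = 0; j < n; j++)
            for(int64_t i = 0; i < m; i++)
                C[j * ldc + i] = alphaVec[i] * T[j * ldc + i] + beta * C[j * ldc + i];
    }

On the API side, the test enables this mode through the HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING matmul-descriptor attribute and then passes the device-side alpha vector in place of the scalar alpha to hipsparseLtMatmul and hipsparseLtMatmulSearch; the host scalar h_alpha is set to 1 so only the per-row vector contributes. The knob is exposed as --alpha_vector_scaling in the benchmark client and as alpha_vector_scaling: [true, false] in the gtest YAML suites.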