Skip to content

Commit

Permalink
Initialized information of compressed matrix's row, col and ld when u…
Browse files Browse the repository at this point in the history
…sing prune2, compress2 related functions directly.
  • Loading branch information
vin-huang committed Sep 6, 2024
1 parent 7182c64 commit de7bb0d
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 5 deletions.
45 changes: 42 additions & 3 deletions clients/include/spmm/testing_compress.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,23 @@ void testing_compress(const Arguments& arg)
arg.b_type,
orderB);

hipsparselt_local_mat_descr matAv2(arg.sparse_b ? hipsparselt_matrix_type_dense
: hipsparselt_matrix_type_structured,
handle,
A_row,
A_col,
lda,
arg.a_type,
orderA);
hipsparselt_local_mat_descr matBv2(arg.sparse_b ? hipsparselt_matrix_type_structured
: hipsparselt_matrix_type_dense,
handle,
B_row,
B_col,
ldb,
arg.b_type,
orderB);

hipsparselt_local_mat_descr matC(
hipsparselt_matrix_type_dense, handle, M, N, ldc, arg.c_type, orderC);
hipsparselt_local_mat_descr matD(
Expand Down Expand Up @@ -484,6 +501,14 @@ void testing_compress(const Arguments& arg)
hipsparseLtMatDescSetAttribute(
handle, matB, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matAv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matBv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matC, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
Expand All @@ -509,6 +534,20 @@ void testing_compress(const Arguments& arg)
eStatus);
if(eStatus != HIPSPARSE_STATUS_SUCCESS)
return;
eStatus = expected_hipsparse_status_of_matrix_stride(stride_a, A_row, A_col, lda, orderA);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matAv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_a, sizeof(int64_t)),
eStatus);
if(eStatus != HIPSPARSE_STATUS_SUCCESS)
return;
eStatus = expected_hipsparse_status_of_matrix_stride(stride_b, B_row, B_col, ldb, orderB);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matBv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_b, sizeof(int64_t)),
eStatus);
if(eStatus != HIPSPARSE_STATUS_SUCCESS)
return;
eStatus = expected_hipsparse_status_of_matrix_stride(stride_c, M, N, ldc, orderC);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
Expand Down Expand Up @@ -545,7 +584,7 @@ void testing_compress(const Arguments& arg)
{
EXPECT_HIPSPARSE_STATUS(
hipsparseLtSpMMACompressedSize2(
handle, arg.sparse_b ? matB : matA, &compressed_size, &compress_buffer_size),
handle, arg.sparse_b ? matBv2 : matAv2, &compressed_size, &compress_buffer_size),
HIPSPARSE_STATUS_SUCCESS);
}
const size_t size_A = stride_a == 0
Expand Down Expand Up @@ -623,7 +662,7 @@ void testing_compress(const Arguments& arg)
else if(arg.func_version == 2)
{
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPrune2(handle,
arg.sparse_b ? matB : matA,
arg.sparse_b ? matBv2 : matAv2,
!arg.sparse_b,
arg.sparse_b ? transB : transA,
dT,
Expand Down Expand Up @@ -717,7 +756,7 @@ void testing_compress(const Arguments& arg)
HIPSPARSE_STATUS_SUCCESS);
else if(arg.func_version == 2)
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMACompress2(handle,
arg.sparse_b ? matB : matA,
arg.sparse_b ? matBv2 : matAv2,
!arg.sparse_b,
arg.sparse_b ? transB : transA,
dT,
Expand Down
42 changes: 40 additions & 2 deletions clients/include/spmm/testing_prune.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,22 @@ void testing_prune(const Arguments& arg)
ldb,
arg.b_type,
orderB);
hipsparselt_local_mat_descr matAv2(arg.sparse_b ? hipsparselt_matrix_type_dense
: hipsparselt_matrix_type_structured,
handle,
A_row,
A_col,
lda,
arg.a_type,
orderA);
hipsparselt_local_mat_descr matBv2(arg.sparse_b ? hipsparselt_matrix_type_structured
: hipsparselt_matrix_type_dense,
handle,
B_row,
B_col,
ldb,
arg.b_type,
orderB);
hipsparselt_local_mat_descr matC(
hipsparselt_matrix_type_dense, handle, M, N, ldc, arg.c_type, orderC);
hipsparselt_local_mat_descr matD(
Expand Down Expand Up @@ -596,6 +612,14 @@ void testing_prune(const Arguments& arg)
hipsparseLtMatDescSetAttribute(
handle, matB, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matAv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matBv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matC, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
Expand All @@ -622,6 +646,20 @@ void testing_prune(const Arguments& arg)
eStatusB);
if(eStatusB != HIPSPARSE_STATUS_SUCCESS)
return;
eStatusA = expected_hipsparse_status_of_matrix_stride(stride_a, A_row, A_col, lda, orderA);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matAv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_a, sizeof(int64_t)),
eStatusA);
if(eStatusA != HIPSPARSE_STATUS_SUCCESS)
return;
eStatusB = expected_hipsparse_status_of_matrix_stride(stride_b, B_row, B_col, ldb, orderB);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matBv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_b, sizeof(int64_t)),
eStatusB);
if(eStatusB != HIPSPARSE_STATUS_SUCCESS)
return;
eStatusC = expected_hipsparse_status_of_matrix_stride(stride_c, M, N, ldc, orderC);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
Expand Down Expand Up @@ -714,7 +752,7 @@ void testing_prune(const Arguments& arg)
HIPSPARSE_STATUS_SUCCESS);
else if(arg.func_version == 2)
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPrune2(handle,
arg.sparse_b ? matB : matA,
arg.sparse_b ? matBv2 : matAv2,
!arg.sparse_b,
arg.sparse_b ? transB : transA,
dT,
Expand All @@ -737,7 +775,7 @@ void testing_prune(const Arguments& arg)
HIPSPARSE_STATUS_SUCCESS);
else if(arg.func_version == 2)
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPruneCheck2(handle,
arg.sparse_b ? matB : matA,
arg.sparse_b ? matBv2 : matAv2,
!arg.sparse_b,
arg.sparse_b ? transB : transA,
dT_pruned,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,40 @@ inline rocsparselt_status getOriginalSizes(rocsparselt_operation opA,
return rocsparselt_status_success;
}

inline void initSparseMatrixLayout(rocsparselt_operation op,
const rocsparselt_mat_descr* sparseMatDescr,
bool isSparseA)

{
auto _sparseMatDescr = reinterpret_cast<_rocsparselt_mat_descr*>(
const_cast<rocsparselt_mat_descr*>(sparseMatDescr));
if(isSparseA)
{
auto m = _sparseMatDescr->m;
auto k = _sparseMatDescr->n;
if (op == rocsparselt_operation_transpose)
std::swap(m, k);
_sparseMatDescr->c_k = k / 2;
_sparseMatDescr->c_ld = m;
_sparseMatDescr->c_n = _sparseMatDescr->c_k;
if((op == rocsparselt_operation_transpose)
!= (_sparseMatDescr->order == rocsparselt_order_row))
std::swap(_sparseMatDescr->c_ld, _sparseMatDescr->c_n);
}
else
{
auto k = _sparseMatDescr->m;
auto n = _sparseMatDescr->n;
if (op == rocsparselt_operation_transpose)
std::swap(n, k);
_sparseMatDescr->c_k = k / 2;
_sparseMatDescr->c_ld = _sparseMatDescr->c_k;
_sparseMatDescr->c_n = n;
if((op == rocsparselt_operation_transpose)
!= (_sparseMatDescr->order == rocsparselt_order_row))
std::swap(_sparseMatDescr->c_ld, _sparseMatDescr->c_n);
}
}
/*******************************************************************************
* Get the offset of the metatdata (in bytes)
******************************************************************************/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,8 @@ rocsparselt_status rocsparselt_smfmac_compress2(const rocsparselt_handle* han
return rocsparselt_status_not_implemented;
}

initSparseMatrixLayout(op, sparseMatDescr, isSparseA);

log_api(_handle,
__func__,
"sparseMatDescr[in]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "rocsparselt.h"
#include "status.h"
#include "utility.hpp"
#include "rocsparselt_spmm_utils.hpp"

#include "hipsparselt_ostream.hpp"
#include <hip/hip_runtime_api.h>
Expand Down Expand Up @@ -805,6 +806,8 @@ rocsparselt_status rocsparselt_smfmac_prune2(const rocsparselt_handle* handle
return rocsparselt_status_not_implemented;
}

initSparseMatrixLayout(op, sparseMatDescr, isSparseA);

log_api(_handle,
__func__,
"sparseMatDescr[in]",
Expand Down Expand Up @@ -961,6 +964,8 @@ rocsparselt_status rocsparselt_smfmac_prune_check2(const rocsparselt_handle*
return rocsparselt_status_not_implemented;
}

initSparseMatrixLayout(op, sparseMatDescr, isSparseA);

log_api(_handle,
__func__,
"sparseMatDescr[in]",
Expand Down

0 comments on commit de7bb0d

Please sign in to comment.