Enable alpha vector scaling
vin-huang committed Sep 3, 2024
1 parent a1a547f commit 1cddae8
Showing 5 changed files with 77 additions and 20 deletions.
11 changes: 7 additions & 4 deletions library/src/hcc_detail/rocsparselt/src/include/handle.h
@@ -214,6 +214,7 @@ struct _rocsparselt_matmul_descr
, bias_pointer(rhs.bias_pointer)
, bias_stride(rhs.bias_stride)
, bias_type(rhs.bias_type)
, alpha_vector_scaling(rhs.alpha_vector_scaling)
, m(rhs.m)
, n(rhs.n)
, k(rhs.k)
@@ -284,10 +285,11 @@ struct _rocsparselt_matmul_descr
float* bias_pointer = nullptr;
int64_t bias_stride = 0;
hipDataType bias_type;
int64_t m = 0;
int64_t n = 0;
int64_t k = 0;
bool is_sparse_a = true;
int alpha_vector_scaling = 0;
int64_t m = 0;
int64_t n = 0;
int64_t k = 0;
bool is_sparse_a = true;

rocsparselt_operation _op_A;
rocsparselt_operation _op_B;
@@ -320,6 +322,7 @@ struct __attribute__((packed, aligned(8))) _rocsparselt_matmul_config

int index;
int use_bias = 0;
int use_scale_alpha_vec = 0;
size_t max_workspace_bytes = 0;
};

11 changes: 10 additions & 1 deletion library/src/hcc_detail/rocsparselt/src/include/tensile_host.hpp
@@ -107,6 +107,7 @@ struct RocsparseltContractionProblem
const void* bias_vector;
int64_t bias_stride;
hipDataType bias_type;
bool alpha_vector_scaling;

void* workspace;
size_t workspaceSize;
@@ -150,6 +151,7 @@ struct RocsparseltContractionProblem
const void* bias_vector,
int64_t bias_stride,
hipDataType bias_type,
bool alpha_vector_scaling,
void* workspace,
size_t workspaceSize,
hipStream_t* streams,
@@ -197,6 +199,7 @@ struct RocsparseltContractionProblem
, bias_vector(bias_vector)
, bias_stride(bias_stride)
, bias_type(bias_type)
, alpha_vector_scaling(alpha_vector_scaling)
, workspace(workspace)
, workspaceSize(workspaceSize)
, streams(streams)
@@ -245,6 +248,7 @@ struct RocsparseltContractionProblem
const void* bias_vector,
int64_t bias_stride,
hipDataType bias_type,
bool alpha_vector_scaling,
void* workspace,
size_t workspaceSize,
hipStream_t* streams,
@@ -292,6 +296,7 @@ struct RocsparseltContractionProblem
, bias_vector(bias_vector)
, bias_stride(bias_stride)
, bias_type(bias_type)
, alpha_vector_scaling(alpha_vector_scaling)
, workspace(workspace)
, workspaceSize(workspaceSize)
, streams(streams)
@@ -341,6 +346,7 @@ struct RocsparseltContractionProblem
const void* bias_vector,
int64_t bias_stride,
hipDataType bias_type,
bool alpha_vector_scaling,
void* workspace,
size_t workspaceSize,
hipStream_t* streams,
@@ -388,6 +394,7 @@ struct RocsparseltContractionProblem
, bias_vector(bias_vector)
, bias_stride(bias_stride)
, bias_type(bias_type)
, alpha_vector_scaling(alpha_vector_scaling)
, workspace(workspace)
, workspaceSize(workspaceSize)
, streams(streams)
@@ -466,7 +473,9 @@ struct RocsparseltContractionProblem
"bias_stride",
prob.bias_stride,
"bias_type",
hipDataType_to_string(prob.bias_type)));
hipDataType_to_string(prob.bias_type),
"alpha_vector_scaling",
prob.alpha_vector_scaling));
};
};

10 changes: 10 additions & 0 deletions library/src/hcc_detail/rocsparselt/src/rocsparselt_auxiliary.cpp
@@ -1038,6 +1038,11 @@ rocsparselt_status
}
break;
}
case rocsparselt_matmul_alpha_vector_scaling:
{
assign_data(&_matmulDescr->alpha_vector_scaling);
break;
}
default:
log_error(
_handle, __func__, "matmulAttribute", matmulAttribute, "is not implemented");
@@ -1195,6 +1200,11 @@ rocsparselt_status
retrive_data(_matmulDescr->bias_type);
break;
}
case rocsparselt_matmul_alpha_vector_scaling:
{
retrive_data(_matmulDescr->alpha_vector_scaling);
break;
}
default:
log_error(
_handle, __func__, "matmulAttribute", matmulAttribute, "is not implemented");
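
The two switch additions above wire the new descriptor attribute into both the set and get paths. As a rough usage sketch — not part of this commit — the snippet below assumes the existing pointer-based rocSPARSELt attribute API (rocsparselt_matmul_descr_set_attribute / rocsparselt_matmul_descr_get_attribute, header rocsparselt.h) and an already-initialized handle and matmul descriptor; only the enum value rocsparselt_matmul_alpha_vector_scaling comes from this change.

#include <rocsparselt.h>

// Hypothetical helper (illustration only): enable per-row alpha scaling on an
// existing matmul descriptor, then read the setting back to confirm it stuck.
rocsparselt_status enable_alpha_vector_scaling(const rocsparselt_handle* handle,
                                               rocsparselt_matmul_descr*  matmul)
{
    int enable = 1; // stored as an int in _rocsparselt_matmul_descr::alpha_vector_scaling
    rocsparselt_status st = rocsparselt_matmul_descr_set_attribute(
        handle, matmul, rocsparselt_matmul_alpha_vector_scaling, &enable, sizeof(enable));
    if(st != rocsparselt_status_success)
        return st;

    int readback = 0; // expected to come back as 1
    return rocsparselt_matmul_descr_get_attribute(
        handle, matmul, rocsparselt_matmul_alpha_vector_scaling, &readback, sizeof(readback));
}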
1 change: 1 addition & 0 deletions (file path not shown)
@@ -455,6 +455,7 @@ rocsparselt_status ConstructRocSparseLtProblem(const char*
matmul_descr->bias_pointer,
matmul_descr->bias_stride,
matmul_descr->bias_type,
matmul_descr->alpha_vector_scaling,
workspace,
workspaceSize,
streams,
64 changes: 49 additions & 15 deletions library/src/hcc_detail/rocsparselt/src/tensile_host.cpp
@@ -172,7 +172,8 @@ namespace
****************************************************************/
template <typename Ti, typename To, typename Tc>
auto ConstructTensileProblem(const RocsparseltContractionProblem<Ti, To, Tc>& prob,
int useBias = 0)
int useBias = 0,
int useScaleAlphaVec = 0)
{
// Tensile DataTypes corresponding to rocsparselt data types
static constexpr Tensile::DataType Tensile_Ti = tensile_datatype<Ti>;
@@ -302,8 +303,14 @@ namespace
// If k==0, we do not need to dereference prob.alpha and can set tensileAlpha=0
// Not positive if this is necessary here as well
typename AlphaBeta<Ti, To, Tc>::tensile_type tensileAlpha;
const Tc ALPHAONE = static_cast<Tc>(1);
if(prob.k)
AlphaBeta<Ti, To, Tc>::copy(&tensileAlpha, prob.alpha);
{
if(prob.alpha_vector_scaling)
AlphaBeta<Ti, To, Tc>::copy(&tensileAlpha, &ALPHAONE);
else
AlphaBeta<Ti, To, Tc>::copy(&tensileAlpha, prob.alpha);
}
else
memset(&tensileAlpha, 0, sizeof(tensileAlpha));
tensileProblem.setAlphaRestriction(Tensile::toScalarValueEnum(tensileAlpha));
@@ -349,11 +356,13 @@ namespace
}
tensileProblem.setParams().setActivationEnum(tensileAct);

assert((useBias == useScaleAlphaVec) || (useBias == 0 || useScaleAlphaVec == 0));

int guessedFactorDim = (prob.order == rocsparselt_order_row) ? 2 : 1;
// set bias mode
if(prob.bias_vector != nullptr || useBias > 0)
{
int guessedUseBias = (prob.order == rocsparselt_order_row) ? 2 : 1;
tensileProblem.setUseBias(useBias > 0 ? useBias : guessedUseBias);
tensileProblem.setUseBias(useBias > 0 ? useBias : guessedFactorDim);
tensileProblem.setBias(hipDataType_to_tensile_type(prob.bias_type),
prob.order == rocsparselt_order_row ? d.sizes()[1]
: d.sizes()[0],
@@ -363,6 +372,16 @@ namespace
prob.order == rocsparselt_order_row);
}

if(prob.alpha_vector_scaling || useScaleAlphaVec > 0)
{

tensileProblem.setUseScaleAlphaVec(useScaleAlphaVec > 0 ? useScaleAlphaVec
: guessedFactorDim);
tensileProblem.setScaleAlphaVec(Tensile_Tc,
prob.order == rocsparselt_order_row ? d.sizes()[1]
: d.sizes()[0],
prob.order == rocsparselt_order_row);
}
return tensileProblem;
}

@@ -405,12 +424,19 @@ namespace

// set bias vector
inputs.bias = reinterpret_cast<const void*>(prob.bias_vector);
if(prob.alpha_vector_scaling)
inputs.scaleAlphaVec = reinterpret_cast<const void*>(prob.alpha);

// alpha and beta are stored by value in Tensile::TypedContractionInputs
// alpha and beta are copied from host to Tensile::TypedContractionInputs
// If k==0, we do not need to dereference prob.alpha and can set inputs.alpha=0
if(prob.k)
inputs.alpha = static_cast<Tensile_Talpha_beta>((*prob.alpha));
{
if(prob.alpha_vector_scaling)
inputs.alpha = static_cast<Tensile_Talpha_beta>(1);
else
inputs.alpha = static_cast<Tensile_Talpha_beta>((*prob.alpha));
}
else
inputs.alpha = static_cast<Tensile_Talpha_beta>(0);
inputs.beta = static_cast<Tensile_Talpha_beta>((*prob.beta));
@@ -766,7 +792,8 @@ rocsparselt_status runContractionProblem(const RocsparseltContractionProblem<Ti,
}
else
{
auto tensile_prob = ConstructTensileProblem(prob, configs[*config_id].use_bias);
auto tensile_prob = ConstructTensileProblem(
prob, configs[*config_id].use_bias, configs[*config_id].use_scale_alpha_vec);

auto tensile_inputs = GetTensileInputs(prob);

@@ -900,8 +927,8 @@ rocsparselt_status getBestSolutions(const RocsparseltContractionProblem<Ti, To,
*foundConfigs = std::min((int)solutions.size(), requestConfigs);

// Finding alternative solutions.
auto findAlternativeSolution = [&](int useBias) {
tensile_prob = ConstructTensileProblem(prob, useBias);
auto findAlternativeSolution = [&](int useBias, int useScaleAlphaVec) {
tensile_prob = ConstructTensileProblem(prob, useBias, useScaleAlphaVec);
solutions = library->findTopSolutions(tensile_prob, *hardware, requestConfigs);
*foundConfigs = std::min((int)solutions.size(), requestConfigs);
};
@@ -910,18 +937,24 @@
{
log_info(prob.handle, __func__, "No solutions found, trying to find alternative solutions");

if(tensile_prob.useBias() == 0)
for(int useBias = tensile_prob.useBias(); useBias < 4; useBias++)
{
for(int useBias = 1; useBias < 4; useBias++)
for(int useScaleAlphaVec = tensile_prob.useScaleAlphaVec(); useScaleAlphaVec < 4;
useScaleAlphaVec++)
{
findAlternativeSolution(useBias);
if(useBias != 0 && useScaleAlphaVec != 0 && useScaleAlphaVec != useBias)
continue; // useBias and useScaleAlphaVec must be in the same dimension.

if(useBias == tensile_prob.useBias()
&& useScaleAlphaVec == tensile_prob.useScaleAlphaVec())
continue; // already tried in the first pass.

findAlternativeSolution(useBias, useScaleAlphaVec);
if(*foundConfigs != 0)
break;
}
}
else if(tensile_prob.useBias() == 1 || tensile_prob.useBias() == 2)
{
findAlternativeSolution(3);
if(*foundConfigs != 0)
break;
}

if(*foundConfigs != 0)
Expand All @@ -936,6 +969,7 @@ rocsparselt_status getBestSolutions(const RocsparseltContractionProblem<Ti, To,
configs[i].index = solution->index;
configs[i].max_workspace_bytes = solution->requiredWorkspaceSize(tensile_prob);
configs[i].use_bias = tensile_prob.useBias();
configs[i].use_scale_alpha_vec = tensile_prob.useScaleAlphaVec();
}
return rocsparselt_status_success;
}
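
Taken together, the tensile_host.cpp changes reroute the alpha pointer once the flag is set: ConstructTensileProblem marks the problem with setUseScaleAlphaVec and pins the scalar alpha restriction to ALPHAONE, while GetTensileInputs passes prob.alpha through inputs.scaleAlphaVec and sets the scalar inputs.alpha to 1. Assuming the scale-alpha vector holds one factor per output row (or the corresponding dimension under row-major order) — the usual interpretation of scaleAlphaVec; the commit itself only shows the plumbing — the effective operation becomes

D = \mathrm{diag}(\alpha)\,\bigl(\mathrm{op}(A)\,\mathrm{op}(B)\bigr) + \beta\,C

instead of the scalar form D = \alpha\,\mathrm{op}(A)\,\mathrm{op}(B) + \beta\,C, with the per-element scaling presumably applied inside the selected Tensile kernel.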
