Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SVE optimised float WSSJ kernel #2917

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cpp/daal/src/algorithms/svm/svm_train_common_impl.i
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* file: svm_train_common_impl.i */
/*******************************************************************************
* Copyright 2020 Intel Corporation
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,7 +42,11 @@

#endif // __CPUID__(DAAL_CPU) == __avx512__
#endif // defined (_M_AMD64) || defined (__amd64) || defined (__x86_64) || defined (__x86_64__)
#endif // DAAL_INTEL_CPP_COMPILER
#elif defined(TARGET_ARM)
#if (__CPUID__(DAAL_CPU) == __sve__)
#include "src/algorithms/svm/svm_train_common_sve_impl.i"
#endif // __CPUID__(DAAL_CPU) == __sve__
#endif // DAAL_INTEL_CPP_COMPILER

namespace daal
{
Expand Down
150 changes: 150 additions & 0 deletions cpp/daal/src/algorithms/svm/svm_train_common_sve_impl.i
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/*
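* Contains the Arm SVE specialization of HelperTrainSVM<float, sve>::WSSjLocal
* (vectorized selection of the second working-set index j during SVM training).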
*/
Comment on lines +16 to +18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment doesn't add any information. Ideally, expand it to describe what the algorithm is doing; otherwise, remove it.


#include <arm_sve.h>
#include "src/services/service_data_utils.h"

namespace daal
{
namespace algorithms
{
namespace svm
{
namespace training
{
namespace internal
{

/*
 * SVE specialization of WSSjLocal for float.
 *
 * Scans the indices j in [jStart, jEnd) one SVE vector at a time and picks the
 * index whose objective objFunc = b * b / a is the largest, where
 *     b = GMin - grad[j]
 *     a = Kii + kernelDiag[j] - 2 * KiBlock[j - jStart]   (replaced by tau when a is not positive)
 * An index takes part in the search only if
 *     (I[j] & mask) == mask, with mask = low, or (low | sign) when signNuType is not none, and
 *     grad[j] >= GMin.
 *
 * Parameters:
 *   jStart, jEnd - half-open index range to scan
 *   KiBlock      - read at offset (j - jStart)
 *   kernelDiag   - read at offset j
 *   grad         - read at offset j
 *   I            - per-index status flags (one char each), read at offset j
 *   GMin         - threshold for grad[j] and the minuend of b
 *   Kii, tau     - scalars used when forming a (see above)
 *   Bj           - [out] selected index, or -1 when no index qualified
 *   GMax         - [out] b * (b / a) recomputed for the selected index; -MaxVal<float> when Bj == -1
 *   GMax2        - [out] largest grad[j] among indices that passed the flag check
 *                  (before the GMin check); -MaxVal<float> when Bj == -1
 *   delta        - [in, out] set to -(b / a) for the selected index
 *   signNuType   - selects whether the sign flag from getSign() is required in addition to low
 *
 * NOTE(review): when Bj == -1, delta is not recomputed but is still negated on exit.
 * Presumably this mirrors the generic implementation - confirm.
 */
template <>
inline void HelperTrainSVM<float, sve>::WSSjLocal(const size_t jStart, const size_t jEnd, const float * KiBlock, const float * kernelDiag,
const float * grad, const char * I, const float GMin, const float Kii, const float tau, int & Bj,
float & GMax, float & GMax2, float & delta, SignNuType signNuType)
{
const int w = (int)svcntw(); // number of 32-bit lanes in one SVE vector; loop stride
float fpMax = MaxVal<float>::get();
float GMax2Local = -fpMax; // store min(grad[i]) or max(y[i]*grad[i]), y[i]*grad[i] = -GMin2
float GMaxLocal = -fpMax; // store min(-b^2/a) or max(b^2/a), b^2/a = -GMin
float GMinLocal = GMin;

float zero(0.0);
float two(2.0);

const char sign = getSign(signNuType);

svbool_t pgf = svptrue_b32(); // all-true predicate for 32-bit lanes

// Per-lane running state: every lane starts at -fpMax / index -1 (nothing selected yet)
svfloat32_t valGMax2 = svdup_f32(GMax2Local);
svfloat32_t valGMax = svdup_f32(GMaxLocal);
svfloat32_t valGMin = svdup_f32(GMinLocal);
svint32_t Bj_vec = svdup_s32(-1);

// some constants used during optimization
// enum SVMVectorStatus low = 0x2
// (reminder only: low is declared outside this file)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure where low would be defined before this, but maybe this isn't supposed to be a comment? The code below uses it, so I'm assuming this should be uncommented

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low is defined elsewhere, this comment basically reminds what the value of low is outside.

// vecSignLow: flag mask broadcast to every lane; a lane qualifies when all mask bits are set in I[j]
svint32_t vecSignLow;
if (signNuType == SignNuType::none)
{
vecSignLow = svdup_n_s32(low);
}
else
{
// Debug check that sign has at most one bit set.
// NOTE(review): the expression is also true for sign == 0, so it does not prove that a bit is set.
// NOTE(review): there is no trailing semicolon here; presumably DAAL_ASSERT expands to a full statement - confirm.
DAAL_ASSERT((sign & (sign - 1)) == 0) // used to make sure sign is always having 1 bit set
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what get sign returns, but this assert is also true when sign = 0, so the comment isn't correct. I suspect this might not be what you want to be checking on the result of getSign

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to keep this optimization done under check, since low = 0x2, if it were to ever change this debug assert would help. Where getSign is defined here -

DAAL_FORCEINLINE static char getSign(SignNuType signNuType)

// mask = low | sign
svint32_t t1 = svdup_n_s32(low);
svint32_t t2 = svdup_n_s32(sign);
vecSignLow = svorr_s32_z(pgf, t1, t2);
}

// Loop-invariant scalars broadcast once
svfloat32_t two_vec = svdup_f32(two);
svfloat32_t Kii_vec = svdup_f32(Kii);
svfloat32_t tau_vec = svdup_f32(tau);

size_t j_cur = jStart;

// NOTE(review): the init expression j_cur below has no effect; it only names the already-initialized counter
for (j_cur; j_cur < jEnd; j_cur += w)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (j_cur; j_cur < jEnd; j_cur += w)
for (; j_cur < jEnd; j_cur += w)

{
// NOTE(review): j_cur (size_t) is converted to a 32-bit lane index here - assumes indices fit in int32, confirm
svint32_t Bj_vec_cur = svindex_s32(j_cur, 1); // Bj value starts with j_cur
svbool_t pg2 = svwhilelt_b32(j_cur, jEnd); // adapts to vector length

svint32_t vec_I = svld1sb_s32(pg2, reinterpret_cast<const int8_t *>(&I[j_cur])); // load chars

// Combine 2 if conditions
// cond1: !(I[j]&sign) {continue}
// cond2: (I[j]&low)!=low {continue}
// combined: (I[j] & (sign | low)) == (sign | low)
// the DAAL_ASSERT on sign before the loop is a prerequisite for the combined condition to satisfy
svint32_t result_of_and32 = svand_s32_m(pg2, vec_I, vecSignLow);
pg2 = svcmpeq_s32(pg2, result_of_and32, vecSignLow); // if pg2 bit is 0 then continue;

svfloat32_t valGrad = svld1_f32(pg2, &grad[j_cur]); // load grads
// if (gradj > GMax2) { GMax2 = gradj; }
valGMax2 = svmax_f32_m(pg2, valGMax2, valGrad);
// cond3: if (gradj < GMin) { continue; }
svbool_t cond3 = svcmpge_f32(pg2, valGrad, valGMin);
pg2 = svand_b_z(pg2, pg2, cond3); // combine all 3 conditions

svfloat32_t b_vec = svsub_f32_x(pg2, valGMin, valGrad); // b = Gmin - grad

svfloat32_t KiBlock_vec = svld1_f32(pg2, KiBlock + j_cur - jStart); // load kiBlocs
svfloat32_t kernelDiag_vec = svld1_f32(pg2, &kernelDiag[j_cur]); // load kernelDiags
svfloat32_t a_vec = svnmls_f32_x(pg2, kernelDiag_vec, two_vec, KiBlock_vec); // a_tmp = two * KiBlock[j - jStart] - kernelDiag[j]

// originally, if a < 0, a = tau
// mask3_ : 1 if Kii > a_tmp
// if mask3_ : 1, a = Kii - a_tmp, else a = tau.
svbool_t mask3_ = svcmpgt_f32(pg2, Kii_vec, a_vec);
a_vec = svsel_f32(mask3_, svsub_f32_x(mask3_, Kii_vec, a_vec), tau_vec);

svfloat32_t dt_vec = svdiv_f32_x(pg2, b_vec, a_vec); // b/a = delta.
svfloat32_t objFunc_vec = svmul_f32_x(pg2, dt_vec, b_vec); // objFunc = b * delta

// strict greater-than: within one lane the earliest index with the best objective is kept
svbool_t mask4_ = svcmpgt_f32(pg2, objFunc_vec, valGMax); // if (objFunc > GMax)
valGMax = svsel_f32(mask4_, objFunc_vec, valGMax); // if mask is 1, valGMax = objFunc_vec, else valGMax original value
Bj_vec = svsel_s32(mask4_, Bj_vec_cur, Bj_vec); // if mask is 1, Bj_vec = Bj_vec_cur, else Bj_vec original value
}

// reductions
// horizontal max over all lanes; Bj is the largest index among the lanes whose objective equals that max
// (stays -1 when no lane was ever updated)
GMax = svmaxv_f32(pgf, valGMax);
GMax2 = svmaxv_f32(pgf, valGMax2);
svbool_t tmp_mask = svcmpeq(pgf, svdup_f32(GMax), valGMax);
Bj = svmaxv_s32(tmp_mask, Bj_vec);

if (Bj != -1)
{
// recompute b, a, delta and GMax for the winning index in scalar code
// NOTE(review): these intermediates are double although the specialization is for float - confirm this is intended
const double gradBj = grad[Bj];
const double b = GMin - gradBj;
double a = Kii + kernelDiag[Bj] - two * KiBlock[Bj - jStart];
if (a <= zero)
{
a = tau;
}
delta = b / a;
GMax = b * delta;
}
else
{
// nothing qualified: report the sentinel for both maxima
GMax = -fpMax;
GMax2 = -fpMax;
}
// sign flip is applied on both paths, including Bj == -1 where delta still holds its incoming value
delta = -delta;
}

} // namespace internal
} // namespace training
} // namespace svm
} // namespace algorithms
} // namespace daal
Loading