Skip to content

Commit

Permalink
Apply rsqrt pattern for double
Browse files Browse the repository at this point in the history
.
  • Loading branch information
igorban-intel authored and igcbot committed Aug 24, 2024
1 parent 945236a commit fc58546
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 47 deletions.
39 changes: 22 additions & 17 deletions IGC/VectorCompiler/lib/BiF/Library/Math/F64/rsqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,34 @@ namespace {

template <int N>
CM_NODEBUG CM_INLINE mask<N> check_is_nan_or_inf(vector<double, N> q) {
vector<uint32_t, 2 * N> q_split = q.template format<uint32_t>();
vector<uint32_t, 2 *N> q_split = q.template format<uint32_t>();
vector<uint32_t, N> q_hi = q_split.template select<N, 2>(1);
return (q_hi >= exp_32bitmask);
}

template <int N>
CM_NODEBUG CM_INLINE vector<uint32_t, N> get_exp(vector<double, N> x) {
vector<uint32_t, 2 * N> x_split = x.template format<uint32_t>();
vector<uint32_t, 2 *N> x_split = x.template format<uint32_t>();
vector<uint32_t, N> x_hi = x_split.template select<N, 2>(1);
return (x_hi >> exp_shift) & exp_mask;
}

template <int N>
CM_NODEBUG CM_INLINE vector<uint32_t, N> get_sign(vector<double, N> x) {
vector<uint32_t, 2 * N> x_split = x.template format<uint32_t>();
vector<uint32_t, 2 *N> x_split = x.template format<uint32_t>();
vector<uint32_t, N> x_hi = x_split.template select<N, 2>(1);
return x_hi & sign_32bit;
}

template <int N> CM_NODEBUG CM_INLINE mask<N> is_denormal(vector<double, N> x) {
vector<uint32_t, 2 * N> x_int = x.template format<uint32_t>();
vector<uint32_t, 2 *N> x_int = x.template format<uint32_t>();
vector<uint32_t, N> x_hi = x_int.template select<N, 2>(1);
return x_hi < min_sign_exp;
}

template <int N>
CM_NODEBUG CM_INLINE vector<uint32_t, N> sep_exp(vector<double, N> x) {
vector<uint32_t, 2 * N> x_int = x.template format<uint32_t>();
vector<uint32_t, 2 *N> x_int = x.template format<uint32_t>();
vector<uint32_t, N> x_hi = x_int.template select<N, 2>(1);
vector<uint32_t, N> res = (x_hi >> exp_shift) - exp_bias;
return res >> 1;
Expand Down Expand Up @@ -84,8 +84,9 @@ CM_NODEBUG CM_INLINE vector<double, N> rsqrt_float(vector<double, N> x) {
}

template <int N>
CM_NODEBUG CM_INLINE vector<double, N> uint64_sub_hi(vector<double, N> x, vector<uint32_t, N> hi) {
vector<uint32_t, 2 * N> ex_mx_int = 0;
CM_NODEBUG CM_INLINE vector<double, N> uint64_sub_hi(vector<double, N> x,
vector<uint32_t, N> hi) {
vector<uint32_t, 2 *N> ex_mx_int = 0;
ex_mx_int.template select<N, 2>(1) = hi;
vector<uint64_t, N> ex_u64 = ex_mx_int.template format<uint64_t>();
vector<uint64_t, N> mx_u64 = x.template format<uint64_t>();
Expand Down Expand Up @@ -163,9 +164,10 @@ CM_NODEBUG CM_INLINE vector<double, N> sqrt_special(vector<double, N> a) {
}

template <int N>
CM_NODEBUG CM_INLINE vector<double, N> calc_sqrt(vector<double, N> x, mask<N> special) {
CM_NODEBUG CM_INLINE vector<double, N> calc_sqrt(vector<double, N> x,
mask<N> special) {
// Now start the SQRT computation
// Use math.rsqtm (emulated here)
// Use math.rsqtm (emulated here)
vector<double, N> y0 = math_rsqt_dp(x);
// predicate is set for 0, neg a, Inf, NaN inputs
y0.merge(sqrt_special(x), special);
Expand All @@ -174,7 +176,8 @@ CM_NODEBUG CM_INLINE vector<double, N> calc_sqrt(vector<double, N> x, mask<N> sp
}

template <int N>
CM_NODEBUG CM_INLINE vector<double, N> invert_calc(vector<double, N> a, vector<double, N> y0) {
CM_NODEBUG CM_INLINE vector<double, N> invert_calc(vector<double, N> a,
vector<double, N> y0) {
// IEEE SQRT computes H0 = 0.5*y0 (can be skipped)
// Step 3: S0 = a*y0
vector<double, N> S0 = a * y0;
Expand Down Expand Up @@ -235,15 +238,17 @@ __vc_builtin_rsqrt_f64__rte_(double a) {
return __impl_rsqrt_f64(va)[0];
}

#define FREM(WIDTH) \
#define RSQRT(WIDTH) \
CM_NODEBUG CM_NOINLINE extern "C" cl_vector<double, WIDTH> \
__vc_builtin_rsqrt_v##WIDTH##f64__rte_(cl_vector<double, WIDTH> a) { \
__vc_builtin_rsqrt_v##WIDTH##f64__rte_(cl_vector<double, WIDTH> a) { \
vector<double, WIDTH> va{a}; \
auto r = __impl_rsqrt_f64(va); \
auto r = __impl_rsqrt_f64(va); \
return r.cl_vector(); \
}

FREM(1)
FREM(2)
FREM(4)
FREM(8)
RSQRT(1)
RSQRT(2)
RSQRT(4)
RSQRT(8)
RSQRT(16)
RSQRT(32)
53 changes: 27 additions & 26 deletions IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,8 +1255,9 @@ bool GenXPatternMatch::flipBoolNot(Instruction *Inst) {
bool GenXPatternMatch::matchInverseSqrt(CallInst *I) {
IGC_ASSERT(I && I->arg_size() == 1);

// Leave as it is for double types
if (I->getType()->getScalarType()->isDoubleTy())
// Double rsqrt may be generated only before legalization
if (I->getType()->getScalarType()->isDoubleTy() &&
(!ST->hasFP64() || Kind == PatternMatchKind::PostLegalization))
return false;

bool IsFast = true;
Expand Down Expand Up @@ -2293,9 +2294,9 @@ bool MinMaxMatcher::emit() {

// For a given instruction, find the insertion position which is the closest
// to all the similar users to the specified reference user.
static Instruction *findOptimalInsertionPos(
Instruction *I, Instruction *Ref, DominatorTree *DT,
std::function<bool(Instruction *, Instruction *)> IsDivisor) {
static Instruction *
findOptimalInsertionPos(Value *I, Instruction *Ref, DominatorTree *DT,
std::function<bool(Instruction *, Value *)> IsDivisor) {
IGC_ASSERT_MESSAGE(!isa<PHINode>(Ref), "PHINode is not expected!");

// Shortcut case. If it's single-used, insert just before that user.
Expand Down Expand Up @@ -2402,48 +2403,48 @@ void GenXPatternMatch::visitFDiv(BinaryOperator &I) {
return;
}

// Skip if FP64 emulation is required for this platform
if (ST->emulateFDivFSqrt64() && I.getType()->getScalarType()->isDoubleTy())
return;

Instruction *Divisor = dyn_cast<Instruction>(Op1);
if (!Divisor)
return;

auto IsDivisor = [](Instruction *I, Instruction *MaybeDivisor) {
auto IsDivisor = [](Instruction *I, Value *MaybeDivisor) {
return I->getOpcode() == Instruction::FDiv &&
I->getOperand(1) == MaybeDivisor;
};

Instruction *Pos = findOptimalInsertionPos(Divisor, &I, DT, IsDivisor);
Instruction *Pos = findOptimalInsertionPos(Op1, &I, DT, IsDivisor);
IRB.SetInsertPoint(Pos);

// (fdiv 1., (sqrt x)) -> (rsqrt x)
// Allow the pattern even if fdiv has no fast-math flags.
auto IID = vc::getAnyIntrinsicID(Divisor);
if ((IID == GenXIntrinsic::genx_sqrt ||
(IID == Intrinsic::sqrt && Divisor->hasApproxFunc())) &&
match(Op0, m_FPOne()) && Divisor->hasOneUse()) {
auto *Rsqrt = createInverseSqrt(Divisor->getOperand(0), Pos);
I.replaceAllUsesWith(Rsqrt);
I.eraseFromParent();
Divisor->eraseFromParent();
if (Divisor) {
auto IID = vc::getAnyIntrinsicID(Divisor);
if ((IID == GenXIntrinsic::genx_sqrt ||
(IID == Intrinsic::sqrt && Divisor->hasApproxFunc())) &&
match(Op0, m_FPOne()) && Divisor->hasOneUse()) {
auto *Rsqrt = createInverseSqrt(Divisor->getOperand(0), Pos);
I.replaceAllUsesWith(Rsqrt);
I.eraseFromParent();
Divisor->eraseFromParent();

Changed |= true;
return;
}
}

Changed |= true;
// Skip if FP64 emulation is required for this platform
if (ST->emulateFDivFSqrt64() && I.getType()->getScalarType()->isDoubleTy())
return;
}

// Skip if reciprocal optimization is not allowed.
if (!I.hasAllowReciprocal())
return;

auto Rcp = getReciprocal(IRB, Divisor);
auto *Rcp = getReciprocal(IRB, Op1);
cast<Instruction>(Rcp)->setDebugLoc(I.getDebugLoc());

for (auto UI = Divisor->user_begin(); UI != Divisor->user_end();) {
for (auto UI = Op1->user_begin(); UI != Op1->user_end();) {
auto *U = *UI++;
Instruction *UserInst = dyn_cast<Instruction>(U);
if (!UserInst || UserInst == Rcp || !IsDivisor(UserInst, Divisor))
if (!UserInst || UserInst == Rcp || !IsDivisor(UserInst, Op1))
continue;
Op0 = UserInst->getOperand(0);
Value *NewVal = Rcp;
Expand Down
26 changes: 22 additions & 4 deletions IGC/VectorCompiler/test/PatternMatch/inverse_sqrt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
;
;============================ end_copyright_notice =============================

; RUN: %opt %use_old_pass_manager% -GenXPatternMatch -march=genx64 -mcpu=Gen9 -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s
; RUN: %opt %use_old_pass_manager% -GenXPatternMatch -march=genx64 -mcpu=Gen9 \
; RUN: -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s

; CHECK-LABEL: @test_inverse
define <16 x float> @test_inverse(<16 x float> %val) {
Expand Down Expand Up @@ -56,10 +57,10 @@ define <16 x float> @test_inverse_not_fast(<16 x float> %src) {
ret <16 x float> %inv
}

; CHECK-LABEL: @test_not_inverse_double
define <16 x double> @test_not_inverse_double(<16 x double> %val_double) {
; CHECK-LABEL: @test_inverse_double
define <16 x double> @test_inverse_double(<16 x double> %val_double) {
%sqrt = call <16 x double> @llvm.sqrt.v16f64(<16 x double> %val_double)
; CHECK: @llvm.genx.inv.v16f64(<16 x double> %sqrt)
; CHECK: call <16 x double> @llvm.genx.rsqrt.v16f64(<16 x double> %val_double)
%inv = call <16 x double> @llvm.genx.inv.v16f64(<16 x double> %sqrt)
ret <16 x double> %inv
}
Expand Down Expand Up @@ -240,6 +241,22 @@ define float @test_inv_sqrt_6(float %val) {
ret float %sqrt
}

; CHECK-LABEL: @test_inverse_double_2
define <16 x double> @test_inverse_double_2(<16 x double> %val_double) {
%sqrt = call <16 x double> @llvm.genx.sqrt.v16f64(<16 x double> %val_double)
%div = fdiv <16 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %sqrt
; CHECK: call <16 x double> @llvm.genx.rsqrt.v16f64(<16 x double> %val_double)
ret <16 x double> %div
}

; CHECK-LABEL: @test_inverse_double_3
define <16 x double> @test_inverse_double_3(<16 x double> %val_double) {
%div = fdiv arcp <16 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %val_double
%sqrt = call <16 x double> @llvm.genx.sqrt.v16f64(<16 x double> %div)
; CHECK: call <16 x double> @llvm.genx.rsqrt.v16f64(<16 x double> %val_double)
ret <16 x double> %sqrt
}

declare float @llvm.sqrt.f32(float)
declare float @llvm.genx.sqrt.f32(float)
declare float @llvm.genx.inv.f32(float)
Expand All @@ -248,5 +265,6 @@ declare <2 x float> @llvm.genx.inv.v2f32(<2 x float>)
declare <16 x float> @llvm.sqrt.v16f32(<16 x float>)
declare <16 x double> @llvm.sqrt.v16f64(<16 x double>)
declare <16 x float> @llvm.genx.sqrt.v16f32(<16 x float>)
declare <16 x double> @llvm.genx.sqrt.v16f64(<16 x double>)
declare <16 x float> @llvm.genx.inv.v16f32(<16 x float>)
declare <16 x double> @llvm.genx.inv.v16f64(<16 x double>)

0 comments on commit fc58546

Please sign in to comment.