Skip to content

Commit

Permalink
Support copysign intrinsic in VC
Browse files Browse the repository at this point in the history
.
  • Loading branch information
vmustya authored and igcbot committed Aug 26, 2024
1 parent 0a90ec2 commit cc0aa07
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 1 deletion.
64 changes: 64 additions & 0 deletions IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ class GenXLowering : public FunctionPass {
bool lowerReduction(CallInst *CI, Instruction::BinaryOps Opcode);
bool lowerReduction(CallInst *CI, Intrinsic::ID);

bool lowerCopySign(CallInst *CI);

bool generatePredicatedWrrForNewLoad(CallInst *CI);
};

Expand Down Expand Up @@ -2157,6 +2159,8 @@ bool GenXLowering::processInst(Instruction *Inst) {
return lowerReduction(CI, Intrinsic::maxnum);
case Intrinsic::vector_reduce_fmin:
return lowerReduction(CI, Intrinsic::minnum);
case Intrinsic::copysign:
return lowerCopySign(CI);
case GenXIntrinsic::genx_get_hwid:
return lowerHardwareThreadID(CI);
case vc::InternalIntrinsic::logical_thread_id:
Expand Down Expand Up @@ -5026,6 +5030,66 @@ bool GenXLowering::lowerReduction(CallInst *CI, Intrinsic::ID IID) {
});
}

/***********************************************************************
 * lowerCopySign : lower llvm.copysign to integer bit operations
 *
 * Both operands are bitcast to vectors of 16-bit words. In every element
 * only the word at byte offset (words-per-element - 1) * WordBytes is
 * touched: its low 15 bits are taken from the magnitude operand and its top
 * bit from the sign operand.
 *
 * When an element is a single word, the and/and/or sequence is applied to
 * the bitcast operands directly. When an element spans 2 or 4 words, that
 * word is read from both operands with rdregion and the combined value is
 * written back over the bitcast magnitude with wrregion, so the other words
 * of the magnitude pass through unchanged.
 *
 * A non-vector call is processed as a one-element vector. The call's uses
 * are replaced with the new value and the call is queued in ToErase.
 *
 * Return: always true
 */
bool GenXLowering::lowerCopySign(CallInst *CI) {
  auto *ResTy = CI->getType();
  auto EltBits = ResTy->getScalarType()->getPrimitiveSizeInBits();
  auto WordsPerElt = EltBits / genx::WordBits;
  IGC_ASSERT(EltBits % genx::WordBits == 0);
  IGC_ASSERT(WordsPerElt == 1 || WordsPerElt == 2 || WordsPerElt == 4);
  const bool IsMultiWord = WordsPerElt > 1;

  // Anything that is not a fixed vector is treated as a single element.
  auto EltCount = 1;
  if (auto *VecTy = dyn_cast<IGCLLVM::FixedVectorType>(ResTy))
    EltCount = VecTy->getNumElements();

  IRBuilder<> Builder(CI);
  auto *WordTy = Builder.getInt16Ty();
  // All words of all elements, and one word per element, respectively.
  auto *AllWordsTy =
      IGCLLVM::FixedVectorType::get(WordTy, EltCount * WordsPerElt);
  auto *SignWordsTy = IGCLLVM::FixedVectorType::get(WordTy, EltCount);

  auto *MagWords = Builder.CreateBitCast(CI->getOperand(0), AllWordsTy);
  auto *SignWords = Builder.CreateBitCast(CI->getOperand(1), AllWordsTy);

  // For one-word elements the bitcast values already are the words to patch.
  auto *MagPiece = MagWords;
  auto *SignPiece = SignWords;

  vc::CMRegion SignWordRgn(SignWordsTy, DL);
  auto &DbgLoc = CI->getDebugLoc();

  if (IsMultiWord) {
    // Pick exactly one word out of each element: the last one.
    SignWordRgn.VStride = WordsPerElt;
    SignWordRgn.Width = 1;
    SignWordRgn.Stride = 0;
    SignWordRgn.Offset = (WordsPerElt - 1) * genx::WordBytes;

    MagPiece = SignWordRgn.createRdRegion(MagPiece, "", CI, DbgLoc);
    SignPiece = SignWordRgn.createRdRegion(SignPiece, "", CI, DbgLoc);
  }

  // Emits `Words & splat(Bits)` with one 16-bit lane per element.
  auto MaskWords = [&](auto *Words, unsigned Bits) {
    auto *Mask =
        Builder.CreateVectorSplat(EltCount, ConstantInt::get(WordTy, Bits));
    return Builder.CreateAnd(Words, Mask);
  };

  auto *MagBits = MaskWords(MagPiece, 0x7FFF);
  auto *SignBit = MaskWords(SignPiece, 0x8000);
  auto *Result = Builder.CreateOr(MagBits, SignBit);

  // Put the patched word back; the remaining magnitude words are kept.
  if (IsMultiWord)
    Result = SignWordRgn.createWrRegion(MagWords, Result, "", CI, DbgLoc);

  Result = Builder.CreateBitCast(Result, ResTy);
  Result->takeName(CI);

  CI->replaceAllUsesWith(Result);
  ToErase.push_back(CI);
  return true;
}

/***********************************************************************
* widenByteOp : widen a vector byte operation to short if that might
* improve code
Expand Down
2 changes: 1 addition & 1 deletion IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ class BfnMatcher {
return std::nullopt;
}

auto *CV = dyn_cast<ConstantVector>(C);
auto *CV = dyn_cast<ConstantDataVector>(C);
if (!CV)
return std::nullopt;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ Value *SPIRVExpander::visitCallInst(CallInst &CI) {
.StartsWith("popcount", Intrinsic::ctpop)
.StartsWith("s_abs", GenXIntrinsic::genx_absi)
// Floating-point intrinsics
.StartsWith("copysign", Intrinsic::copysign)
.StartsWith("fabs", Intrinsic::fabs)
.StartsWith("fmax", Intrinsic::maxnum)
.StartsWith("fma", Intrinsic::fma)
Expand Down
72 changes: 72 additions & 0 deletions IGC/VectorCompiler/test/Lowering/copysign.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
;=========================== begin_copyright_notice ============================
;
; Copyright (C) 2024 Intel Corporation
;
; SPDX-License-Identifier: MIT
;
;============================ end_copyright_notice =============================

; RUN: %opt %use_old_pass_manager% -GenXLowering -march=genx64 -mcpu=XeHPC -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s

declare <4 x half> @llvm.copysign.v4f16(<4 x half>, <4 x half>)
declare <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat>, <4 x bfloat>)
declare <4 x float> @llvm.copysign.v4f32(<4 x float>, <4 x float>)
declare <4 x double> @llvm.copysign.v4f64(<4 x double>, <4 x double>)

; CHECK-LABEL: @test_v4f16
; Half-precision case: each element is already 16 bits wide. The expected
; sequence is a bitcast of both operands to <4 x i16>, a mask of the magnitude
; with 0x7FFF (32767), a mask of the sign operand with 0x8000 (-32768), an
; `or` of the two, and a bitcast back to <4 x half>. No region intrinsics
; appear in the expected sequence.
define <4 x half> @test_v4f16(<4 x half> %src, <4 x half> %sign) {
; CHECK: [[MAG:%.*]] = bitcast <4 x half> %src to <4 x i16>
; CHECK: [[SGN:%.*]] = bitcast <4 x half> %sign to <4 x i16>
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG]], <i16 32767, i16 32767, i16 32767, i16 32767>
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
; CHECK: [[RES_HALF:%.*]] = bitcast <4 x i16> [[RES]] to <4 x half>
; CHECK: ret <4 x half> [[RES_HALF]]
%res = call <4 x half> @llvm.copysign.v4f16(<4 x half> %src, <4 x half> %sign)
ret <4 x half> %res
}

; CHECK-LABEL: @test_v4bf16
; bfloat case: same expected shape as the half-precision test, since each
; element is 16 bits wide. Both operands are bitcast to <4 x i16>, masked
; with 0x7FFF (32767) and 0x8000 (-32768) respectively, combined with `or`,
; and bitcast back to <4 x bfloat>.
define <4 x bfloat> @test_v4bf16(<4 x bfloat> %src, <4 x bfloat> %sign) {
; CHECK: [[MAG:%.*]] = bitcast <4 x bfloat> %src to <4 x i16>
; CHECK: [[SGN:%.*]] = bitcast <4 x bfloat> %sign to <4 x i16>
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG]], <i16 32767, i16 32767, i16 32767, i16 32767>
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
; CHECK: [[RES_BF:%.*]] = bitcast <4 x i16> [[RES]] to <4 x bfloat>
; CHECK: ret <4 x bfloat> [[RES_BF]]
%res = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %src, <4 x bfloat> %sign)
ret <4 x bfloat> %res
}

; CHECK-LABEL: @test_v4f32
; Single-precision case: each 32-bit element spans two 16-bit words, so both
; operands are bitcast to <8 x i16>. One word per element is read from each
; operand with rdregioni (region arguments 2, 1, 0 and byte offset 2, i.e. the
; second word of every element). Those words are masked with 0x7FFF / 0x8000
; and combined, and the result is written back over the bitcast magnitude with
; wrregioni using the same region, then bitcast to <4 x float>.
define <4 x float> @test_v4f32(<4 x float> %src, <4 x float> %sign) {
; CHECK: [[MAG:%.*]] = bitcast <4 x float> %src to <8 x i16>
; CHECK: [[SGN:%.*]] = bitcast <4 x float> %sign to <8 x i16>
; CHECK: [[MAG_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v8i16.i16(<8 x i16> [[MAG]], i32 2, i32 1, i32 0, i16 2, i32 undef)
; CHECK: [[SGN_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v8i16.i16(<8 x i16> [[SGN]], i32 2, i32 1, i32 0, i16 2, i32 undef)
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG_EXTRACT]], <i16 32767, i16 32767, i16 32767, i16 32767>
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN_EXTRACT]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
; CHECK: [[RES_INSERT:%.*]] = call <8 x i16> @llvm.genx.wrregioni.v8i16.v4i16.i16.i1(<8 x i16> [[MAG]], <4 x i16> [[RES]], i32 2, i32 1, i32 0, i16 2, i32 undef, i1 true)
; CHECK: [[RES_FLOAT:%.*]] = bitcast <8 x i16> [[RES_INSERT]] to <4 x float>
; CHECK: ret <4 x float> [[RES_FLOAT]]
%res = call <4 x float> @llvm.copysign.v4f32(<4 x float> %src, <4 x float> %sign)
ret <4 x float> %res
}

; CHECK-LABEL: @test_v4f64
; Double-precision case: each 64-bit element spans four 16-bit words, so both
; operands are bitcast to <16 x i16>. One word per element is read from each
; operand with rdregioni (region arguments 4, 1, 0 and byte offset 6, i.e. the
; fourth word of every element). Those words are masked with 0x7FFF / 0x8000
; and combined, and the result is written back over the bitcast magnitude with
; wrregioni using the same region, then bitcast to <4 x double>.
define <4 x double> @test_v4f64(<4 x double> %src, <4 x double> %sign) {
; CHECK: [[MAG:%.*]] = bitcast <4 x double> %src to <16 x i16>
; CHECK: [[SGN:%.*]] = bitcast <4 x double> %sign to <16 x i16>
; CHECK: [[MAG_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v16i16.i16(<16 x i16> [[MAG]], i32 4, i32 1, i32 0, i16 6, i32 undef)
; CHECK: [[SGN_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v16i16.i16(<16 x i16> [[SGN]], i32 4, i32 1, i32 0, i16 6, i32 undef)
; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG_EXTRACT]], <i16 32767, i16 32767, i16 32767, i16 32767>
; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN_EXTRACT]], <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]]
; CHECK: [[RES_INSERT:%.*]] = call <16 x i16> @llvm.genx.wrregioni.v16i16.v4i16.i16.i1(<16 x i16> [[MAG]], <4 x i16> [[RES]], i32 4, i32 1, i32 0, i16 6, i32 undef, i1 true)
; CHECK: [[RES_DOUBLE:%.*]] = bitcast <16 x i16> [[RES_INSERT]] to <4 x double>
; CHECK: ret <4 x double> [[RES_DOUBLE]]
%res = call <4 x double> @llvm.copysign.v4f64(<4 x double> %src, <4 x double> %sign)
ret <4 x double> %res
}
9 changes: 9 additions & 0 deletions IGC/VectorCompiler/test/PatternMatch/bfn_match.ll
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,12 @@ define i32 @test_unmatch_flag(<32 x i1> %a, <32 x i1> %b, <32 x i1> %c) {

ret i32 %2
}

; CHECK-LABEL: @test_match_combine_by_mask_inv_vector(
; Vector "combine by mask" pattern with complementary splat constants:
; (%mag & 0x7FFF) | (%sgn & 0x8000) on <4 x i16>. The expected result is a
; single llvm.genx.bfn call that takes %mag, %sgn and the 0x7FFF splat as its
; three sources, with i8 -84 (0xAC) as the last operand.
define <4 x i16> @test_match_combine_by_mask_inv_vector(<4 x i16> %mag, <4 x i16> %sgn) {
; CHECK: %res = call <4 x i16> @llvm.genx.bfn.v4i16.v4i16(<4 x i16> %mag, <4 x i16> %sgn, <4 x i16> <i16 32767, i16 32767, i16 32767, i16 32767>, i8 -84)
%abs = and <4 x i16> %mag, <i16 32767, i16 32767, i16 32767, i16 32767>
%sign = and <4 x i16> %sgn, <i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%res = or <4 x i16> %abs, %sign
ret <4 x i16> %res
}
8 changes: 8 additions & 0 deletions IGC/VectorCompiler/test/SPIRVBuiltins/math_native_builtins.ll
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ declare spir_func <16 x float> @_Z15__spirv_ocl_madDv16_fS_S_(<16 x float>, <16
declare spir_func <7 x double> @_Z15__spirv_ocl_fmaDv7_dS_S_(<7 x double>, <7 x double>, <7 x double>)
declare spir_func <7 x double> @_Z15__spirv_ocl_fmaxDv7_dS_S_(<7 x double>, <7 x double>)
declare spir_func <16 x double> @_Z16__spirv_ocl_fabsDv16_d(<16 x double>)
declare spir_func <7 x double> @_Z15__spirv_ocl_copysignDv7_dS_S_(<7 x double>, <7 x double>)

define spir_func i32 @popcount(i32 %arg) {
; CHECK-LABEL: @popcount
Expand Down Expand Up @@ -229,3 +230,10 @@ define spir_func <16 x double> @abs_vector(<16 x double> %arg) {
%res = call spir_func <16 x double> @_Z16__spirv_ocl_fabsDv16_d(<16 x double> %arg)
ret <16 x double> %res
}

; Calls the SPIR-V OpenCL copysign builtin on <7 x double>. The call is
; expected to be replaced by llvm.copysign.v7f64 with the same two operands
; in the same order.
define spir_func <7 x double> @copysign_vector(<7 x double> %arg1, <7 x double> %arg2) {
; CHECK-LABEL: @copysign_vector
; CHECK: %res = call <7 x double> @llvm.copysign.v7f64(<7 x double> %arg1, <7 x double> %arg2)
%res = call spir_func <7 x double> @_Z15__spirv_ocl_copysignDv7_dS_S_(<7 x double> %arg1, <7 x double> %arg2)
ret <7 x double> %res
}

0 comments on commit cc0aa07

Please sign in to comment.