From cc0aa07751e84f4989baf7928cf9c3f2f163ab5d Mon Sep 17 00:00:00 2001 From: Victor Mustya Date: Sat, 24 Aug 2024 02:24:50 +0000 Subject: [PATCH] Support copysign intrinsic in VC . --- .../lib/GenXCodeGen/GenXLowering.cpp | 64 +++++++++++++++++ .../lib/GenXCodeGen/GenXPatternMatch.cpp | 2 +- .../CMTrans/GenXTranslateSPIRVBuiltins.cpp | 1 + IGC/VectorCompiler/test/Lowering/copysign.ll | 72 +++++++++++++++++++ .../test/PatternMatch/bfn_match.ll | 9 +++ .../SPIRVBuiltins/math_native_builtins.ll | 8 +++ 6 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 IGC/VectorCompiler/test/Lowering/copysign.ll diff --git a/IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp b/IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp index 76b40e12864f..74f3f8bbc772 100644 --- a/IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp +++ b/IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp @@ -276,6 +276,8 @@ class GenXLowering : public FunctionPass { bool lowerReduction(CallInst *CI, Instruction::BinaryOps Opcode); bool lowerReduction(CallInst *CI, Intrinsic::ID); + bool lowerCopySign(CallInst *CI); + bool generatePredicatedWrrForNewLoad(CallInst *CI); }; @@ -2157,6 +2159,8 @@ bool GenXLowering::processInst(Instruction *Inst) { return lowerReduction(CI, Intrinsic::maxnum); case Intrinsic::vector_reduce_fmin: return lowerReduction(CI, Intrinsic::minnum); + case Intrinsic::copysign: + return lowerCopySign(CI); case GenXIntrinsic::genx_get_hwid: return lowerHardwareThreadID(CI); case vc::InternalIntrinsic::logical_thread_id: @@ -5026,6 +5030,66 @@ bool GenXLowering::lowerReduction(CallInst *CI, Intrinsic::ID IID) { }); } +bool GenXLowering::lowerCopySign(CallInst *CI) { + IRBuilder<> Builder(CI); + + auto *Ty = CI->getType()->getScalarType(); + auto ElementSize = Ty->getPrimitiveSizeInBits(); + auto Stride = ElementSize / genx::WordBits; + IGC_ASSERT(ElementSize % genx::WordBits == 0); + IGC_ASSERT(Stride == 1 || Stride == 2 || Stride == 4); + + auto NumElements = 1; + if (auto *VTy = dyn_cast(CI->getType())) + NumElements = VTy->getNumElements(); + auto CastNumElements = NumElements * Stride; + + auto *Int16Ty = Builder.getInt16Ty(); + auto *CastTy = IGCLLVM::FixedVectorType::get(Int16Ty, CastNumElements); + auto *LowerTy = IGCLLVM::FixedVectorType::get(Int16Ty, NumElements); + + auto *Mag = CI->getOperand(0); + auto *Sign = CI->getOperand(1); + + auto *MagCast = Builder.CreateBitCast(Mag, CastTy); + auto *MagInt = MagCast; + auto *SignInt = Builder.CreateBitCast(Sign, CastTy); + + vc::CMRegion R(LowerTy, DL); + auto &DebugLoc = CI->getDebugLoc(); + + if (Stride > 1) { + R.VStride = Stride; + R.Width = 1; + R.Stride = 0; + R.Offset = (Stride - 1) * genx::WordBytes; + + MagInt = R.createRdRegion(MagInt, "", CI, DebugLoc); + SignInt = R.createRdRegion(SignInt, "", CI, DebugLoc); + } + + auto *MagMask = ConstantInt::get(Int16Ty, 0x7FFF); + auto *SignMask = ConstantInt::get(Int16Ty, 0x8000); + + auto *MagAbs = Builder.CreateAnd( + MagInt, Builder.CreateVectorSplat(NumElements, MagMask)); + auto *SignBit = Builder.CreateAnd( + SignInt, Builder.CreateVectorSplat(NumElements, SignMask)); + + auto *Res = Builder.CreateOr(MagAbs, SignBit); + + if (Stride > 1) + Res = R.createWrRegion(MagCast, Res, "", CI, DebugLoc); + + Res = Builder.CreateBitCast(Res, CI->getType()); + + Res->takeName(CI); + CI->replaceAllUsesWith(Res); + ToErase.push_back(CI); + + return true; +} + /*********************************************************************** * widenByteOp : widen a vector byte operation to short if that might * improve code diff --git a/IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp b/IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp index 9f24d5ac3214..e4e1b50a0b03 100644 --- a/IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp +++ b/IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp @@ -505,7 +505,7 @@ class BfnMatcher { return std::nullopt; } - auto *CV = dyn_cast(C); + auto *CV = dyn_cast(C); if (!CV) return std::nullopt; diff --git a/IGC/VectorCompiler/lib/GenXOpts/CMTrans/GenXTranslateSPIRVBuiltins.cpp b/IGC/VectorCompiler/lib/GenXOpts/CMTrans/GenXTranslateSPIRVBuiltins.cpp index 5f823f2380ba..edc6bde202d8 100644 --- a/IGC/VectorCompiler/lib/GenXOpts/CMTrans/GenXTranslateSPIRVBuiltins.cpp +++ b/IGC/VectorCompiler/lib/GenXOpts/CMTrans/GenXTranslateSPIRVBuiltins.cpp @@ -261,6 +261,7 @@ Value *SPIRVExpander::visitCallInst(CallInst &CI) { .StartsWith("popcount", Intrinsic::ctpop) .StartsWith("s_abs", GenXIntrinsic::genx_absi) // Floating-point intrinsics + .StartsWith("copysign", Intrinsic::copysign) .StartsWith("fabs", Intrinsic::fabs) .StartsWith("fmax", Intrinsic::maxnum) .StartsWith("fma", Intrinsic::fma) diff --git a/IGC/VectorCompiler/test/Lowering/copysign.ll b/IGC/VectorCompiler/test/Lowering/copysign.ll new file mode 100644 index 000000000000..4596d6381f78 --- /dev/null +++ b/IGC/VectorCompiler/test/Lowering/copysign.ll @@ -0,0 +1,72 @@ +;=========================== begin_copyright_notice ============================ +; +; Copyright (C) 2024 Intel Corporation +; +; SPDX-License-Identifier: MIT +; +;============================ end_copyright_notice ============================= + +; RUN: %opt %use_old_pass_manager% -GenXLowering -march=genx64 -mcpu=XeHPC -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s + +declare <4 x half> @llvm.copysign.v4f16(<4 x half>, <4 x half>) +declare <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat>, <4 x bfloat>) +declare <4 x float> @llvm.copysign.v4f32(<4 x float>, <4 x float>) +declare <4 x double> @llvm.copysign.v4f64(<4 x double>, <4 x double>) + +; CHECK-LABEL: @test_v4f16 +define <4 x half> @test_v4f16(<4 x half> %src, <4 x half> %sign) { +; CHECK: [[MAG:%.*]] = bitcast <4 x half> %src to <4 x i16> +; CHECK: [[SGN:%.*]] = bitcast <4 x half> %sign to <4 x i16> +; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG]], +; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN]], +; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]] +; CHECK: [[RES_HALF:%.*]] = bitcast <4 x i16> [[RES]] to <4 x half> +; CHECK: ret <4 x half> [[RES_HALF]] + %res = call <4 x half> @llvm.copysign.v4f16(<4 x half> %src, <4 x half> %sign) + ret <4 x half> %res +} + +; CHECK-LABEL: @test_v4bf16 +define <4 x bfloat> @test_v4bf16(<4 x bfloat> %src, <4 x bfloat> %sign) { +; CHECK: [[MAG:%.*]] = bitcast <4 x bfloat> %src to <4 x i16> +; CHECK: [[SGN:%.*]] = bitcast <4 x bfloat> %sign to <4 x i16> +; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG]], +; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN]], +; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]] +; CHECK: [[RES_BF:%.*]] = bitcast <4 x i16> [[RES]] to <4 x bfloat> +; CHECK: ret <4 x bfloat> [[RES_BF]] + %res = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %src, <4 x bfloat> %sign) + ret <4 x bfloat> %res +} + +; CHECK-LABEL: @test_v4f32 +define <4 x float> @test_v4f32(<4 x float> %src, <4 x float> %sign) { +; CHECK: [[MAG:%.*]] = bitcast <4 x float> %src to <8 x i16> +; CHECK: [[SGN:%.*]] = bitcast <4 x float> %sign to <8 x i16> +; CHECK: [[MAG_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v8i16.i16(<8 x i16> [[MAG]], i32 2, i32 1, i32 0, i16 2, i32 undef) +; CHECK: [[SGN_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v8i16.i16(<8 x i16> [[SGN]], i32 2, i32 1, i32 0, i16 2, i32 undef) +; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG_EXTRACT]], +; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN_EXTRACT]], +; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]] +; CHECK: [[RES_INSERT:%.*]] = call <8 x i16> @llvm.genx.wrregioni.v8i16.v4i16.i16.i1(<8 x i16> [[MAG]], <4 x i16> [[RES]], i32 2, i32 1, i32 0, i16 2, i32 undef, i1 true) +; CHECK: [[RES_FLOAT:%.*]] = bitcast <8 x i16> [[RES_INSERT]] to <4 x float> +; CHECK: ret <4 x float> [[RES_FLOAT]] + %res = call <4 x float> @llvm.copysign.v4f32(<4 x float> %src, <4 x float> %sign) + ret <4 x float> %res +} + +; CHECK-LABEL: @test_v4f64 +define <4 x double> @test_v4f64(<4 x double> %src, <4 x double> %sign) { +; CHECK: [[MAG:%.*]] = bitcast <4 x double> %src to <16 x i16> +; CHECK: [[SGN:%.*]] = bitcast <4 x double> %sign to <16 x i16> +; CHECK: [[MAG_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v16i16.i16(<16 x i16> [[MAG]], i32 4, i32 1, i32 0, i16 6, i32 undef) +; CHECK: [[SGN_EXTRACT:%.*]] = call <4 x i16> @llvm.genx.rdregioni.v4i16.v16i16.i16(<16 x i16> [[SGN]], i32 4, i32 1, i32 0, i16 6, i32 undef) +; CHECK: [[ABS:%.*]] = and <4 x i16> [[MAG_EXTRACT]], +; CHECK: [[SIGN:%.*]] = and <4 x i16> [[SGN_EXTRACT]], +; CHECK: [[RES:%.*]] = or <4 x i16> [[ABS]], [[SIGN]] +; CHECK: [[RES_INSERT:%.*]] = call <16 x i16> @llvm.genx.wrregioni.v16i16.v4i16.i16.i1(<16 x i16> [[MAG]], <4 x i16> [[RES]], i32 4, i32 1, i32 0, i16 6, i32 undef, i1 true) +; CHECK: [[RES_DOUBLE:%.*]] = bitcast <16 x i16> [[RES_INSERT]] to <4 x double> +; CHECK: ret <4 x double> [[RES_DOUBLE]] + %res = call <4 x double> @llvm.copysign.v4f64(<4 x double> %src, <4 x double> %sign) + ret <4 x double> %res +} diff --git a/IGC/VectorCompiler/test/PatternMatch/bfn_match.ll b/IGC/VectorCompiler/test/PatternMatch/bfn_match.ll index d6c6134d4d17..2a0901c23416 100644 --- a/IGC/VectorCompiler/test/PatternMatch/bfn_match.ll +++ b/IGC/VectorCompiler/test/PatternMatch/bfn_match.ll @@ -132,3 +132,12 @@ define i32 @test_unmatch_flag(<32 x i1> %a, <32 x i1> %b, <32 x i1> %c) { ret i32 %2 } + +; CHECK-LABEL: @test_match_combine_by_mask_inv_vector( +define <4 x i16> @test_match_combine_by_mask_inv_vector(<4 x i16> %mag, <4 x i16> %sgn) { +; CHECK: %res = call <4 x i16> @llvm.genx.bfn.v4i16.v4i16(<4 x i16> %mag, <4 x i16> %sgn, <4 x i16> , i8 -84) + %abs = and <4 x i16> %mag, + %sign = and <4 x i16> %sgn, + %res = or <4 x i16> %abs, %sign + ret <4 x i16> %res +} diff --git a/IGC/VectorCompiler/test/SPIRVBuiltins/math_native_builtins.ll b/IGC/VectorCompiler/test/SPIRVBuiltins/math_native_builtins.ll index 8c38e465e775..3abfadfa2ad9 100644 --- a/IGC/VectorCompiler/test/SPIRVBuiltins/math_native_builtins.ll +++ b/IGC/VectorCompiler/test/SPIRVBuiltins/math_native_builtins.ll @@ -38,6 +38,7 @@ declare spir_func <16 x float> @_Z15__spirv_ocl_madDv16_fS_S_(<16 x float>, <16 declare spir_func <7 x double> @_Z15__spirv_ocl_fmaDv7_dS_S_(<7 x double>, <7 x double>, <7 x double>) declare spir_func <7 x double> @_Z15__spirv_ocl_fmaxDv7_dS_S_(<7 x double>, <7 x double>) declare spir_func <16 x double> @_Z16__spirv_ocl_fabsDv16_d(<16 x double>) +declare spir_func <7 x double> @_Z15__spirv_ocl_copysignDv7_dS_S_(<7 x double>, <7 x double>) define spir_func i32 @popcount(i32 %arg) { ; CHECK-LABEL: @popcount @@ -229,3 +230,10 @@ define spir_func <16 x double> @abs_vector(<16 x double> %arg) { %res = call spir_func <16 x double> @_Z16__spirv_ocl_fabsDv16_d(<16 x double> %arg) ret <16 x double> %res } + +define spir_func <7 x double> @copysign_vector(<7 x double> %arg1, <7 x double> %arg2) { +; CHECK-LABEL: @copysign_vector +; CHECK: %res = call <7 x double> @llvm.copysign.v7f64(<7 x double> %arg1, <7 x double> %arg2) + %res = call spir_func <7 x double> @_Z15__spirv_ocl_copysignDv7_dS_S_(<7 x double> %arg1, <7 x double> %arg2) + ret <7 x double> %res +}