diff --git a/IGC/Compiler/CustomSafeOptPass.cpp b/IGC/Compiler/CustomSafeOptPass.cpp index 0df3680551d8..f4a8a33bf810 100644 --- a/IGC/Compiler/CustomSafeOptPass.cpp +++ b/IGC/Compiler/CustomSafeOptPass.cpp @@ -1688,6 +1688,162 @@ bool CustomSafeOptPass::isEmulatedAdd(BinaryOperator& I) return false; } +// Attempt to create new float instruction if both operands are from FPTruncInst instructions. +// Example with fadd: +// %Temp-31.prec.i = fptrunc float %34 to half +// %Temp-30.prec.i = fptrunc float %33 to half +// %41 = fadd fast half %Temp-31.prec.i, %Temp-30.prec.i +// %Temp-32.i = fpext half %41 to float +// +// This fadd is used as a float, and doesn't need the operands to be cased to half. +// We can remove the extra casts in this case. +// This becomes: +// %41 = fadd fast float %34, %33 +// Can also do matches with fadd/fmul that will later become an mad instruction. +// mad example: +// %.prec70.i = fptrunc float %273 to half +// %.prec78.i = fptrunc float %276 to half +// %279 = fmul fast half %233, %.prec70.i +// %282 = fadd fast half %279, %.prec78.i +// %.prec84.i = fpext half %282 to float +// This becomes: +// %279 = fpext half %233 to float +// %280 = fmul fast float %273, %279 +// %281 = fadd fast float %280, %276 +void CustomSafeOptPass::removeHftoFCast(Instruction& I) +{ + if (!I.getType()->isFloatingPointTy()) + return; + + // Check if the only user is a FPExtInst + if (!I.hasOneUse()) + return; + + // Check if this instruction is used in a single FPExtInst + FPExtInst* castInst = NULL; + User* U = *I.user_begin(); + if (FPExtInst* inst = dyn_cast(U)) + { + if (inst->getType()->isFloatTy()) + { + castInst = inst; + } + } + if (!castInst) + return; + + // Check for fmad pattern + if (I.getOpcode() == Instruction::FAdd) + { + Value* src0 = nullptr, * src1 = nullptr, * src2 = nullptr; + + // CodeGenPatternMatch::MatchMad matches the first fmul. + Instruction* fmulInst = nullptr; + for (uint i = 0; i < 2; i++) + { + fmulInst = dyn_cast(I.getOperand(i)); + if (fmulInst && fmulInst->getOpcode() == Instruction::FMul) + { + src0 = fmulInst->getOperand(0); + src1 = fmulInst->getOperand(1); + src2 = I.getOperand(1 - i); + break; + } + else + { + // Prevent other non-fmul instructions from getting used + fmulInst = nullptr; + } + } + if (fmulInst) + { + // Used to get the new float operands for the new instructions + auto getFloatValue = [](Value* operand, Instruction* I, Type* type) + { + if (FPTruncInst* inst = dyn_cast(operand)) + { + // Use the float input of the FPTrunc + if (inst->getOperand(0)->getType()->isFloatTy()) + { + return inst->getOperand(0); + } + else + { + return (Value*)NULL; + } + } + else if (Instruction* inst = dyn_cast(operand)) + { + // Cast the result of this operand to a float + return dyn_cast(new FPExtInst(inst, type, "", I)); + } + return (Value*)NULL; + }; + + int convertCount = 0; + if (dyn_cast(src0)) + convertCount++; + if (dyn_cast(src1)) + convertCount++; + if (dyn_cast(src2)) + convertCount++; + if (convertCount >= 2) + { + // Conversion for the hf values + auto floatTy = castInst->getType(); + src0 = getFloatValue(src0, fmulInst, floatTy); + src1 = getFloatValue(src1, fmulInst, floatTy); + src2 = getFloatValue(src2, &I, floatTy); + + if (!src0 || !src1 || !src2) + return; + + // Create new float fmul and fadd instructions + Value* newFmul = BinaryOperator::Create(Instruction::FMul, src0, src1, "", &I); + Value* newFadd = BinaryOperator::Create(Instruction::FAdd, newFmul, src2, "", &I); + + // Copy fast math flags + Instruction* fmulInst = dyn_cast(newFmul); + Instruction* faddInst = dyn_cast(newFadd); + fmulInst->copyFastMathFlags(fmulInst); + faddInst->copyFastMathFlags(&I); + faddInst->setDebugLoc(castInst->getDebugLoc()); + + castInst->replaceAllUsesWith(faddInst); + return; + } + } + } + + // Check if operands come from a Float to HF Cast + Value *S1 = NULL, *S2 = NULL; + if (FPTruncInst* inst = dyn_cast(I.getOperand(0))) + { + if (!inst->getType()->isHalfTy()) + return; + S1 = inst->getOperand(0); + } + if (FPTruncInst* inst = dyn_cast(I.getOperand(1))) + { + if (!inst->getType()->isHalfTy()) + return; + S2 = inst->getOperand(0); + } + if (!S1 || !S2) + { + return; + } + + Value* newInst = NULL; + if (BinaryOperator* bo = dyn_cast(&I)) + { + newInst = BinaryOperator::Create(bo->getOpcode(), S1, S2, "", &I); + Instruction* inst = dyn_cast(newInst); + inst->copyFastMathFlags(&I); + inst->setDebugLoc(castInst->getDebugLoc()); + castInst->replaceAllUsesWith(inst); + } +} void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I) { @@ -1763,6 +1919,8 @@ void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I) } } } + } else if (I.getType()->isFloatingPointTy()) { + removeHftoFCast(I); } if (IGC_IS_FLAG_ENABLED(ForceHoistDp3) || (!pContext->m_retryManager.IsFirstTry() && IGC_IS_FLAG_ENABLED(EnableHoistDp3))) diff --git a/IGC/Compiler/CustomSafeOptPass.hpp b/IGC/Compiler/CustomSafeOptPass.hpp index 1fcd7b95287d..3c605c801dc9 100644 --- a/IGC/Compiler/CustomSafeOptPass.hpp +++ b/IGC/Compiler/CustomSafeOptPass.hpp @@ -55,6 +55,7 @@ namespace IGC void visitUDiv(llvm::BinaryOperator& I); void visitAllocaInst(llvm::AllocaInst& I); void visitCallInst(llvm::CallInst& C); + void removeHftoFCast(llvm::Instruction& I); void visitBinaryOperator(llvm::BinaryOperator& I); bool isEmulatedAdd(llvm::BinaryOperator& I); void visitBfi(llvm::CallInst* inst); diff --git a/IGC/Compiler/CustomUnsafeOptPass.cpp b/IGC/Compiler/CustomUnsafeOptPass.cpp index f599d2d5eb0d..e92fb7e23818 100644 --- a/IGC/Compiler/CustomUnsafeOptPass.cpp +++ b/IGC/Compiler/CustomUnsafeOptPass.cpp @@ -1822,12 +1822,6 @@ void CustomUnsafeOptPass::visitBinaryOperator(BinaryOperator& I) patternFound = visitBinaryOperatorFmulFaddPropagation(I); } - // remove casting to half when assigning to float - if (!patternFound) - { - patternFound = visitBinaryOperatorRemoveHftoFCast(I); - } - // A/B +C/D can be changed to (A * D +C * B)/(B * D). if (!patternFound && IGC_IS_FLAG_ENABLED(EnableSumFractions)) { @@ -2035,167 +2029,6 @@ void CustomUnsafeOptPass::visitBinaryOperator(BinaryOperator& I) } } -// Attempt to create new float instruction if both operands are from FPTruncInst instructions. -// Example with fadd: -// %Temp-31.prec.i = fptrunc float %34 to half -// %Temp-30.prec.i = fptrunc float %33 to half -// %41 = fadd fast half %Temp-31.prec.i, %Temp-30.prec.i -// %Temp-32.i = fpext half %41 to float -// -// This fadd is used as a float, and doesn't need the operands to be cased to half. -// We can remove the extra casts in this case. -// This becomes: -// %41 = fadd fast float %34, %33 -// Can also do matches with fadd/fmul that will later become an mad instruction. -// mad example: -// %.prec70.i = fptrunc float %273 to half -// %.prec78.i = fptrunc float %276 to half -// %279 = fmul fast half %233, %.prec70.i -// %282 = fadd fast half %279, %.prec78.i -// %.prec84.i = fpext half %282 to float -// This becomes: -// %279 = fpext half %233 to float -// %280 = fmul fast float %273, %279 -// %281 = fadd fast float %280, %276 -bool CustomUnsafeOptPass::visitBinaryOperatorRemoveHftoFCast(BinaryOperator& I) -{ - // Allow only if the reassoc or afn flags are used - if (!(I.hasAllowReassoc() || I.hasApproxFunc())) - return false; - - // Check if the only user is a FPExtInst - if (!I.hasOneUse()) - return false; - - // Check if this instruction is used in a single FPExtInst - FPExtInst* CastInst = NULL; - User* U = *I.user_begin(); - if (FPExtInst* inst = dyn_cast(U)) - { - if (inst->getType()->isFloatTy()) - { - CastInst = inst; - } - } - if (!CastInst || CastInst->use_empty()) - return false; - - - // Check for fmad pattern - if (I.getOpcode() == Instruction::FAdd) - { - Value* Src0 = nullptr, * Src1 = nullptr, * Src2 = nullptr; - - // CodeGenPatternMatch::MatchMad matches the first fmul. - Instruction* FmulInst = nullptr; - for (uint i = 0; i < 2; i++) - { - FmulInst = dyn_cast(I.getOperand(i)); - if (FmulInst && FmulInst->getOpcode() == Instruction::FMul) - { - Src0 = FmulInst->getOperand(0); - Src1 = FmulInst->getOperand(1); - Src2 = I.getOperand(1 - i); - break; - } - else - { - // Prevent other non-fmul instructions from getting used - FmulInst = nullptr; - } - } - if (FmulInst && (I.hasAllowReassoc() || I.hasApproxFunc())) - { - // Used to get the new float operands for the new instructions - auto getFloatValue = [](Value* operand, Instruction* I, Type* type) - { - if (FPTruncInst* Inst = dyn_cast(operand)) - { - // Use the float input of the FPTrunc - if (Inst->getOperand(0)->getType()->isFloatTy()) - { - return Inst->getOperand(0); - } - else - { - return (Value*)NULL; - } - } - else if (operand->getType()->isHalfTy()) - { - return dyn_cast(new FPExtInst(operand, type, "", I)); - } - return (Value*)NULL; - }; - - int ConvertCount = 0; - if (dyn_cast(Src0)) - ConvertCount++; - if (dyn_cast(Src1)) - ConvertCount++; - if (dyn_cast(Src2)) - ConvertCount++; - if (ConvertCount >= 2) - { - // Conversion for the hf values - auto FloatTy = CastInst->getType(); - Src0 = getFloatValue(Src0, FmulInst, FloatTy); - Src1 = getFloatValue(Src1, FmulInst, FloatTy); - Src2 = getFloatValue(Src2, &I, FloatTy); - - if (!Src0 || !Src1 || !Src2) - return false; - - // Create new float fmul and fadd instructions - Value* NewFmul = BinaryOperator::Create(Instruction::FMul, Src0, Src1, "", &I); - Value* NewFadd = BinaryOperator::Create(Instruction::FAdd, NewFmul, Src2, "", &I); - - // Copy fast math flags - Instruction* FmulInst = dyn_cast(NewFmul); - Instruction* FaddInst = dyn_cast(NewFadd); - FmulInst->copyFastMathFlags(FmulInst); - FaddInst->copyFastMathFlags(&I); - FaddInst->setDebugLoc(CastInst->getDebugLoc()); - CastInst->replaceAllUsesWith(FaddInst); - collectForErase(*CastInst, 3); - return true; - } - } - } - - // Check if operands come from a Float to HF Cast - Value* S1 = NULL, * S2 = NULL; - if (FPTruncInst* Inst = dyn_cast(I.getOperand(0))) - { - if (!Inst->getType()->isHalfTy()) - return false; - S1 = Inst->getOperand(0); - } - if (FPTruncInst* Inst = dyn_cast(I.getOperand(1))) - { - if (!Inst->getType()->isHalfTy()) - return false; - S2 = Inst->getOperand(0); - } - if (!S1 || !S2) - { - return false; - } - - Value* newInst = NULL; - if (BinaryOperator* BinOp = dyn_cast(&I)) - { - newInst = BinaryOperator::Create(BinOp->getOpcode(), S1, S2, "", &I); - Instruction* Inst = dyn_cast(newInst); - Inst->copyFastMathFlags(&I); - Inst->setDebugLoc(CastInst->getDebugLoc()); - CastInst->replaceAllUsesWith(Inst); - collectForErase(*CastInst, 2); - return true; - } - return false; -} - // Optimize mix operation if detected. // Mix is computed as x*(1 - a) + y*a // Replace it with a*(y - x) + x to save one instruction ('add' ISA, 'sub' in IR). diff --git a/IGC/Compiler/CustomUnsafeOptPass.hpp b/IGC/Compiler/CustomUnsafeOptPass.hpp index 9b36403ef850..30dada28c35b 100644 --- a/IGC/Compiler/CustomUnsafeOptPass.hpp +++ b/IGC/Compiler/CustomUnsafeOptPass.hpp @@ -70,7 +70,6 @@ namespace IGC bool visitBinaryOperatorAddSubOp(llvm::BinaryOperator& I); bool visitBinaryOperatorDivAddDiv(llvm::BinaryOperator& I); bool visitBinaryOperatorFDivFMulCancellation(llvm::BinaryOperator& I); - bool visitBinaryOperatorRemoveHftoFCast(llvm::BinaryOperator& I); bool isFDiv(llvm::Value* I, llvm::Value*& numerator, llvm::Value*& denominator); bool possibleForFmadOpt(llvm::Instruction* inst); bool visitFCmpInstFCmpFAddOp(llvm::FCmpInst& FC); diff --git a/IGC/Compiler/tests/CustomUnsafeOptPass/fp_to_half_to_fp.ll b/IGC/Compiler/tests/CustomUnsafeOptPass/fp_to_half_to_fp.ll deleted file mode 100755 index ec13dec19063..000000000000 --- a/IGC/Compiler/tests/CustomUnsafeOptPass/fp_to_half_to_fp.ll +++ /dev/null @@ -1,126 +0,0 @@ -;=========================== begin_copyright_notice ============================ -; -; Copyright (C) 2024 Intel Corporation -; -; SPDX-License-Identifier: MIT -; -;============================ end_copyright_notice ============================= - -; RUN: igc_opt -igc-custom-unsafe-opt-pass -S %s -o %t.ll -; RUN: FileCheck %s --input-file=%t.ll - -; tests removeHftoFCast - -define float @doNothing(float %a, float %b) { - %1 = fptrunc float %a to half - %2 = fptrunc float %b to half - %3 = fadd half %1, %2 - %4 = fpext half %3 to float - ret float %4 -} - -; CHECK-LABEL: define float @doNothing -; CHECK: %1 = fptrunc float %a to half -; CHECK: %2 = fptrunc float %b to half -; CHECK: %3 = fadd half %1, %2 -; CHECK: %4 = fpext half %3 to float -; CHECK: ret float %4 - -define float @fastAttr(float %a, float %b) { - %1 = fptrunc float %a to half - %2 = fptrunc float %b to half - %3 = fadd fast half %1, %2 - %4 = fpext half %3 to float - ret float %4 -} - -; CHECK-LABEL: define float @fastAttr -; CHECK-NOT: fptrunc -; CHECK-NOT: fpext -; CHECK: %1 = fadd fast float %a, %b -; CHECK: ret float %1 - -define float @afnAttr(float %a, float %b) { - %1 = fptrunc float %a to half - %2 = fptrunc float %b to half - %3 = fadd afn half %1, %2 - %4 = fpext half %3 to float - ret float %4 -} - -; CHECK-LABEL: define float @afnAttr -; CHECK-NOT: fptrunc -; CHECK-NOT: fpext -; CHECK: %1 = fadd afn float %a, %b -; CHECK: ret float %1 - -define float @reassocAttr(float %a, float %b) { - %1 = fptrunc float %a to half - %2 = fptrunc float %b to half - %3 = fadd reassoc half %1, %2 - %4 = fpext half %3 to float - ret float %4 -} - -; CHECK-LABEL: define float @reassocAttr -; CHECK-NOT: fptrunc -; CHECK-NOT: fpext -; CHECK: %1 = fadd reassoc float %a, %b -; CHECK: ret float %1 - -define float @doNothingMulAdd(float %a, float %b, half %c) { - %1 = fptrunc float %a to half - %2 = fptrunc float %b to half - %3 = fmul half %1, %c - %4 = fadd half %2, %3 - %5 = fpext half %4 to float - ret float %5 -} - -; CHECK-LABEL: define float @doNothingMulAdd -; CHECK: %1 = fptrunc float %a to half -; CHECK: %2 = fptrunc float %b to half -; CHECK: %3 = fmul half %1, %c -; CHECK: %4 = fadd half %2, %3 -; CHECK: %5 = fpext half %4 to float -; CHECK: ret float %5 - -define float @thirdOperandIsInstruction(float %a, float %b, i32 %c) { - %dummy_half = sitofp i32 %c to half - %1 = fptrunc float %a to half - %2 = fptrunc float %b to half - %3 = fmul fast half %1, %dummy_half - %4 = fadd fast half %2, %3 - %5 = fpext half %4 to float - ret float %5 -} - -; CHECK-LABEL: define float @thirdOperandIsInstruction -; CHECK-NOT: fptrunc -; CHECK: %1 = fpext half %dummy_half to float -; CHECK: %2 = fmul float %a, %1 -; CHECK: %3 = fadd fast float %2, %b -; CHECK: ret float %3 - -define float @thirdOperandIsValue(float %a, float %b, half %c) { - %1 = fptrunc float %a to half - %2 = fptrunc float %b to half - %3 = fmul fast half %1, %c - %4 = fadd fast half %2, %3 - %5 = fpext half %4 to float - ret float %5 -} - -; CHECK-LABEL: define float @thirdOperandIsValue -; CHECK-NOT: fptrunc -; CHECK: %1 = fpext half %c to float -; CHECK: %2 = fmul float %a, %1 -; CHECK: %3 = fadd fast float %2, %b -; CHECK: ret float %3 - - -!IGCMetadata = !{!0} - -!0 = !{!"ModuleMD", !1} -!1 = !{!"compOpt", !2} -!2 = !{!"FastRelaxedMath", i1 true}