Skip to content

Commit

Permalink
[Autobackout][FunctionalRegression]Revert of change: 8723ba3: Move Ht…
Browse files Browse the repository at this point in the history
…oFp optimization to unsafe.

Truncate float to half -> add/multiply add -> extend half to float,
    skips the truncation and extension instructions, performing calculations
    on floats directly.
    This optimization is now unsafe and should only be used with the
    fast, reassoc, or afn attributes.
  • Loading branch information
sys-igc authored and igcbot committed Sep 20, 2024
1 parent 5f9ed8a commit bf295fd
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 294 deletions.
158 changes: 158 additions & 0 deletions IGC/Compiler/CustomSafeOptPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,162 @@ bool CustomSafeOptPass::isEmulatedAdd(BinaryOperator& I)
return false;
}

// Attempt to create new float instruction if both operands are from FPTruncInst instructions.
// Example with fadd:
// %Temp-31.prec.i = fptrunc float %34 to half
// %Temp-30.prec.i = fptrunc float %33 to half
// %41 = fadd fast half %Temp-31.prec.i, %Temp-30.prec.i
// %Temp-32.i = fpext half %41 to float
//
// This fadd is used as a float, and doesn't need the operands to be cased to half.
// We can remove the extra casts in this case.
// This becomes:
// %41 = fadd fast float %34, %33
// Can also do matches with fadd/fmul that will later become an mad instruction.
// mad example:
// %.prec70.i = fptrunc float %273 to half
// %.prec78.i = fptrunc float %276 to half
// %279 = fmul fast half %233, %.prec70.i
// %282 = fadd fast half %279, %.prec78.i
// %.prec84.i = fpext half %282 to float
// This becomes:
// %279 = fpext half %233 to float
// %280 = fmul fast float %273, %279
// %281 = fadd fast float %280, %276
void CustomSafeOptPass::removeHftoFCast(Instruction& I)
{
if (!I.getType()->isFloatingPointTy())
return;

// Check if the only user is a FPExtInst
if (!I.hasOneUse())
return;

// Check if this instruction is used in a single FPExtInst
FPExtInst* castInst = NULL;
User* U = *I.user_begin();
if (FPExtInst* inst = dyn_cast<FPExtInst>(U))
{
if (inst->getType()->isFloatTy())
{
castInst = inst;
}
}
if (!castInst)
return;

// Check for fmad pattern
if (I.getOpcode() == Instruction::FAdd)
{
Value* src0 = nullptr, * src1 = nullptr, * src2 = nullptr;

// CodeGenPatternMatch::MatchMad matches the first fmul.
Instruction* fmulInst = nullptr;
for (uint i = 0; i < 2; i++)
{
fmulInst = dyn_cast<Instruction>(I.getOperand(i));
if (fmulInst && fmulInst->getOpcode() == Instruction::FMul)
{
src0 = fmulInst->getOperand(0);
src1 = fmulInst->getOperand(1);
src2 = I.getOperand(1 - i);
break;
}
else
{
// Prevent other non-fmul instructions from getting used
fmulInst = nullptr;
}
}
if (fmulInst)
{
// Used to get the new float operands for the new instructions
auto getFloatValue = [](Value* operand, Instruction* I, Type* type)
{
if (FPTruncInst* inst = dyn_cast<FPTruncInst>(operand))
{
// Use the float input of the FPTrunc
if (inst->getOperand(0)->getType()->isFloatTy())
{
return inst->getOperand(0);
}
else
{
return (Value*)NULL;
}
}
else if (Instruction* inst = dyn_cast<Instruction>(operand))
{
// Cast the result of this operand to a float
return dyn_cast<Value>(new FPExtInst(inst, type, "", I));
}
return (Value*)NULL;
};

int convertCount = 0;
if (dyn_cast<FPTruncInst>(src0))
convertCount++;
if (dyn_cast<FPTruncInst>(src1))
convertCount++;
if (dyn_cast<FPTruncInst>(src2))
convertCount++;
if (convertCount >= 2)
{
// Conversion for the hf values
auto floatTy = castInst->getType();
src0 = getFloatValue(src0, fmulInst, floatTy);
src1 = getFloatValue(src1, fmulInst, floatTy);
src2 = getFloatValue(src2, &I, floatTy);

if (!src0 || !src1 || !src2)
return;

// Create new float fmul and fadd instructions
Value* newFmul = BinaryOperator::Create(Instruction::FMul, src0, src1, "", &I);
Value* newFadd = BinaryOperator::Create(Instruction::FAdd, newFmul, src2, "", &I);

// Copy fast math flags
Instruction* fmulInst = dyn_cast<Instruction>(newFmul);
Instruction* faddInst = dyn_cast<Instruction>(newFadd);
fmulInst->copyFastMathFlags(fmulInst);
faddInst->copyFastMathFlags(&I);
faddInst->setDebugLoc(castInst->getDebugLoc());

castInst->replaceAllUsesWith(faddInst);
return;
}
}
}

// Check if operands come from a Float to HF Cast
Value *S1 = NULL, *S2 = NULL;
if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand(0)))
{
if (!inst->getType()->isHalfTy())
return;
S1 = inst->getOperand(0);
}
if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand(1)))
{
if (!inst->getType()->isHalfTy())
return;
S2 = inst->getOperand(0);
}
if (!S1 || !S2)
{
return;
}

Value* newInst = NULL;
if (BinaryOperator* bo = dyn_cast<BinaryOperator>(&I))
{
newInst = BinaryOperator::Create(bo->getOpcode(), S1, S2, "", &I);
Instruction* inst = dyn_cast<Instruction>(newInst);
inst->copyFastMathFlags(&I);
inst->setDebugLoc(castInst->getDebugLoc());
castInst->replaceAllUsesWith(inst);
}
}

void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
{
Expand Down Expand Up @@ -1763,6 +1919,8 @@ void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
}
}
}
} else if (I.getType()->isFloatingPointTy()) {
removeHftoFCast(I);
}

if (IGC_IS_FLAG_ENABLED(ForceHoistDp3) || (!pContext->m_retryManager.IsFirstTry() && IGC_IS_FLAG_ENABLED(EnableHoistDp3)))
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CustomSafeOptPass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ namespace IGC
void visitUDiv(llvm::BinaryOperator& I);
void visitAllocaInst(llvm::AllocaInst& I);
void visitCallInst(llvm::CallInst& C);
void removeHftoFCast(llvm::Instruction& I);
void visitBinaryOperator(llvm::BinaryOperator& I);
bool isEmulatedAdd(llvm::BinaryOperator& I);
void visitBfi(llvm::CallInst* inst);
Expand Down
167 changes: 0 additions & 167 deletions IGC/Compiler/CustomUnsafeOptPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1822,12 +1822,6 @@ void CustomUnsafeOptPass::visitBinaryOperator(BinaryOperator& I)
patternFound = visitBinaryOperatorFmulFaddPropagation(I);
}

// remove casting to half when assigning to float
if (!patternFound)
{
patternFound = visitBinaryOperatorRemoveHftoFCast(I);
}

// A/B +C/D can be changed to (A * D +C * B)/(B * D).
if (!patternFound && IGC_IS_FLAG_ENABLED(EnableSumFractions))
{
Expand Down Expand Up @@ -2035,167 +2029,6 @@ void CustomUnsafeOptPass::visitBinaryOperator(BinaryOperator& I)
}
}

// Attempt to create new float instruction if both operands are from FPTruncInst instructions.
// Example with fadd:
// %Temp-31.prec.i = fptrunc float %34 to half
// %Temp-30.prec.i = fptrunc float %33 to half
// %41 = fadd fast half %Temp-31.prec.i, %Temp-30.prec.i
// %Temp-32.i = fpext half %41 to float
//
// This fadd is used as a float, and doesn't need the operands to be cased to half.
// We can remove the extra casts in this case.
// This becomes:
// %41 = fadd fast float %34, %33
// Can also do matches with fadd/fmul that will later become an mad instruction.
// mad example:
// %.prec70.i = fptrunc float %273 to half
// %.prec78.i = fptrunc float %276 to half
// %279 = fmul fast half %233, %.prec70.i
// %282 = fadd fast half %279, %.prec78.i
// %.prec84.i = fpext half %282 to float
// This becomes:
// %279 = fpext half %233 to float
// %280 = fmul fast float %273, %279
// %281 = fadd fast float %280, %276
bool CustomUnsafeOptPass::visitBinaryOperatorRemoveHftoFCast(BinaryOperator& I)
{
// Allow only if the reassoc or afn flags are used
if (!(I.hasAllowReassoc() || I.hasApproxFunc()))
return false;

// Check if the only user is a FPExtInst
if (!I.hasOneUse())
return false;

// Check if this instruction is used in a single FPExtInst
FPExtInst* CastInst = NULL;
User* U = *I.user_begin();
if (FPExtInst* inst = dyn_cast<FPExtInst>(U))
{
if (inst->getType()->isFloatTy())
{
CastInst = inst;
}
}
if (!CastInst || CastInst->use_empty())
return false;


// Check for fmad pattern
if (I.getOpcode() == Instruction::FAdd)
{
Value* Src0 = nullptr, * Src1 = nullptr, * Src2 = nullptr;

// CodeGenPatternMatch::MatchMad matches the first fmul.
Instruction* FmulInst = nullptr;
for (uint i = 0; i < 2; i++)
{
FmulInst = dyn_cast<Instruction>(I.getOperand(i));
if (FmulInst && FmulInst->getOpcode() == Instruction::FMul)
{
Src0 = FmulInst->getOperand(0);
Src1 = FmulInst->getOperand(1);
Src2 = I.getOperand(1 - i);
break;
}
else
{
// Prevent other non-fmul instructions from getting used
FmulInst = nullptr;
}
}
if (FmulInst && (I.hasAllowReassoc() || I.hasApproxFunc()))
{
// Used to get the new float operands for the new instructions
auto getFloatValue = [](Value* operand, Instruction* I, Type* type)
{
if (FPTruncInst* Inst = dyn_cast<FPTruncInst>(operand))
{
// Use the float input of the FPTrunc
if (Inst->getOperand(0)->getType()->isFloatTy())
{
return Inst->getOperand(0);
}
else
{
return (Value*)NULL;
}
}
else if (operand->getType()->isHalfTy())
{
return dyn_cast<Value>(new FPExtInst(operand, type, "", I));
}
return (Value*)NULL;
};

int ConvertCount = 0;
if (dyn_cast<FPTruncInst>(Src0))
ConvertCount++;
if (dyn_cast<FPTruncInst>(Src1))
ConvertCount++;
if (dyn_cast<FPTruncInst>(Src2))
ConvertCount++;
if (ConvertCount >= 2)
{
// Conversion for the hf values
auto FloatTy = CastInst->getType();
Src0 = getFloatValue(Src0, FmulInst, FloatTy);
Src1 = getFloatValue(Src1, FmulInst, FloatTy);
Src2 = getFloatValue(Src2, &I, FloatTy);

if (!Src0 || !Src1 || !Src2)
return false;

// Create new float fmul and fadd instructions
Value* NewFmul = BinaryOperator::Create(Instruction::FMul, Src0, Src1, "", &I);
Value* NewFadd = BinaryOperator::Create(Instruction::FAdd, NewFmul, Src2, "", &I);

// Copy fast math flags
Instruction* FmulInst = dyn_cast<Instruction>(NewFmul);
Instruction* FaddInst = dyn_cast<Instruction>(NewFadd);
FmulInst->copyFastMathFlags(FmulInst);
FaddInst->copyFastMathFlags(&I);
FaddInst->setDebugLoc(CastInst->getDebugLoc());
CastInst->replaceAllUsesWith(FaddInst);
collectForErase(*CastInst, 3);
return true;
}
}
}

// Check if operands come from a Float to HF Cast
Value* S1 = NULL, * S2 = NULL;
if (FPTruncInst* Inst = dyn_cast<FPTruncInst>(I.getOperand(0)))
{
if (!Inst->getType()->isHalfTy())
return false;
S1 = Inst->getOperand(0);
}
if (FPTruncInst* Inst = dyn_cast<FPTruncInst>(I.getOperand(1)))
{
if (!Inst->getType()->isHalfTy())
return false;
S2 = Inst->getOperand(0);
}
if (!S1 || !S2)
{
return false;
}

Value* newInst = NULL;
if (BinaryOperator* BinOp = dyn_cast<BinaryOperator>(&I))
{
newInst = BinaryOperator::Create(BinOp->getOpcode(), S1, S2, "", &I);
Instruction* Inst = dyn_cast<Instruction>(newInst);
Inst->copyFastMathFlags(&I);
Inst->setDebugLoc(CastInst->getDebugLoc());
CastInst->replaceAllUsesWith(Inst);
collectForErase(*CastInst, 2);
return true;
}
return false;
}

// Optimize mix operation if detected.
// Mix is computed as x*(1 - a) + y*a
// Replace it with a*(y - x) + x to save one instruction ('add' ISA, 'sub' in IR).
Expand Down
1 change: 0 additions & 1 deletion IGC/Compiler/CustomUnsafeOptPass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ namespace IGC
bool visitBinaryOperatorAddSubOp(llvm::BinaryOperator& I);
bool visitBinaryOperatorDivAddDiv(llvm::BinaryOperator& I);
bool visitBinaryOperatorFDivFMulCancellation(llvm::BinaryOperator& I);
bool visitBinaryOperatorRemoveHftoFCast(llvm::BinaryOperator& I);
bool isFDiv(llvm::Value* I, llvm::Value*& numerator, llvm::Value*& denominator);
bool possibleForFmadOpt(llvm::Instruction* inst);
bool visitFCmpInstFCmpFAddOp(llvm::FCmpInst& FC);
Expand Down
Loading

0 comments on commit bf295fd

Please sign in to comment.