Skip to content

Commit

Permalink
Fixing selective scalarization
Browse files Browse the repository at this point in the history
ScalarizeFunction pass can keep some instructions vectorized, if the vector is used as the whole entity. The pass builds a web of instructions protected from scalarization. The ending legs of the web consist of vectorial instructions such as insert and extract elements, vector shuffles, GenISA intrinsics and function calls. The vectorial instructions inside the web consist of bitcasts and PHI nodes.
  • Loading branch information
adam-bzowski authored and igcbot committed Aug 6, 2024
1 parent 9a6da08 commit 2adb59c
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 90 deletions.
160 changes: 82 additions & 78 deletions IGC/Compiler/Optimizer/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ SPDX-License-Identifier: MIT
#include "common/LLVMWarningsPop.hpp"
#include "common/igc_regkeys.hpp"
#include "common/Types.hpp"
#include <iostream>
#include "Probe/Assertion.h"
#include <vector>

using namespace llvm;
using namespace IGC;
Expand Down Expand Up @@ -62,6 +62,8 @@ ScalarizeFunction::ScalarizeFunction(bool selectiveScalarization) : FunctionPass
initializeScalarizeFunctionPass(*PassRegistry::getPassRegistry());

for (int i = 0; i < Instruction::OtherOpsEnd; i++) m_transposeCtr[i] = 0;

// Needs IGC_EnableSelectiveScalarizer = 1
m_SelectiveScalarization = selectiveScalarization;

// Initialize SCM buffers and allocation
Expand All @@ -70,14 +72,17 @@ ScalarizeFunction::ScalarizeFunction(bool selectiveScalarization) : FunctionPass
m_SCMArrayLocation = 0;

V_PRINT(scalarizer, "ScalarizeFunction constructor\n");
V_PRINT(scalarizer, "IGC_EnableSelectiveScalarizer = ");
V_PRINT(scalarizer, IGC_IS_FLAG_ENABLED(EnableSelectiveScalarizer));
V_PRINT(scalarizer, "\n");
}

ScalarizeFunction::~ScalarizeFunction()
{
bool ScalarizeFunction::doFinalization(llvm::Module& M) {
releaseAllSCMEntries();
delete[] m_SCMAllocationArray;
destroyDummyFunc();
V_PRINT(scalarizer, "ScalarizeFunction destructor\n");
V_PRINT(scalarizer, "ScalarizeFunction doFinalization\n");
return true;
}

bool ScalarizeFunction::runOnFunction(Function& F)
Expand Down Expand Up @@ -157,7 +162,7 @@ bool ScalarizeFunction::runOnFunction(Function& F)
for (; index != re; ++index)
{
// get rid of old users
if (Value * val = dyn_cast<Value>(*index))
if (Value* val = dyn_cast<Value>(*index))
{
UndefValue* undefVal = UndefValue::get((*index)->getType());
(val)->replaceAllUsesWith(undefVal);
Expand All @@ -171,13 +176,18 @@ bool ScalarizeFunction::runOnFunction(Function& F)
}

/// <summary>
/// @brief We want to avoid scalarize vector-phi node if the vector is used
/// @brief We want to avoid scalarization of vector instructions if the vector is used
/// as a whole entity somewhere in the program. This function tries to find
/// this kind of definition web that involves phi-node, insert-element etc,
/// then add them into the exclusion-set (excluded from scalarization).
/// </summary>
void ScalarizeFunction::buildExclusiveSet()
{

auto isAddToWeb = [](Value* V) -> bool {
return isa<PHINode>(V) || isa<BitCastInst>(V);
};

auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
for (auto dfi = df_begin(DT->getRootNode()), dfe = df_end(DT->getRootNode());
dfi != dfe; ++dfi)
Expand All @@ -190,7 +200,10 @@ void ScalarizeFunction::buildExclusiveSet()
Instruction* currInst = &*sI;
++sI;
// find the seed for the workset
std::vector<llvm::Value*> workset;
std::vector<Value*> workset;

// Instructions that accept vectorial arguments can end legs of the web
// i.e. the instructions that produce the vectorial arguments may be protected from scalarization
if (GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(currInst))
{
unsigned numOperands = IGCLLVM::getNumArgOperands(GII);
Expand All @@ -203,6 +216,16 @@ void ScalarizeFunction::buildExclusiveSet()
}
}
}
else if (CallInst * CI = dyn_cast<CallInst>(currInst))
{
for (auto arg = CI->arg_begin(); arg != CI->arg_end(); ++arg)
{
if (isa<VectorType>(arg->get()->getType()))
{
workset.push_back(arg->get());
}
}
}
else if (auto IEI = dyn_cast<InsertElementInst>(currInst))
{
Value* scalarIndexVal = IEI->getOperand(2);
Expand All @@ -219,9 +242,12 @@ void ScalarizeFunction::buildExclusiveSet()
workset.push_back(EEI->getOperand(0));
}
}
// try to find a phi-web from the seed
bool HasPHI = false;
std::set<llvm::Value*> defweb;
else if (BitCastInst* BCI = dyn_cast<BitCastInst>(currInst))
{
workset.push_back(BCI->getOperand(0));
}
// try to find a web from the seed
std::set<Value*> defweb;
while (!workset.empty())
{
auto Def = workset.back();
Expand All @@ -230,70 +256,45 @@ void ScalarizeFunction::buildExclusiveSet()
{
continue;
}
if (auto IEI = dyn_cast<InsertElementInst>(Def))
{
defweb.insert(IEI);
if (!defweb.count(IEI->getOperand(0)) &&
(isa<PHINode>(IEI->getOperand(0)) ||
isa<ShuffleVectorInst>(IEI->getOperand(0)) ||
isa<InsertElementInst>(IEI->getOperand(0))))
{
workset.push_back(IEI->getOperand(0));
}
}
else if (auto SVI = dyn_cast<ShuffleVectorInst>(Def))

// The web grows "up" through BitCasts and PHI nodes
// but insert/extract elements and vector shuffles should be scalarized
if (!isAddToWeb(Def)) continue;

if (BitCastInst* BCI = dyn_cast<BitCastInst>(Def))
{
defweb.insert(SVI);
if (!defweb.count(SVI->getOperand(0)) &&
(isa<PHINode>(SVI->getOperand(0)) ||
isa<ShuffleVectorInst>(SVI->getOperand(0)) ||
isa<InsertElementInst>(SVI->getOperand(0))))
{
workset.push_back(SVI->getOperand(0));
}
if (!defweb.count(SVI->getOperand(1)) &&
(isa<PHINode>(SVI->getOperand(1)) ||
isa<ShuffleVectorInst>(SVI->getOperand(1)) ||
isa<InsertElementInst>(SVI->getOperand(1))))
defweb.insert(BCI);
if (!defweb.count(BCI->getOperand(0)) && isAddToWeb(BCI->getOperand(0)))
{
workset.push_back(SVI->getOperand(1));
workset.push_back(BCI->getOperand(0));
}
}
else if (auto PHI = dyn_cast<PHINode>(Def))
{
defweb.insert(PHI);
HasPHI = true; // !this def-web is qualified!
for (int i = 0, n = PHI->getNumOperands(); i < n; ++i)
if (!defweb.count(PHI->getOperand(i)) &&
(isa<PHINode>(PHI->getOperand(i)) ||
isa<ShuffleVectorInst>(PHI->getOperand(i)) ||
isa<InsertElementInst>(PHI->getOperand(i))))
{
if (!defweb.count(PHI->getOperand(i)) && isAddToWeb(PHI->getOperand(i)))
{
workset.push_back(PHI->getOperand(i));
}
}
}
else
{
continue;
}
// check use

// The web grows "down" through BitCasts and PHI nodes as well
for (auto U : Def->users())
{
if (!defweb.count(U) &&
(isa<PHINode>(U) ||
isa<ShuffleVectorInst>(U) ||
isa<InsertElementInst>(U)))
if (!defweb.count(U) && isAddToWeb(U))
{
workset.push_back(U);
}
}
}
// if we find a qualified web with PHINode, add those instructions
// into the exclusion set
if (HasPHI)
{
m_Excludes.merge(defweb);
}
m_Excludes.merge(defweb);
}
}
}
Expand Down Expand Up @@ -390,7 +391,7 @@ void ScalarizeFunction::recoverNonScalarizableInst(Instruction* Inst)
if (isa<VectorType>(Inst->getType())) getSCMEntry(Inst);

// Iterate over all arguments. Check that they all exist (or rebuilt)
if (CallInst * CI = dyn_cast<CallInst>(Inst))
if (CallInst* CI = dyn_cast<CallInst>(Inst))
{
unsigned numOperands = IGCLLVM::getNumArgOperands(CI);
for (unsigned i = 0; i < numOperands; i++)
Expand Down Expand Up @@ -508,7 +509,7 @@ void ScalarizeFunction::scalarizeInstruction(BinaryOperator* BI)
BI->getName(),
BI
);
if (BinaryOperator * BO = dyn_cast<BinaryOperator>(Val)) {
if (BinaryOperator* BO = dyn_cast<BinaryOperator>(Val)) {
// Copy overflow flags if any.
if (isa<OverflowingBinaryOperator>(BO)) {
BO->setHasNoSignedWrap(BI->hasNoSignedWrap());
Expand Down Expand Up @@ -609,7 +610,7 @@ void ScalarizeFunction::scalarizeInstruction(CastInst* CI)
"unexpected type!");
IGC_ASSERT_MESSAGE(
cast<IGCLLVM::FixedVectorType>(CI->getOperand(0)->getType())
->getNumElements() == numElements,
->getNumElements() == numElements,
"unexpected vector width");

// Obtain scalarized argument
Expand Down Expand Up @@ -666,7 +667,7 @@ void ScalarizeFunction::scalarizeInstruction(PHINode* PI)
{
auto* Op = PI->getIncomingValue(i);

if (auto * GII = dyn_cast<GenIntrinsicInst>(Op))
if (auto* GII = dyn_cast<GenIntrinsicInst>(Op))
{
switch (GII->getIntrinsicID())
{
Expand Down Expand Up @@ -694,7 +695,7 @@ void ScalarizeFunction::scalarizeInstruction(PHINode* PI)
phis.pop_back();
for (auto U : PN->users())
{
if (GenIntrinsicInst * GII = dyn_cast<GenIntrinsicInst>(U))
if (GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(U))
{
switch (GII->getIntrinsicID())
{
Expand All @@ -703,11 +704,16 @@ void ScalarizeFunction::scalarizeInstruction(PHINode* PI)
case GenISAIntrinsic::GenISA_sub_group_dpas:
case GenISAIntrinsic::GenISA_dpas:
case GenISAIntrinsic::GenISA_simdBlockWrite:
case GenISAIntrinsic::GenISA_simdBlockWriteBindless:
case GenISAIntrinsic::GenISA_simdMediaBlockWrite:
case GenISAIntrinsic::GenISA_LSC2DBlockWrite:
case GenISAIntrinsic::GenISA_LSC2DBlockWriteAddrPayload:
case GenISAIntrinsic::GenISA_LSCStoreBlock:
recoverNonScalarizableInst(PI);
return;
}
}
else if (PHINode * N = dyn_cast<PHINode>(U))
else if (PHINode* N = dyn_cast<PHINode>(U))
{
if (visited.count(N) == 0) {
visited[N] = 1;
Expand All @@ -720,7 +726,6 @@ void ScalarizeFunction::scalarizeInstruction(PHINode* PI)
phis.clear();
}


// Prepare empty SCM entry for the instruction
SCMEntry* newEntry = getSCMEntry(PI);

Expand Down Expand Up @@ -1047,7 +1052,7 @@ void ScalarizeFunction::scalarizeInstruction(GetElementPtrInst* GI)
auto op1 = baseValue->getType()->isVectorTy() ? operand1[i] : baseValue;
auto op2 = indexValue->getType()->isVectorTy() ? operand2[i] : indexValue;

Type *BaseTy = IGCLLVM::getNonOpaquePtrEltTy(op1->getType());
Type* BaseTy = IGCLLVM::getNonOpaquePtrEltTy(op1->getType());
Value* newGEP = GetElementPtrInst::Create(BaseTy, op1, op2,
VALUE_NAME(GI->getName()), GI);
Value* constIndex = ConstantInt::get(Type::getInt32Ty(context()), i);
Expand Down Expand Up @@ -1123,7 +1128,7 @@ void ScalarizeFunction::obtainScalarizedValues(SmallVectorImpl<Value*>& retValue
retValues[i + destIdx] = undefElement;
}
}
else if (Constant * vectorConst = dyn_cast<Constant>(origValue))
else if (Constant* vectorConst = dyn_cast<Constant>(origValue))
{
V_PRINT(scalarizer, "\t\t\tProper constant: " << *vectorConst << "\n");
// Value is a constant. Break it down to scalars by employing a constant expression
Expand Down Expand Up @@ -1310,7 +1315,7 @@ void ScalarizeFunction::updateSCMEntryWithValues(ScalarizeFunction::SCMEntry* en

if (matchDbgLoc)
{
if (const Instruction * origInst = dyn_cast<Instruction>(origValue))
if (const Instruction* origInst = dyn_cast<Instruction>(origValue))
{
for (unsigned i = 0; i < width; ++i)
{
Expand Down Expand Up @@ -1347,17 +1352,17 @@ void ScalarizeFunction::resolveDeferredInstructions()

// lambda to check if a value is a dummy instruction
auto isDummyValue = [this](Value* val)
{
auto* call = dyn_cast<CallInst>(val);
if (!call) return false;
// If the Value is one of the dummy functions that we created.
for (const auto& function : createdDummyFunctions) {
if (call->getCalledFunction() == function.second)
return true;
}
{
auto* call = dyn_cast<CallInst>(val);
if (!call) return false;
// If the Value is one of the dummy functions that we created.
for (const auto& function : createdDummyFunctions) {
if (call->getCalledFunction() == function.second)
return true;
}

return false;
};
return false;
};

for (auto deferredEntry = m_DRL.begin(); m_DRL.size() > 0;)
{
Expand Down Expand Up @@ -1395,8 +1400,8 @@ void ScalarizeFunction::resolveDeferredInstructions()
newInsts.resize(width);
for (unsigned i = 0; i < width; i++)
{
Value *constIndex = ConstantInt::get(Type::getInt32Ty(context()), i);
Instruction *EE = ExtractElementInst::Create(vectorInst, constIndex,
Value* constIndex = ConstantInt::get(Type::getInt32Ty(context()), i);
Instruction* EE = ExtractElementInst::Create(vectorInst, constIndex,
VALUE_NAME(vectorInst->getName() + ".scalar"), &(*insertLocation));
newInsts[i] = EE;
}
Expand All @@ -1417,7 +1422,7 @@ void ScalarizeFunction::resolveDeferredInstructions()
// It's possible the scalar values are not resolved earlier and are themselves dummy instructions.
// In order to find the real value, we look in the map to see which value replaced it.
if (dummyToScalarMap.count(scalarVal))
scalarVal = dummyToScalarMap[scalarVal];
scalarVal = dummyToScalarMap[scalarVal];
else
totallyResolved = false;
}
Expand All @@ -1441,10 +1446,10 @@ void ScalarizeFunction::resolveDeferredInstructions()
}
}

for (const auto &entry : dummyToScalarMap)
for (const auto& entry : dummyToScalarMap)
{
// Replace and erase all dummy instructions (don't use eraseFromParent as the dummy is not in the function)
Instruction *dummyInst = cast<Instruction>(entry.first);
Instruction* dummyInst = cast<Instruction>(entry.first);
dummyInst->replaceAllUsesWith(entry.second);
dummyInst->deleteValue();
}
Expand All @@ -1453,9 +1458,8 @@ void ScalarizeFunction::resolveDeferredInstructions()
m_DRL.clear();
}

extern "C" FunctionPass* createScalarizerPass(bool selectiveScalarization)
extern "C" FunctionPass * createScalarizerPass(bool selectiveScalarization)
{
return new ScalarizeFunction(selectiveScalarization);
}


Loading

0 comments on commit 2adb59c

Please sign in to comment.