Skip to content

Commit

Permalink
logic fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 9, 2024
1 parent 46e713f commit 0ce3c7d
Showing 1 changed file with 39 additions and 51 deletions.
90 changes: 39 additions & 51 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2612,9 +2612,7 @@ struct compare_insts {

// return true if A appears later than B.
bool operator()(Instruction * A, Instruction *B) const {
llvm::errs() << " comparing " << *A << " and " << *B << "\n";
if (A == B) {
llvm::errs() << " - false\n";
return false;
}
if (A->getParent() == B->getParent()) {
Expand All @@ -2624,11 +2622,6 @@ struct compare_insts {
auto BB = B->getParent();
assert(AB->getParent() == BB->getParent());

if (DT.dominates(AB, BB))
return false;
if (!DT.dominates(BB, AB))
return true;

for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) {
if (prev == AB)
return false;
Expand All @@ -2641,18 +2634,13 @@ class DominatorOrderSet : public std::set<Instruction*, compare_insts> {
public:
DominatorOrderSet(DominatorTree &DT, LoopInfo &LI) : std::set<Instruction*, compare_insts>(compare_insts(DT, LI)) {}
bool contains(Instruction* I) const {
llvm::errs() << " contains(" << *I << ")\n";
auto v = count(I);
llvm::errs() << " count -> " << v << "\n";
return v != 0;
auto __i = find(I);
return __i != end();
}
void remove(Instruction* I) {
llvm::errs() << "pre remove(" << *I << ")\n";
auto __i = find(I);
assert (__i != end());
erase(__i);
llvm::errs() << "post remove(" << *I << ")\n";
assert(count(I) == 0);
}
Instruction* pop_back_val() {
auto back = end();
Expand Down Expand Up @@ -2731,16 +2719,12 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}
if (Q.contains(I)) {
Q.remove(I);
llvm::errs() << " removed from queue\n";
}
for (auto q : Q)
llvm::errs() << " -- q: " << *q << "\n";
assert(!Q.contains(I));
llvm::errs() << "erasing I: " << *I << "\n";
I->eraseFromParent();
for (auto op : operands)
if (op->getNumUses() == 0) {
if (Q.contains(I))
if (Q.contains(op))
Q.remove(op);
op->eraseFromParent();
}
Expand Down Expand Up @@ -4116,27 +4100,17 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}

// select(cond, const1, b) ?= const2 -> select(cond, const1 ?= const2, b ?= const2)
llvm::errs() <<" pre nottriggered:" << *cur << "\n";
if (auto fcmp = dyn_cast<FCmpInst>(cur)) {
llvm::errs() << " outermost nottriggered: " << *fcmp << "\n";
if (auto fcmp = dyn_cast<FCmpInst>(cur))
for (int i=0; i<2; i++)
if (auto const2 = dyn_cast<Constant>(fcmp->getOperand(i))) {
if (auto sel = dyn_cast<SelectInst>(fcmp->getOperand(1-i))) {
if (auto const2 = dyn_cast<Constant>(fcmp->getOperand(i)))
if (auto sel = dyn_cast<SelectInst>(fcmp->getOperand(1-i)))
if (isa<Constant>(sel->getTrueValue()) || isa<Constant>(sel->getFalseValue())) {
auto tval = pushcse(B.CreateFCmp(fcmp->getPredicate(), sel->getTrueValue(), const2));
auto fval = pushcse(B.CreateFCmp(fcmp->getPredicate(), sel->getFalseValue(), const2));
auto res = pushcse(B.CreateSelect(sel->getCondition(), tval, fval));
replaceAndErase(cur, res);
return "FCmpSelectConst";
} else
llvm::errs() << " nottriggered(2): " << *fcmp << " " << *sel << "\n";
} else {
llvm::errs() << " nottriggered(1): " << *fcmp << " " << *fcmp->getOperand(1-i) << "\n";
}
} else {
llvm::errs() << " nottriggered(0): " << *fcmp << " " << *fcmp->getOperand(1-i) << "\n";
}
}

// mul (mul a, const), b -> mul (mul a, b), const
// note we avoid the case where b = (mul a, const) since otherwise
Expand Down Expand Up @@ -5176,6 +5150,13 @@ return true;
set.push_back(ty);
*/
}
static SetTy intersect(const SetTy &lhs, const SetTy &rhs) {
SetTy res;
for (auto &v : lhs)
if (rhs.count(v))
res.insert(v);
return res;
}
static void set_subtract(SetTy &set, const SetTy &rhs) {
for (auto &v : rhs)
if (set.count(v))
Expand All @@ -5199,7 +5180,7 @@ return true;
case Type::Compare:
return std::make_shared<Constraints>(node, !isEqual);
case Type::Union: {
// not of or's is and or not's
// not of or's is and of not's
SetTy next;
for (const auto &v : values)
insert(next, v->notB());
Expand All @@ -5208,7 +5189,7 @@ return true;
return std::make_shared<Constraints>(Type::Intersect, next);
}
case Type::Intersect: {
// not of and's is or or not's
// not of and's is or of not's
SetTy next;
for (const auto &v : values)
insert(next, v->notB());
Expand Down Expand Up @@ -5434,8 +5415,7 @@ return true;
for (auto uv : unionVals)
cur = cur->orB(uv, SE);

if (cur->ty != Type::Union)
return andB(cur, SE);
return andB(cur, SE);
}

SetTy vals = values;
Expand All @@ -5449,8 +5429,7 @@ return true;
// (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) and
// (c or e))
if (ty == Type::Union && rhs->ty == Type::Union) {
SetTy intersection = values;
set_subtract(intersection, rhs->values);
SetTy intersection = intersect(values, rhs->values);
if (intersection.size() != 0) {
InnerTy other_lhs = remove(intersection);
InnerTy other_rhs = rhs->remove(intersection);
Expand Down Expand Up @@ -5605,21 +5584,9 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,

llvm::errs() << " pre fix inner: " << F << "\n";

llvm::errs() << "<QUEUE>\n";
for (auto a : Q) {
llvm::errs() << " - " << *a << "\n";
}
llvm::errs() << "</QUEUE>\n";

// Full simplification
while (!Q.empty()) {
auto cur = Q.pop_back_val();

llvm::errs() << "<QUEUE POST POP>\n";
for (auto a : Q) {
llvm::errs() << " - " << *a << "\n";
}
llvm::errs() << "</QUEUE POST POP>\n";
std::set<Instruction*> prev;
for (auto v : Q) prev.insert(v);
llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n";
Expand Down Expand Up @@ -5750,6 +5717,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
->andB(getSparseConditions(I->getOperand(1),
Constraints::all(), I),
SE);
llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n";
return res;
}

Expand All @@ -5759,10 +5727,23 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
->orB(getSparseConditions(I->getOperand(1),
Constraints::none(), I),
SE);
llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n";
return res;
}

// cmp x, 1.0 -> false/true
if (I->getOpcode() == Instruction::Xor) {
for (int i=0; i<2; i++) {
if (auto C=dyn_cast<ConstantInt>(I->getOperand(i)))
if (C->isOne()) {
auto pres = getSparseConditions(I->getOperand(1-i), defaultFloat->notB(), scope);
auto res = pres->notB();
llvm::errs() << " negate: " << *I << " pre: " << *pres << " negated: " << *res << "\n";
llvm::errs() << " getSparse(not, " << *I << ") = " << *res << "\n";
return res;
}
}
}

if (auto icmp = dyn_cast<ICmpInst>(I)) {
auto lhs = SE.getSCEVAtScope(icmp->getOperand(0), L);
auto rhs = SE.getSCEVAtScope(icmp->getOperand(1), L);
Expand All @@ -5787,6 +5768,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
if (div == div_e) {
auto res = std::make_shared<Constraints>(
div, icmp->getPredicate() == ICmpInst::ICMP_EQ);
llvm::errs() << " getSparse(icmp, " << *I << ") = " << *res << "\n";
return res;
}
}
Expand All @@ -5797,8 +5779,10 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
}
}

// cmp x, 1.0 -> false/true
if (auto fcmp = dyn_cast<FCmpInst>(I)) {
auto res = defaultFloat;
llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n";
return res;

if (fcmp->getPredicate() == CmpInst::FCMP_OEQ ||
Expand All @@ -5816,11 +5800,15 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
return Constraints::all();
};

// default is condition avoids sparse, negated is condition goes
// to sparse
Instruction * context = isa<Instruction>(cond) ? cast<Instruction>(cond) : idx;
auto solutions = getSparseConditions(cond, negated ? Constraints::all()
: Constraints::none(), context);
llvm::errs() << " solutions pre negate: " << *solutions << "\n";
if (!negated)
solutions = solutions->notB();
llvm::errs() << " solutions post negate: " << *solutions << "\n";
if (!legal) {
sawError = true;
continue;
Expand Down

0 comments on commit 0ce3c7d

Please sign in to comment.