From 0ce3c7dbc7d3980aad2b789bf8519ba8611771a9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 8 Jan 2024 22:30:44 -0500 Subject: [PATCH] logic fix --- enzyme/Enzyme/FunctionUtils.cpp | 90 ++++++++++++++------------------- 1 file changed, 39 insertions(+), 51 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 7a8c37a104ee..e3f19000b776 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2612,9 +2612,7 @@ struct compare_insts { // return true if A appears later than B. bool operator()(Instruction * A, Instruction *B) const { - llvm::errs() << " comparing " << *A << " and " << *B << "\n"; if (A == B) { - llvm::errs() << " - false\n"; return false; } if (A->getParent() == B->getParent()) { @@ -2624,11 +2622,6 @@ struct compare_insts { auto BB = B->getParent(); assert(AB->getParent() == BB->getParent()); - if (DT.dominates(AB, BB)) - return false; - if (!DT.dominates(BB, AB)) - return true; - for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) { if (prev == AB) return false; @@ -2641,18 +2634,13 @@ class DominatorOrderSet : public std::set { public: DominatorOrderSet(DominatorTree &DT, LoopInfo &LI) : std::set(compare_insts(DT, LI)) {} bool contains(Instruction* I) const { - llvm::errs() << " contains(" << *I << ")\n"; - auto v = count(I); - llvm::errs() << " count -> " << v << "\n"; - return v != 0; + auto __i = find(I); + return __i != end(); } void remove(Instruction* I) { - llvm::errs() << "pre remove(" << *I << ")\n"; auto __i = find(I); assert (__i != end()); erase(__i); - llvm::errs() << "post remove(" << *I << ")\n"; - assert(count(I) == 0); } Instruction* pop_back_val() { auto back = end(); @@ -2731,16 +2719,12 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } if (Q.contains(I)) { Q.remove(I); - llvm::errs() << " removed from queue\n"; } - for (auto q : Q) - llvm::errs() << " -- q: " << *q << "\n"; assert(!Q.contains(I)); - llvm::errs() << "erasing I: " << *I << "\n"; I->eraseFromParent(); for (auto op : operands) if (op->getNumUses() == 0) { - if (Q.contains(I)) + if (Q.contains(op)) Q.remove(op); op->eraseFromParent(); } @@ -4116,27 +4100,17 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } // select(cond, const1, b) ?= const2 -> select(cond, const1 ?= const2, b ?= const2) - llvm::errs() <<" pre nottriggered:" << *cur << "\n"; - if (auto fcmp = dyn_cast(cur)) { - llvm::errs() << " outermost nottriggered: " << *fcmp << "\n"; + if (auto fcmp = dyn_cast(cur)) for (int i=0; i<2; i++) - if (auto const2 = dyn_cast(fcmp->getOperand(i))) { - if (auto sel = dyn_cast(fcmp->getOperand(1-i))) { + if (auto const2 = dyn_cast(fcmp->getOperand(i))) + if (auto sel = dyn_cast(fcmp->getOperand(1-i))) if (isa(sel->getTrueValue()) || isa(sel->getFalseValue())) { auto tval = pushcse(B.CreateFCmp(fcmp->getPredicate(), sel->getTrueValue(), const2)); auto fval = pushcse(B.CreateFCmp(fcmp->getPredicate(), sel->getFalseValue(), const2)); auto res = pushcse(B.CreateSelect(sel->getCondition(), tval, fval)); replaceAndErase(cur, res); return "FCmpSelectConst"; - } else - llvm::errs() << " nottriggered(2): " << *fcmp << " " << *sel << "\n"; - } else { - llvm::errs() << " nottriggered(1): " << *fcmp << " " << *fcmp->getOperand(1-i) << "\n"; - } - } else { - llvm::errs() << " nottriggered(0): " << *fcmp << " " << *fcmp->getOperand(1-i) << "\n"; } - } // mul (mul a, const), b -> mul (mul a, b), const // note we avoid the case where b = (mul a, const) since otherwise @@ -5176,6 +5150,13 @@ return true; set.push_back(ty); */ } + static SetTy intersect(const SetTy &lhs, const SetTy &rhs) { + SetTy res; + for (auto &v : lhs) + if (rhs.count(v)) + res.insert(v); + return res; + } static void set_subtract(SetTy &set, const SetTy &rhs) { for (auto &v : rhs) if (set.count(v)) @@ -5199,7 +5180,7 @@ return true; case Type::Compare: return std::make_shared(node, !isEqual); case Type::Union: { - // not of or's is and or not's + // not of or's is and of not's SetTy next; for (const auto &v : values) insert(next, v->notB()); @@ -5208,7 +5189,7 @@ return true; return std::make_shared(Type::Intersect, next); } case Type::Intersect: { - // not of and's is or or not's + // not of and's is or of not's SetTy next; for (const auto &v : values) insert(next, v->notB()); @@ -5434,8 +5415,7 @@ return true; for (auto uv : unionVals) cur = cur->orB(uv, SE); - if (cur->ty != Type::Union) - return andB(cur, SE); + return andB(cur, SE); } SetTy vals = values; @@ -5449,8 +5429,7 @@ return true; // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) and // (c or e)) if (ty == Type::Union && rhs->ty == Type::Union) { - SetTy intersection = values; - set_subtract(intersection, rhs->values); + SetTy intersection = intersect(values, rhs->values); if (intersection.size() != 0) { InnerTy other_lhs = remove(intersection); InnerTy other_rhs = rhs->remove(intersection); @@ -5605,21 +5584,9 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, llvm::errs() << " pre fix inner: " << F << "\n"; - llvm::errs() << "\n"; - for (auto a : Q) { - llvm::errs() << " - " << *a << "\n"; - } - llvm::errs() << "\n"; - // Full simplification while (!Q.empty()) { auto cur = Q.pop_back_val(); - - llvm::errs() << "\n"; - for (auto a : Q) { - llvm::errs() << " - " << *a << "\n"; - } - llvm::errs() << "\n"; std::set prev; for (auto v : Q) prev.insert(v); llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n"; @@ -5750,6 +5717,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, ->andB(getSparseConditions(I->getOperand(1), Constraints::all(), I), SE); + llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n"; return res; } @@ -5759,10 +5727,23 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, ->orB(getSparseConditions(I->getOperand(1), Constraints::none(), I), SE); + llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n"; return res; } - // cmp x, 1.0 -> false/true + if (I->getOpcode() == Instruction::Xor) { + for (int i=0; i<2; i++) { + if (auto C=dyn_cast(I->getOperand(i))) + if (C->isOne()) { + auto pres = getSparseConditions(I->getOperand(1-i), defaultFloat->notB(), scope); + auto res = pres->notB(); + llvm::errs() << " negate: " << *I << " pre: " << *pres << " negated: " << *res << "\n"; + llvm::errs() << " getSparse(not, " << *I << ") = " << *res << "\n"; + return res; + } + } + } + if (auto icmp = dyn_cast(I)) { auto lhs = SE.getSCEVAtScope(icmp->getOperand(0), L); auto rhs = SE.getSCEVAtScope(icmp->getOperand(1), L); @@ -5787,6 +5768,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, if (div == div_e) { auto res = std::make_shared( div, icmp->getPredicate() == ICmpInst::ICMP_EQ); + llvm::errs() << " getSparse(icmp, " << *I << ") = " << *res << "\n"; return res; } } @@ -5797,8 +5779,10 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, } } + // cmp x, 1.0 -> false/true if (auto fcmp = dyn_cast(I)) { auto res = defaultFloat; + llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n"; return res; if (fcmp->getPredicate() == CmpInst::FCMP_OEQ || @@ -5816,11 +5800,15 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, return Constraints::all(); }; + // default is condition avoids sparse, negated is condition goes + // to sparse Instruction * context = isa(cond) ? cast(cond) : idx; auto solutions = getSparseConditions(cond, negated ? Constraints::all() : Constraints::none(), context); + llvm::errs() << " solutions pre negate: " << *solutions << "\n"; if (!negated) solutions = solutions->notB(); + llvm::errs() << " solutions post negate: " << *solutions << "\n"; if (!legal) { sawError = true; continue;