Skip to content

Commit

Permalink
Timeout options and more
Browse files Browse the repository at this point in the history
  • Loading branch information
adam committed Mar 13, 2024
1 parent 5f4c383 commit dcee97b
Show file tree
Hide file tree
Showing 4 changed files with 488 additions and 44 deletions.
2 changes: 1 addition & 1 deletion CSiMBA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ cl::opt<bool>
cl::opt<bool> RunOptimizer(
"optimize", cl::Optional,
cl::desc("Optimize LLVM IR before simplification (Default true)"),
cl::value_desc("optimize"), cl::init(false));
cl::value_desc("optimize"), cl::init(true));

cl::opt<bool> Debug("simba-debug", cl::Optional,
cl::desc("Print debug information (Default false)"),
Expand Down
108 changes: 65 additions & 43 deletions LLVMParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,14 @@ llvm::cl::opt<std::string>
"simplification (Supports: SiMBA/GAMBA"),
cl::value_desc("external-simplifier"), cl::init(""));

llvm::cl::opt<int> MaxVarCount("max-var-count", cl::Optional,
cl::desc("Max variable count for simplification"),
cl::value_desc("max-var-count"), cl::init(20));
llvm::cl::opt<int>
MaxVarCount("max-var-count", cl::Optional,
cl::desc("Max variable count for simplification"),
cl::value_desc("max-var-count"), cl::init(5));

llvm::cl::opt<int> MinASTSize("min-ast-size", cl::Optional,
cl::desc("Minimum AST size for simplification"),
cl::value_desc("min-ast-size"), cl::init(3));
cl::desc("Minimum AST size for simplification"),
cl::value_desc("min-ast-size"), cl::init(4));

namespace LSiMBA {

Expand Down Expand Up @@ -286,10 +287,6 @@ int LLVMParser::extractAndSimplify() {

DominatorTree DT(*F);

// Clone function to compare
ValueToValueMapTy VMap;
auto FClone = CloneFunction(F, VMap);

// Measure Time
auto start = high_resolution_clock::now();

Expand Down Expand Up @@ -327,9 +324,6 @@ int LLVMParser::extractAndSimplify() {
MBASimplified++;
MBACount++;
}

// When we reach this here, replacements are valid
FClone->eraseFromParent();
}

auto stop = high_resolution_clock::now();
Expand Down Expand Up @@ -501,12 +495,10 @@ bool LLVMParser::verify(llvm::SmallVectorImpl<BFSEntry> &AST,
int Operations = 0;

llvm::SmallVector<APInt, 16> par;
llvm::SmallVector<int64_t, 16> parInt;
for (int i = 0; i < NUM_TEST_CASES; i++) {
for (int j = 0; j < VNumber; j++) {
auto v = SP64.next();
par.push_back(APInt(BitWidth, v));
parInt.push_back(v);
}

// Eval AST
Expand All @@ -516,24 +508,17 @@ bool LLVMParser::verify(llvm::SmallVectorImpl<BFSEntry> &AST,
return false;
}

// Mod
/*
if (AP_R0.isSignBitSet()) {
AP_R0 = AP_R0.srem(Modulus);
} else {
AP_R0 = AP_R0.urem(Modulus);
}
*/

// Eval replacement
auto AP_R1 = eval(Expr1_replVar, par, BitWidth, &Operations);
/*
if (AP_R1.isSignBitSet()) {
AP_R1 = AP_R1.srem(Modulus);
} else {
AP_R1 = AP_R1.urem(Modulus);

// Check if replacement is cheaper than original expression
if (AST.size() <= Operations) {
#ifdef DEBUG_SIMPLIFICATION
outs() << "[!] Simplification is no improvement: AST: " << AST.size()
<< " Operations: " << Operations << "\n";
#endif
return false;
}
*/

if (AP_R0.getSExtValue() != AP_R1.getSExtValue()) {
#ifdef DEBUG_SIMPLIFICATION
Expand All @@ -545,15 +530,6 @@ bool LLVMParser::verify(llvm::SmallVectorImpl<BFSEntry> &AST,
par.clear();
}

// Check if replacement is cheaper than original expression
if (AST.size() <= Operations) {
#ifdef DEBUG_SIMPLIFICATION
outs() << "[!] Simpl is no improvement! (" << AST.size()
<< " <= " << Operations << ")\n";
#endif
return false;
}

// Prove with z3
if (this->Prove) {
z3::context Z3Ctx;
Expand Down Expand Up @@ -731,38 +707,52 @@ bool LLVMParser::isSupportedInstruction(llvm::Value *V) {

void LLVMParser::extractCandidates(llvm::Function &F,
std::vector<MBACandidate> &Candidates) {

std::set<llvm::Value *> Visited;
auto isVisited = [&](llvm::Value *I) -> bool {
return Visited.find(I) != Visited.end();
};

// Instruction to look for 'store', 'select', 'gep', 'icmp', 'ret'
for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
switch (I->getOpcode()) {
case Instruction::Store: {
// Check Candidate
auto SI = dyn_cast<StoreInst>(&*I);
auto Op = SI->getValueOperand();
if (isSupportedInstruction(Op)) {
if (!isVisited(Op) && isSupportedInstruction(Op)) {
MBACandidate Cand;
Cand.Candidate = dyn_cast<Instruction>(Op);
Candidates.push_back(Cand);
Visited.insert(Op);
}
} break;
case Instruction::Call: {
auto CI = dyn_cast<CallInst>(&*I);
for (unsigned int i = 0; i < CI->arg_size(); i++) {
if (isSupportedInstruction(CI->getArgOperand(i)->stripPointerCasts())) {
if (isVisited(CI->getArgOperand(i)->stripPointerCasts()))
continue;

MBACandidate Cand;
Cand.Candidate =
dyn_cast<Instruction>(CI->getArgOperand(i)->stripPointerCasts());
Candidates.push_back(Cand);
Visited.insert(CI->getArgOperand(i)->stripPointerCasts());
}
}
} break;
case Instruction::ICmp:
case Instruction::Select: {
for (unsigned int i = 0; i < I->getNumOperands(); i++) {
if (isSupportedInstruction(I->getOperand(i)->stripPointerCasts())) {
if (isVisited(I->getOperand(i)->stripPointerCasts()))
continue;
MBACandidate Cand;
Cand.Candidate =
dyn_cast<Instruction>(I->getOperand(i)->stripPointerCasts());
Candidates.push_back(Cand);
Visited.insert(I->getOperand(i)->stripPointerCasts());
}
}
} break;
Expand All @@ -771,9 +761,13 @@ void LLVMParser::extractCandidates(llvm::Function &F,
auto Index = GEP->getOperand(GEP->getNumOperands() - 1);

if (isSupportedInstruction(Index->stripPointerCasts())) {
if (isVisited(Index->stripPointerCasts()))
continue;

MBACandidate Cand;
Cand.Candidate = dyn_cast<Instruction>(Index->stripPointerCasts());
Candidates.push_back(Cand);
Visited.insert(Index->stripPointerCasts());
}
} break;
case Instruction::Ret: {
Expand All @@ -782,19 +776,25 @@ void LLVMParser::extractCandidates(llvm::Function &F,
continue;

if (isSupportedInstruction(RI->getReturnValue()->stripPointerCasts())) {
if (isVisited(RI->getReturnValue()->stripPointerCasts()))
continue;
MBACandidate Cand;
Cand.Candidate =
dyn_cast<Instruction>(RI->getReturnValue()->stripPointerCasts());
Candidates.push_back(Cand);
Visited.insert(RI->getReturnValue()->stripPointerCasts());
}
} break;
case Instruction::PHI: {
auto Phi = dyn_cast<PHINode>(&*I);
for (auto &Inc : Phi->incoming_values()) {
if (isSupportedInstruction(Inc->stripPointerCasts())) {
if (isVisited(Inc->stripPointerCasts()))
continue;
MBACandidate Cand;
Cand.Candidate = dyn_cast<Instruction>(Inc->stripPointerCasts());
Candidates.push_back(Cand);
Visited.insert(Inc->stripPointerCasts());
}
}
} break;
Expand All @@ -811,14 +811,22 @@ void LLVMParser::extractCandidates(llvm::Function &F,
case Instruction::SRem:
case Instruction::IntToPtr:
case Instruction::BitCast: {
if (isVisited(&*I))
continue;
MBACandidate Cand;
Cand.Candidate = dyn_cast<Instruction>(&*I);
Candidates.push_back(Cand);
Visited.insert(&*I);
}
default: {
}
}
}
#ifdef DEBUG_SIMPLIFICATION
outs() << "[*] Found " << Candidates.size()
<< " candidates Duplicates: " << (Visited.size() - Candidates.size())
<< "\n";
#endif
}

bool LLVMParser::findReplacements(llvm::DominatorTree *DT,
Expand All @@ -836,10 +844,21 @@ bool LLVMParser::findReplacements(llvm::DominatorTree *DT,

// Search for replacements
std::vector<MBACandidate> SubASTCandidates;
auto StartTime = high_resolution_clock::now();
for (int i = 0; i < Candidates.size(); i++) {
auto &Cand = Candidates[i];
getAST(DT, Cand.Candidate, Cand.AST, Cand.Variables, true);
}

auto EndTime = high_resolution_clock::now();
auto Duration = duration_cast<milliseconds>(EndTime - StartTime);
#ifdef DEBUG_SIMPLIFICATION
outs() << "[*] Extracted ASTs in " << Duration.count() << " ms\n";

#endif

for (int i = 0; i < Candidates.size(); i++) {
auto &Cand = Candidates[i];
int s = Cand.AST.size();
if (Cand.AST.size() < MinASTSize) {
continue;
Expand All @@ -857,7 +876,7 @@ bool LLVMParser::findReplacements(llvm::DominatorTree *DT,
}
#endif

// Only handle max 20 Vars
// Only handle max xx Vars
if (Cand.Variables.size() > MaxVarCount) {
Cand.isValid = false;
continue;
Expand Down Expand Up @@ -920,14 +939,14 @@ bool LLVMParser::findReplacements(llvm::DominatorTree *DT,

if (Cand.isValid == false) {
// Could not simplify the whole AST so walk through SubASTs
ReplacementFound = walkSubAST(DT, Cand.AST, SubASTCandidates);
ReplacementFound |= walkSubAST(DT, Cand.AST, SubASTCandidates);
} else {
if (this->Debug) {
outs() << "[*] Full AST Simplified Expression: " << Cand.Replacement
<< "\n";
}

ReplacementFound = true;
ReplacementFound |= true;
}
}

Expand Down Expand Up @@ -979,6 +998,8 @@ bool LLVMParser::walkSubAST(llvm::DominatorTree *DT,
continue;

int BitWidth = C.AST.front().I->getType()->getIntegerBitWidth();
if (BitWidth == 0 || BitWidth > 64)
continue;

auto Modulus = getModulus(BitWidth);

Expand Down Expand Up @@ -1016,8 +1037,9 @@ bool LLVMParser::walkSubAST(llvm::DominatorTree *DT,
}
#endif

if (!SkipVerify)
if (!SkipVerify) {
C.isValid = this->verify(C.AST, C.Replacement, C.Variables);
}

if (C.isValid) {
// Store valid replacement
Expand Down
8 changes: 8 additions & 0 deletions Z3Prover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@ llvm::cl::opt<bool>
llvm::cl::desc("Print SMT2 formula for debugging purposes"),
llvm::cl::value_desc("print-smt"), llvm::cl::init(false));

// Add timeout parameter as string
llvm::cl::opt<std::string>
timeout("timeout", llvm::cl::Optional,
llvm::cl::desc("Timeout for Z3 solver (Default 100)"),
llvm::cl::value_desc("timeout"), llvm::cl::init("100"));

bool prove(z3::expr conjecture) {
z3::context &c = conjecture.ctx();

Z3_global_param_set("timeout", timeout.c_str());

auto t = (z3::tactic(c, "simplify") & z3::tactic(c, "bit-blast") &
z3::tactic(c, "smt"));
auto s = t.mk_solver();
Expand Down
Loading

0 comments on commit dcee97b

Please sign in to comment.