Skip to content

Commit

Permalink
no jlinstsimpl
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Aug 11, 2024
1 parent da2dadc commit 3ae6b89
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 128 deletions.
113 changes: 14 additions & 99 deletions enzyme/Enzyme/JLInstSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,58 +129,6 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst) {
return true;
}

/// Collect the set of underlying base objects that \p V may refer to.
///
/// Each visited value is first resolved with getBaseObject (forwarding
/// \p offsetAllowed unchanged). When the resolved value is a PHI node or a
/// select, its incoming / chosen values are explored instead of recording the
/// PHI or select itself. Every other resolved value is added to the result.
///
/// The result is a SetVector, so each base appears once, in the order the
/// depth-first worklist first reaches it.
static inline SetVector<llvm::Value *> getBaseObjects(llvm::Value *V,
                                                      bool offsetAllowed) {
  SetVector<llvm::Value *> bases;

  // Values already expanded; keeps PHI cycles from looping forever.
  SmallPtrSet<llvm::Value *, 2> visited;
  SmallVector<llvm::Value *, 1> worklist;
  worklist.push_back(V);

  while (!worklist.empty()) {
    llvm::Value *next = worklist.pop_back_val();

    // insert() reports whether the value was newly added.
    if (!visited.insert(next).second)
      continue;

    auto base = getBaseObject(next, offsetAllowed);

    if (auto phi = dyn_cast<PHINode>(base)) {
      // Look through the PHI: every incoming value is a candidate.
      for (unsigned i = 0, e = phi->getNumIncomingValues(); i != e; ++i)
        worklist.push_back(phi->getIncomingValue(i));
    } else if (auto sel = dyn_cast<SelectInst>(base)) {
      // Look through the select: either arm may be the result.
      worklist.push_back(sel->getTrueValue());
      worklist.push_back(sel->getFalseValue());
    } else {
      bases.insert(base);
    }
  }

  return bases;
}

/// Return true only when every value in \p lhs_v can be treated as a distinct
/// pointer from every value in \p rhs_v.
///
/// A pair (lhs, rhs) is accepted when it satisfies the same rule that
/// jlInstSimplify applies to a single pair of base objects:
///
///   (isNoAlias(lhs) && (isNoAlias(rhs) || isa<Argument>(rhs))) ||
///   (isNoAlias(rhs) && isa<Argument>(lhs))
///
/// Identical values on both sides are never distinct. Two plain arguments are
/// rejected, and so is a noalias value paired with something that is neither
/// noalias nor an argument (for example a loaded pointer), since nothing here
/// rules out the two pointers being equal.
///
/// NOTE(review): if \p lhs_v is empty the result is vacuously true, and the
/// same holds when \p rhs_v is empty and every lhs value is noalias or an
/// argument; callers are presumed to pass non-empty sets.
bool noaliased_or_arg(SetVector<llvm::Value *> &lhs_v,
                      SetVector<llvm::Value *> &rhs_v) {
  for (auto lhs : lhs_v) {
    const bool lhs_na = isNoAlias(lhs);
    const bool lhs_arg = isa<Argument>(lhs);

    // This LHS value is neither noalias nor an argument: nothing can be
    // concluded about it, whatever the RHS values are.
    if (!lhs_na && !lhs_arg)
      return false;

    for (auto rhs : rhs_v) {
      // The very same base object on both sides certainly aliases.
      if (lhs == rhs)
        return false;

      const bool rhs_na = isNoAlias(rhs);
      const bool rhs_arg = isa<Argument>(rhs);

      // At least one side must be noalias and the other side must be
      // noalias or an argument; anything else may be the same pointer.
      const bool distinct =
          (lhs_na && (rhs_na || rhs_arg)) || (rhs_na && lhs_arg);
      if (!distinct)
        return false;
    }
  }
  return true;
}

bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
llvm::AAResults &AA, llvm::LoopInfo &LI) {
bool changed = false;
Expand Down Expand Up @@ -227,59 +175,33 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
}

if (legal) {
auto lhs_v = getBaseObjects(I.getOperand(0), /*offsetAllowed*/ false);
auto rhs_v = getBaseObjects(I.getOperand(1), /*offsetAllowed*/ false);
if (lhs_v.size() == 1 && rhs_v.size() == 1 && lhs_v[0] == rhs_v[0]) {
auto lhs = getBaseObject(I.getOperand(0), /*offsetAllowed*/ false);
auto rhs = getBaseObject(I.getOperand(1), /*offsetAllowed*/ false);
if (lhs == rhs) {
auto repval = ICmpInst::isTrueWhenEqual(pred)
? ConstantInt::get(I.getType(), 1)
: ConstantInt::get(I.getType(), 0);
I.replaceAllUsesWith(repval);
changed = true;
continue;
}
if (noaliased_or_arg(lhs_v, rhs_v)) {
if ((isNoAlias(lhs) && (isNoAlias(rhs) || isa<Argument>(rhs))) ||
(isNoAlias(rhs) && isa<Argument>(lhs))) {
auto repval = ICmpInst::isTrueWhenEqual(pred)
? ConstantInt::get(I.getType(), 0)
: ConstantInt::get(I.getType(), 1);
I.replaceAllUsesWith(repval);
changed = true;
continue;
}
bool loadlegal = true;
SmallVector<LoadInst *, 1> llhs, lrhs;
for (auto lhs : lhs_v) {
auto ld = dyn_cast<LoadInst>(lhs);
if (!ld || !isa<PointerType>(ld->getType())) {
loadlegal = false;
break;
}
llhs.push_back(ld);
}
for (auto rhs : rhs_v) {
auto ld = dyn_cast<LoadInst>(rhs);
if (!ld || !isa<PointerType>(ld->getType())) {
loadlegal = false;
break;
}
lrhs.push_back(ld);
}
SetVector<Value *> llhs_s, lrhs_s;
for (auto v : llhs) {
for (auto obj :
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
llhs_s.insert(obj);
}
}
for (auto v : lrhs) {
for (auto obj :
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
lrhs_s.insert(obj);
}
}
// TODO handle multi size
if (llhs_s.size() == 1 && lrhs_s.size() == 1 && loadlegal) {
auto lhsv = llhs_s[0];
auto rhsv = lrhs_s[0];
auto llhs = dyn_cast<LoadInst>(lhs);
auto lrhs = dyn_cast<LoadInst>(rhs);
if (llhs && lrhs && isa<PointerType>(llhs->getType()) &&
isa<PointerType>(lrhs->getType())) {
auto lhsv =
getBaseObject(llhs->getOperand(0), /*offsetAllowed*/ false);
auto rhsv =
getBaseObject(lrhs->getOperand(0), /*offsetAllowed*/ false);
if ((isNoAlias(lhsv) && (isNoAlias(rhsv) || isa<Argument>(rhsv) ||
notCapturedBefore(lhsv, &I))) ||
(isNoAlias(rhsv) &&
Expand All @@ -303,14 +225,7 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
if (!I->mayWriteToMemory())
return /*earlyBreak*/ false;

for (auto LI : llhs)
if (writesToMemoryReadBy(nullptr, AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
overwritten = true;
return /*earlyBreak*/ true;
}
for (auto LI : lrhs)
for (auto LI : {llhs, lrhs})
if (writesToMemoryReadBy(nullptr, AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
Expand Down
26 changes: 0 additions & 26 deletions enzyme/test/Enzyme/JLSimplify/yesptr2.ll

This file was deleted.

6 changes: 3 additions & 3 deletions enzyme/test/Enzyme/ReverseMode/blas_diffuse.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ target triple = "x86_64-unknown-linux-gnu"

%struct.Prod = type { ptr, double }

declare i32 @dgemm_(ptr nocapture noundef readonly %transa_t, ptr nocapture noundef readonly %transb_t, ptr nocapture noundef readonly %m, ptr nocapture noundef readonly %n, ptr nocapture noundef readonly %k, ptr nocapture noundef readonly %alpha, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %lda, ptr nocapture noundef readonly %b, ptr nocapture noundef readonly %ldb, ptr nocapture noundef readonly %beta, ptr nocapture noundef %c, ptr nocapture noundef readonly %ldc)
declare i32 @dgemm_(ptr nocapture noundef readonly %transa_t, ptr nocapture noundef readonly %transb_t, ptr nocapture noundef readonly %m, ptr nocapture noundef readonly %n, ptr nocapture noundef readonly %k, ptr nocapture noundef readonly %alpha, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %lda, ptr nocapture noundef readonly %b, ptr nocapture noundef readonly %ldb, ptr nocapture noundef readonly %beta, ptr nocapture noundef %c, ptr nocapture noundef readonly %ldc, i32, i32)

; Function Attrs: mustprogress noinline nounwind uwtable
define dso_local void @_Z3mulR4ProdPd(ptr nocapture noundef nonnull align 8 dereferenceable(16) %P, ptr noalias nocapture noundef readonly %rhs) {
Expand All @@ -22,9 +22,9 @@ entry:
store i32 2, ptr %ten, align 4
store double 1.000000e+00, ptr %one, align 8
store double 0.000000e+00, ptr %zero, align 8
%call1 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten)
%call1 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten, i32 1, i32 1)
%0 = load ptr, ptr %P, align 8
%call2 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %zero, ptr noundef %0, ptr noundef nonnull %ten)
%call2 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %zero, ptr noundef %0, ptr noundef nonnull %ten, i32 1, i32 1)
%alpha = getelementptr inbounds %struct.Prod, ptr %P, i64 0, i32 1
store double 0.000000e+00, ptr %alpha, align 8
ret void
Expand Down

0 comments on commit 3ae6b89

Please sign in to comment.