Skip to content

Commit

Permalink
[ActivityAnalysis] Remove isConstantValue call in activity analysis (#…
Browse files Browse the repository at this point in the history
…1608)

* Remove cop2

* Add integration test

* Add test

* Update enzyme/test/ActivityAnalysis/integration.ll

* Update test

* Update test

* Add missing return

* Format and test

* Back gt

* Test fix trial

* Lower test llvm to 15

* Update enzyme/test/ActivityAnalysis/integration.ll

* Update activity printer for opaque type

* update ->

* Update activity analysis

* Update if/elif llvm

* Update llvm versioning
  • Loading branch information
rmoyard authored Feb 21, 2024
1 parent f600e2e commit 1beb98b
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 18 deletions.
8 changes: 4 additions & 4 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2073,14 +2073,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
<< "\n";
if (auto SI = dyn_cast<StoreInst>(I)) {
bool cop = !Hypothesis->isConstantValue(TR, SI->getValueOperand());
bool cop2 = !Hypothesis->isConstantValue(TR, SI->getPointerOperand());
// bool cop2 = !Hypothesis->isConstantValue(TR,
// SI->getPointerOperand());
if (EnzymePrintActivity)
llvm::errs() << " -- store potential activity: " << (int)cop << ","
<< (int)cop2 << ","
llvm::errs() << " -- store potential activity: " << (int)cop
<< " - " << *SI << " of "
<< " Val=" << *Val << "\n";
potentialStore = I;
if (cop && cop2)
if (cop) // && cop2)
potentiallyActiveStore = SI;
} else if (auto MTI = dyn_cast<MemTransferInst>(I)) {
bool cop = !Hypothesis->isConstantValue(TR, MTI->getArgOperand(1));
Expand Down
38 changes: 24 additions & 14 deletions enzyme/Enzyme/ActivityAnalysisPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) {
if (a.getType()->isFPOrFPVectorTy()) {
dt = ConcreteType(a.getType()->getScalarType());
} else if (a.getType()->isPointerTy()) {
#if LLVM_VERSION_MAJOR >= 17
#else
auto et = a.getType()->getPointerElementType();
if (et->isFPOrFPVectorTy()) {
dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
} else if (et->isPointerTy()) {
dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
#if LLVM_VERSION_MAJOR < 17
#if LLVM_VERSION_MAJOR >= 13
if (a.getContext().supportsTypedPointers()) {
#endif
auto et = a.getType()->getPointerElementType();
if (et->isFPOrFPVectorTy()) {
dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
} else if (et->isPointerTy()) {
dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
}
#if LLVM_VERSION_MAJOR >= 13
}
#endif
#endif
} else if (a.getType()->isIntOrIntVectorTy()) {
dt = ConcreteType(BaseType::Integer);
Expand All @@ -113,14 +118,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) {
if (F.getReturnType()->isFPOrFPVectorTy()) {
dt = ConcreteType(F.getReturnType()->getScalarType());
} else if (F.getReturnType()->isPointerTy()) {
#if LLVM_VERSION_MAJOR >= 17
#else
auto et = F.getReturnType()->getPointerElementType();
if (et->isFPOrFPVectorTy()) {
dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
} else if (et->isPointerTy()) {
dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
#if LLVM_VERSION_MAJOR < 17
#if LLVM_VERSION_MAJOR >= 13
if (F.getContext().supportsTypedPointers()) {
#endif
auto et = F.getReturnType()->getPointerElementType();
if (et->isFPOrFPVectorTy()) {
dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
} else if (et->isPointerTy()) {
dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
}
#if LLVM_VERSION_MAJOR >= 13
}
#endif
#endif
} else if (F.getReturnType()->isIntOrIntVectorTy()) {
dt = ConcreteType(BaseType::Integer);
Expand Down
93 changes: 93 additions & 0 deletions enzyme/test/ActivityAnalysis/integration.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %OPnewLoadEnzyme -passes="print-activity-analysis" -activity-analysis-func=f.preprocess -S | FileCheck %s; fi

declare void @free(ptr)

declare ptr @malloc(i64)

; Test entry point. The function is void: it loads a double from %param and
; stores 2 * that value into %res, so the derivative of the stored result with
; respect to the input should be 2.0.
define void @f.preprocess(ptr %param, i64 %mallocsize, ptr %res) {

; arithmetic block, changing anything here makes the bug go away
; (%inttoptr is %tmp with its low 6 bits cleared by the -64 mask; 4 * input is
; stored both through %inttoptr and into %buffer1, and %tmp is freed in between)
%buffer1 = call ptr @malloc(i64 %mallocsize)
%tmp = call ptr @malloc(i64 72)
%ptrtoint = ptrtoint ptr %tmp to i64
%and = and i64 %ptrtoint, -64
%inttoptr = inttoptr i64 %and to ptr
%loadarg = load double, ptr %param
%storedargmul = fmul double %loadarg, 4.000000e+00
store double %storedargmul, ptr %inttoptr
call void @free(ptr %tmp)
store double %storedargmul, ptr %buffer1

; prep arg 0 by setting the aligned pointer to the input
; (field 1 of the { ptr, ptr, i64 } struct receives %param)
%arg0 = alloca { ptr, ptr, i64 }
%arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1
store ptr %param, ptr %arg0_aligned

; prep arg 1 by setting the aligned pointer to buffer1
; (field 1 of the struct receives %buffer1)
%arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }
%arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1
store ptr %buffer1, ptr %arg1_aligned

; prep arg 2 by setting the aligned pointer to buffer2
; (%buffer2 is a fresh 8-byte allocation stored into field 1 of the struct)
%arg2 = alloca { ptr, ptr, i64 }
%arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1
%buffer2 = call ptr @malloc(i64 8)
store ptr %buffer2, ptr %arg2_aligned

; nested call, required for bug
call void @nested(ptr %arg0, ptr %arg1, ptr %arg2)

; return a result from this function, needs to be positioned after arithmetic block for bug
; (the "result" is written through %res; the function itself returns void)
%x = load double, ptr %param
%y = fmul double %x, 2.0
store double %y, ptr %res

ret void
}

; Identity-style copy: reads the pointer held in field 1 of the struct at %arg0,
; loads a double through it, and stores that double through the pointer held in
; field 1 of the struct at %arg2. The 2nd argument (%arg1) is required for the
; bug but is never used in the body.
define void @nested(ptr %arg0, ptr %arg1, ptr %arg2) {

; load aligned pointer from %arg0 & load argument value
%loadarg = load { ptr, ptr, i64 }, ptr %arg0
%extractarg = extractvalue { ptr, ptr, i64 } %loadarg, 1
%loadextractarg = load double, ptr %extractarg

; load aligned pointer from %arg2 & store result value
%loadarg2 = load { ptr, ptr, i64 }, ptr %arg2
%extractarg2 = extractvalue { ptr, ptr, i64 } %loadarg2, 1
store double %loadextractarg, ptr %extractarg2

ret void
}

; CHECK: ptr %param: icv:0
; CHECK-NEXT: i64 %mallocsize: icv:1
; CHECK-NEXT: ptr %res: icv:0

; CHECK: %buffer1 = call ptr @malloc(i64 %mallocsize): icv:0 ici:1
; CHECK-NEXT: %tmp = call ptr @malloc(i64 72): icv:1 ici:1
; CHECK-NEXT: %ptrtoint = ptrtoint ptr %tmp to i64: icv:1 ici:1
; CHECK-NEXT: %and = and i64 %ptrtoint, -64: icv:1 ici:1
; CHECK-NEXT: %inttoptr = inttoptr i64 %and to ptr: icv:1 ici:1
; CHECK-NEXT: %loadarg = load double, ptr %param, align 8: icv:0 ici:0
; CHECK-NEXT: %storedargmul = fmul double %loadarg, 4.000000e+00: icv:0 ici:0
; CHECK-NEXT: store double %storedargmul, ptr %inttoptr, align 8: icv:1 ici:1
; CHECK-NEXT: call void @free(ptr %tmp): icv:1 ici:1
; CHECK-NEXT: store double %storedargmul, ptr %buffer1, align 8: icv:1 ici:0
; CHECK-NEXT: %arg0 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1
; CHECK-NEXT: %arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1: icv:0 ici:1
; CHECK-NEXT: store ptr %param, ptr %arg0_aligned, align 8: icv:1 ici:0
; CHECK-NEXT: %arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8: icv:0 ici:1
; CHECK-NEXT: %arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1: icv:0 ici:1
; CHECK-NEXT: store ptr %buffer1, ptr %arg1_aligned, align 8: icv:1 ici:0
; CHECK-NEXT: %arg2 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1
; CHECK-NEXT: %arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1: icv:0 ici:1
; CHECK-NEXT: %buffer2 = call ptr @malloc(i64 8): icv:0 ici:1
; CHECK-NEXT: store ptr %buffer2, ptr %arg2_aligned, align 8: icv:1 ici:0
; CHECK-NEXT: call void @nested(ptr %arg0, ptr %arg1, ptr %arg2): icv:1 ici:0
; CHECK-NEXT: %x = load double, ptr %param, align 8: icv:0 ici:0
; CHECK-NEXT: %y = fmul double %x, 2.000000e+00: icv:0 ici:0
; CHECK-NEXT: store double %y, ptr %res, align 8: icv:1 ici:0
; CHECK-NEXT: ret void: icv:1 ici:1

0 comments on commit 1beb98b

Please sign in to comment.