Skip to content

Commit

Permalink
Efficient fwd mode potrf (#1953)
Browse files Browse the repository at this point in the history
* Efficient fwd mode potrf

* Now featuring tests

* correct bug

* Update BlasDerivatives.td

* wip rev

* cleanup

* fixup

* fixup

* improve overwritten in blas

* fix

* fix hide uplo

* fixup

* fixed

* Disable reverse complex

* fmt

* correct cache err

* Reverse mode potrf

* fix

* fix

* layout fix

* Fix

* fix weird cublas

* fix cublas

* remove mutable check

* fix
  • Loading branch information
wsmoses authored Jul 1, 2024
1 parent b1f627a commit d23b53a
Show file tree
Hide file tree
Showing 18 changed files with 1,072 additions and 309 deletions.
239 changes: 178 additions & 61 deletions enzyme/Enzyme/BlasDerivatives.td

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8226,7 +8226,6 @@ void GradientUtils::eraseFictiousPHIs() {
for (auto pair : phis) {
auto pp = pair.first;
if (pp->getNumUses() != 0) {
assert(0);
if (CustomErrorHandler) {
std::string str;
raw_string_ostream ss(str);
Expand Down
128 changes: 91 additions & 37 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,10 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
llvm::ArrayRef<llvm::Value *> args,
llvm::Type *copy_retty,
llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
auto copy_name =
std::string(blas.prefix) + blas.floatType + "copy" + blas.suffix;
const bool cublasv2 =
blas.prefix == "cublas" && StringRef(blas.suffix).contains("v2");
auto copy_name = std::string(blas.prefix) + blas.floatType + "copy" +
(cublasv2 ? "" : blas.suffix);

SmallVector<Type *, 1> tys;
for (auto arg : args)
Expand Down Expand Up @@ -793,9 +795,7 @@ void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas,
cast<PointerType>(blasalpha->getType())->getAddressSpace()));
alpha = B1.CreateLoad(fpTy, VP);
}
Value *is_u = is_uper(B1, blasuplo, byRef);
// Value *k = B1.CreateSelect(is_u, ConstantInt::get(IT, 0),
// ConstantInt::get(IT, 1), "k");
Value *is_l = is_lower(B1, blasuplo, byRef, /*cublas*/ false);
B1.CreateCondBr(B1.CreateICmpEQ(n, ConstantInt::get(IT, 0)), end, init);

IRBuilder<> B2(init);
Expand All @@ -811,7 +811,7 @@ void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas,
blasdAP,
PointerType::get(
fpTy, cast<PointerType>(blasdAP->getType())->getAddressSpace()));
B2.CreateCondBr(is_u, uper_code, lower_code);
B2.CreateCondBr(is_l, lower_code, uper_code);

IRBuilder<> B3(uper_code);
B3.setFastMathFlags(getFast());
Expand Down Expand Up @@ -2654,9 +2654,11 @@ std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
#endif
{
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv",
"syrk", "nrm2", "trmm", "trmv", "symm"};
const char *floatType[] = {"s", "d"}; // c, z
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm",
"spmv", "syrk", "nrm2", "trmm", "trmv",
"symm", "potrf", "copy", "spmv", "syr2k",
"potrs", "getrf", "getrs", "trtrs", "getri"};
const char *floatType[] = {"s", "d", "c", "z"};
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
const char *suffixes[] = {"", "_", "64_", "_64_"};
for (auto t : floatType) {
Expand All @@ -2674,8 +2676,8 @@ llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
}
}
// c interface to cublas
const char *cuCFloatType[] = {"S", "D"}; // c, z
const char *cuFFloatType[] = {"s", "d"}; // c, z
const char *cuCFloatType[] = {"S", "D", "C", "Z"};
const char *cuFFloatType[] = {"s", "d", "c", "z"};
const char *cuCPrefixes[] = {"cublas"};
const char *cuSuffixes[] = {"", "_v2", "_64", "_v2_64"};
for (auto t : llvm::enumerate(cuCFloatType)) {
Expand Down Expand Up @@ -2737,13 +2739,13 @@ llvm::FastMathFlags getFast() {
void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty,
llvm::SmallVectorImpl<llvm::Value *> &cacheValues,
llvm::IRBuilder<> &BuilderZ, const Twine &name) {
if (!cache_arg)
return;
if (!arg->getType()->isPointerTy()) {
assert(arg->getType() == ty);
cacheValues.push_back(arg);
return;
}
if (!cache_arg)
return;
#if LLVM_VERSION_MAJOR < 17
auto PT = cast<PointerType>(arg->getType());
#if LLVM_VERSION_MAJOR <= 14
Expand Down Expand Up @@ -2796,38 +2798,47 @@ llvm::Value *to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef,
return allocV;
}

/// Emit a select between two candidate dimensions based on a BLAS transpose
/// flag.
///
/// \param B      builder used to emit the instructions.
/// \param trans  the transpose argument, forwarded unchanged to is_normal.
/// \param dim1   value produced when is_normal(...) evaluates to true.
/// \param dim2   value produced when is_normal(...) evaluates to false.
/// \param byRef  forwarded to is_normal.
/// \param cublas forwarded to is_normal.
/// \return the value created by the select of dim1/dim2 on that condition.
llvm::Value *select_vec_dims(IRBuilder<> &B, llvm::Value *trans,
                             llvm::Value *dim1, llvm::Value *dim2, bool byRef,
                             bool cublas) {
  // The condition is materialised first, then the select that consumes it.
  return B.CreateSelect(is_normal(B, trans, byRef, cublas), dim1, dim2);
}

Value *is_uper(IRBuilder<> &B, Value *trans, bool byRef) {
IntegerType *charTy;
Value *is_lower(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas) {
if (cublas) {
Value *isNormal = nullptr;
isNormal = B.CreateICmpEQ(
uplo, ConstantInt::get(uplo->getType(),
/*cublasFillMode_t::CUBLAS_FILL_MODE_LOWER*/ 0));
return isNormal;
}
if (auto CI = dyn_cast<ConstantInt>(uplo)) {
if (CI->getValue() == 'L' || CI->getValue() == 'l')
return ConstantInt::getTrue(B.getContext());
if (CI->getValue() == 'U' || CI->getValue() == 'u')
return ConstantInt::getFalse(B.getContext());
}
if (byRef) {
// can't inspect opaque ptr, so assume 8 (Julia)
charTy = IntegerType::get(trans->getContext(), 8);
trans = B.CreateLoad(charTy, trans, "loaded.trans");
IntegerType *charTy = IntegerType::get(uplo->getContext(), 8);
uplo = B.CreateLoad(charTy, uplo, "loaded.trans");

auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L'));
auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l'));
// fortran blas
return B.CreateOr(isl, isL);
} else {
// we can inspect scalars
unsigned int len = trans->getType()->getScalarSizeInBits();
charTy = IntegerType::get(trans->getContext(), len);
auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 122));
// TODO we really should just return capi, but for sake of consistency,
// we will accept either here.
auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L'));
auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l'));
return B.CreateOr(capi, B.CreateOr(isl, isL));
}

Value *isUper =
B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'u')),
B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'U')));
return isUper;
}

llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef,
bool cublas) {
if (cublas) {
Value *isNormal = nullptr;
isNormal = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
isNormal = B.CreateICmpEQ(
trans, ConstantInt::get(trans->getType(),
/*cublasOperation_t::CUBLAS_OP_N*/ 0));
return isNormal;
}
// Explicitly support 'N' always, since we use in the rule infra
Expand All @@ -2841,13 +2852,56 @@ llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef,
IntegerType *charTy = IntegerType::get(trans->getContext(), 8);
trans = B.CreateLoad(charTy, trans, "loaded.trans");

auto isN = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N'));
auto isn = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n'));
auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
// fortran blas
return B.CreateOr(isn, isN);
} else {
// TODO we really should just return capi, but for sake of consistency,
// we will accept either here.
// we can inspect scalars
return B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
auto capi = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
// fortran blas
return B.CreateOr(capi, B.CreateOr(isn, isN));
}
}

/// Build an i1 value that is true when a BLAS `side` argument selects the
/// left side.
///
/// \param B      builder used to emit the comparison instructions.
/// \param side   the side argument; an integer value, or when \p byRef is set
///               a pointer from which an 8-bit character is loaded.
/// \param byRef  whether \p side must be loaded before it is inspected.
/// \param cublas whether \p side uses the cuBLAS encoding (compared against 0).
/// \return the condition value: a constant true/false when \p side is a
///         constant 'L'/'l' or 'R'/'r', otherwise the emitted comparison.
llvm::Value *is_left(IRBuilder<> &B, llvm::Value *side, bool byRef,
                     bool cublas) {
  // Emits `side == code`, typing the constant like the current `side` value.
  auto matches = [&](uint64_t code) -> Value * {
    return B.CreateICmpEQ(side, ConstantInt::get(side->getType(), code));
  };

  if (cublas)
    return matches(/*cublasSideMode_t::CUBLAS_SIDE_LEFT*/ 0);

  // Explicitly support 'L'/'R' always, since we use in the rule infra
  if (auto *constSide = dyn_cast<ConstantInt>(side)) {
    const auto &bits = constSide->getValue();
    if (bits == 'L' || bits == 'l')
      return ConstantInt::getTrue(B.getContext());
    if (bits == 'R' || bits == 'r')
      return ConstantInt::getFalse(B.getContext());
  }

  // Only the by-value path additionally tests the C-API code 141.
  Value *capiLeft = nullptr;
  if (byRef) {
    // can't inspect opaque ptr, so assume 8 (Julia)
    side = B.CreateLoad(IntegerType::get(side->getContext(), 8), side,
                        "loaded.side");
  } else {
    // TODO we really should just return capi, but for sake of consistency,
    // we will accept either here.
    capiLeft = matches(141);
  }

  // fortran blas: accept the character in either case
  Value *upperL = matches('L');
  Value *lowerL = matches('l');
  Value *charLeft = B.CreateOr(lowerL, upperL);
  if (!capiLeft)
    return charLeft;
  return B.CreateOr(capiLeft, charLeft);
}

Expand Down
9 changes: 5 additions & 4 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1848,10 +1848,11 @@ static inline llvm::SmallVector<llvm::Value *, 1> concat_values(T &&...t) {

llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef,
bool cublas);
llvm::Value *is_uper(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef);
llvm::Value *select_vec_dims(llvm::IRBuilder<> &B, llvm::Value *trans,
llvm::Value *dim1, llvm::Value *dim2, bool byRef,
bool cublas);
llvm::Value *is_left(llvm::IRBuilder<> &B, llvm::Value *side, bool byRef,
bool cublas);
llvm::Value *is_lower(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef,
bool cublas);

// The first one assumes V is an Integer
llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool cublas);
// The second one assumes V is an Integer or a ptr to an int (depends on byRef)
Expand Down
17 changes: 8 additions & 9 deletions enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,16 @@ entry:
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double*
; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N)
; CHECK-NEXT: %1 = select i1 true, i32 %N, i32 %N
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %N, 8
; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1)
; CHECK-NEXT: %2 = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2
; CHECK-NEXT: %3 = getelementptr inbounds i8*, i8** %2, i64 %iv
; CHECK-NEXT: store i8* %malloccall2, i8** %3, align 8, !invariant.group !7
; CHECK-NEXT: %4 = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5
; CHECK-NEXT: %5 = getelementptr inbounds i8*, i8** %4, i64 %iv
; CHECK-NEXT: store i8* %malloccall, i8** %5, align 8, !invariant.group !8
; CHECK-NEXT: %[[i2:.+]] = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2
; CHECK-NEXT: %[[i3:.+]] = getelementptr inbounds i8*, i8** %[[i2]], i64 %iv
; CHECK-NEXT: store i8* %malloccall2, i8** %[[i3]], align 8, !invariant.group !7
; CHECK-NEXT: %[[i4:.+]] = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5
; CHECK-NEXT: %[[i5:.+]] = getelementptr inbounds i8*, i8** %[[i4]], i64 %iv
; CHECK-NEXT: store i8* %malloccall, i8** %[[i5]], align 8, !invariant.group !8
; CHECK-NEXT: %cache.x = bitcast i8* %malloccall2 to double*
; CHECK-NEXT: call void @cblas_dcopy(i32 %1, double* %x0, i32 1, double* %cache.x, i32 1)
; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x, i32 1)
; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1)
; CHECK-NEXT: %exitcond.not = icmp eq i64 %iv.next, 5000
; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
Expand Down
17 changes: 8 additions & 9 deletions enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,16 @@ entry:
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double*
; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N)
; CHECK-NEXT: %1 = select i1 true, i32 %N, i32 %N
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %N, 8
; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1)
; CHECK-NEXT: %2 = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2
; CHECK-NEXT: %3 = getelementptr inbounds i8*, i8** %2, i64 %iv
; CHECK-NEXT: store i8* %malloccall2, i8** %3, align 8, !invariant.group !7
; CHECK-NEXT: %4 = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5
; CHECK-NEXT: %5 = getelementptr inbounds i8*, i8** %4, i64 %iv
; CHECK-NEXT: store i8* %malloccall, i8** %5, align 8, !invariant.group !8
; CHECK-NEXT: %[[i2:.+]] = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2
; CHECK-NEXT: %[[i3:.+]] = getelementptr inbounds i8*, i8** %[[i2]], i64 %iv
; CHECK-NEXT: store i8* %malloccall2, i8** %[[i3]], align 8, !invariant.group !7
; CHECK-NEXT: %[[i4:.+]] = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5
; CHECK-NEXT: %[[i5:.+]] = getelementptr inbounds i8*, i8** %[[i4]], i64 %iv
; CHECK-NEXT: store i8* %malloccall, i8** %[[i5]], align 8, !invariant.group !8
; CHECK-NEXT: %cache.x = bitcast i8* %malloccall2 to double*
; CHECK-NEXT: call void @cblas_dcopy(i32 %1, double* %x0, i32 1, double* %cache.x, i32 1)
; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x, i32 1)
; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1)
; CHECK-NEXT: %exitcond.not = icmp eq i64 %iv.next, 5000
; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
Expand Down
20 changes: 8 additions & 12 deletions enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ entry:
; CHECK-NEXT: br i1 %7, label %__enzyme_memcpy_double_mat_32.exit, label %init.idx.i

; CHECK: __enzyme_memcpy_double_mat_32.exit: ; preds = %entry, %init.end.i
; CHECK-NEXT: %8 = select i1 true, i32 %N, i32 %N
; CHECK-NEXT: %mallocsize22 = mul nuw nsw i32 %8, 8
; CHECK-NEXT: %mallocsize22 = mul nuw nsw i32 %N, 8
; CHECK-NEXT: %malloccall23 = tail call noalias nonnull i8* @malloc(i32 %mallocsize22)
; CHECK-NEXT: %cache.x24 = bitcast i8* %malloccall23 to double*
; CHECK-NEXT: call void @cblas_dcopy(i32 %8, double* %x0, i32 1, double* %cache.x24, i32 1)
; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x24, i32 1)
; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1)
; CHECK-NEXT: %[[i11:.+]] = mul i32 %N, %N
; CHECK-NEXT: %mallocsize11 = mul nuw nsw i32 %[[i11]], 8
Expand Down Expand Up @@ -100,11 +99,10 @@ entry:
; CHECK-NEXT: br i1 %[[i18:.+]], label %__enzyme_memcpy_double_mat_32.exit38, label %init.idx.i29

; CHECK: __enzyme_memcpy_double_mat_32.exit38: ; preds = %__enzyme_memcpy_double_mat_32.exit, %init.end.i37
; CHECK-NEXT: %[[i19:.+]] = select i1 true, i32 %N, i32 %N
; CHECK-NEXT: %mallocsize14 = mul nuw nsw i32 %[[i19]], 8
; CHECK-NEXT: %mallocsize14 = mul nuw nsw i32 %N, 8
; CHECK-NEXT: %malloccall15 = tail call noalias nonnull i8* @malloc(i32 %mallocsize14)
; CHECK-NEXT: %cache.x16 = bitcast i8* %malloccall15 to double*
; CHECK-NEXT: call void @cblas_dcopy(i32 %[[i19]], double* %x0, i32 1, double* %cache.x16, i32 1)
; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x16, i32 1)
; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1)
; CHECK-NEXT: %[[i22:.+]] = mul i32 %N, %N
; CHECK-NEXT: %mallocsize3 = mul nuw nsw i32 %[[i22]], 8
Expand Down Expand Up @@ -138,11 +136,10 @@ entry:
; CHECK-NEXT: br i1 %[[i29]], label %__enzyme_memcpy_double_mat_32.exit50, label %init.idx.i41

; CHECK: __enzyme_memcpy_double_mat_32.exit50: ; preds = %__enzyme_memcpy_double_mat_32.exit38, %init.end.i49
; CHECK-NEXT: %[[i30:.+]] = select i1 true, i32 %N, i32 %N
; CHECK-NEXT: %mallocsize6 = mul nuw nsw i32 %[[i30]], 8
; CHECK-NEXT: %mallocsize6 = mul nuw nsw i32 %N, 8
; CHECK-NEXT: %malloccall7 = tail call noalias nonnull i8* @malloc(i32 %mallocsize6)
; CHECK-NEXT: %cache.x8 = bitcast i8* %malloccall7 to double*
; CHECK-NEXT: call void @cblas_dcopy(i32 %[[i30]], double* %x0, i32 1, double* %cache.x8, i32 1)
; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x8, i32 1)
; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1)
; CHECK-NEXT: %[[i33:.+]] = mul i32 %N, %N
; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %[[i33]], 8
Expand Down Expand Up @@ -176,11 +173,10 @@ entry:
; CHECK-NEXT: br i1 %[[i40]], label %__enzyme_memcpy_double_mat_32.exit62, label %init.idx.i53

; CHECK: __enzyme_memcpy_double_mat_32.exit62: ; preds = %__enzyme_memcpy_double_mat_32.exit50, %init.end.i61
; CHECK-NEXT: %[[i41:.+]] = select i1 true, i32 %N, i32 %N
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %[[i41]], 8
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %N, 8
; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1)
; CHECK-NEXT: %cache.x = bitcast i8* %malloccall2 to double*
; CHECK-NEXT: call void @cblas_dcopy(i32 %[[i41]], double* %x0, i32 1, double* %cache.x, i32 1)
; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x, i32 1)
; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1)
; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1)
; CHECK-NEXT: br label %invertentry
Expand Down
Loading

0 comments on commit d23b53a

Please sign in to comment.