From b006eb21388dc1f0c13e7f4584847805767f95aa Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 31 Aug 2024 17:30:32 -0500 Subject: [PATCH] Fix forward mode ldg (#2065) * Fix forward mode ldg * fix --- enzyme/Enzyme/GradientUtils.cpp | 2 +- enzyme/test/Enzyme/ForwardMode/nvvm_ldg.ll | 29 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/nvvm_ldg.ll diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index b76b0557292..f652f0e295f 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -6078,7 +6078,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, return applyChainRule( II->getType(), bb, [&](Value *ptr) { - Value *args[] = {ptr}; + Value *args[] = {ptr, getNewFromOriginal(II->getArgOperand(1))}; auto li = bb.CreateCall(II->getCalledFunction(), args); llvm::SmallVector ToCopy2(MD_ToCopy); ToCopy2.push_back(LLVMContext::MD_noalias); diff --git a/enzyme/test/Enzyme/ForwardMode/nvvm_ldg.ll b/enzyme/test/Enzyme/ForwardMode/nvvm_ldg.ll new file mode 100644 index 00000000000..2d97c294a1c --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/nvvm_ldg.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-ni:10:11:12:13" +target triple = "nvptx64-nvidia-cuda" + +declare float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* nocapture, i32) + +define float @vmul(float addrspace(1)* %inp) { +top: + %ld = call float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* %inp, i32 4) + ret float %ld +} + + +define float @test_derivative(float addrspace(1)* %inp, float addrspace(1)* %dinp) { +entry: + %0 = tail call float (float (float addrspace(1)*)*, ...) @__enzyme_fwddiff(float (float addrspace(1)*)* nonnull @vmul, float addrspace(1)* %inp, float addrspace(1)* %dinp) + ret float %0 +} + +; Function Attrs: nounwind +declare float @__enzyme_fwddiff(float (float addrspace(1)*)*, ...) + +; CHECK: define internal float @fwddiffevmul(float addrspace(1)* %inp, float addrspace(1)* %"inp'") +; CHECK-NEXT: top: +; CHECK-NEXT: %[[res:.+]] = call{{( fast)?}} float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* %"inp'", i32 4) +; CHECK-NEXT: ret float %[[res]] +; CHECK-NEXT: } \ No newline at end of file