Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Support memory cleanup at a return statement #116304

Merged
merged 3 commits into from
Nov 15, 2024

Conversation

khaki3
Copy link
Contributor

@khaki3 khaki3 commented Nov 15, 2024

We generate cuf.free and func.return twice if a return statement exists at the end of program.

program test
  integer, device :: a(10)
  return
end
% flang -x cuda test.cuf -mmlir --mlir-print-ir-after-all
error: loc("/path/to/test.cuf":3:3): 'func.return' op must be the last operation in the parent block
// -----// IR Dump After Fortran::lower::VerifierPass Failed () //----- //

Dumped IR:

  "func.func"() <{function_type = () -> (), sym_name = "_QQmain"}> ({
...
    "cuf.free"(%5#1) <{data_attr = #cuf.cuda<device>}> : (!fir.ref<!fir.array<10xi32>>) -> ()
    "func.return"() : () -> ()
    "cuf.free"(%5#1) <{data_attr = #cuf.cuda<device>}> : (!fir.ref<!fir.array<10xi32>>) -> ()
    "func.return"() : () -> ()
}
...

The routine genExitRoutine in Bridge.cpp is guarded by blockIsUnterminated() to make sure that func.return is generated only at the end of a block. However, we redundantly run bridge.fctCtx().finalizeAndKeep() before genExitRoutine in this case, resulting in two pairs of cuf.free and func.return. This PR fixes Bridge.cpp by using blockIsUnterminated() to guard finalizeAndKeep as well.

@khaki3 khaki3 changed the title [flang][cuf] Support memory finalization at a final return statement [flang][cuf] Support memory cleanup at a return statement Nov 15, 2024
@khaki3 khaki3 changed the title [flang][cuf] Support memory cleanup at a return statement [flang][cuda] Support memory cleanup at a return statement Nov 15, 2024
Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks to work on this Matsu! This look ok to me. Just adding Jean and Val to have there feedback. There might be a case we forgot.

@@ -79,23 +79,26 @@ class StatementContext {
}
}

/// Make cleanup calls. Retain the stack top list for a repeat call.
/// Make a cleanup call. Retain the stack top list for a repeat call.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can keep the original comment since we generate 0 to N cleanup calls.

@clementval
Copy link
Contributor

clementval commented Nov 15, 2024

Can you add a test with multiple return like:

program test
  integer, device :: a(10)
  logical :: l

  if (l) then
    return
  end if

  return
end

and maybe also a subroutine test

subroutine test(l)
  integer, device :: a(10)
  logical :: l

  if (l) then
    l = .false.
    return
  end if

  return
end

Thanks

Copy link
Contributor

@vdonaldson vdonaldson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an improvement to me as is.

I'm not familiar with OpenACC code. But for the record, if OpenACC regions are directives in the PFT, it might be possible to further simplify this code by subsuming the openAccCtx into the activeConstructStack, and/or possibly even integrating those with the fctCtx. That would be a little more ambitious though.

@khaki3 khaki3 merged commit ff7fca7 into llvm:main Nov 15, 2024
7 of 8 checks passed
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Nov 16, 2024
@llvmbot
Copy link

llvmbot commented Nov 16, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (khaki3)

Changes

We generate cuf.free and func.return twice if a return statement exists at the end of program.

program test
  integer, device :: a(10)
  return
end
% flang -x cuda test.cuf -mmlir --mlir-print-ir-after-all
error: loc("/path/to/test.cuf":3:3): 'func.return' op must be the last operation in the parent block
// -----// IR Dump After Fortran::lower::VerifierPass Failed () //----- //

Dumped IR:

  "func.func"() &lt;{function_type = () -&gt; (), sym_name = "_QQmain"}&gt; ({
...
    "cuf.free"(%5#<!-- -->1) &lt;{data_attr = #cuf.cuda&lt;device&gt;}&gt; : (!fir.ref&lt;!fir.array&lt;10xi32&gt;&gt;) -&gt; ()
    "func.return"() : () -&gt; ()
    "cuf.free"(%5#<!-- -->1) &lt;{data_attr = #cuf.cuda&lt;device&gt;}&gt; : (!fir.ref&lt;!fir.array&lt;10xi32&gt;&gt;) -&gt; ()
    "func.return"() : () -&gt; ()
}
...

The routine genExitRoutine in Bridge.cpp is guarded by blockIsUnterminated() to make sure that func.return is generated only at the end of a block. However, we redundantly run bridge.fctCtx().finalizeAndKeep() before genExitRoutine in this case, resulting in two pairs of cuf.free and func.return. This PR fixes Bridge.cpp by using blockIsUnterminated() to guard finalizeAndKeep as well.


Full diff: https://github.com/llvm/llvm-project/pull/116304.diff

4 Files Affected:

  • (modified) flang/include/flang/Lower/StatementContext.h (+4-1)
  • (modified) flang/lib/Lower/Bridge.cpp (+17-20)
  • (added) flang/test/Lower/CUDA/cuda-return01.cuf (+14)
  • (added) flang/test/Lower/CUDA/cuda-return02.cuf (+48)
diff --git a/flang/include/flang/Lower/StatementContext.h b/flang/include/flang/Lower/StatementContext.h
index 7776edc93ed737..eef21d4bae5aab 100644
--- a/flang/include/flang/Lower/StatementContext.h
+++ b/flang/include/flang/Lower/StatementContext.h
@@ -92,10 +92,13 @@ class StatementContext {
     cufs.back().reset();
   }
 
+  /// Pop the stack top list.
+  void pop() { cufs.pop_back(); }
+
   /// Make cleanup calls. Pop the stack top list.
   void finalizeAndPop() {
     finalizeAndKeep();
-    cufs.pop_back();
+    pop();
   }
 
   bool hasCode() const {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index da53edf7e734b0..7f41742bf5e8b2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1621,13 +1621,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   // Termination of symbolically referenced execution units
   //===--------------------------------------------------------------------===//
 
-  /// END of program
+  /// Exit of a routine
   ///
-  /// Generate the cleanup block before the program exits
-  void genExitRoutine() {
-
-    if (blockIsUnterminated())
-      builder->create<mlir::func::ReturnOp>(toLocation());
+  /// Generate the cleanup block before the routine exits
+  void genExitRoutine(bool earlyReturn, mlir::ValueRange retval = {}) {
+    if (blockIsUnterminated()) {
+      bridge.openAccCtx().finalizeAndKeep();
+      bridge.fctCtx().finalizeAndKeep();
+      builder->create<mlir::func::ReturnOp>(toLocation(), retval);
+    }
+    if (!earlyReturn) {
+      bridge.openAccCtx().pop();
+      bridge.fctCtx().pop();
+    }
   }
 
   /// END of procedure-like constructs
@@ -1684,9 +1690,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
             resultRef = builder->createConvert(loc, resultRefType, resultRef);
           return builder->create<fir::LoadOp>(loc, resultRef);
         });
-    bridge.openAccCtx().finalizeAndPop();
-    bridge.fctCtx().finalizeAndPop();
-    builder->create<mlir::func::ReturnOp>(loc, resultVal);
+    genExitRoutine(false, resultVal);
   }
 
   /// Get the return value of a call to \p symbol, which is a subroutine entry
@@ -1712,13 +1716,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     } else if (Fortran::semantics::HasAlternateReturns(symbol)) {
       mlir::Value retval = builder->create<fir::LoadOp>(
           toLocation(), getAltReturnResult(symbol));
-      bridge.openAccCtx().finalizeAndPop();
-      bridge.fctCtx().finalizeAndPop();
-      builder->create<mlir::func::ReturnOp>(toLocation(), retval);
+      genExitRoutine(false, retval);
     } else {
-      bridge.openAccCtx().finalizeAndPop();
-      bridge.fctCtx().finalizeAndPop();
-      genExitRoutine();
+      genExitRoutine(false);
     }
   }
 
@@ -5018,8 +5018,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       it->stmtCtx.finalizeAndKeep();
     }
     if (funit->isMainProgram()) {
-      bridge.fctCtx().finalizeAndKeep();
-      genExitRoutine();
+      genExitRoutine(true);
       return;
     }
     mlir::Location loc = toLocation();
@@ -5478,9 +5477,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void endNewFunction(Fortran::lower::pft::FunctionLikeUnit &funit) {
     setCurrentPosition(Fortran::lower::pft::stmtSourceLoc(funit.endStmt));
     if (funit.isMainProgram()) {
-      bridge.openAccCtx().finalizeAndPop();
-      bridge.fctCtx().finalizeAndPop();
-      genExitRoutine();
+      genExitRoutine(false);
     } else {
       genFIRProcedureExit(funit, funit.getSubprogramSymbol());
     }
diff --git a/flang/test/Lower/CUDA/cuda-return01.cuf b/flang/test/Lower/CUDA/cuda-return01.cuf
new file mode 100644
index 00000000000000..c9f9a8b57ef041
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-return01.cuf
@@ -0,0 +1,14 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Check if finalization works with a return statement
+
+program main
+  integer, device :: a(10)
+  return
+end
+
+! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "main"} {
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare
+! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
+! CHECK-NEXT: return
+! CHECK-NEXT: }
diff --git a/flang/test/Lower/CUDA/cuda-return02.cuf b/flang/test/Lower/CUDA/cuda-return02.cuf
new file mode 100644
index 00000000000000..5d01f0a24b420b
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-return02.cuf
@@ -0,0 +1,48 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Check if finalization works with multiple return statements
+
+program test
+  integer, device :: a(10)
+  logical :: l
+
+  if (l) then
+    return
+  end if
+
+  return
+end
+
+! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "test"} {
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare
+! CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
+! CHECK-NEXT: ^bb1:
+! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
+! CHECK-NEXT: return
+! CHECK-NEXT: ^bb2:
+! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
+! CHECK-NEXT: return
+! CHECK-NEXT: }
+
+subroutine sub(l)
+  integer, device :: a(10)
+  logical :: l
+
+  if (l) then
+    l = .false.
+    return
+  end if
+
+  return
+end
+
+! CHECK: func.func @_QPsub(%arg0: !fir.ref<!fir.logical<4>> {fir.bindc_name = "l"}) {
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare
+! CHECK: cf.cond_br %6, ^bb1, ^bb2
+! CHECK: ^bb1:
+! CHECK: cf.br ^bb3
+! CHECK: ^bb2:
+! CHECK: cf.br ^bb3
+! CHECK: ^bb3:
+! CHECK: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
+! CHECK: }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants