================
@@ -219,200 +215,359 @@ class BoxedProcedurePass
   inline mlir::ModuleOp getModule() { return getOperation(); }
 
   void runOnOperation() override final {
-    if (options.useThunks) {
+    if (useThunks) {
       auto *context = &getContext();
       mlir::IRRewriter rewriter(context);
       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
-      getModule().walk([&](mlir::Operation *op) {
-        bool opIsValid = true;
-        typeConverter.setLocation(op->getLoc());
-        if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
-          mlir::Type ty = addr.getVal().getType();
-          mlir::Type resTy = addr.getResult().getType();
-          if (llvm::isa<mlir::FunctionType>(ty) ||
-              llvm::isa<fir::BoxProcType>(ty)) {
-            // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
-            // or function type to be `fir.convert` ops.
-            rewriter.setInsertionPoint(addr);
-            rewriter.replaceOpWithNewOp<ConvertOp>(
-                addr, typeConverter.convertType(addr.getType()), 
addr.getVal());
-            opIsValid = false;
-          } else if (typeConverter.needsConversion(resTy)) {
-            rewriter.startOpModification(op);
-            op->getResult(0).setType(typeConverter.convertType(resTy));
-            rewriter.finalizeOpModification(op);
-          }
-        } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
-          mlir::FunctionType ty = func.getFunctionType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(func);
-            auto toTy =
-                mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
-            if (!func.empty())
-              for (auto e : llvm::enumerate(toTy.getInputs())) {
-                unsigned i = e.index();
-                auto &block = func.front();
-                block.insertArgument(i, e.value(), func.getLoc());
-                block.getArgument(i + 1).replaceAllUsesWith(
-                    block.getArgument(i));
-                block.eraseArgument(i + 1);
-              }
-            func.setType(toTy);
-            rewriter.finalizeOpModification(func);
+
+      // When using safe trampolines, we need to track handles per
+      // function so we can insert FreeTrampoline calls at each return.
+      // Process functions individually to manage this state.
+      if (useSafeTrampoline) {
+        getModule().walk([&](mlir::func::FuncOp funcOp) {
+          trampolineHandles.clear();
+          trampolineCallableMap.clear();
+          processFunction(funcOp, rewriter, typeConverter);
+          insertTrampolineFrees(funcOp, rewriter);
+        });
+        // Also process non-function ops at module level (globals, etc.)
+        processModuleLevelOps(rewriter, typeConverter);
+      } else {
+        getModule().walk([&](mlir::Operation *op) {
+          processOp(op, rewriter, typeConverter);
+        });
+      }
+    }
+  }
+
+private:
+  /// Trampoline handles collected while processing a function.
+  /// Each entry is a Value representing the opaque handle returned
+  /// by _FortranATrampolineInit, which must be freed before the
+  /// function returns.
+  llvm::SmallVector<mlir::Value> trampolineHandles;
+
+  /// Cache of trampoline callable addresses keyed by the func SSA value
+  /// of the emboxproc. This deduplicates trampolines when the same
+  /// internal procedure is emboxed multiple times in one host function.
+  llvm::DenseMap<mlir::Value, mlir::Value> trampolineCallableMap;
+
+  /// Process all ops within a function.
+  void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
+                       BoxprocTypeRewriter &typeConverter) {
+    funcOp.walk(
+        [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
+  }
+
+  /// Process non-function ops at module level (globals, etc.)
+  void processModuleLevelOps(mlir::IRRewriter &rewriter,
+                             BoxprocTypeRewriter &typeConverter) {
+    for (auto &op : getModule().getBody()->getOperations())
+      if (!mlir::isa<mlir::func::FuncOp>(op))
+        processOp(&op, rewriter, typeConverter);
+  }
+
+  /// Insert _FortranATrampolineFree calls before every return in the function.
+  void insertTrampolineFrees(mlir::func::FuncOp funcOp,
+                             mlir::IRRewriter &rewriter) {
+    if (trampolineHandles.empty())
+      return;
+
+    auto module{funcOp->getParentOfType<mlir::ModuleOp>()};
+    // Insert TrampolineFree calls before every func.return in this function.
+    // At this pass stage (after CFGConversion), func.return is the only
+    // terminator that exits the function. Other terminators are either
+    // intra-function branches (cf.br, cf.cond_br, fir.select*) or
+    // fir.unreachable (after STOP/ERROR STOP), which don't need cleanup
+    // since the process is terminating.
+    funcOp.walk([&](mlir::func::ReturnOp retOp) {
+      rewriter.setInsertionPoint(retOp);
+      FirOpBuilder builder(rewriter, module);
+      auto loc{retOp.getLoc()};
+      for (mlir::Value handle : trampolineHandles)
+        fir::runtime::genTrampolineFree(builder, loc, handle);
+    });
+  }
+
+  /// Process a single operation for boxproc type rewriting.
+  void processOp(mlir::Operation *op, mlir::IRRewriter &rewriter,
+                 BoxprocTypeRewriter &typeConverter) {
+    bool opIsValid{true};
+    typeConverter.setLocation(op->getLoc());
+    if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
+      mlir::Type ty{addr.getVal().getType()};
+      mlir::Type resTy{addr.getResult().getType()};
+      if (llvm::isa<mlir::FunctionType>(ty) ||
+          llvm::isa<fir::BoxProcType>(ty)) {
+        // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
+        // or function type to be `fir.convert` ops.
+        rewriter.setInsertionPoint(addr);
+        rewriter.replaceOpWithNewOp<ConvertOp>(
+            addr, typeConverter.convertType(addr.getType()), addr.getVal());
+        opIsValid = false;
+      } else if (typeConverter.needsConversion(resTy)) {
+        rewriter.startOpModification(op);
+        op->getResult(0).setType(typeConverter.convertType(resTy));
+        rewriter.finalizeOpModification(op);
+      }
+    } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
+      mlir::FunctionType ty{func.getFunctionType()};
+      if (typeConverter.needsConversion(ty)) {
+        rewriter.startOpModification(func);
+        auto toTy{
+            mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty))};
+        if (!func.empty())
+          for (auto e : llvm::enumerate(toTy.getInputs())) {
+            auto i{static_cast<unsigned>(e.index())};
+            auto &block{func.front()};
+            block.insertArgument(i, e.value(), func.getLoc());
+            block.getArgument(i + 1).replaceAllUsesWith(block.getArgument(i));
+            block.eraseArgument(i + 1);
           }
-        } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
-          // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
-          // as required.
-          mlir::Type toTy = typeConverter.convertType(
-              mlir::cast<BoxProcType>(embox.getType()).getEleTy());
-          rewriter.setInsertionPoint(embox);
-          if (embox.getHost()) {
-            // Create the thunk.
-            auto module = embox->getParentOfType<mlir::ModuleOp>();
-            FirOpBuilder builder(rewriter, module);
-            const auto triple{fir::getTargetTriple(module)};
-            auto loc = embox.getLoc();
-            mlir::Type i8Ty = builder.getI8Type();
-            mlir::Type i8Ptr = builder.getRefType(i8Ty);
-            // For PPC32 and PPC64, the thunk is populated by a call to
-            // __trampoline_setup, which is defined in
-            // compiler-rt/lib/builtins/trampoline_setup.c and requires the
-            // thunk size greater than 32 bytes.  For AArch64, RISCV and 
x86_64,
-            // the thunk setup doesn't go through __trampoline_setup and fits 
in
-            // 32 bytes.
-            fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
-            mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
-            auto buffer = AllocaOp::create(builder, loc, buffTy);
-            mlir::Value closure =
-                builder.createConvert(loc, i8Ptr, embox.getHost());
-            mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
-            mlir::Value func =
-                builder.createConvert(loc, i8Ptr, embox.getFunc());
-            fir::CallOp::create(
-                builder, loc, factory::getLlvmInitTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp, func, closure});
-            auto adjustCall = fir::CallOp::create(
-                builder, loc, factory::getLlvmAdjustTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp});
+        func.setType(toTy);
+        rewriter.finalizeOpModification(func);
+      }
+    } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
+      // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
+      // as required.
+      mlir::Type toTy{typeConverter.convertType(
+          mlir::cast<BoxProcType>(embox.getType()).getEleTy())};
+      rewriter.setInsertionPoint(embox);
+      if (embox.getHost()) {
+        auto module{embox->getParentOfType<mlir::ModuleOp>()};
+        auto loc{embox.getLoc()};
+
+        if (useSafeTrampoline) {
+          // Runtime trampoline pool path (W^X compliant).
+          // Insert Init/Adjust in the function's entry block so the
+          // handle dominates all func.return ops where TrampolineFree
+          // is emitted. This is necessary because fir.emboxproc may
+          // appear inside control flow branches. A cache avoids
+          // creating duplicate trampolines for the same internal
+          // procedure within a single host function.
+          mlir::Value funcVal{embox.getFunc()};
+          auto cacheIt{trampolineCallableMap.find(funcVal)};
+          if (cacheIt != trampolineCallableMap.end()) {
             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   adjustCall.getResult(0));
-            opIsValid = false;
+                                                   cacheIt->second);
           } else {
-            // Just forward the function as a pointer.
-            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   embox.getFunc());
-            opIsValid = false;
-          }
-        } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
-          auto ty = global.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(global);
-            auto toTy = typeConverter.convertType(ty);
-            global.setType(toTy);
-            rewriter.finalizeOpModification(global);
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            bool isPinned = mem.getPinned();
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocaOp>(
-                mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocMemOp>(
-                mem, toTy, uniqName, bindcName, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
-          auto ty = coor.getType();
-          mlir::Type baseTy = coor.getBaseType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(baseTy)) {
-            rewriter.setInsertionPoint(coor);
-            auto toTy = typeConverter.convertType(ty);
-            auto toBaseTy = typeConverter.convertType(baseTy);
-            rewriter.replaceOpWithNewOp<CoordinateOp>(
-                coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy,
-                coor.getFieldIndicesAttr());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<FieldIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, 
index.getTypeparams());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<LenParamIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, 
index.getTypeparams());
-            opIsValid = false;
-          }
-        } else {
-          rewriter.startOpModification(op);
-          // Convert the operands if needed
-          for (auto i : llvm::enumerate(op->getResultTypes()))
-            if (typeConverter.needsConversion(i.value())) {
-              auto toTy = typeConverter.convertType(i.value());
-              op->getResult(i.index()).setType(toTy);
+            auto parentFunc{embox->getParentOfType<mlir::func::FuncOp>()};
+            auto &entryBlock{parentFunc.front()};
+
+            auto savedIP{rewriter.saveInsertionPoint()};
+
+            // Find the right insertion point in the entry block.
+            // Walk up from the emboxproc to find its top-level
+            // ancestor in the entry block. For an emboxproc directly
+            // in the entry block, this is the emboxproc itself.
+            // For one inside a structured op (fir.if, fir.do_loop),
+            // this is that structured op. For one inside an explicit
+            // branch target (cf.cond_br → ^bb1), we fall back to the
+            // entry block terminator.
+            mlir::Operation *entryAncestor{embox.getOperation()};
+            while (entryAncestor->getBlock() != &entryBlock) {
+              entryAncestor = entryAncestor->getParentOp();
+              if (!entryAncestor ||
+                  mlir::isa<mlir::func::FuncOp>(entryAncestor))
+                break;
             }
+            bool ancestorInEntry{
+                entryAncestor &&
+                !mlir::isa<mlir::func::FuncOp>(entryAncestor) &&
+                entryAncestor->getBlock() == &entryBlock};
 
-          // Convert the type attributes if needed
-          for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
-            if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
-              if (typeConverter.needsConversion(tyAttr.getValue())) {
-                auto toTy = typeConverter.convertType(tyAttr.getValue());
-                op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
+            // If the func value is not in the entry block (e.g.,
+            // address_of generated inside a structured fir.if),
+            // clone it into the entry block.
+            mlir::Value funcValInEntry{funcVal};
+            if (auto *funcDef{funcVal.getDefiningOp()}) {
+              if (funcDef->getBlock() != &entryBlock) {
+                if (ancestorInEntry)
+                  rewriter.setInsertionPoint(entryAncestor);
+                else
+                  rewriter.setInsertionPoint(entryBlock.getTerminator());
+                auto *cloned{rewriter.clone(*funcDef)};
+                funcValInEntry = cloned->getResult(0);
               }
-          rewriter.finalizeOpModification(op);
+            }
+
+            // Similarly clone the host value if not in entry block.
+            mlir::Value hostValInEntry{embox.getHost()};
+            if (auto *hostDef{embox.getHost().getDefiningOp()}) {
+              if (hostDef->getBlock() != &entryBlock) {
+                if (ancestorInEntry)
+                  rewriter.setInsertionPoint(entryAncestor);
+                else
+                  rewriter.setInsertionPoint(entryBlock.getTerminator());
+                auto *cloned{rewriter.clone(*hostDef)};
+                hostValInEntry = cloned->getResult(0);
+              }
----------------
jeanPerier wrote:

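Thanks, this looks good to me. Thinking more about it, the invariant this relies 
on (that the trampoline can be created in the entry block of the function 
containing the `fir.emboxproc`, so that its handle dominates every `func.return`) 
will be broken if the MLIR inliner is run before this pass and a host procedure 
is inlined into some other function body, potentially not in the entry block.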

I do not like the idea of trying to clone the IR for the host link creation to 
solve this. It may not be possible anyway if it depends on some calls. I think 
the best solution is probably to add an operation in lowering to deal 
with the trampoline cleanups when needed (with some mechanism to connect the 
fir.emboxproc to it).
Since MLIR inlining is experimental in flang, I think it is OK to proceed 
with your patch, but we should be aware that the current approach risks hitting 
a compilation error when combined with the `-mllvm -inline-all` developer 
option.

https://github.com/llvm/llvm-project/pull/183108
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to