================
@@ -219,200 +215,359 @@ class BoxedProcedurePass
inline mlir::ModuleOp getModule() { return getOperation(); }
void runOnOperation() override final {
- if (options.useThunks) {
+ if (useThunks) {
auto *context = &getContext();
mlir::IRRewriter rewriter(context);
BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
- getModule().walk([&](mlir::Operation *op) {
- bool opIsValid = true;
- typeConverter.setLocation(op->getLoc());
- if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
- mlir::Type ty = addr.getVal().getType();
- mlir::Type resTy = addr.getResult().getType();
- if (llvm::isa<mlir::FunctionType>(ty) ||
- llvm::isa<fir::BoxProcType>(ty)) {
- // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
- // or function type to be `fir.convert` ops.
- rewriter.setInsertionPoint(addr);
-            rewriter.replaceOpWithNewOp<ConvertOp>(
-                addr, typeConverter.convertType(addr.getType()), addr.getVal());
- opIsValid = false;
- } else if (typeConverter.needsConversion(resTy)) {
- rewriter.startOpModification(op);
- op->getResult(0).setType(typeConverter.convertType(resTy));
- rewriter.finalizeOpModification(op);
- }
- } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
- mlir::FunctionType ty = func.getFunctionType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.startOpModification(func);
- auto toTy =
- mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
- if (!func.empty())
- for (auto e : llvm::enumerate(toTy.getInputs())) {
- unsigned i = e.index();
- auto &block = func.front();
- block.insertArgument(i, e.value(), func.getLoc());
- block.getArgument(i + 1).replaceAllUsesWith(
- block.getArgument(i));
- block.eraseArgument(i + 1);
- }
- func.setType(toTy);
- rewriter.finalizeOpModification(func);
+
+ // When using safe trampolines, we need to track handles per
+ // function so we can insert FreeTrampoline calls at each return.
+ // Process functions individually to manage this state.
+ if (useSafeTrampoline) {
+ getModule().walk([&](mlir::func::FuncOp funcOp) {
+ trampolineHandles.clear();
+ trampolineCallableMap.clear();
+ processFunction(funcOp, rewriter, typeConverter);
+ insertTrampolineFrees(funcOp, rewriter);
+ });
+ // Also process non-function ops at module level (globals, etc.)
+ processModuleLevelOps(rewriter, typeConverter);
+ } else {
+ getModule().walk([&](mlir::Operation *op) {
+ processOp(op, rewriter, typeConverter);
+ });
+ }
+ }
+ }
+
+private:
+ /// Trampoline handles collected while processing a function.
+ /// Each entry is a Value representing the opaque handle returned
+ /// by _FortranATrampolineInit, which must be freed before the
+ /// function returns.
+ llvm::SmallVector<mlir::Value> trampolineHandles;
+
+ /// Cache of trampoline callable addresses keyed by the func SSA value
+ /// of the emboxproc. This deduplicates trampolines when the same
+ /// internal procedure is emboxed multiple times in one host function.
+ llvm::DenseMap<mlir::Value, mlir::Value> trampolineCallableMap;
+
+ /// Process all ops within a function.
+ void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ funcOp.walk(
+ [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
+ }
+
+ /// Process non-function ops at module level (globals, etc.)
+ void processModuleLevelOps(mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ for (auto &op : getModule().getBody()->getOperations())
+ if (!mlir::isa<mlir::func::FuncOp>(op))
+ processOp(&op, rewriter, typeConverter);
+ }
+
+ /// Insert _FortranATrampolineFree calls before every return in the function.
+ void insertTrampolineFrees(mlir::func::FuncOp funcOp,
+ mlir::IRRewriter &rewriter) {
+ if (trampolineHandles.empty())
+ return;
+
+ auto module{funcOp->getParentOfType<mlir::ModuleOp>()};
+ // Insert TrampolineFree calls before every func.return in this function.
+ // At this pass stage (after CFGConversion), func.return is the only
+ // terminator that exits the function. Other terminators are either
+ // intra-function branches (cf.br, cf.cond_br, fir.select*) or
+ // fir.unreachable (after STOP/ERROR STOP), which don't need cleanup
+ // since the process is terminating.
+ funcOp.walk([&](mlir::func::ReturnOp retOp) {
+ rewriter.setInsertionPoint(retOp);
+ FirOpBuilder builder(rewriter, module);
+ auto loc{retOp.getLoc()};
+ for (mlir::Value handle : trampolineHandles)
+ fir::runtime::genTrampolineFree(builder, loc, handle);
+ });
+ }
+
+ /// Process a single operation for boxproc type rewriting.
+ void processOp(mlir::Operation *op, mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ bool opIsValid{true};
+ typeConverter.setLocation(op->getLoc());
+ if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
+ mlir::Type ty{addr.getVal().getType()};
+ mlir::Type resTy{addr.getResult().getType()};
+ if (llvm::isa<mlir::FunctionType>(ty) ||
+ llvm::isa<fir::BoxProcType>(ty)) {
+ // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
+ // or function type to be `fir.convert` ops.
+ rewriter.setInsertionPoint(addr);
+ rewriter.replaceOpWithNewOp<ConvertOp>(
+ addr, typeConverter.convertType(addr.getType()), addr.getVal());
+ opIsValid = false;
+ } else if (typeConverter.needsConversion(resTy)) {
+ rewriter.startOpModification(op);
+ op->getResult(0).setType(typeConverter.convertType(resTy));
+ rewriter.finalizeOpModification(op);
+ }
+ } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
+ mlir::FunctionType ty{func.getFunctionType()};
+ if (typeConverter.needsConversion(ty)) {
+ rewriter.startOpModification(func);
+ auto toTy{
+ mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty))};
+ if (!func.empty())
+ for (auto e : llvm::enumerate(toTy.getInputs())) {
+ auto i{static_cast<unsigned>(e.index())};
+ auto &block{func.front()};
+ block.insertArgument(i, e.value(), func.getLoc());
+ block.getArgument(i + 1).replaceAllUsesWith(block.getArgument(i));
+ block.eraseArgument(i + 1);
}
- } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
- // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
- // as required.
- mlir::Type toTy = typeConverter.convertType(
- mlir::cast<BoxProcType>(embox.getType()).getEleTy());
- rewriter.setInsertionPoint(embox);
- if (embox.getHost()) {
- // Create the thunk.
- auto module = embox->getParentOfType<mlir::ModuleOp>();
- FirOpBuilder builder(rewriter, module);
- const auto triple{fir::getTargetTriple(module)};
- auto loc = embox.getLoc();
- mlir::Type i8Ty = builder.getI8Type();
- mlir::Type i8Ptr = builder.getRefType(i8Ty);
-            // For PPC32 and PPC64, the thunk is populated by a call to
-            // __trampoline_setup, which is defined in
-            // compiler-rt/lib/builtins/trampoline_setup.c and requires the
-            // thunk size greater than 32 bytes. For AArch64, RISCV and x86_64,
-            // the thunk setup doesn't go through __trampoline_setup and fits in
-            // 32 bytes.
- fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
- mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
- auto buffer = AllocaOp::create(builder, loc, buffTy);
- mlir::Value closure =
- builder.createConvert(loc, i8Ptr, embox.getHost());
- mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
- mlir::Value func =
- builder.createConvert(loc, i8Ptr, embox.getFunc());
- fir::CallOp::create(
- builder, loc, factory::getLlvmInitTrampoline(builder),
- llvm::ArrayRef<mlir::Value>{tramp, func, closure});
- auto adjustCall = fir::CallOp::create(
- builder, loc, factory::getLlvmAdjustTrampoline(builder),
- llvm::ArrayRef<mlir::Value>{tramp});
+ func.setType(toTy);
+ rewriter.finalizeOpModification(func);
+ }
+ } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
+ // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
+ // as required.
+ mlir::Type toTy{typeConverter.convertType(
+ mlir::cast<BoxProcType>(embox.getType()).getEleTy())};
+ rewriter.setInsertionPoint(embox);
+ if (embox.getHost()) {
+ auto module{embox->getParentOfType<mlir::ModuleOp>()};
+ auto loc{embox.getLoc()};
+
+ if (useSafeTrampoline) {
+ // Runtime trampoline pool path (W^X compliant).
+ // Insert Init/Adjust in the function's entry block so the
+ // handle dominates all func.return ops where TrampolineFree
+ // is emitted. This is necessary because fir.emboxproc may
+ // appear inside control flow branches. A cache avoids
+ // creating duplicate trampolines for the same internal
+ // procedure within a single host function.
+ mlir::Value funcVal{embox.getFunc()};
+ auto cacheIt{trampolineCallableMap.find(funcVal)};
+ if (cacheIt != trampolineCallableMap.end()) {
rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
- adjustCall.getResult(0));
- opIsValid = false;
+ cacheIt->second);
} else {
- // Just forward the function as a pointer.
- rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
- embox.getFunc());
- opIsValid = false;
- }
- } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
- auto ty = global.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.startOpModification(global);
- auto toTy = typeConverter.convertType(ty);
- global.setType(toTy);
- rewriter.finalizeOpModification(global);
- }
- } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
- auto ty = mem.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.setInsertionPoint(mem);
- auto toTy = typeConverter.convertType(unwrapRefType(ty));
- bool isPinned = mem.getPinned();
- llvm::StringRef uniqName =
- mem.getUniqName().value_or(llvm::StringRef());
- llvm::StringRef bindcName =
- mem.getBindcName().value_or(llvm::StringRef());
- rewriter.replaceOpWithNewOp<AllocaOp>(
- mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
- mem.getShape());
- opIsValid = false;
- }
- } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
- auto ty = mem.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.setInsertionPoint(mem);
- auto toTy = typeConverter.convertType(unwrapRefType(ty));
- llvm::StringRef uniqName =
- mem.getUniqName().value_or(llvm::StringRef());
- llvm::StringRef bindcName =
- mem.getBindcName().value_or(llvm::StringRef());
- rewriter.replaceOpWithNewOp<AllocMemOp>(
- mem, toTy, uniqName, bindcName, mem.getTypeparams(),
- mem.getShape());
- opIsValid = false;
- }
- } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
- auto ty = coor.getType();
- mlir::Type baseTy = coor.getBaseType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(baseTy)) {
- rewriter.setInsertionPoint(coor);
- auto toTy = typeConverter.convertType(ty);
- auto toBaseTy = typeConverter.convertType(baseTy);
- rewriter.replaceOpWithNewOp<CoordinateOp>(
- coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy,
- coor.getFieldIndicesAttr());
- opIsValid = false;
- }
- } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
- auto ty = index.getType();
- mlir::Type onTy = index.getOnType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(onTy)) {
- rewriter.setInsertionPoint(index);
- auto toTy = typeConverter.convertType(ty);
- auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<FieldIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
- opIsValid = false;
- }
- } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
- auto ty = index.getType();
- mlir::Type onTy = index.getOnType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(onTy)) {
- rewriter.setInsertionPoint(index);
- auto toTy = typeConverter.convertType(ty);
- auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<LenParamIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
- opIsValid = false;
- }
- } else {
- rewriter.startOpModification(op);
- // Convert the operands if needed
- for (auto i : llvm::enumerate(op->getResultTypes()))
- if (typeConverter.needsConversion(i.value())) {
- auto toTy = typeConverter.convertType(i.value());
- op->getResult(i.index()).setType(toTy);
+ auto parentFunc{embox->getParentOfType<mlir::func::FuncOp>()};
+ auto &entryBlock{parentFunc.front()};
+
+ auto savedIP{rewriter.saveInsertionPoint()};
+
+ // Find the right insertion point in the entry block.
+ // Walk up from the emboxproc to find its top-level
+ // ancestor in the entry block. For an emboxproc directly
+ // in the entry block, this is the emboxproc itself.
+ // For one inside a structured op (fir.if, fir.do_loop),
+ // this is that structured op. For one inside an explicit
+ // branch target (cf.cond_br → ^bb1), we fall back to the
+ // entry block terminator.
+ mlir::Operation *entryAncestor{embox.getOperation()};
+ while (entryAncestor->getBlock() != &entryBlock) {
+ entryAncestor = entryAncestor->getParentOp();
+ if (!entryAncestor ||
+ mlir::isa<mlir::func::FuncOp>(entryAncestor))
+ break;
}
+ bool ancestorInEntry{
+ entryAncestor &&
+ !mlir::isa<mlir::func::FuncOp>(entryAncestor) &&
+ entryAncestor->getBlock() == &entryBlock};
- // Convert the type attributes if needed
- for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
- if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
- if (typeConverter.needsConversion(tyAttr.getValue())) {
- auto toTy = typeConverter.convertType(tyAttr.getValue());
- op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
+ // If the func value is not in the entry block (e.g.,
+ // address_of generated inside a structured fir.if),
+ // clone it into the entry block.
+ mlir::Value funcValInEntry{funcVal};
+ if (auto *funcDef{funcVal.getDefiningOp()}) {
+ if (funcDef->getBlock() != &entryBlock) {
+ if (ancestorInEntry)
+ rewriter.setInsertionPoint(entryAncestor);
+ else
+ rewriter.setInsertionPoint(entryBlock.getTerminator());
+ auto *cloned{rewriter.clone(*funcDef)};
+ funcValInEntry = cloned->getResult(0);
}
- rewriter.finalizeOpModification(op);
+ }
+
+ // Similarly clone the host value if not in entry block.
+ mlir::Value hostValInEntry{embox.getHost()};
+ if (auto *hostDef{embox.getHost().getDefiningOp()}) {
+ if (hostDef->getBlock() != &entryBlock) {
+ if (ancestorInEntry)
+ rewriter.setInsertionPoint(entryAncestor);
+ else
+ rewriter.setInsertionPoint(entryBlock.getTerminator());
+ auto *cloned{rewriter.clone(*hostDef)};
+ hostValInEntry = cloned->getResult(0);
+ }
----------------
Saieiei wrote:
I replaced the host-link “clone to entry” logic with a hard error. If the
host-link defining op isn’t in the entry block, we now emit an error via
`mlir::emitError` with a clear explanation: cloning only that defining op would
skip its initializing stores and could silently produce wrong code. In
trigger, since the host link is expected to be either a function entry block
argument (`!fir.ref<tuple<...>>`) or a `fir.alloca` created at function entry,
so the invariant holds for valid IR. I kept the existing `funcDef` cloning
behavior unchanged, because `fir.address_of` can legitimately appear under
structured ops and cloning it is safe (no side-effecting stores to miss).
https://github.com/llvm/llvm-project/compare/448886e8394650220d7e1dcf6fb984205d043fcc..a866872e85a28dc20bacdd1d071c96564489a54c
https://github.com/llvm/llvm-project/pull/183108
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits