Author: adams381
Date: 2026-04-14T14:00:59-05:00
New Revision: a36e9d1d57b12de3674689a617ab7452ed43d9a2

URL: 
https://github.com/llvm/llvm-project/commit/a36e9d1d57b12de3674689a617ab7452ed43d9a2
DIFF: 
https://github.com/llvm/llvm-project/commit/a36e9d1d57b12de3674689a617ab7452ed43d9a2.diff

LOG: [CIR] Add musttail thunks and covariant return null-check (#191255)

Implement variadic thunk emission via musttail calls, and null-check
pointer returns in covariant thunk return adjustment, matching classic
codegen behavior.

Add a musttail UnitAttr to cir.call/cir.try_call, lowered to
LLVM::TailCallKind::MustTail.

Made with [Cursor](https://cursor.com)

Added: 
    

Modified: 
    clang/include/clang/CIR/Dialect/IR/CIRDialect.td
    clang/include/clang/CIR/Dialect/IR/CIROps.td
    clang/lib/CIR/CodeGen/CIRGenVTables.cpp
    clang/lib/CIR/Dialect/IR/CIRDialect.cpp
    clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
    clang/test/CIR/CodeGen/thunks.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td 
b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index b57f874c34393..5b808ea92f470 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -77,6 +77,7 @@ def CIR_Dialect : Dialect {
     static llvm::StringRef getArgAttrsAttrName() { return "arg_attrs"; }
     static llvm::StringRef getRecordLayoutsAttrName() { return 
"cir.record_layouts"; }
     static llvm::StringRef getCUDABinaryHandleAttrName() { return 
"cir.cu.binary_handle"; }
+    static llvm::StringRef getMustTailAttrName() { return "musttail"; }
 
     static llvm::StringRef getAMDGPUCodeObjectVersionAttrName() { return 
"cir.amdhsa_code_object_version"; }
     static llvm::StringRef getAMDGPUPrintfKindAttrName() { return 
"cir.amdgpu_printf_kind"; }

diff  --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td 
b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index e10b9102dae78..6f8db65acccc9 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -3921,6 +3921,7 @@ class CIR_CallOpBase<string mnemonic, list<Trait> 
extra_traits = []>
   dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
       Variadic<CIR_AnyType>:$args,
       UnitAttr:$nothrow,
+      UnitAttr:$musttail,
       DefaultValuedAttr<CIR_SideEffect, "SideEffect::All">:$side_effect,
       OptionalAttr<DictArrayAttr>:$arg_attrs,
       OptionalAttr<DictArrayAttr>:$res_attrs

diff  --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp 
b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
index 56839ca03dbb1..6e1a80926f679 100644
--- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
@@ -559,26 +559,41 @@ uint64_t 
CIRGenVTables::getSecondaryVirtualPointerIndex(const CXXRecordDecl *rd,
 
 static RValue performReturnAdjustment(CIRGenFunction &cgf, QualType resultType,
                                       RValue rv, const ThunkInfo &thunk) {
-  // Emit the return adjustment.
+  // Emit the return adjustment.  For non-reference pointer returns, match
+  // classic codegen: skip the adjustment when the returned pointer is null.
   bool nullCheckValue = !resultType->isReferenceType();
-
   mlir::Value returnValue = rv.getValue();
 
-  if (nullCheckValue)
-    cgf.cgm.errorNYI(
-        "return adjustment with null check for non-reference types");
-
   const CXXRecordDecl *classDecl =
       resultType->getPointeeType()->getAsCXXRecordDecl();
   CharUnits classAlign = cgf.cgm.getClassPointerAlignment(classDecl);
   mlir::Type pointeeType = cgf.convertTypeForMem(resultType->getPointeeType());
-  returnValue = cgf.cgm.getCXXABI().performReturnAdjustment(
-      cgf, Address(returnValue, pointeeType, classAlign), classDecl,
-      thunk.Return);
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+  mlir::Location loc = returnValue.getLoc();
+
+  if (!nullCheckValue) {
+    returnValue = cgf.cgm.getCXXABI().performReturnAdjustment(
+        cgf, Address(returnValue, pointeeType, classAlign), classDecl,
+        thunk.Return);
+    return RValue::get(returnValue);
+  }
 
-  if (nullCheckValue)
-    cgf.cgm.errorNYI(
-        "return adjustment with null check for non-reference types");
+  mlir::Value isNotNull = builder.createPtrIsNotNull(returnValue);
+  returnValue =
+      cir::TernaryOp::create(
+          builder, loc, isNotNull,
+          [&](mlir::OpBuilder &, mlir::Location) {
+            mlir::Value adjusted = cgf.cgm.getCXXABI().performReturnAdjustment(
+                cgf, Address(returnValue, pointeeType, classAlign), classDecl,
+                thunk.Return);
+            builder.createYield(loc, adjusted);
+          },
+          [&](mlir::OpBuilder &, mlir::Location) {
+            mlir::Value nullVal =
+                builder.getNullPtr(returnValue.getType(), loc).getResult();
+            builder.createYield(loc, nullVal);
+          })
+          .getResult();
 
   return RValue::get(returnValue);
 }
@@ -743,8 +758,33 @@ void CIRGenFunction::emitCallAndReturnForThunk(cir::FuncOp 
callee,
 void CIRGenFunction::emitMustTailThunk(GlobalDecl gd,
                                        mlir::Value adjustedThisPtr,
                                        cir::FuncOp callee) {
-  assert(!cir::MissingFeatures::opCallMustTail());
-  cgm.errorNYI("musttail thunk");
+  // Forward all function arguments, replacing 'this' with the adjusted 
pointer.
+  // The call is marked musttail so varargs are forwarded correctly.
+  mlir::Block *entryBlock = getCurFunctionEntryBlock();
+  SmallVector<mlir::Value> args;
+  for (mlir::BlockArgument arg : entryBlock->getArguments())
+    args.push_back(arg);
+
+  // Replace the 'this' argument (first arg) with the adjusted pointer.
+  assert(!args.empty() && "thunk must have at least 'this' argument");
+  if (adjustedThisPtr.getType() != args[0].getType())
+    adjustedThisPtr = builder.createBitcast(adjustedThisPtr, 
args[0].getType());
+  args[0] = adjustedThisPtr;
+
+  mlir::Location loc = curFn->getLoc();
+  cir::FuncType calleeTy = callee.getFunctionType();
+  mlir::Type retTy = calleeTy.getReturnType();
+
+  cir::CallOp call = builder.createCallOp(loc, callee, args);
+  call->setAttr(cir::CIRDialect::getMustTailAttrName(),
+                mlir::UnitAttr::get(builder.getContext()));
+
+  if (isa<cir::VoidType>(retTy))
+    cir::ReturnOp::create(builder, loc);
+  else
+    cir::ReturnOp::create(builder, loc, call->getResult(0));
+
+  finishThunk();
 }
 
 void CIRGenFunction::generateThunk(cir::FuncOp fn,

diff  --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp 
b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index d685a14ef263a..4514b04780746 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -918,6 +918,10 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser 
&parser,
     return ::mlir::failure();
   }
 
+  if (parser.parseOptionalKeyword("musttail").succeeded())
+    result.addAttribute(CIRDialect::getMustTailAttrName(),
+                        mlir::UnitAttr::get(parser.getContext()));
+
   if (parser.parseOptionalKeyword("nothrow").succeeded())
     result.addAttribute(CIRDialect::getNoThrowAttrName(),
                         mlir::UnitAttr::get(parser.getContext()));
@@ -1020,6 +1024,9 @@ printCallCommon(mlir::Operation *op, 
mlir::FlatSymbolRefAttr calleeSym,
     printer << tryCall.getUnwindDest();
   }
 
+  if (op->hasAttr(CIRDialect::getMustTailAttrName()))
+    printer << " musttail";
+
   if (isNothrow)
     printer << " nothrow";
 
@@ -1031,6 +1038,7 @@ printCallCommon(mlir::Operation *op, 
mlir::FlatSymbolRefAttr calleeSym,
 
   llvm::SmallVector<::llvm::StringRef> elidedAttrs = {
       CIRDialect::getCalleeAttrName(),
+      CIRDialect::getMustTailAttrName(),
       CIRDialect::getNoThrowAttrName(),
       CIRDialect::getSideEffectAttrName(),
       CIRDialect::getOperandSegmentSizesAttrName(),

diff  --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 032971410c64b..b7fd20715287a 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1663,7 +1663,8 @@ static void lowerCallAttributes(cir::CIRCallOpInterface 
op,
         attr.getName() == CIRDialect::getSideEffectAttrName() ||
         attr.getName() == CIRDialect::getNoThrowAttrName() ||
         attr.getName() == CIRDialect::getNoUnwindAttrName() ||
-        attr.getName() == CIRDialect::getNoReturnAttrName())
+        attr.getName() == CIRDialect::getNoReturnAttrName() ||
+        attr.getName() == CIRDialect::getMustTailAttrName())
       continue;
 
     assert(!cir::MissingFeatures::opFuncExtraAttrs());
@@ -1764,6 +1765,8 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange 
callOperands,
     newOp.setNoUnwind(noUnwind);
     newOp.setWillReturn(willReturn);
     newOp.setNoreturn(noReturn);
+    if (op->hasAttr(CIRDialect::getMustTailAttrName()))
+      newOp.setTailCallKind(mlir::LLVM::TailCallKind::MustTail);
   }
 
   return mlir::success();

diff  --git a/clang/test/CIR/CodeGen/thunks.cpp 
b/clang/test/CIR/CodeGen/thunks.cpp
index 15c4810738420..b36e8a9805516 100644
--- a/clang/test/CIR/CodeGen/thunks.cpp
+++ b/clang/test/CIR/CodeGen/thunks.cpp
@@ -91,6 +91,36 @@ void C::g(int x) {}
 
 } // namespace Test4
 
+namespace CovariantReturn {
+// Covariant return with virtual inheritance: return-adjusting thunks use a
+// null check for pointer returns (classic PerformReturnAdjustment).
+struct A {
+  virtual A *f();
+};
+struct B : virtual A {
+  virtual A *f();
+};
+struct C : B {
+  virtual C *f();
+};
+C *C::f() { return 0; }
+} // namespace CovariantReturn
+
+namespace VarargThunk {
+// Variadic this-adjusting thunk.  On x86_64, both classic codegen and CIR
+// forward the arguments to the target function via a musttail call.
+struct A {
+  virtual void f(int x, ...);
+};
+struct B {
+  virtual void f(int x, ...);
+};
+struct C : A, B {
+  void f(int x, ...) override;
+};
+void C::f(int x, ...) {}
+} // namespace VarargThunk
+
 // In CIR, all globals are emitted before functions.
 
 // Test1 vtable: C's vtable references the thunk for B's entry.
@@ -183,6 +213,23 @@ void C::g(int x) {}
 // CIR:   cir.call @_ZN5Test41C1gEi(%[[T4_RESULT]], %[[T4_ARG]])
 // CIR:   cir.return
 
+// --- CovariantReturn: return adjustment with null check on pointer return ---
+
+// CIR-LABEL: cir.func {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// CIR:       cir.call @_ZN15CovariantReturn1C1fEv
+// CIR:       cir.ternary
+
+// --- VarargThunk: variadic this-adjusting thunk ---
+
+// CIR: cir.func {{.*}} @_ZThn8_N11VarargThunk1C1fEiz(%arg0: !cir.ptr<
+// CIR:   %[[VT_THIS:.*]] = cir.load
+// CIR:   %[[VT_CAST:.*]] = cir.cast bitcast %[[VT_THIS]] : !cir.ptr<{{.*}}> 
-> !cir.ptr<!u8i>
+// CIR:   %[[VT_OFFSET:.*]] = cir.const #cir.int<-8> : !s64i
+// CIR:   %[[VT_ADJUSTED:.*]] = cir.ptr_stride %[[VT_CAST]], %[[VT_OFFSET]] : 
(!cir.ptr<!u8i>, !s64i) -> !cir.ptr<!u8i>
+// CIR:   %[[VT_RESULT:.*]] = cir.cast bitcast %[[VT_ADJUSTED]] : 
!cir.ptr<!u8i> -> !cir.ptr<
+// CIR:   cir.call @_ZN11VarargThunk1C1fEiz(%[[VT_RESULT]], %arg1) musttail
+// CIR:   cir.return
+
 // --- LLVM checks ---
 
 // LLVM: @_ZTVN5Test11CE = global { [3 x ptr], [3 x ptr] } {
@@ -231,6 +278,14 @@ void C::g(int x) {}
 // LLVM:   %[[L4_ARG:.*]] = load i32, ptr
 // LLVM:   call void @_ZN5Test41C1gEi(ptr{{.*}} %[[L4_ADJ]], i32{{.*}} 
%[[L4_ARG]])
 
+// LLVM-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// LLVM:       call {{.*}} @_ZN15CovariantReturn1C1fEv
+// LLVM:       phi ptr
+
+// LLVM-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, 
i32{{.*}}, ...)
+// LLVM:   getelementptr i8, ptr {{.*}}, i64 -8
+// LLVM:   musttail call void (ptr, i32, ...) 
@_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
+
 // --- OGCG checks ---
 
 // OGCG: @_ZTVN5Test11CE = unnamed_addr constant { [3 x ptr], [3 x ptr] } {
@@ -278,3 +333,11 @@ void C::g(int x) {}
 // OGCG:   %[[O4_ADJ:.*]] = getelementptr inbounds i8, ptr %[[O4_THIS]], i64 -8
 // OGCG:   %[[O4_ARG:.*]] = load i32, ptr
 // OGCG:   {{.*}}call void @_ZN5Test41C1gEi(ptr{{.*}} %[[O4_ADJ]], i32{{.*}} 
%[[O4_ARG]])
+
+// OGCG-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// OGCG:       {{.*}}call {{.*}} @_ZN15CovariantReturn1C1fEv
+// OGCG:       phi ptr
+
+// OGCG-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, 
i32{{.*}}, ...)
+// OGCG:   getelementptr inbounds i8, ptr {{.*}}, i64 -8
+// OGCG:   musttail call void (ptr, i32, ...) 
@_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to