Author: Andy Kaylor Date: 2025-08-15T15:14:51-07:00 New Revision: 0cd35e7afd91ba64bdb2fc11caf13d0826780865
URL: https://github.com/llvm/llvm-project/commit/0cd35e7afd91ba64bdb2fc11caf13d0826780865 DIFF: https://github.com/llvm/llvm-project/commit/0cd35e7afd91ba64bdb2fc11caf13d0826780865.diff LOG: [CIR] Add cir.vtable.get_vptr operation (#153630) This adds support for the cir.vtable.get_vptr operation and uses it to initialize the vptr member during constructors of dynamic classes. Added: Modified: clang/include/clang/CIR/Dialect/IR/CIROps.td clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td clang/include/clang/CIR/Dialect/IR/CIRTypes.td clang/lib/CIR/CodeGen/CIRGenBuilder.h clang/lib/CIR/CodeGen/CIRGenClass.cpp clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h clang/test/CIR/CodeGen/virtual-function-calls.cpp Removed: ################################################################################ diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index a77e9199cdc96..a181c95494eff 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1749,6 +1749,39 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [ }]; } +//===----------------------------------------------------------------------===// +// VTableGetVPtr +//===----------------------------------------------------------------------===// + +def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> { + let summary = "Get the address of the vtable pointer for an object"; + let description = [{ + The `vtable.get_vptr` operation retrieves the address of the vptr for a + C++ object. This operation requires that the object pointer points to + the start of a complete object. (TODO: Describe how we get that). + The vptr will always be at offset zero in the object, but this operation + is more explicit about what is being retrieved than a direct bitcast. + + The return type is always `!cir.ptr<!cir.vptr>`. 
+ + Example: + ```mlir + %2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C> + %3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr> + ``` + }]; + + let arguments = (ins + Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src + ); + + let results = (outs CIR_PtrToVPtr:$result); + + let assemblyFormat = [{ + $src `:` qualified(type($src)) `->` qualified(type($result)) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // SetBitfieldOp //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td index d7d55dfbc0654..82f6e1d33043e 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td @@ -289,6 +289,14 @@ def CIR_AnyFloatOrVecOfFloatType let cppFunctionName = "isFPOrVectorOfFPType"; } +//===----------------------------------------------------------------------===// +// VPtr type predicates +//===----------------------------------------------------------------------===// + +def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType", "vptr type">; + +def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>; + //===----------------------------------------------------------------------===// // Scalar Type predicates //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index a258df79a6184..312d0a9422673 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -296,10 +296,10 @@ def CIR_VPtrType : CIR_Type<"VPtr", "vptr", [ access to the vptr. This type will be the element type of the 'vptr' member of structures that - require a vtable pointer. 
A pointer to this type is returned by the - `cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this - pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to - get the address of a virtual function pointer. + require a vtable pointer. The `cir.vtable.address_point` operation returns + this type. The `cir.vtable.get_vptr` operation returns a pointer to this + type. This pointer may be passed to the `cir.vtable.get_virtual_fn_addr` + operation to get the address of a virtual function pointer. The pointer may also be cast to other pointer types in order to perform pointer arithmetic based on information encoded in the AST layout to get diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index 59d2adc15a01a..a7537a0480a23 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -84,6 +84,10 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { llvm_unreachable("Unsupported format for long double"); } + mlir::Type getPtrToVPtrType() { + return getPointerTo(cir::VPtrType::get(getContext())); + } + /// Get a CIR record kind from a AST declaration tag. cir::RecordType::RecordKind getRecordKind(const clang::TagTypeKind kind) { switch (kind) { diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp index 31c93cd00d083..a3947047de079 100644 --- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp @@ -289,7 +289,7 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc, } // Apply the offsets. 
- Address vtableField = loadCXXThisAddress(); + Address classAddr = loadCXXThisAddress(); if (!nonVirtualOffset.isZero() || virtualOffset) { cgm.errorNYI(loc, "initializeVTablePointer: non-virtual and virtual offset"); @@ -300,9 +300,9 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc, // vtable field is derived from `this` pointer, therefore they should be in // the same addr space. assert(!cir::MissingFeatures::addressSpace()); - // TODO(cir): This should be cir.vtable.get_vptr. - vtableField = builder.createElementBitCast(loc, vtableField, - vtableAddressPoint.getType()); + auto vtablePtr = cir::VTableGetVPtrOp::create( + builder, loc, builder.getPtrToVPtrType(), classAddr.getPointer()); + Address vtableField = Address(vtablePtr, classAddr.getAlignment()); builder.createStore(loc, vtableAddressPoint, vtableField); assert(!cir::MissingFeatures::opTBAA()); assert(!cir::MissingFeatures::createInvariantGroup()); @@ -657,6 +657,23 @@ Address CIRGenFunction::getAddressOfBaseClass( return value; } +mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr, + const CXXRecordDecl *rd) { + auto vtablePtr = cir::VTableGetVPtrOp::create( + builder, loc, builder.getPtrToVPtrType(), thisAddr.getPointer()); + Address vtablePtrAddr = Address(vtablePtr, thisAddr.getAlignment()); + + auto vtable = builder.createLoad(loc, vtablePtrAddr); + assert(!cir::MissingFeatures::opTBAA()); + + if (cgm.getCodeGenOpts().OptimizationLevel > 0 && + cgm.getCodeGenOpts().StrictVTablePointers) { + assert(!cir::MissingFeatures::createInvariantGroup()); + } + + return vtable; +} + void CIRGenFunction::emitCXXConstructorCall(const clang::CXXConstructorDecl *d, clang::CXXCtorType type, bool forVirtualBase, diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 1ea296a6887ef..9f7521db78bec 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2344,7 +2344,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMVecShuffleOpLowering, CIRToLLVMVecSplatOpLowering, CIRToLLVMVecTernaryOpLowering, - CIRToLLVMVTableAddrPointOpLowering + CIRToLLVMVTableAddrPointOpLowering, + CIRToLLVMVTableGetVPtrOpLowering // clang-format on >(converter, patterns.getContext()); @@ -2468,6 +2469,18 @@ mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite( + cir::VTableGetVPtrOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // cir.vtable.get_vptr is equivalent to a bitcast from the source object + // pointer to the vptr type. Since the LLVM dialect uses opaque pointers + // we can just replace uses of this operation with the original pointer. + mlir::Value srcVal = adaptor.getSrc(); + rewriter.replaceAllUsesWith(op, srcVal); + rewriter.eraseOp(op); + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite( cir::StackSaveOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index e32bf2d1bae0c..91e8505233379 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -467,6 +467,16 @@ class CIRToLLVMVTableAddrPointOpLowering mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMVTableGetVPtrOpLowering + : public mlir::OpConversionPattern<cir::VTableGetVPtrOp> { +public: + using mlir::OpConversionPattern<cir::VTableGetVPtrOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VTableGetVPtrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + class CIRToLLVMStackSaveOpLowering : public mlir::OpConversionPattern<cir::StackSaveOp> { 
public: diff --git a/clang/test/CIR/CodeGen/virtual-function-calls.cpp b/clang/test/CIR/CodeGen/virtual-function-calls.cpp index 004b6dab30563..4787d78aa0e35 100644 --- a/clang/test/CIR/CodeGen/virtual-function-calls.cpp +++ b/clang/test/CIR/CodeGen/virtual-function-calls.cpp @@ -27,8 +27,8 @@ A::A() {} // CIR: cir.store %arg0, %[[THIS_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>> // CIR: %[[THIS:.*]] = cir.load %[[THIS_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A> // CIR: %[[VPTR:.*]] = cir.vtable.address_point(@_ZTV1A, address_point = <index = 0, offset = 2>) : !cir.vptr -// CIR: %[[THIS_VPTR_PTR:.*]] = cir.cast(bitcast, %[[THIS]] : !cir.ptr<!rec_A>), !cir.ptr<!cir.vptr> -// CIR: cir.store align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr> +// CIR: %[[THIS_VPTR_PTR:.*]] = cir.vtable.get_vptr %[[THIS]] : !cir.ptr<!rec_A> -> !cir.ptr<!cir.vptr> +// CIR: cir.store{{.*}} align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr> // CIR: cir.return // LLVM: define{{.*}} void @_ZN1AC2Ev(ptr %[[ARG0:.*]]) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits