Author: Andy Kaylor
Date: 2025-11-17T09:09:46-08:00
New Revision: 0c8464330a510e0c3b629883ed1acd81da17da5d

URL: 
https://github.com/llvm/llvm-project/commit/0c8464330a510e0c3b629883ed1acd81da17da5d
DIFF: 
https://github.com/llvm/llvm-project/commit/0c8464330a510e0c3b629883ed1acd81da17da5d.diff

LOG: [CIR] Upstream handling for BaseToDerived casts (#167769)

Upstream handling for BaseToDerived casts, adding the
cir.derived_class_addr operation and lowering to LLVM IR.

Added: 
    clang/test/CIR/CodeGen/base-to-derived.cpp

Modified: 
    clang/include/clang/CIR/Dialect/IR/CIROps.td
    clang/lib/CIR/CodeGen/CIRGenBuilder.h
    clang/lib/CIR/CodeGen/CIRGenClass.cpp
    clang/lib/CIR/CodeGen/CIRGenExpr.cpp
    clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
    clang/lib/CIR/CodeGen/CIRGenFunction.h
    clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td 
b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 2124b1dc62a81..7b987ea49bf97 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -3386,6 +3386,10 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
     cannot be known by the operation, and that information affects how the
     operation is lowered.
 
+    The validity of the relationship of derived and base cannot yet be 
verified.
+    If the target class is not a valid base class for the object, the behavior
+    is undefined.
+
     Example:
     ```c++
     struct Base { };
@@ -3399,8 +3403,6 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
     ```
   }];
 
-  // The validity of the relationship of derived and base cannot yet be
-  // verified, currently not worth adding a verifier.
   let arguments = (ins
     Arg<CIR_PointerType, "derived class pointer", [MemRead]>:$derived_addr,
     IndexAttr:$offset, UnitAttr:$assume_not_null);
@@ -3414,6 +3416,56 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// DerivedClassAddrOp
+//===----------------------------------------------------------------------===//
+
+def CIR_DerivedClassAddrOp : CIR_Op<"derived_class_addr"> {
+  let summary = "Get the derived class address for a class/struct";
+  let description = [{
+    The `cir.derived_class_addr` operation gets the address of a particular
+    derived class given a non-virtual base class pointer. The offset in bytes
+    of the base class must be passed in, similar to `cir.base_class_addr`, but
+    going in the other direction. This means lowering to a negative offset.
+
+    The operation contains a flag for whether or not the operand may be nullptr.
+    That depends on the context and cannot be known by the operation, and that
+    information affects how the operation is lowered.
+
+    The validity of the relationship of derived and base cannot yet be verified.
+    If the target class is not a valid derived class for the object, the
+    behavior is undefined.
+
+    Example:
+    ```c++
+    class A {};
+    class B : public A {};
+
+    B *getAsB(A *a) {
+      return static_cast<B*>(a);
+    }
+    ```
+
+    leads to
+    ```mlir
+      %2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
+      %3 = cir.derived_class_addr %2 : !cir.ptr<!rec_A> [0] -> !cir.ptr<!rec_B>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<CIR_PointerType, "base class pointer", [MemRead]>:$base_addr,
+    IndexAttr:$offset, UnitAttr:$assume_not_null);
+
+  let results = (outs Res<CIR_PointerType, "">:$derived_addr);
+
+  let assemblyFormat = [{
+      $base_addr `:` qualified(type($base_addr))
+      (`nonnull` $assume_not_null^)?
+      ` ` `[` $offset `]` `->` qualified(type($derived_addr)) attr-dict
+  }];
+}
+
 
//===----------------------------------------------------------------------===//
 // ComplexCreateOp
 
//===----------------------------------------------------------------------===//

diff  --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h 
b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
index a391d7e70ace7..5ab1d0e05cf8a 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h
+++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -405,6 +405,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
     return Address(baseAddr, destType, addr.getAlignment());
   }
 
+  /// Build a `cir.derived_class_addr` that turns \p addr (a base class
+  /// address) into an address whose element type is \p destType.
+  ///
+  /// \p offset is stored on the operation as a 64-bit index and
+  /// \p assumeNotNull is forwarded to it unchanged. If \p addr already has
+  /// element type \p destType, no operation is created and \p addr is returned
+  /// as-is. The alignment of \p addr is carried over to the result.
+  Address createDerivedClassAddr(mlir::Location loc, Address addr,
+                                 mlir::Type destType, unsigned offset,
+                                 bool assumeNotNull) {
+    const bool needsCast = destType != addr.getElementType();
+    if (!needsCast)
+      return addr;
+
+    cir::PointerType derivedPtrTy = getPointerTo(destType);
+    auto castOp = cir::DerivedClassAddrOp::create(
+        *this, loc, derivedPtrTy, addr.getPointer(), mlir::APInt(64, offset),
+        assumeNotNull);
+    return Address(castOp, destType, addr.getAlignment());
+  }
+
   mlir::Value createVTTAddrPoint(mlir::Location loc, mlir::Type retTy,
                                  mlir::Value addr, uint64_t offset) {
     return cir::VTTAddrPointOp::create(*this, loc, retTy,

diff  --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp 
b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
index a8296782ebc40..89c4696b9da94 100644
--- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
@@ -1110,6 +1110,25 @@ mlir::Value CIRGenFunction::getVTTParameter(GlobalDecl 
gd, bool forVirtualBase,
   }
 }
 
+/// Compute the address of the \p derived object from \p baseAddr, the address
+/// of one of its non-virtual base class subobjects.
+///
+/// \param loc            location attached to the emitted operation.
+/// \param baseAddr       address of the base class subobject.
+/// \param derived        the class being cast to.
+/// \param path           the base path of the cast; must not be empty.
+/// \param nullCheckValue true when \p baseAddr may be null; the emitted
+///                       operation is marked not-null only when this is false.
+/// \returns the address produced by the builder's createDerivedClassAddr.
+Address CIRGenFunction::getAddressOfDerivedClass(
+    mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
+    llvm::iterator_range<CastExpr::path_const_iterator> path,
+    bool nullCheckValue) {
+  assert(!path.empty() && "Base path should not be empty!");
+
+  const QualType derivedTy = getContext().getCanonicalTagType(derived);
+  const mlir::Type derivedValueTy = convertType(derivedTy);
+  const CharUnits nonVirtualOffset =
+      cgm.computeNonVirtualBaseClassOffset(derived, path);
+
+  // Note that in OG, no offset (nonVirtualOffset.getQuantity() == 0) means it
+  // just gives the address back. In CIR a `cir.derived_class_addr` is still
+  // created and is turned into a no-op later on, during lowering.
+  return builder.createDerivedClassAddr(loc, baseAddr, derivedValueTy,
+                                        nonVirtualOffset.getQuantity(),
+                                        /*assumeNotNull=*/!nullCheckValue);
+}
+
 Address CIRGenFunction::getAddressOfBaseClass(
     Address value, const CXXRecordDecl *derived,
     llvm::iterator_range<CastExpr::path_const_iterator> path,

diff  --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp 
b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index d35bb0af0de14..681a801cd7d81 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -1301,7 +1301,6 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
   case CK_NonAtomicToAtomic:
   case CK_AtomicToNonAtomic:
   case CK_ToUnion:
-  case CK_BaseToDerived:
   case CK_ObjCObjectLValueCast:
   case CK_VectorSplat:
   case CK_ConstructorConversion:
@@ -1336,6 +1335,7 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
                                   lv.getAddress().getAlignment()),
                           e->getType(), lv.getBaseInfo());
   }
+
   case CK_LValueBitCast: {
     // This must be a reinterpret_cast (or c-style equivalent).
     const auto *ce = cast<ExplicitCastExpr>(e);
@@ -1387,6 +1387,22 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) 
{
     return makeAddrLValue(baseAddr, e->getType(), lv.getBaseInfo());
   }
 
+  case CK_BaseToDerived: {
+    const auto *derivedClassDecl = e->getType()->castAsCXXRecordDecl();
+    LValue lv = emitLValue(e->getSubExpr());
+
+    // Perform the base-to-derived conversion
+    Address derived = getAddressOfDerivedClass(
+        getLoc(e->getSourceRange()), lv.getAddress(), derivedClassDecl,
+        e->path(), /*NullCheckValue=*/false);
+    // C++11 [expr.static.cast]p2: Behavior is undefined if a downcast is
+    // performed and the object is not of the derived type.
+    assert(!cir::MissingFeatures::sanitizers());
+
+    assert(!cir::MissingFeatures::opTBAA());
+    return makeAddrLValue(derived, e->getType(), lv.getBaseInfo());
+  }
+
   case CK_ZeroToOCLOpaqueType:
     llvm_unreachable("NULL to OpenCL opaque type lvalue cast is not valid");
   }

diff  --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp 
b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index ce95607bd468d..3b0977d213325 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -1972,6 +1972,20 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr 
*ce) {
     return builder.createIntToPtr(middleVal, destCIRTy);
   }
 
+  case CK_BaseToDerived: {
+    const CXXRecordDecl *derivedClassDecl = destTy->getPointeeCXXRecordDecl();
+    assert(derivedClassDecl && "BaseToDerived arg isn't a C++ object 
pointer!");
+    Address base = cgf.emitPointerWithAlignment(subExpr);
+    Address derived = cgf.getAddressOfDerivedClass(
+        cgf.getLoc(ce->getSourceRange()), base, derivedClassDecl, ce->path(),
+        cgf.shouldNullCheckClassCastValue(ce));
+
+    // C++11 [expr.static.cast]p11: Behavior is undefined if a downcast is
+    // performed and the object is not of the derived type.
+    assert(!cir::MissingFeatures::sanitizers());
+
+    return cgf.getAsNaturalPointerTo(derived, ce->getType()->getPointeeType());
+  }
   case CK_UncheckedDerivedToBase:
   case CK_DerivedToBase: {
     // The EmitPointerWithAlignment path does this fine; just discard
@@ -1979,7 +1993,6 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr 
*ce) {
     return cgf.getAsNaturalPointerTo(cgf.emitPointerWithAlignment(ce),
                                      ce->getType()->getPointeeType());
   }
-
   case CK_Dynamic: {
     Address v = cgf.emitPointerWithAlignment(subExpr);
     const auto *dce = cast<CXXDynamicCastExpr>(ce);

diff  --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h 
b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 2dddf26981105..b22bf2d87fc10 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -823,6 +823,11 @@ class CIRGenFunction : public CIRGenTypeCache {
       llvm::iterator_range<CastExpr::path_const_iterator> path,
       bool nullCheckValue, SourceLocation loc);
 
+  /// Convert \p baseAddr, the address of a non-virtual base class subobject,
+  /// to the address of the enclosing \p derived object. \p path is the base
+  /// path of the cast and must not be empty. \p nullCheckValue says whether
+  /// the base address may be null; when it is false the emitted operation is
+  /// marked as not-null.
+  Address getAddressOfDerivedClass(
+      mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
+      llvm::iterator_range<CastExpr::path_const_iterator> path,
+      bool nullCheckValue);
+
   /// Return the VTT parameter that should be passed to a base
   /// constructor/destructor with virtual bases.
   /// FIXME: VTTs are Itanium ABI-specific, so the definition should move

diff  --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index d88a4ad76f27b..92434d730eb31 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1360,6 +1360,41 @@ mlir::LogicalResult 
CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+/// Lower `cir.derived_class_addr` to LLVM IR.
+///
+/// The operation stores the base class offset as a non-negative byte count;
+/// going from base to derived means stepping backwards by that many bytes.
+/// A zero offset folds to the operand itself. Otherwise an `inbounds` byte GEP
+/// is emitted and, unless the operation is marked `nonnull`, it is wrapped in
+/// a select so that a null base pointer is passed through unchanged.
+mlir::LogicalResult CIRToLLVMDerivedClassAddrOpLowering::matchAndRewrite(
+    cir::DerivedClassAddrOp derivedClassOp, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  const mlir::Type resultType =
+      getTypeConverter()->convertType(derivedClassOp.getType());
+  mlir::Value baseAddr = adaptor.getBaseAddr();
+  // The offset is set in the operation as an unsigned value, but it must be
+  // applied as a negative offset. Convert to a signed type before negating:
+  // unary minus on the uint64_t returned by getZExtValue() only works through
+  // unsigned wrap-around followed by an implicit narrowing conversion (and
+  // triggers MSVC warning C4146).
+  const int64_t offsetVal =
+      -static_cast<int64_t>(adaptor.getOffset().getZExtValue());
+  if (offsetVal == 0) {
+    // If the offset is zero, we can just return the base address.
+    rewriter.replaceOp(derivedClassOp, baseAddr);
+    return mlir::success();
+  }
+  llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {offsetVal};
+  mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
+                                               mlir::IntegerType::Signless);
+  if (derivedClassOp.getAssumeNotNull()) {
+    // Known non-null: the adjusted pointer is the result.
+    rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
+        derivedClassOp, resultType, byteType, baseAddr, offset,
+        mlir::LLVM::GEPNoWrapFlags::inbounds);
+  } else {
+    // Possibly null: select the original (null) pointer when the operand is
+    // null, and the adjusted pointer otherwise.
+    mlir::Location loc = derivedClassOp.getLoc();
+    mlir::Value isNull = mlir::LLVM::ICmpOp::create(
+        rewriter, loc, mlir::LLVM::ICmpPredicate::eq, baseAddr,
+        mlir::LLVM::ZeroOp::create(rewriter, loc, baseAddr.getType()));
+    mlir::Value adjusted = mlir::LLVM::GEPOp::create(
+        rewriter, loc, resultType, byteType, baseAddr, offset,
+        mlir::LLVM::GEPNoWrapFlags::inbounds);
+    rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(derivedClassOp, isNull,
+                                                      baseAddr, adjusted);
+  }
+  return mlir::success();
+}
+
 mlir::LogicalResult CIRToLLVMATanOpLowering::matchAndRewrite(
     cir::ATanOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {

diff  --git a/clang/test/CIR/CodeGen/base-to-derived.cpp 
b/clang/test/CIR/CodeGen/base-to-derived.cpp
new file mode 100644
index 0000000000000..af9aa0ffd19c1
--- /dev/null
+++ b/clang/test/CIR/CodeGen/base-to-derived.cpp
@@ -0,0 +1,97 @@
+// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-cir %s 
-o %t.cir
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-llvm 
%s -o %t-cir.ll
+// RUN: FileCheck --check-prefix=LLVM --input-file=%t-cir.ll %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -emit-llvm %s -o %t.ll
+// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s
+
+class A {
+    int a;
+};
+
+class B {
+    int b;
+public:
+    A *getAsA();
+};
+
+class X : public A, public B {
+    int x;
+};
+
+X *castAtoX(A *a) {
+  return static_cast<X*>(a);
+}
+
+// CIR: cir.func {{.*}} @_Z8castAtoXP1A(%[[ARG0:.*]]: !cir.ptr<!rec_A> {{.*}})
+// CIR:   %[[A_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A>, 
!cir.ptr<!cir.ptr<!rec_A>>, ["a", init]
+// CIR:   cir.store %[[ARG0]], %[[A_ADDR]] : !cir.ptr<!rec_A>, 
!cir.ptr<!cir.ptr<!rec_A>>
+// CIR:   %[[A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, 
!cir.ptr<!rec_A>
+// CIR:   %[[X:.*]] = cir.derived_class_addr %[[A]] : !cir.ptr<!rec_A> [0] -> 
!cir.ptr<!rec_X>
+
+// Note: Because the offset is 0, a null check is not needed.
+
+// LLVM: define {{.*}} ptr @_Z8castAtoXP1A(ptr %[[ARG0:.*]])
+// LLVM:   %[[A_ADDR:.*]] = alloca ptr
+// LLVM:   store ptr %[[ARG0]], ptr %[[A_ADDR]]
+// LLVM:   %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
+
+// OGCG: define {{.*}} ptr @_Z8castAtoXP1A(ptr {{.*}} %[[ARG0:.*]])
+// OGCG:   %[[A_ADDR:.*]] = alloca ptr
+// OGCG:   store ptr %[[ARG0]], ptr %[[A_ADDR]]
+// OGCG:   %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
+
+X *castBtoX(B *b) {
+  return static_cast<X*>(b);
+}
+
+// CIR: cir.func {{.*}} @_Z8castBtoXP1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> {{.*}})
+// CIR:   %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, 
!cir.ptr<!cir.ptr<!rec_B>>, ["b", init]
+// CIR:   cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, 
!cir.ptr<!cir.ptr<!rec_B>>
+// CIR:   %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, 
!cir.ptr<!rec_B>
+// CIR:   %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> [4] -> 
!cir.ptr<!rec_X>
+
+// LLVM: define {{.*}} ptr @_Z8castBtoXP1B(ptr %[[ARG0:.*]])
+// LLVM:   %[[B_ADDR:.*]] = alloca ptr, i64 1, align 8
+// LLVM:   store ptr %[[ARG0]], ptr %[[B_ADDR]], align 8
+// LLVM:   %[[B:.*]] = load ptr, ptr %[[B_ADDR]], align 8
+// LLVM:   %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
+// LLVM:   %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
+// LLVM:   %[[X:.*]] = select i1 %[[IS_NULL]], ptr %[[B]], ptr %[[B_NON_NULL]]
+
+// OGCG: define {{.*}} ptr @_Z8castBtoXP1B(ptr {{.*}} %[[ARG0:.*]])
+// OGCG: entry:
+// OGCG:   %[[B_ADDR:.*]] = alloca ptr
+// OGCG:   store ptr %[[ARG0]], ptr %[[B_ADDR]]
+// OGCG:   %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
+// OGCG:   %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
+// OGCG:   br i1 %[[IS_NULL]], label %[[LABEL_NULL:.*]], label 
%[[LABEL_NOTNULL:.*]]
+// OGCG: [[LABEL_NOTNULL]]:
+// OGCG:   %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4
+// OGCG:   br label %[[LABEL_END:.*]]
+// OGCG: [[LABEL_NULL]]:
+// OGCG:   br label %[[LABEL_END:.*]]
+// OGCG: [[LABEL_END]]:
+// OGCG:   %[[X:.*]] = phi ptr [ %[[B_NON_NULL]], %[[LABEL_NOTNULL]] ], [ 
null, %[[LABEL_NULL]] ]
+
+X &castBReftoXRef(B &b) {
+  return static_cast<X&>(b);
+}
+
+// CIR: cir.func {{.*}} @_Z14castBReftoXRefR1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> 
{{.*}})
+// CIR:   %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, 
!cir.ptr<!cir.ptr<!rec_B>>, ["b", init, const]
+// CIR:   cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, 
!cir.ptr<!cir.ptr<!rec_B>>
+// CIR:   %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, 
!cir.ptr<!rec_B>
+// CIR:   %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> nonnull 
[4] -> !cir.ptr<!rec_X>
+
+// LLVM: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr %[[ARG0:.*]])
+// LLVM:   %[[B_ADDR:.*]] = alloca ptr
+// LLVM:   store ptr %[[ARG0]], ptr %[[B_ADDR]]
+// LLVM:   %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
+// LLVM:   %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
+
+// OGCG: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr {{.*}} %[[ARG0:.*]])
+// OGCG:   %[[B_ADDR:.*]] = alloca ptr
+// OGCG:   store ptr %[[ARG0]], ptr %[[B_ADDR]]
+// OGCG:   %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
+// OGCG:   %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to