https://github.com/el-ev created https://github.com/llvm/llvm-project/pull/137501
Closes #136059 >From f27d8d2f2ecbf65479601f1d21206b1cb4d9ef6a Mon Sep 17 00:00:00 2001 From: Iris Shi <0...@owo.li> Date: Sun, 27 Apr 2025 15:16:19 +0800 Subject: [PATCH] [CIR] Upstream initial support for union type --- .../include/clang/CIR/Dialect/IR/CIRTypes.td | 1 + clang/lib/CIR/CodeGen/CIRGenExpr.cpp | 25 ++-- clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp | 2 +- .../CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp | 93 +++++++++++- clang/lib/CIR/Dialect/IR/CIRTypes.cpp | 39 +++-- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 15 +- clang/test/CIR/CodeGen/union.c | 135 +++++++++++++++++- 7 files changed, 275 insertions(+), 35 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index 0a821e152d353..afb7da1fa011d 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -494,6 +494,7 @@ def CIR_RecordType : CIR_Type<"Record", "record", bool isComplete() const { return !isIncomplete(); }; bool isIncomplete() const; + mlir::Type getLargestMember(const mlir::DataLayout &dataLayout) const; size_t getNumElements() const { return getMembers().size(); }; std::string getKindAsStr() { switch (getKind()) { diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 0a518c0fd935d..f3bcffe042280 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -317,20 +317,25 @@ LValue CIRGenFunction::emitLValueForField(LValue base, const FieldDecl *field) { } unsigned recordCVR = base.getVRQualifiers(); - if (rec->isUnion()) { - cgm.errorNYI(field->getSourceRange(), "emitLValueForField: union"); - return LValue(); - } - assert(!cir::MissingFeatures::preservedAccessIndexRegion()); llvm::StringRef fieldName = field->getName(); - const CIRGenRecordLayout &layout = - cgm.getTypes().getCIRGenRecordLayout(field->getParent()); - unsigned fieldIndex = layout.getCIRFieldNo(field); - 
assert(!cir::MissingFeatures::lambdaFieldToName()); + if (rec->isUnion()) { + unsigned fieldIndex = field->getFieldIndex(); + assert(!cir::MissingFeatures::lambdaFieldToName()); + addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex); - addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex); + } else { + assert(!cir::MissingFeatures::preservedAccessIndexRegion()); + + const CIRGenRecordLayout &layout = + cgm.getTypes().getCIRGenRecordLayout(field->getParent()); + unsigned fieldIndex = layout.getCIRFieldNo(field); + + assert(!cir::MissingFeatures::lambdaFieldToName()); + + addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex); + } // If this is a reference field, load the reference right now. if (fieldType->isReferenceType()) { diff --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp index 368a6cb27c0fd..e006a77c6e7d6 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp @@ -1,4 +1,4 @@ -//===--- CIRGenExprAgg.cpp - Emit CIR Code from Aggregate Expressions -----===// +//===- CIRGenExprAggregate.cpp - Emit CIR Code from Aggregate Expressions -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp b/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp index 83aba256cd48e..778a178505684 100644 --- a/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp @@ -56,7 +56,7 @@ struct CIRRecordLowering final { }; // The constructor. CIRRecordLowering(CIRGenTypes &cirGenTypes, const RecordDecl *recordDecl, - bool isPacked); + bool packed); /// Constructs a MemberInfo instance from an offset and mlir::Type. 
MemberInfo makeStorageInfo(CharUnits offset, mlir::Type data) { @@ -64,6 +64,7 @@ struct CIRRecordLowering final { } void lower(); + void lowerUnion(); /// Determines if we need a packed llvm struct. void determinePacked(); @@ -83,6 +84,10 @@ struct CIRRecordLowering final { return CharUnits::fromQuantity(dataLayout.layout.getTypeABIAlignment(Ty)); } + bool isZeroInitializable(const FieldDecl *fd) { + return cirGenTypes.isZeroInitializable(fd->getType()); + } + /// Wraps cir::IntType with some implicit arguments. mlir::Type getUIntNType(uint64_t numBits) { unsigned alignedBits = llvm::PowerOf2Ceil(numBits); @@ -121,6 +126,13 @@ struct CIRRecordLowering final { /// Fills out the structures that are ultimately consumed. void fillOutputFields(); + void appendPaddingBytes(CharUnits size) { + if (!size.isZero()) { + fieldTypes.push_back(getByteArrayType(size)); + padded = true; + } + } + CIRGenTypes &cirGenTypes; CIRGenBuilderTy &builder; const ASTContext &astContext; @@ -136,6 +148,8 @@ struct CIRRecordLowering final { LLVM_PREFERRED_TYPE(bool) unsigned zeroInitializable : 1; LLVM_PREFERRED_TYPE(bool) + unsigned zeroInitializableAsBase : 1; + LLVM_PREFERRED_TYPE(bool) unsigned packed : 1; LLVM_PREFERRED_TYPE(bool) unsigned padded : 1; @@ -148,18 +162,19 @@ struct CIRRecordLowering final { CIRRecordLowering::CIRRecordLowering(CIRGenTypes &cirGenTypes, const RecordDecl *recordDecl, - bool isPacked) + bool packed) : cirGenTypes(cirGenTypes), builder(cirGenTypes.getBuilder()), astContext(cirGenTypes.getASTContext()), recordDecl(recordDecl), astRecordLayout( cirGenTypes.getASTContext().getASTRecordLayout(recordDecl)), dataLayout(cirGenTypes.getCGModule().getModule()), - zeroInitializable(true), packed(isPacked), padded(false) {} + zeroInitializable(true), zeroInitializableAsBase(true), packed(packed), + padded(false) {} void CIRRecordLowering::lower() { if (recordDecl->isUnion()) { - cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(), - "lower: union"); + 
lowerUnion();
+    assert(!cir::MissingFeatures::bitfields());
     return;
   }
 
@@ -306,3 +321,84 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
   // TODO: implement verification
   return rl;
 }
+
+/// Lowers a union record: tracks the storage type of every member, picks the
+/// storage type used to compute trailing padding and packing, and appends any
+/// padding needed to reach the AST record layout size.
+void CIRRecordLowering::lowerUnion() {
+  CharUnits layoutSize = astRecordLayout.getSize();
+  mlir::Type storageType = nullptr;
+  bool seenNamedMember = false;
+
+  // Iterate through the fields setting bitFieldInfo and the Fields array. Also
+  // locate the "most appropriate" storage type. The heuristic for finding the
+  // storage type isn't necessary, the first (non-0-length-bitfield) field's
+  // type would work fine and be simpler but would be different than what we've
+  // been doing and cause lit tests to change.
+  for (const FieldDecl *field : recordDecl->fields()) {
+    mlir::Type fieldType;
+    if (field->isBitField())
+      cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
+                                         "bitfields in lowerUnion");
+    else
+      fieldType = getStorageType(field);
+
+    fields[field->getCanonicalDecl()] = 0;
+
+    // Bit-fields have no storage type yet (the NYI error was reported above).
+    // Skip them instead of querying the size and alignment of a null type.
+    if (!fieldType)
+      continue;
+
+    // NOTE(cir): Track all union member's types, not just the largest one. It
+    // allows for proper type-checking and retains more info for analysis. This
+    // has to happen before the zero-initializable early exit below; otherwise
+    // the members of a non-zero-initializable union would be dropped.
+    fieldTypes.push_back(fieldType);
+
+    // Compute zero-initializable status.
+    // This union might not be zero initialized: it may contain a pointer to
+    // data member which might have some exotic initialization sequence.
+    // If this is the case, then we ought not to try and come up with a "better"
+    // type, it might not be very easy to come up with a Constant which
+    // correctly initializes it.
+    if (!seenNamedMember) {
+      seenNamedMember = field->getIdentifier();
+      if (!seenNamedMember)
+        if (const RecordDecl *fieldRD = field->getType()->getAsRecordDecl())
+          seenNamedMember = fieldRD->findFirstNamedDataMember();
+      if (seenNamedMember && !isZeroInitializable(field)) {
+        zeroInitializable = zeroInitializableAsBase = false;
+        storageType = fieldType;
+      }
+    }
+
+    // Because our union isn't zero initializable, we won't be getting a better
+    // storage type.
+    if (!zeroInitializable)
+      continue;
+
+    // Conditionally update our storage type if we've got a new "better" one.
+    if (!storageType || getAlignment(fieldType) > getAlignment(storageType) ||
+        (getAlignment(fieldType) == getAlignment(storageType) &&
+         getSize(fieldType) > getSize(storageType)))
+      storageType = fieldType;
+  }
+
+  if (!storageType) {
+    cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
+                                       "No-storage Union NYI");
+    // Without a storage type there is nothing to size or pad against.
+    return;
+  }
+
+  if (layoutSize < getSize(storageType))
+    storageType = getByteArrayType(layoutSize);
+
+  // NOTE(cir): Defer padding calculations to the lowering process.
+  appendPaddingBytes(layoutSize - getSize(storageType));
+
+  // Set packed if we need it.
+  if (layoutSize % getAlignment(storageType))
+    packed = true;
+}
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index c6133b9a20e4f..38ac14f80c06c 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -230,6 +230,32 @@ void RecordType::complete(ArrayRef<Type> members, bool packed, bool padded) {
   llvm_unreachable("failed to complete record");
 }
 
+/// Return the member that determines the storage of a union: the member with
+/// the highest ABI alignment, with ties broken in favor of the larger size.
+/// The trailing padding member of a padded union is never considered. Returns
+/// a null type when the union has no other members.
+Type RecordType::getLargestMember(const ::mlir::DataLayout &dataLayout) const {
+  assert(isUnion() && "Only call getLargestMember on unions");
+  Type largestMember;
+  uint64_t largestMemberSize = 0;
+  unsigned numElements = getNumElements();
+  auto members = getMembers();
+  if (getPadded())
+    numElements -= 1; // The last element is padding.
+  for (unsigned i = 0; i < numElements; ++i) {
+    Type ty = members[i];
+    if (!largestMember || dataLayout.getTypeABIAlignment(ty) >
+                              dataLayout.getTypeABIAlignment(largestMember) ||
+        (dataLayout.getTypeABIAlignment(ty) ==
+             dataLayout.getTypeABIAlignment(largestMember) &&
+         dataLayout.getTypeSize(ty) > largestMemberSize)) {
+      largestMember = ty;
+      largestMemberSize = dataLayout.getTypeSize(largestMember);
+    }
+  }
+  return largestMember;
+}
+
 //===----------------------------------------------------------------------===//
 // Data Layout information for types
 //===----------------------------------------------------------------------===//
@@ -237,10 +263,20 @@ void RecordType::complete(ArrayRef<Type> members, bool packed, bool padded) {
 llvm::TypeSize
 RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
                               mlir::DataLayoutEntryListRef params) const {
-  if (isUnion()) {
-    // TODO(CIR): Implement union layout.
-    return llvm::TypeSize::getFixed(8);
-  }
+  if (isUnion()) {
+    // A union without members has no largest member; report an empty size
+    // rather than querying the data layout with a null type.
+    mlir::Type largestMember = getLargestMember(dataLayout);
+    if (!largestMember)
+      return llvm::TypeSize::getFixed(0);
+
+    // getTypeSize() reports bytes, but this hook must answer in bits. The
+    // trailing padding member (if any) also contributes to the union size.
+    uint64_t sizeInBits = dataLayout.getTypeSizeInBits(largestMember);
+    if (getPadded())
+      sizeInBits += dataLayout.getTypeSizeInBits(getMembers().back());
+    return llvm::TypeSize::getFixed(sizeInBits);
+  }
 
   unsigned recordSize = computeStructSize(dataLayout);
   return llvm::TypeSize::getFixed(recordSize * 8);
@@ -249,10 +285,11 @@ RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
 uint64_t
 RecordType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
                             ::mlir::DataLayoutEntryListRef params) const {
-  if (isUnion()) {
-    // TODO(CIR): Implement union layout.
-    return 8;
-  }
+  if (isUnion()) {
+    // Fall back to byte alignment when the union has no members.
+    mlir::Type largestMember = getLargestMember(dataLayout);
+    return largestMember ? dataLayout.getTypeABIAlignment(largestMember) : 1;
+  }
 
   // Packed structures always have an ABI alignment of 1.
   if (getPacked())
@@ -268,8 +305,6 @@ RecordType::computeStructSize(const mlir::DataLayout &dataLayout) const {
   unsigned recordSize = 0;
   uint64_t recordAlignment = 1;
 
-  // We can't use a range-based for loop here because we might be ignoring the
-  // last element.
for (mlir::Type ty : getMembers()) { // This assumes that we're calculating size based on the ABI alignment, not // the preferred alignment for each type. diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 8bb27942d9646..0dba45a5aba35 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1431,7 +1431,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, break; // Unions are lowered as only the largest member. case cir::RecordType::Union: - llvm_unreachable("Lowering of unions is NYI"); + if (auto largestMember = type.getLargestMember(dataLayout)) + llvmMembers.push_back( + convertTypeForMemory(converter, dataLayout, largestMember)); + if (type.getPadded()) { + auto last = *type.getMembers().rbegin(); + llvmMembers.push_back( + convertTypeForMemory(converter, dataLayout, last)); + } break; } @@ -1604,7 +1611,11 @@ mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite( return mlir::success(); } case cir::RecordType::Union: - return op.emitError() << "NYI: union get_member lowering"; + // Union members share the address space, so we just need a bitcast to + // conform to type-checking. 
+ rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy, + adaptor.getAddr()); + return mlir::success(); } } diff --git a/clang/test/CIR/CodeGen/union.c b/clang/test/CIR/CodeGen/union.c index c4db37f835add..71cb9c2b20ca1 100644 --- a/clang/test/CIR/CodeGen/union.c +++ b/clang/test/CIR/CodeGen/union.c @@ -5,25 +5,146 @@ // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll // RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s +union U1 { + int n; + char c; +}; + +// CIR: !rec_U1 = !cir.record<union "U1" {!s32i, !s8i}> +// LLVM: %union.U1 = type { i32 } +// OGCG: %union.U1 = type { i32 } + +union U2 { + char b; + short s; + int i; + float f; + double d; +}; + +// CIR: !rec_U2 = !cir.record<union "U2" {!s8i, !s16i, !s32i, !cir.float, !cir.double}> +// LLVM: %union.U2 = type { double } +// OGCG: %union.U2 = type { double } + union IncompleteU *p; -// CIR: cir.global external @p = #cir.ptr<null> : !cir.ptr<!rec_IncompleteU> +// CIR: cir.global external @p = #cir.ptr<null> : !cir.ptr<!rec_IncompleteU> // LLVM: @p = dso_local global ptr null // OGCG: @p = global ptr null, align 8 -void f(void) { +void f1(void) { union IncompleteU *p; } -// CIR: cir.func @f() -// CIR-NEXT: cir.alloca !cir.ptr<!rec_IncompleteU>, !cir.ptr<!cir.ptr<!rec_IncompleteU>>, ["p"] -// CIR-NEXT: cir.return +// CIR: cir.func @f1() +// CIR-NEXT: cir.alloca !cir.ptr<!rec_IncompleteU>, !cir.ptr<!cir.ptr<!rec_IncompleteU>>, ["p"] +// CIR-NEXT: cir.return -// LLVM: define void @f() +// LLVM: define void @f1() // LLVM-NEXT: %[[P:.*]] = alloca ptr, i64 1, align 8 // LLVM-NEXT: ret void -// OGCG: define{{.*}} void @f() +// OGCG: define{{.*}} void @f1() // OGCG-NEXT: entry: // OGCG-NEXT: %[[P:.*]] = alloca ptr, align 8 // OGCG-NEXT: ret void + +int f2(void) { + union U1 u; + u.n = 42; + return u.n; +} + +// CIR: cir.func @f2() -> !s32i +// CIR-NEXT: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} +// CIR-NEXT: %1 = cir.alloca !rec_U1, 
!cir.ptr<!rec_U1>, ["u"] {alignment = 4 : i64} +// CIR-NEXT: %2 = cir.const #cir.int<42> : !s32i +// CIR-NEXT: %3 = cir.get_member %1[0] {name = "n"} : !cir.ptr<!rec_U1> -> !cir.ptr<!s32i> +// CIR-NEXT: cir.store %2, %3 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: %4 = cir.get_member %1[0] {name = "n"} : !cir.ptr<!rec_U1> -> !cir.ptr<!s32i> +// CIR-NEXT: %5 = cir.load %4 : !cir.ptr<!s32i>, !s32i +// CIR-NEXT: cir.store %5, %0 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: %6 = cir.load %0 : !cir.ptr<!s32i>, !s32i +// CIR-NEXT: cir.return %6 : !s32i + +// LLVM: define i32 @f2() +// LLVM-NEXT: %1 = alloca i32, i64 1, align 4 +// LLVM-NEXT: %2 = alloca %union.U1, i64 1, align 4 +// LLVM-NEXT: store i32 42, ptr %2, align 4 +// LLVM-NEXT: %3 = load i32, ptr %2, align 4 +// LLVM-NEXT: store i32 %3, ptr %1, align 4 +// LLVM-NEXT: %4 = load i32, ptr %1, align 4 +// LLVM-NEXT: ret i32 %4 + +// OGCG: define dso_local i32 @f2() +// OGCG-NEXT: entry: +// OGCG-NEXT: %u = alloca %union.U1, align 4 +// OGCG-NEXT: store i32 42, ptr %u, align 4 +// OGCG-NEXT: %0 = load i32, ptr %u, align 4 +// OGCG-NEXT: ret i32 %0 + + +void shouldGenerateUnionAccess(union U2 u) { + u.b = 0; + u.b; + u.i = 1; + u.i; + u.f = 0.1F; + u.f; + u.d = 0.1; + u.d; +} + +// CIR: cir.func @shouldGenerateUnionAccess(%arg0: !rec_U2 +// CIR-NEXT: %0 = cir.alloca !rec_U2, !cir.ptr<!rec_U2>, ["u", init] {alignment = 8 : i64} +// CIR-NEXT: cir.store %arg0, %0 : !rec_U2, !cir.ptr<!rec_U2> +// CIR-NEXT: %1 = cir.const #cir.int<0> : !s32i +// CIR-NEXT: %2 = cir.cast(integral, %1 : !s32i), !s8i +// CIR-NEXT: %3 = cir.get_member %0[0] {name = "b"} : !cir.ptr<!rec_U2> -> !cir.ptr<!s8i> +// CIR-NEXT: cir.store %2, %3 : !s8i, !cir.ptr<!s8i> +// CIR-NEXT: %4 = cir.get_member %0[0] {name = "b"} : !cir.ptr<!rec_U2> -> !cir.ptr<!s8i> +// CIR-NEXT: %5 = cir.load %4 : !cir.ptr<!s8i>, !s8i +// CIR-NEXT: %6 = cir.const #cir.int<1> : !s32i +// CIR-NEXT: %7 = cir.get_member %0[2] {name = "i"} : !cir.ptr<!rec_U2> -> !cir.ptr<!s32i> +// CIR-NEXT: 
cir.store %6, %7 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: %8 = cir.get_member %0[2] {name = "i"} : !cir.ptr<!rec_U2> -> !cir.ptr<!s32i> +// CIR-NEXT: %9 = cir.load %8 : !cir.ptr<!s32i>, !s32i +// CIR-NEXT: %10 = cir.const #cir.fp<1.000000e-01> : !cir.float +// CIR-NEXT: %11 = cir.get_member %0[3] {name = "f"} : !cir.ptr<!rec_U2> -> !cir.ptr<!cir.float> +// CIR-NEXT: cir.store %10, %11 : !cir.float, !cir.ptr<!cir.float> +// CIR-NEXT: %12 = cir.get_member %0[3] {name = "f"} : !cir.ptr<!rec_U2> -> !cir.ptr<!cir.float> +// CIR-NEXT: %13 = cir.load %12 : !cir.ptr<!cir.float>, !cir.float +// CIR-NEXT: %14 = cir.const #cir.fp<1.000000e-01> : !cir.double +// CIR-NEXT: %15 = cir.get_member %0[4] {name = "d"} : !cir.ptr<!rec_U2> -> !cir.ptr<!cir.double> +// CIR-NEXT: cir.store %14, %15 : !cir.double, !cir.ptr<!cir.double> +// CIR-NEXT: %16 = cir.get_member %0[4] {name = "d"} : !cir.ptr<!rec_U2> -> !cir.ptr<!cir.double> +// CIR-NEXT: %17 = cir.load %16 : !cir.ptr<!cir.double>, !cir.double +// CIR-NEXT: cir.return + +// LLVM: define void @shouldGenerateUnionAccess(%union.U2 %0) { +// LLVM-NEXT: %2 = alloca %union.U2, i64 1, align 8 +// LLVM-NEXT: store %union.U2 %0, ptr %2, align 8 +// LLVM-NEXT: store i8 0, ptr %2, align 1 +// LLVM-NEXT: %3 = load i8, ptr %2, align 1 +// LLVM-NEXT: store i32 1, ptr %2, align 4 +// LLVM-NEXT: %4 = load i32, ptr %2, align 4 +// LLVM-NEXT: store float 0x3FB99999A0000000, ptr %2, align 4 +// LLVM-NEXT: %5 = load float, ptr %2, align 4 +// LLVM-NEXT: store double 1.000000e-01, ptr %2, align 8 +// LLVM-NEXT: %6 = load double, ptr %2, align 8 +// LLVM-NEXT: ret void + +// OGCG: define dso_local void @shouldGenerateUnionAccess(i64 %u.coerce) #0 { +// OGCG-NEXT: entry: +// OGCG-NEXT: %u = alloca %union.U2, align 8 +// OGCG-NEXT: %coerce.dive = getelementptr inbounds nuw %union.U2, ptr %u, i32 0, i32 0 +// OGCG-NEXT: store i64 %u.coerce, ptr %coerce.dive, align 8 +// OGCG-NEXT: store i8 0, ptr %u, align 8 +// OGCG-NEXT: %0 = load i8, ptr %u, align 8 +// 
OGCG-NEXT: store i32 1, ptr %u, align 8 +// OGCG-NEXT: %1 = load i32, ptr %u, align 8 +// OGCG-NEXT: store float 0x3FB99999A0000000, ptr %u, align 8 +// OGCG-NEXT: %2 = load float, ptr %u, align 8 +// OGCG-NEXT: store double 1.000000e-01, ptr %u, align 8 +// OGCG-NEXT: %3 = load double, ptr %u, align 8 +// OGCG-NEXT: ret void \ No newline at end of file _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits