https://github.com/adams381 created https://github.com/llvm/llvm-project/pull/192119
Adds `CIRABIRewriteContext`, the CIR-specific concrete subclass of the `ABIRewriteContext` interface introduced in #190661. It rewrites CIR FuncOps and CallOps to match ABI-lowered signatures.

This first PR handles the scalar cases:

- Direct passthrough and scalar coercion (bitcast)
- Extend (integer widening with signext/zeroext attrs)
- Ignore (void returns, empty-struct arg erasure)
- Call-site rewrites for all of the above

Struct coercion (sret, byval, multi-register flattening) comes next.

11 C++ unit tests, each of which constructs a `FunctionClassification` by hand and verifies the rewritten IR, so the tests have no dependency on an ABI classifier.

Depends on #190661.

Made with [Cursor](https://cursor.com)

>From 23ef4674ee8ac04d7777d5442071a781b8a09b7b Mon Sep 17 00:00:00 2001
From: Adam Smith
Date: Mon, 6 Apr 2026 12:14:06 -0700
Subject: [PATCH 1/3] [mlir][ABI] Add ABITypeMapper and ABIRewriteContext

Dialect-agnostic layer bridging MLIR and the LLVM ABI library.
ABITypeMapper handles built-in MLIR types; ABIRewriteContext is the
interface dialects implement for ABI rewrites (see
clang/docs/ClangIRABILowering.md Section 4).
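To make the scalar rules (Direct, Extend, Ignore) more concrete, here is a small, self-contained C++ model of how a hand-built classification might map an original signature to a lowered one. It is a sketch under stated assumptions, not the `CIRABIRewriteContext` code: the `sketch` namespace, `lowerScalarSignature`, `LoweredParam`, and `LoweredSignature` are hypothetical names, types are plain strings such as `"i32"` instead of `mlir::Type`, and the real rewrite must also adapt the function body and call sites, which this model ignores.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical stand-ins for mlir::abi::ArgKind / ArgClassification.
// Types are plain strings (e.g. "i32") instead of mlir::Type.
enum class ArgKind : uint8_t { Direct, Extend, Indirect, Ignore, Expand };

struct ArgClassification {
  ArgKind kind = ArgKind::Direct;
  std::string coercedType; // Empty means "keep the original type".
  bool signExtend = false; // Only meaningful for Extend.
};

struct LoweredParam {
  std::string type;
  std::string attr; // "signext", "zeroext", or empty.
};

struct LoweredSignature {
  std::string returnType;
  std::string returnAttr;
  std::vector<LoweredParam> params;
};

// Lowers one argument or return value. Returns std::nullopt when the value
// disappears from the signature (Ignore) and sets `unsupported` for the
// non-scalar kinds this sketch does not model.
static std::optional<LoweredParam> lowerOne(const std::string &origType,
                                            const ArgClassification &info,
                                            bool &unsupported) {
  const std::string &type =
      info.coercedType.empty() ? origType : info.coercedType;
  switch (info.kind) {
  case ArgKind::Direct:
    // Passthrough, or a scalar coercion (a bitcast in the rewritten IR).
    return LoweredParam{type, ""};
  case ArgKind::Extend:
    // Integer widening, tagged so caller and callee agree on the upper bits.
    return LoweredParam{type, info.signExtend ? "signext" : "zeroext"};
  case ArgKind::Ignore:
    // Void return or empty-struct argument: dropped from the signature.
    return std::nullopt;
  case ArgKind::Indirect: // sret/byval: out of scope for the scalar cases.
  case ArgKind::Expand:   // Aggregate flattening: also out of scope.
    break;
  }
  unsupported = true;
  return std::nullopt;
}

// Builds the lowered signature from a hand-built classification. Returns
// std::nullopt on a size mismatch or when any value needs a non-scalar kind.
std::optional<LoweredSignature>
lowerScalarSignature(const std::string &origRet,
                     const ArgClassification &retInfo,
                     const std::vector<std::string> &origArgs,
                     const std::vector<ArgClassification> &argInfos) {
  if (origArgs.size() != argInfos.size())
    return std::nullopt;
  bool unsupported = false;
  LoweredSignature sig;
  std::optional<LoweredParam> ret = lowerOne(origRet, retInfo, unsupported);
  sig.returnType = ret ? ret->type : std::string("void");
  sig.returnAttr = ret ? ret->attr : std::string();
  for (std::size_t i = 0; i < origArgs.size(); ++i)
    if (std::optional<LoweredParam> p =
            lowerOne(origArgs[i], argInfos[i], unsupported))
      sig.params.push_back(*p);
  if (unsupported)
    return std::nullopt;
  return sig;
}

} // namespace sketch
```

Returning `std::nullopt` for `Indirect` and `Expand` mirrors the scope described above: sret/byval and aggregate flattening are left for a follow-up, so a scalar-only rewriter should refuse rather than guess.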
Made-with: Cursor --- mlir/include/mlir/ABI/ABIRewriteContext.h | 159 +++++++++++++++++ mlir/include/mlir/ABI/ABITypeMapper.h | 66 +++++++ mlir/lib/ABI/ABITypeMapper.cpp | 102 +++++++++++ mlir/lib/ABI/CMakeLists.txt | 14 ++ mlir/lib/CMakeLists.txt | 1 + mlir/unittests/ABI/ABIRewriteContextTest.cpp | 99 +++++++++++ mlir/unittests/ABI/ABITypeMapperTest.cpp | 173 +++++++++++++++++++ mlir/unittests/ABI/CMakeLists.txt | 12 ++ mlir/unittests/CMakeLists.txt | 1 + 9 files changed, 627 insertions(+) create mode 100644 mlir/include/mlir/ABI/ABIRewriteContext.h create mode 100644 mlir/include/mlir/ABI/ABITypeMapper.h create mode 100644 mlir/lib/ABI/ABITypeMapper.cpp create mode 100644 mlir/lib/ABI/CMakeLists.txt create mode 100644 mlir/unittests/ABI/ABIRewriteContextTest.cpp create mode 100644 mlir/unittests/ABI/ABITypeMapperTest.cpp create mode 100644 mlir/unittests/ABI/CMakeLists.txt diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h new file mode 100644 index 0000000000000..71d5a56b599b7 --- /dev/null +++ b/mlir/include/mlir/ABI/ABIRewriteContext.h @@ -0,0 +1,159 @@ +//===- ABIRewriteContext.h - Dialect-specific ABI rewriting -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines ABIRewriteContext, the abstract interface for dialect- +// specific ABI lowering rewrites. Each MLIR dialect that wants ABI lowering +// (CIR, FIR, etc.) provides a concrete subclass. +// +// ABIRewriteContext consumes ABI classification results and drives the +// creation of lowered function signatures, argument coercions, and call +// site rewrites using dialect-specific operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ABI_ABIREWRITECONTEXT_H +#define MLIR_ABI_ABIREWRITECONTEXT_H + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/Support/Alignment.h" + +namespace mlir { +namespace abi { + +/// Classification of how a single argument or return value should be +/// passed at the ABI level. +/// +/// This is a dialect-agnostic representation. It mirrors the kinds +/// found in the LLVM ABI library and in CIR's ABIArgInfo, but does +/// not depend on either. +enum class ArgKind : uint8_t { + /// Pass directly in registers, possibly coerced to a different type. + Direct, + + /// Like Direct, but with a sign/zero extension attribute. + Extend, + + /// Pass indirectly via a pointer (sret for returns, byval for args). + Indirect, + + /// Ignore (void return, empty struct). + Ignore, + + /// Expand an aggregate into its constituent scalar fields. + Expand, +}; + +/// Describes how a single argument or return value is passed after ABI +/// lowering. +struct ArgClassification { + ArgKind Kind = ArgKind::Direct; + + /// The ABI-coerced type, if different from the original. Null means + /// use the original type. + Type CoercedType = nullptr; + + /// For Indirect: alignment of the pointed-to object. + llvm::Align IndirectAlign = llvm::Align(1); + + /// For Extend: whether to sign-extend (true) or zero-extend (false). + bool SignExtend = false; + + /// For Direct: whether a struct coercion can be flattened into + /// individual register-width arguments. + bool CanFlatten = true; + + /// For Indirect: whether the callee gets ownership (byval). 
+ bool ByVal = false; + + static ArgClassification getDirect(Type coerced = nullptr) { + ArgClassification c; + c.Kind = ArgKind::Direct; + c.CoercedType = coerced; + return c; + } + + static ArgClassification getIgnore() { + ArgClassification c; + c.Kind = ArgKind::Ignore; + return c; + } + + static ArgClassification getIndirect(llvm::Align align, bool byVal = true) { + ArgClassification c; + c.Kind = ArgKind::Indirect; + c.IndirectAlign = align; + c.ByVal = byVal; + return c; + } + + static ArgClassification getExtend(Type coerced, bool signExt) { + ArgClassification c; + c.Kind = ArgKind::Extend; + c.CoercedType = coerced; + c.SignExtend = signExt; + return c; + } +}; + +/// Holds the full ABI classification for a function: return type and +/// all arguments. +struct FunctionClassification { + ArgClassification ReturnInfo; + SmallVector<ArgClassification> ArgInfos; +}; + +/// ABIRewriteContext is the abstract interface that each dialect +/// implements to perform ABI-specific rewrites on its operations. +/// +/// The pass orchestrator calls these methods after ABI classification +/// to rewrite function definitions and call sites. +class ABIRewriteContext { +public: + virtual ~ABIRewriteContext() = default; + + /// Rewrite a function definition to use ABI-lowered types. + /// + /// This creates a new function with the lowered signature, rewrites + /// the function body to adapt between the ABI types and the + /// original high-level types, and replaces the original function. + /// + /// \param funcOp The function to rewrite (via FunctionOpInterface). + /// \param fc The ABI classification for this function. + /// \param rewriter The pattern rewriter to use for modifications. + /// \returns success() if the function was rewritten. + virtual LogicalResult + rewriteFunctionDefinition(FunctionOpInterface funcOp, + const FunctionClassification &fc, + OpBuilder &rewriter) = 0; + + /// Rewrite a call operation to match the callee's ABI-lowered + /// signature. 
+ /// + /// This coerces arguments, handles indirect returns (sret), and + /// adapts the call result back to the original high-level type. + /// + /// \param callOp The call operation to rewrite. + /// \param fc The ABI classification for the callee. + /// \param rewriter The pattern rewriter to use for modifications. + /// \returns success() if the call was rewritten. + virtual LogicalResult rewriteCallSite(Operation *callOp, + const FunctionClassification &fc, + OpBuilder &rewriter) = 0; + + /// Return the dialect namespace this context handles (e.g. "cir"). + virtual StringRef getDialectNamespace() const = 0; +}; + +} // namespace abi +} // namespace mlir + +#endif // MLIR_ABI_ABIREWRITECONTEXT_H diff --git a/mlir/include/mlir/ABI/ABITypeMapper.h b/mlir/include/mlir/ABI/ABITypeMapper.h new file mode 100644 index 0000000000000..2180c9c8a918d --- /dev/null +++ b/mlir/include/mlir/ABI/ABITypeMapper.h @@ -0,0 +1,66 @@ +//===- ABITypeMapper.h - Map MLIR types to ABI types -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines ABITypeMapper, which translates mlir::Type instances into +// the llvm::abi::Type hierarchy defined in llvm/ABI/Types.h. Dialect-specific +// types are handled via MLIR's DataLayoutTypeInterface. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ABI_ABITYPEMAPPER_H +#define MLIR_ABI_ABITYPEMAPPER_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "llvm/ABI/Types.h" +#include "llvm/Support/Allocator.h" + +namespace mlir { +namespace abi { + +/// ABITypeMapper translates mlir::Type values into the llvm::abi::Type +/// hierarchy used by the LLVM ABI Lowering Library. +/// +/// Standard MLIR types (IntegerType, FloatType, IndexType, VectorType, +/// MemRefType) are mapped directly. Dialect-specific types are mapped +/// by querying the MLIR DataLayout for size and alignment. +/// +/// Callers must supply a DataLayout (typically from the enclosing module) +/// so the mapper can determine sizes and alignments. +/// +/// The mapper owns a BumpPtrAllocator; all returned abi::Type pointers +/// are valid for the lifetime of the mapper. +class ABITypeMapper { +public: + explicit ABITypeMapper(const DataLayout &dl); + + /// Map an MLIR type to its ABI type representation. Returns nullptr + /// if the type cannot be mapped. + const llvm::abi::Type *map(mlir::Type type); + + /// Access the underlying TypeBuilder for advanced use. 
+ llvm::abi::TypeBuilder &getTypeBuilder() { return Builder; } + +private: + const llvm::abi::Type *mapIntegerType(mlir::IntegerType type); + const llvm::abi::Type *mapFloatType(mlir::FloatType type); + const llvm::abi::Type *mapIndexType(mlir::IndexType type); + const llvm::abi::Type *mapVectorType(mlir::VectorType type); + const llvm::abi::Type *mapMemRefType(mlir::MemRefType type); + const llvm::abi::Type *mapNoneType(mlir::NoneType type); + + const DataLayout &DL; + llvm::BumpPtrAllocator Allocator; + llvm::abi::TypeBuilder Builder; +}; + +} // namespace abi +} // namespace mlir + +#endif // MLIR_ABI_ABITYPEMAPPER_H diff --git a/mlir/lib/ABI/ABITypeMapper.cpp b/mlir/lib/ABI/ABITypeMapper.cpp new file mode 100644 index 0000000000000..c7a69780bbe64 --- /dev/null +++ b/mlir/lib/ABI/ABITypeMapper.cpp @@ -0,0 +1,102 @@ +//===- ABITypeMapper.cpp - Map MLIR types to ABI types --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/ABI/ABITypeMapper.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/Support/Alignment.h" + +using namespace mlir; +using namespace mlir::abi; + +ABITypeMapper::ABITypeMapper(const DataLayout &dl) + : DL(dl), Builder(Allocator) {} + +const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) { + if (auto intTy = dyn_cast<mlir::IntegerType>(type)) + return mapIntegerType(intTy); + + if (auto floatTy = dyn_cast<mlir::FloatType>(type)) + return mapFloatType(floatTy); + + if (auto indexTy = dyn_cast<mlir::IndexType>(type)) + return mapIndexType(indexTy); + + if (auto vecTy = dyn_cast<mlir::VectorType>(type)) + return mapVectorType(vecTy); + + if (auto memRefTy = dyn_cast<mlir::MemRefType>(type)) + return mapMemRefType(memRefTy); + + if (auto noneTy = dyn_cast<mlir::NoneType>(type)) + return mapNoneType(noneTy); + + // For dialect-specific types, fall back to DataLayout queries. + // The type must implement DataLayoutTypeInterface for this to work. + llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type); + uint64_t abiAlign = DL.getTypeABIAlignment(type); + return Builder.getIntegerType(sizeInBits.getFixedValue(), + llvm::Align(abiAlign), + /*Signed=*/false); +} + +const llvm::abi::Type *ABITypeMapper::mapIntegerType(mlir::IntegerType type) { + uint64_t width = type.getWidth(); + uint64_t abiAlign = DL.getTypeABIAlignment(type); + // MLIR signless integers are treated as signed for ABI purposes. + // Most C/C++ integer types are signless in MLIR but behave as + // signed for ABI classification (sign extension, etc.). 
+ bool isSigned = type.isSigned() || type.isSignless(); + return Builder.getIntegerType(width, llvm::Align(abiAlign), isSigned); +} + +const llvm::abi::Type *ABITypeMapper::mapFloatType(mlir::FloatType type) { + uint64_t abiAlign = DL.getTypeABIAlignment(type); + const llvm::fltSemantics &semantics = type.getFloatSemantics(); + return Builder.getFloatType(semantics, llvm::Align(abiAlign)); +} + +const llvm::abi::Type *ABITypeMapper::mapIndexType(mlir::IndexType type) { + llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type); + uint64_t abiAlign = DL.getTypeABIAlignment(type); + return Builder.getIntegerType(sizeInBits.getFixedValue(), + llvm::Align(abiAlign), + /*Signed=*/false); +} + +const llvm::abi::Type *ABITypeMapper::mapVectorType(mlir::VectorType type) { + const llvm::abi::Type *elementTy = map(type.getElementType()); + if (!elementTy) + return nullptr; + + auto shape = type.getShape(); + // MLIR VectorType is always fixed-length and can be multi-dimensional. + // Flatten to a single dimension for ABI purposes. + uint64_t totalElements = 1; + for (int64_t dim : shape) + totalElements *= dim; + + llvm::ElementCount ec = llvm::ElementCount::getFixed(totalElements); + uint64_t abiAlign = DL.getTypeABIAlignment(type); + return Builder.getVectorType(elementTy, ec, llvm::Align(abiAlign)); +} + +const llvm::abi::Type *ABITypeMapper::mapMemRefType(mlir::MemRefType type) { + // MemRef is pointer-like for ABI purposes. 
+ llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type); + uint64_t abiAlign = DL.getTypeABIAlignment(type); + unsigned addrSpace = 0; + if (auto as = type.getMemorySpace()) + if (auto intAttr = dyn_cast<IntegerAttr>(as)) + addrSpace = intAttr.getInt(); + return Builder.getPointerType(sizeInBits.getFixedValue(), + llvm::Align(abiAlign), addrSpace); +} + +const llvm::abi::Type *ABITypeMapper::mapNoneType(mlir::NoneType type) { + return Builder.getVoidType(); +} diff --git a/mlir/lib/ABI/CMakeLists.txt b/mlir/lib/ABI/CMakeLists.txt new file mode 100644 index 0000000000000..eb434d25dd390 --- /dev/null +++ b/mlir/lib/ABI/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_library(MLIRABI + ABITypeMapper.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/ABI + + LINK_COMPONENTS + ABI + Support + + LINK_LIBS PUBLIC + MLIRIR + MLIRDataLayoutInterfaces + ) diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index 91ed05f6548d7..d7a6e28d98586 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -1,6 +1,7 @@ # Enable errors for any global constructors. add_flag_if_supported("-Werror=global-constructors" WERROR_GLOBAL_CONSTRUCTOR) +add_subdirectory(ABI) add_subdirectory(Analysis) add_subdirectory(AsmParser) add_subdirectory(Bytecode) diff --git a/mlir/unittests/ABI/ABIRewriteContextTest.cpp b/mlir/unittests/ABI/ABIRewriteContextTest.cpp new file mode 100644 index 0000000000000..04c28991cc752 --- /dev/null +++ b/mlir/unittests/ABI/ABIRewriteContextTest.cpp @@ -0,0 +1,99 @@ +//===- ABIRewriteContextTest.cpp - Unit tests for ABIRewriteContext -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/ABI/ABIRewriteContext.h" +#include <gtest/gtest.h> + +using namespace mlir; +using namespace mlir::abi; + +namespace { + +class MockRewriteContext : public ABIRewriteContext { +public: + LogicalResult rewriteFunctionDefinition(FunctionOpInterface, + const FunctionClassification &, + OpBuilder &) override { + return success(); + } + + LogicalResult rewriteCallSite(Operation *, const FunctionClassification &, + OpBuilder &) override { + return success(); + } + + StringRef getDialectNamespace() const override { return "mock"; } +}; + +TEST(ABIRewriteContextTest, MockCanBeConstructedAndDestroyed) { + MockRewriteContext ctx; + EXPECT_EQ(ctx.getDialectNamespace(), "mock"); +} + +TEST(ABIRewriteContextTest, ArgClassificationDirect) { + auto c = ArgClassification::getDirect(); + EXPECT_EQ(c.Kind, ArgKind::Direct); + EXPECT_EQ(c.CoercedType, nullptr); + EXPECT_TRUE(c.CanFlatten); +} + +TEST(ABIRewriteContextTest, ArgClassificationDirectWithType) { + MLIRContext mlirCtx; + auto i32 = IntegerType::get(&mlirCtx, 32); + auto c = ArgClassification::getDirect(i32); + EXPECT_EQ(c.Kind, ArgKind::Direct); + EXPECT_EQ(c.CoercedType, i32); +} + +TEST(ABIRewriteContextTest, ArgClassificationIgnore) { + auto c = ArgClassification::getIgnore(); + EXPECT_EQ(c.Kind, ArgKind::Ignore); +} + +TEST(ABIRewriteContextTest, ArgClassificationIndirect) { + auto c = ArgClassification::getIndirect(llvm::Align(8), true); + EXPECT_EQ(c.Kind, ArgKind::Indirect); + EXPECT_EQ(c.IndirectAlign, llvm::Align(8)); + EXPECT_TRUE(c.ByVal); +} + +TEST(ABIRewriteContextTest, ArgClassificationIndirectNoByVal) { + auto c = ArgClassification::getIndirect(llvm::Align(16), false); + EXPECT_EQ(c.Kind, ArgKind::Indirect); + EXPECT_EQ(c.IndirectAlign, llvm::Align(16)); + EXPECT_FALSE(c.ByVal); +} + +TEST(ABIRewriteContextTest, ArgClassificationExtend) { + 
MLIRContext mlirCtx; + auto i8 = IntegerType::get(&mlirCtx, 8); + + auto signExt = ArgClassification::getExtend(i8, true); + EXPECT_EQ(signExt.Kind, ArgKind::Extend); + EXPECT_TRUE(signExt.SignExtend); + + auto zeroExt = ArgClassification::getExtend(i8, false); + EXPECT_EQ(zeroExt.Kind, ArgKind::Extend); + EXPECT_FALSE(zeroExt.SignExtend); +} + +TEST(ABIRewriteContextTest, FunctionClassificationHoldsReturnAndArgs) { + FunctionClassification fc; + fc.ReturnInfo = ArgClassification::getDirect(); + fc.ArgInfos.push_back(ArgClassification::getDirect()); + fc.ArgInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true)); + fc.ArgInfos.push_back(ArgClassification::getIgnore()); + + EXPECT_EQ(fc.ReturnInfo.Kind, ArgKind::Direct); + EXPECT_EQ(fc.ArgInfos.size(), 3u); + EXPECT_EQ(fc.ArgInfos[0].Kind, ArgKind::Direct); + EXPECT_EQ(fc.ArgInfos[1].Kind, ArgKind::Indirect); + EXPECT_EQ(fc.ArgInfos[2].Kind, ArgKind::Ignore); +} + +} // namespace diff --git a/mlir/unittests/ABI/ABITypeMapperTest.cpp b/mlir/unittests/ABI/ABITypeMapperTest.cpp new file mode 100644 index 0000000000000..4a7989298a149 --- /dev/null +++ b/mlir/unittests/ABI/ABITypeMapperTest.cpp @@ -0,0 +1,173 @@ +//===- ABITypeMapperTest.cpp - Unit tests for ABITypeMapper ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/ABI/ABITypeMapper.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "llvm/ABI/Types.h" + +#include <gtest/gtest.h> + +using namespace mlir; +using namespace mlir::abi; + +namespace { + +class ABITypeMapperTest : public ::testing::Test { +protected: + void SetUp() override { + ctx.loadDialect<DLTIDialect>(); + module = ModuleOp::create(UnknownLoc::get(&ctx)); + } + + void TearDown() override { module->destroy(); } + + MLIRContext ctx; + ModuleOp module; +}; + +TEST_F(ABITypeMapperTest, MapI32) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto i32 = IntegerType::get(&ctx, 32); + const llvm::abi::Type *result = mapper.map(i32); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isInteger()); + + auto *intTy = llvm::cast<llvm::abi::IntegerType>(result); + EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 32u); +} + +TEST_F(ABITypeMapperTest, MapI1) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto i1 = IntegerType::get(&ctx, 1); + const llvm::abi::Type *result = mapper.map(i1); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isInteger()); + + auto *intTy = llvm::cast<llvm::abi::IntegerType>(result); + EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 1u); +} + +TEST_F(ABITypeMapperTest, MapI64) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto i64 = IntegerType::get(&ctx, 64); + const llvm::abi::Type *result = mapper.map(i64); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isInteger()); + + auto *intTy = llvm::cast<llvm::abi::IntegerType>(result); + EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 64u); +} + +TEST_F(ABITypeMapperTest, MapF32) { + DataLayout dl(module); 
+ ABITypeMapper mapper(dl); + + auto f32 = Float32Type::get(&ctx); + const llvm::abi::Type *result = mapper.map(f32); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isFloat()); + + auto *floatTy = llvm::cast<llvm::abi::FloatType>(result); + EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 32u); +} + +TEST_F(ABITypeMapperTest, MapF64) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto f64 = Float64Type::get(&ctx); + const llvm::abi::Type *result = mapper.map(f64); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isFloat()); + + auto *floatTy = llvm::cast<llvm::abi::FloatType>(result); + EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 64u); +} + +TEST_F(ABITypeMapperTest, MapF16) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto f16 = Float16Type::get(&ctx); + const llvm::abi::Type *result = mapper.map(f16); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isFloat()); + + auto *floatTy = llvm::cast<llvm::abi::FloatType>(result); + EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 16u); +} + +TEST_F(ABITypeMapperTest, MapNoneType) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto none = NoneType::get(&ctx); + const llvm::abi::Type *result = mapper.map(none); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isVoid()); +} + +TEST_F(ABITypeMapperTest, MapVectorOf4xF32) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto f32 = Float32Type::get(&ctx); + auto vec = VectorType::get({4}, f32); + const llvm::abi::Type *result = mapper.map(vec); + + ASSERT_NE(result, nullptr); + EXPECT_TRUE(result->isVector()); + + auto *vecTy = llvm::cast<llvm::abi::VectorType>(result); + EXPECT_EQ(vecTy->getNumElements().getFixedValue(), 4u); + EXPECT_TRUE(vecTy->getElementType()->isFloat()); +} + +TEST_F(ABITypeMapperTest, MapSignedI32) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto si32 = IntegerType::get(&ctx, 32, IntegerType::Signed); + const llvm::abi::Type *result = mapper.map(si32); + + 
ASSERT_NE(result, nullptr); + auto *intTy = llvm::cast<llvm::abi::IntegerType>(result); + EXPECT_TRUE(intTy->isSigned()); +} + +TEST_F(ABITypeMapperTest, MapUnsignedI32) { + DataLayout dl(module); + ABITypeMapper mapper(dl); + + auto ui32 = IntegerType::get(&ctx, 32, IntegerType::Unsigned); + const llvm::abi::Type *result = mapper.map(ui32); + + ASSERT_NE(result, nullptr); + auto *intTy = llvm::cast<llvm::abi::IntegerType>(result); + EXPECT_FALSE(intTy->isSigned()); +} + +} // namespace diff --git a/mlir/unittests/ABI/CMakeLists.txt b/mlir/unittests/ABI/CMakeLists.txt new file mode 100644 index 0000000000000..39f955a8efea6 --- /dev/null +++ b/mlir/unittests/ABI/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_unittest(MLIRABITests + ABIRewriteContextTest.cpp + ABITypeMapperTest.cpp +) + +mlir_target_link_libraries(MLIRABITests + PRIVATE + MLIRABI + MLIRDataLayoutInterfaces + MLIRDLTIDialect + MLIRIR +) diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index 89332bce5fe05..654ec44d90b04 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname) add_unittest(MLIRUnitTests ${test_dirname} ${ARGN}) endfunction() +add_subdirectory(ABI) add_subdirectory(Analysis) add_subdirectory(Bytecode) add_subdirectory(Conversion)

>From ca2c452c7fd71bda991e72c00d7248f17e3a7a62 Mon Sep 17 00:00:00 2001
From: Adam Smith
Date: Mon, 13 Apr 2026 15:07:34 -0700
Subject: [PATCH 2/3] [mlir][ABI][NFC] Rename member variables to camelCase

Rename struct and class members from PascalCase to camelCase per the
MLIR style guide. Applies to ArgClassification, FunctionClassification,
and ABITypeMapper members.

Addresses review feedback from @andykaylor.
Made-with: Cursor --- mlir/include/mlir/ABI/ABIRewriteContext.h | 34 ++++++------- mlir/include/mlir/ABI/ABITypeMapper.h | 8 +-- mlir/lib/ABI/ABITypeMapper.cpp | 42 +++++++--------- mlir/unittests/ABI/ABIRewriteContextTest.cpp | 52 ++++++++++---------- 4 files changed, 65 insertions(+), 71 deletions(-) diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h index 71d5a56b599b7..7c48c626207bb 100644 --- a/mlir/include/mlir/ABI/ABIRewriteContext.h +++ b/mlir/include/mlir/ABI/ABIRewriteContext.h @@ -55,51 +55,51 @@ enum class ArgKind : uint8_t { /// Describes how a single argument or return value is passed after ABI /// lowering. struct ArgClassification { - ArgKind Kind = ArgKind::Direct; + ArgKind kind = ArgKind::Direct; /// The ABI-coerced type, if different from the original. Null means /// use the original type. - Type CoercedType = nullptr; + Type coercedType = nullptr; /// For Indirect: alignment of the pointed-to object. - llvm::Align IndirectAlign = llvm::Align(1); + llvm::Align indirectAlign = llvm::Align(1); /// For Extend: whether to sign-extend (true) or zero-extend (false). - bool SignExtend = false; + bool signExtend = false; /// For Direct: whether a struct coercion can be flattened into /// individual register-width arguments. - bool CanFlatten = true; + bool canFlatten = true; /// For Indirect: whether the callee gets ownership (byval). 
- bool ByVal = false; + bool byVal = false; static ArgClassification getDirect(Type coerced = nullptr) { ArgClassification c; - c.Kind = ArgKind::Direct; - c.CoercedType = coerced; + c.kind = ArgKind::Direct; + c.coercedType = coerced; return c; } static ArgClassification getIgnore() { ArgClassification c; - c.Kind = ArgKind::Ignore; + c.kind = ArgKind::Ignore; return c; } static ArgClassification getIndirect(llvm::Align align, bool byVal = true) { ArgClassification c; - c.Kind = ArgKind::Indirect; - c.IndirectAlign = align; - c.ByVal = byVal; + c.kind = ArgKind::Indirect; + c.indirectAlign = align; + c.byVal = byVal; return c; } static ArgClassification getExtend(Type coerced, bool signExt) { ArgClassification c; - c.Kind = ArgKind::Extend; - c.CoercedType = coerced; - c.SignExtend = signExt; + c.kind = ArgKind::Extend; + c.coercedType = coerced; + c.signExtend = signExt; return c; } }; @@ -107,8 +107,8 @@ struct ArgClassification { /// Holds the full ABI classification for a function: return type and /// all arguments. struct FunctionClassification { - ArgClassification ReturnInfo; - SmallVector<ArgClassification> ArgInfos; + ArgClassification returnInfo; + SmallVector<ArgClassification> argInfos; }; /// ABIRewriteContext is the abstract interface that each dialect diff --git a/mlir/include/mlir/ABI/ABITypeMapper.h b/mlir/include/mlir/ABI/ABITypeMapper.h index 2180c9c8a918d..2499e910ca797 100644 --- a/mlir/include/mlir/ABI/ABITypeMapper.h +++ b/mlir/include/mlir/ABI/ABITypeMapper.h @@ -45,7 +45,7 @@ class ABITypeMapper { const llvm::abi::Type *map(mlir::Type type); /// Access the underlying TypeBuilder for advanced use. 
- llvm::abi::TypeBuilder &getTypeBuilder() { return Builder; } + llvm::abi::TypeBuilder &getTypeBuilder() { return builder; } private: const llvm::abi::Type *mapIntegerType(mlir::IntegerType type); @@ -55,9 +55,9 @@ class ABITypeMapper { const llvm::abi::Type *mapMemRefType(mlir::MemRefType type); const llvm::abi::Type *mapNoneType(mlir::NoneType type); - const DataLayout &DL; - llvm::BumpPtrAllocator Allocator; - llvm::abi::TypeBuilder Builder; + const DataLayout &dl; + llvm::BumpPtrAllocator allocator; + llvm::abi::TypeBuilder builder; }; } // namespace abi diff --git a/mlir/lib/ABI/ABITypeMapper.cpp b/mlir/lib/ABI/ABITypeMapper.cpp index c7a69780bbe64..83dc6990ec5bb 100644 --- a/mlir/lib/ABI/ABITypeMapper.cpp +++ b/mlir/lib/ABI/ABITypeMapper.cpp @@ -13,8 +13,8 @@ using namespace mlir; using namespace mlir::abi; -ABITypeMapper::ABITypeMapper(const DataLayout &dl) - : DL(dl), Builder(Allocator) {} +ABITypeMapper::ABITypeMapper(const DataLayout &dataLayout) + : dl(dataLayout), builder(allocator) {} const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) { if (auto intTy = dyn_cast<mlir::IntegerType>(type)) @@ -37,33 +37,30 @@ const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) { // For dialect-specific types, fall back to DataLayout queries. // The type must implement DataLayoutTypeInterface for this to work. - llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type); - uint64_t abiAlign = DL.getTypeABIAlignment(type); - return Builder.getIntegerType(sizeInBits.getFixedValue(), + llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type); + uint64_t abiAlign = dl.getTypeABIAlignment(type); + return builder.getIntegerType(sizeInBits.getFixedValue(), llvm::Align(abiAlign), /*Signed=*/false); } const llvm::abi::Type *ABITypeMapper::mapIntegerType(mlir::IntegerType type) { uint64_t width = type.getWidth(); - uint64_t abiAlign = DL.getTypeABIAlignment(type); - // MLIR signless integers are treated as signed for ABI purposes. 
- // Most C/C++ integer types are signless in MLIR but behave as - // signed for ABI classification (sign extension, etc.). + uint64_t abiAlign = dl.getTypeABIAlignment(type); bool isSigned = type.isSigned() || type.isSignless(); - return Builder.getIntegerType(width, llvm::Align(abiAlign), isSigned); + return builder.getIntegerType(width, llvm::Align(abiAlign), isSigned); } const llvm::abi::Type *ABITypeMapper::mapFloatType(mlir::FloatType type) { - uint64_t abiAlign = DL.getTypeABIAlignment(type); + uint64_t abiAlign = dl.getTypeABIAlignment(type); const llvm::fltSemantics &semantics = type.getFloatSemantics(); - return Builder.getFloatType(semantics, llvm::Align(abiAlign)); + return builder.getFloatType(semantics, llvm::Align(abiAlign)); } const llvm::abi::Type *ABITypeMapper::mapIndexType(mlir::IndexType type) { - llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type); - uint64_t abiAlign = DL.getTypeABIAlignment(type); - return Builder.getIntegerType(sizeInBits.getFixedValue(), + llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type); + uint64_t abiAlign = dl.getTypeABIAlignment(type); + return builder.getIntegerType(sizeInBits.getFixedValue(), llvm::Align(abiAlign), /*Signed=*/false); } @@ -74,29 +71,26 @@ const llvm::abi::Type *ABITypeMapper::mapVectorType(mlir::VectorType type) { return nullptr; auto shape = type.getShape(); - // MLIR VectorType is always fixed-length and can be multi-dimensional. - // Flatten to a single dimension for ABI purposes. uint64_t totalElements = 1; for (int64_t dim : shape) totalElements *= dim; llvm::ElementCount ec = llvm::ElementCount::getFixed(totalElements); - uint64_t abiAlign = DL.getTypeABIAlignment(type); - return Builder.getVectorType(elementTy, ec, llvm::Align(abiAlign)); + uint64_t abiAlign = dl.getTypeABIAlignment(type); + return builder.getVectorType(elementTy, ec, llvm::Align(abiAlign)); } const llvm::abi::Type *ABITypeMapper::mapMemRefType(mlir::MemRefType type) { - // MemRef is pointer-like for ABI purposes. 
- llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type); - uint64_t abiAlign = DL.getTypeABIAlignment(type); + llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type); + uint64_t abiAlign = dl.getTypeABIAlignment(type); unsigned addrSpace = 0; if (auto as = type.getMemorySpace()) if (auto intAttr = dyn_cast<IntegerAttr>(as)) addrSpace = intAttr.getInt(); - return Builder.getPointerType(sizeInBits.getFixedValue(), + return builder.getPointerType(sizeInBits.getFixedValue(), llvm::Align(abiAlign), addrSpace); } const llvm::abi::Type *ABITypeMapper::mapNoneType(mlir::NoneType type) { - return Builder.getVoidType(); + return builder.getVoidType(); } diff --git a/mlir/unittests/ABI/ABIRewriteContextTest.cpp b/mlir/unittests/ABI/ABIRewriteContextTest.cpp index 04c28991cc752..59a5307225e0d 100644 --- a/mlir/unittests/ABI/ABIRewriteContextTest.cpp +++ b/mlir/unittests/ABI/ABIRewriteContextTest.cpp @@ -37,36 +37,36 @@ TEST(ABIRewriteContextTest, MockCanBeConstructedAndDestroyed) { TEST(ABIRewriteContextTest, ArgClassificationDirect) { auto c = ArgClassification::getDirect(); - EXPECT_EQ(c.Kind, ArgKind::Direct); - EXPECT_EQ(c.CoercedType, nullptr); - EXPECT_TRUE(c.CanFlatten); + EXPECT_EQ(c.kind, ArgKind::Direct); + EXPECT_EQ(c.coercedType, nullptr); + EXPECT_TRUE(c.canFlatten); } TEST(ABIRewriteContextTest, ArgClassificationDirectWithType) { MLIRContext mlirCtx; auto i32 = IntegerType::get(&mlirCtx, 32); auto c = ArgClassification::getDirect(i32); - EXPECT_EQ(c.Kind, ArgKind::Direct); - EXPECT_EQ(c.CoercedType, i32); + EXPECT_EQ(c.kind, ArgKind::Direct); + EXPECT_EQ(c.coercedType, i32); } TEST(ABIRewriteContextTest, ArgClassificationIgnore) { auto c = ArgClassification::getIgnore(); - EXPECT_EQ(c.Kind, ArgKind::Ignore); + EXPECT_EQ(c.kind, ArgKind::Ignore); } TEST(ABIRewriteContextTest, ArgClassificationIndirect) { auto c = ArgClassification::getIndirect(llvm::Align(8), true); - EXPECT_EQ(c.Kind, ArgKind::Indirect); - EXPECT_EQ(c.IndirectAlign, llvm::Align(8)); - 
EXPECT_TRUE(c.ByVal); + EXPECT_EQ(c.kind, ArgKind::Indirect); + EXPECT_EQ(c.indirectAlign, llvm::Align(8)); + EXPECT_TRUE(c.byVal); } TEST(ABIRewriteContextTest, ArgClassificationIndirectNoByVal) { auto c = ArgClassification::getIndirect(llvm::Align(16), false); - EXPECT_EQ(c.Kind, ArgKind::Indirect); - EXPECT_EQ(c.IndirectAlign, llvm::Align(16)); - EXPECT_FALSE(c.ByVal); + EXPECT_EQ(c.kind, ArgKind::Indirect); + EXPECT_EQ(c.indirectAlign, llvm::Align(16)); + EXPECT_FALSE(c.byVal); } TEST(ABIRewriteContextTest, ArgClassificationExtend) { @@ -74,26 +74,26 @@ TEST(ABIRewriteContextTest, ArgClassificationExtend) { auto i8 = IntegerType::get(&mlirCtx, 8); auto signExt = ArgClassification::getExtend(i8, true); - EXPECT_EQ(signExt.Kind, ArgKind::Extend); - EXPECT_TRUE(signExt.SignExtend); + EXPECT_EQ(signExt.kind, ArgKind::Extend); + EXPECT_TRUE(signExt.signExtend); auto zeroExt = ArgClassification::getExtend(i8, false); - EXPECT_EQ(zeroExt.Kind, ArgKind::Extend); - EXPECT_FALSE(zeroExt.SignExtend); + EXPECT_EQ(zeroExt.kind, ArgKind::Extend); + EXPECT_FALSE(zeroExt.signExtend); } TEST(ABIRewriteContextTest, FunctionClassificationHoldsReturnAndArgs) { FunctionClassification fc; - fc.ReturnInfo = ArgClassification::getDirect(); - fc.ArgInfos.push_back(ArgClassification::getDirect()); - fc.ArgInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true)); - fc.ArgInfos.push_back(ArgClassification::getIgnore()); - - EXPECT_EQ(fc.ReturnInfo.Kind, ArgKind::Direct); - EXPECT_EQ(fc.ArgInfos.size(), 3u); - EXPECT_EQ(fc.ArgInfos[0].Kind, ArgKind::Direct); - EXPECT_EQ(fc.ArgInfos[1].Kind, ArgKind::Indirect); - EXPECT_EQ(fc.ArgInfos[2].Kind, ArgKind::Ignore); + fc.returnInfo = ArgClassification::getDirect(); + fc.argInfos.push_back(ArgClassification::getDirect()); + fc.argInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true)); + fc.argInfos.push_back(ArgClassification::getIgnore()); + + EXPECT_EQ(fc.returnInfo.kind, ArgKind::Direct); + 
EXPECT_EQ(fc.argInfos.size(), 3u); + EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Direct); + EXPECT_EQ(fc.argInfos[1].kind, ArgKind::Indirect); + EXPECT_EQ(fc.argInfos[2].kind, ArgKind::Ignore); } } // namespace >From 8ce99db064c122d3b5026e540d4345bb8ea932b6 Mon Sep 17 00:00:00 2001 From: Adam Smith <[email protected]> Date: Tue, 14 Apr 2026 12:50:29 -0700 Subject: [PATCH 3/3] [CIR] Add CIRABIRewriteContext for ABI function/call rewriting Add CIRABIRewriteContext, the CIR dialect's concrete implementation of the shared ABIRewriteContext interface from #190661. This class rewrites CIR FuncOps and CallOps to match ABI-lowered signatures. This initial PR covers Direct, Extend, and Ignore argument/return kinds with 11 unit tests. Struct coercion (sret, byval, multi-register flattening) will follow in a subsequent PR. Depends on #190661. Made-with: Cursor --- .../TargetLowering/CIRABIRewriteContext.cpp | 469 ++++++++++++++++++ .../TargetLowering/CIRABIRewriteContext.h | 50 ++ .../Transforms/TargetLowering/CMakeLists.txt | 2 + .../CIR/CIRABIRewriteContextTest.cpp | 406 +++++++++++++++ clang/unittests/CIR/CMakeLists.txt | 5 + 5 files changed, 932 insertions(+) create mode 100644 clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp create mode 100644 clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h create mode 100644 clang/unittests/CIR/CIRABIRewriteContextTest.cpp diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp new file mode 100644 index 0000000000000..cab6faf44eda4 --- /dev/null +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp @@ -0,0 +1,469 @@ +//===- CIRABIRewriteContext.cpp - CIR-specific ABI rewriting --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CIRABIRewriteContext.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Types.h" +#include "clang/CIR/Dialect/IR/CIRAttrs.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/IR/CIROpsEnums.h" + +using namespace cir; +using namespace mlir; +using namespace mlir::abi; + +/// Emit a value coercion between two types. When neither type is a +/// record or complex type and both or neither are vectors, a single +/// cir.cast (bitcast) is emitted; it does not widen or sign-extend. +/// Otherwise (aggregate <-> scalar, vector <-> non-vector) a bitcast +/// cannot reinterpret the value, so we go through memory: alloca +/// srcTy -> store src -> bitcast ptr -> load dstTy. +static Value emitCoercion(OpBuilder &rewriter, Location loc, Type dstTy, + Value src) { + Type srcTy = src.getType(); + if (srcTy == dstTy) + return src; + + bool needsMemory = + mlir::isa<cir::RecordType, cir::ComplexType>(srcTy) || + mlir::isa<cir::RecordType, cir::ComplexType>(dstTy) || + (mlir::isa<cir::VectorType>(srcTy) != mlir::isa<cir::VectorType>(dstTy)); + + if (!needsMemory) + return cir::CastOp::create(rewriter, loc, dstTy, cir::CastKind::bitcast, + src); + + auto srcPtrTy = cir::PointerType::get(srcTy); + auto dstPtrTy = cir::PointerType::get(dstTy); + + auto alloca = + cir::AllocaOp::create(rewriter, loc, srcPtrTy, srcTy, + /*name=*/rewriter.getStringAttr("coerce"), + /*alignment=*/rewriter.getI64IntegerAttr(8)); + + cir::StoreOp::create(rewriter, loc, src, alloca, + /*isVolatile=*/mlir::UnitAttr(), + /*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr(), + /*mem_order=*/cir::MemOrderAttr()); + + auto ptrCast = cir::CastOp::create(rewriter, loc, dstPtrTy, + cir::CastKind::bitcast, alloca); + + return cir::LoadOp::create(rewriter, loc, dstTy, ptrCast, + 
/*isDeref=*/mlir::UnitAttr(), + /*isVolatile=*/mlir::UnitAttr(), +
/*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr(), + /*mem_order=*/cir::MemOrderAttr()); +} + +/// Insert coercion before each cir.return to coerce the return value +/// from the original type to the ABI type. +static void insertReturnCoercion(FunctionOpInterface funcOp, Type origRetTy, + Type coercedRetTy, OpBuilder &rewriter) { + SmallVector<cir::ReturnOp> returnOps; + funcOp->walk([&](cir::ReturnOp retOp) { returnOps.push_back(retOp); }); + + for (cir::ReturnOp retOp : returnOps) { + if (retOp.getInput().empty()) + continue; + + Value origVal = retOp.getInput()[0]; + if (origVal.getType() == coercedRetTy) + continue; + + rewriter.setInsertionPoint(retOp); + Value coerced = + emitCoercion(rewriter, retOp.getLoc(), coercedRetTy, origVal); + retOp->setOperand(0, coerced); + } +} + +/// For each argument that requires ABI coercion (Extend or Direct +/// with a coerced type), insert a cast at the function entry and +/// replace all uses of the block argument with the cast result. 
+static void insertArgAdaptation(FunctionOpInterface funcOp, + const FunctionClassification &fc, + OpBuilder &rewriter) { + Region &body = funcOp->getRegion(0); + if (body.empty()) + return; + + Block &entryBlock = body.front(); + Operation *lastInserted = nullptr; + + for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) { + if (!argClass.coercedType) + continue; + + if (argClass.kind != ArgKind::Extend && argClass.kind != ArgKind::Direct) + continue; + + BlockArgument blockArg = entryBlock.getArgument(idx); + Type oldArgTy = blockArg.getType(); + Type newArgTy = argClass.coercedType; + + if (oldArgTy == newArgTy) + continue; + + blockArg.setType(newArgTy); + + if (lastInserted) + rewriter.setInsertionPointAfter(lastInserted); + else + rewriter.setInsertionPointToStart(&entryBlock); + + Value adapted; + SmallPtrSet<Operation *, 4> coercionOps; + + if (argClass.kind == ArgKind::Extend) { + auto cast = cir::CastOp::create(rewriter, funcOp.getLoc(), oldArgTy, + cir::CastKind::integral, blockArg); + adapted = cast; + coercionOps.insert(cast.getOperation()); + } else { + auto srcPtrTy = cir::PointerType::get(newArgTy); + auto dstPtrTy = cir::PointerType::get(oldArgTy); + Location loc = funcOp.getLoc(); + + auto alloca = + cir::AllocaOp::create(rewriter, loc, srcPtrTy, newArgTy, + /*name=*/rewriter.getStringAttr("coerce"), + /*alignment=*/rewriter.getI64IntegerAttr(8)); + + auto store = cir::StoreOp::create(rewriter, loc, blockArg, alloca, + /*isVolatile=*/mlir::UnitAttr(), + /*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr(), + /*mem_order=*/cir::MemOrderAttr()); + + auto ptrCast = cir::CastOp::create(rewriter, loc, dstPtrTy, + cir::CastKind::bitcast, alloca); + + auto load = cir::LoadOp::create(rewriter, loc, oldArgTy, ptrCast, + /*isDeref=*/mlir::UnitAttr(), + /*isVolatile=*/mlir::UnitAttr(), + /*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr(), + /*mem_order=*/cir::MemOrderAttr()); + + adapted = load; + 
coercionOps.insert(alloca.getOperation()); + coercionOps.insert(store.getOperation()); + coercionOps.insert(ptrCast.getOperation()); + coercionOps.insert(load.getOperation()); + } + lastInserted = adapted.getDefiningOp(); + + blockArg.replaceAllUsesExcept(adapted, coercionOps); + } +} + +LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition( + FunctionOpInterface funcOp, const FunctionClassification &fc, + OpBuilder &rewriter) { + ArrayRef<Type> oldArgTypes = funcOp.getArgumentTypes(); + ArrayRef<Type> oldResultTypes = funcOp.getResultTypes(); + bool isDecl = funcOp.isDeclaration(); + + bool returnCoerced = false; + bool hasArgChanges = false; + SmallVector<unsigned> ignoredArgIndices; + + // Compute new argument types. + SmallVector<Type> newArgTypes; + + for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) { + Type origTy = oldArgTypes[idx]; + switch (argClass.kind) { + case ArgKind::Direct: + case ArgKind::Extend: + newArgTypes.push_back(argClass.coercedType ? argClass.coercedType + : origTy); + if (argClass.coercedType && argClass.coercedType != origTy) + hasArgChanges = true; + break; + case ArgKind::Ignore: + ignoredArgIndices.push_back(idx); + hasArgChanges = true; + break; + case ArgKind::Indirect: + case ArgKind::Expand: + newArgTypes.push_back(origTy); + break; + } + } + + // Compute new result type. CIR's FuncType::clone expects exactly + // one result type (VoidType for void-returning functions). + auto voidTy = cir::VoidType::get(funcOp->getContext()); + Type origRetTy = oldResultTypes.empty() ? 
voidTy : oldResultTypes[0]; + Type newRetTy = origRetTy; + + if (fc.returnInfo.kind == ArgKind::Direct || + fc.returnInfo.kind == ArgKind::Extend) { + if (fc.returnInfo.coercedType && !oldResultTypes.empty() && + fc.returnInfo.coercedType != oldResultTypes[0]) { + newRetTy = fc.returnInfo.coercedType; + returnCoerced = true; + } + } else if (fc.returnInfo.kind == ArgKind::Ignore) { + newRetTy = voidTy; + } + + SmallVector<Type> newResultTypes = {newRetTy}; + + // If nothing changed, skip the rewrite -- unless we have + // Extend args/returns that need signext/zeroext attrs. + bool hasExtend = fc.returnInfo.kind == ArgKind::Extend; + for (auto &argClass : fc.argInfos) + if (argClass.kind == ArgKind::Extend) + hasExtend = true; + if (!hasArgChanges && !hasExtend && !returnCoerced && newRetTy == origRetTy && + newArgTypes == SmallVector<Type>(oldArgTypes)) + return success(); + + // Body modifications only apply to definitions. + if (!isDecl) { + if (hasArgChanges) + insertArgAdaptation(funcOp, fc, rewriter); + + // Erase block arguments for Ignore'd args (in reverse to keep + // indices valid). Replace any remaining uses with undef first. 
+ if (!ignoredArgIndices.empty()) { + Region &body = funcOp->getRegion(0); + if (!body.empty()) { + Block &entry = body.front(); + for (int i = ignoredArgIndices.size() - 1; i >= 0; --i) { + unsigned blockIdx = ignoredArgIndices[i]; + if (blockIdx < entry.getNumArguments()) { + BlockArgument arg = entry.getArgument(blockIdx); + if (!arg.use_empty()) { + rewriter.setInsertionPointToStart(&entry); + auto ptrTy = cir::PointerType::get(arg.getType()); + auto alloca = cir::AllocaOp::create( + rewriter, funcOp.getLoc(), ptrTy, arg.getType(), + /*name=*/rewriter.getStringAttr("ignored"), + /*alignment=*/rewriter.getI64IntegerAttr(1)); + auto load = cir::LoadOp::create( + rewriter, funcOp.getLoc(), arg.getType(), alloca, + /*isDeref=*/mlir::UnitAttr(), + /*isVolatile=*/mlir::UnitAttr(), + /*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr(), + /*mem_order=*/cir::MemOrderAttr()); + arg.replaceAllUsesWith(load); + } + entry.eraseArgument(blockIdx); + } + } + } + } + + if (returnCoerced) + insertReturnCoercion(funcOp, origRetTy, fc.returnInfo.coercedType, + rewriter); + + // When the return type is Ignore (empty struct), rewrite all + // return ops to drop their operand so they return void. + if (fc.returnInfo.kind == ArgKind::Ignore && !oldResultTypes.empty()) { + funcOp.walk([&](cir::ReturnOp retOp) { + if (retOp.getNumOperands() > 0) { + rewriter.setInsertionPoint(retOp); + cir::ReturnOp::create(rewriter, retOp.getLoc()); + retOp->erase(); + } + }); + } + } + + Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes); + funcOp.setFunctionTypeAttr(TypeAttr::get(newFnTy)); + + // Attach signext/zeroext attributes for Extend args and returns. 
+ { + MLIRContext *ctx = funcOp->getContext(); + unsigned numArgs = newArgTypes.size(); + bool needsArgAttrs = false; + bool hasIgnoredArgs = !ignoredArgIndices.empty(); + for (auto &argClass : fc.argInfos) + if (argClass.kind == ArgKind::Extend) + needsArgAttrs = true; + if (hasIgnoredArgs && funcOp->hasAttr("arg_attrs")) + needsArgAttrs = true; + + if (needsArgAttrs) { + SmallVector<Attribute> argAttrDicts(numArgs, DictionaryAttr::get(ctx)); + + // Preserve existing arg_attrs, skipping Ignore'd args. + if (auto existingAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs")) { + unsigned newIdx = 0; + for (unsigned oldIdx = 0; oldIdx < existingAttrs.size(); ++oldIdx) { + if (oldIdx < fc.argInfos.size() && + fc.argInfos[oldIdx].kind == ArgKind::Ignore) + continue; + if (newIdx < numArgs) + argAttrDicts[newIdx] = existingAttrs[oldIdx]; + ++newIdx; + } + } + + for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) { + if (argClass.kind != ArgKind::Extend) + continue; + if (idx >= numArgs) + continue; + auto existing = mlir::cast<DictionaryAttr>(argAttrDicts[idx]); + SmallVector<NamedAttribute> attrs(existing.begin(), existing.end()); + StringRef attrName = + argClass.signExtend ? "llvm.signext" : "llvm.zeroext"; + attrs.push_back( + rewriter.getNamedAttr(attrName, rewriter.getUnitAttr())); + argAttrDicts[idx] = DictionaryAttr::get(ctx, attrs); + } + + funcOp->setAttr("arg_attrs", ArrayAttr::get(ctx, argAttrDicts)); + } + + // Add signext/zeroext to return value for Extend returns. + if (fc.returnInfo.kind == ArgKind::Extend) { + SmallVector<NamedAttribute> retAttrs; + if (auto existing = funcOp->getAttrOfType<ArrayAttr>("res_attrs")) + if (existing.size() > 0) + for (auto attr : mlir::cast<DictionaryAttr>(existing[0])) + retAttrs.push_back(attr); + StringRef attrName = + fc.returnInfo.signExtend ? 
"llvm.signext" : "llvm.zeroext"; + retAttrs.push_back( + rewriter.getNamedAttr(attrName, rewriter.getUnitAttr())); + SmallVector<Attribute> resAttrDicts; + resAttrDicts.push_back(DictionaryAttr::get(ctx, retAttrs)); + funcOp->setAttr("res_attrs", ArrayAttr::get(ctx, resAttrDicts)); + } + } + + return success(); +} + +LogicalResult CIRABIRewriteContext::rewriteCallSite( + Operation *callOp, const FunctionClassification &fc, OpBuilder &rewriter) { + auto call = cast<cir::CallOp>(callOp); + + SmallVector<Value> newArgs; + bool argsChanged = false; + auto argOperands = call.getArgOperands(); + + for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) { + if (idx >= argOperands.size()) + break; + + Value arg = argOperands[idx]; + + if (argClass.kind == ArgKind::Ignore) { + argsChanged = true; + continue; + } + + if ((argClass.kind == ArgKind::Extend || + argClass.kind == ArgKind::Direct) && + argClass.coercedType && arg.getType() != argClass.coercedType) { + rewriter.setInsertionPoint(call); + Value coerced; + if (argClass.kind == ArgKind::Extend) + coerced = + cir::CastOp::create(rewriter, call.getLoc(), argClass.coercedType, + cir::CastKind::integral, arg); + else + coerced = + emitCoercion(rewriter, call.getLoc(), argClass.coercedType, arg); + newArgs.push_back(coerced); + argsChanged = true; + } else { + newArgs.push_back(arg); + } + } + + // Pass through any extra operands beyond classified args. + for (unsigned i = fc.argInfos.size(); i < argOperands.size(); ++i) + newArgs.push_back(argOperands[i]); + + // Handle direct return coercion. + bool returnCoerced = false; + Type coercedRetTy; + if ((fc.returnInfo.kind == ArgKind::Direct || + fc.returnInfo.kind == ArgKind::Extend) && + fc.returnInfo.coercedType) { + returnCoerced = true; + coercedRetTy = fc.returnInfo.coercedType; + } + + // Handle Ignore return: replace with void call. 
+ if (fc.returnInfo.kind == ArgKind::Ignore && call.getNumResults() > 0) { + rewriter.setInsertionPoint(call); + auto voidTy = cir::VoidType::get(call.getContext()); + auto newCall = cir::CallOp::create(rewriter, call.getLoc(), + call.getCalleeAttr(), voidTy, newArgs); + for (NamedAttribute attr : call->getAttrs()) + if (!newCall->hasAttr(attr.getName())) + newCall->setAttr(attr.getName(), attr.getValue()); + + if (!call.getResult().use_empty()) { + rewriter.setInsertionPointAfter(newCall); + Type origRetTy = call.getResult().getType(); + auto ptrTy = cir::PointerType::get(origRetTy); + auto alloca = + cir::AllocaOp::create(rewriter, call.getLoc(), ptrTy, origRetTy, + /*name=*/rewriter.getStringAttr("ignored"), + /*alignment=*/rewriter.getI64IntegerAttr(1)); + auto load = + cir::LoadOp::create(rewriter, call.getLoc(), origRetTy, alloca, + /*isDeref=*/mlir::UnitAttr(), + /*isVolatile=*/mlir::UnitAttr(), + /*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr(), + /*mem_order=*/cir::MemOrderAttr()); + call.getResult().replaceAllUsesWith(load); + } + call->erase(); + return success(); + } + + if (!returnCoerced && !argsChanged) + return success(); + + Type callRetTy; + Type origRetTy; + bool hasResult = call.getNumResults() > 0; + + if (hasResult) { + origRetTy = call.getResult().getType(); + callRetTy = returnCoerced ? 
coercedRetTy : origRetTy; + } else { + callRetTy = cir::VoidType::get(call.getContext()); + } + + rewriter.setInsertionPoint(call); + auto newCall = cir::CallOp::create(rewriter, call.getLoc(), + call.getCalleeAttr(), callRetTy, newArgs); + for (NamedAttribute attr : call->getAttrs()) + if (!newCall->hasAttr(attr.getName())) + newCall->setAttr(attr.getName(), attr.getValue()); + + if (hasResult && returnCoerced && origRetTy != coercedRetTy) { + rewriter.setInsertionPointAfter(newCall); + Value castBack = + emitCoercion(rewriter, call.getLoc(), origRetTy, newCall.getResult()); + call.getResult().replaceAllUsesWith(castBack); + } else if (hasResult) { + call.getResult().replaceAllUsesWith(newCall.getResult()); + } + + call->erase(); + return success(); +} diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h new file mode 100644 index 0000000000000..93d968c9db123 --- /dev/null +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h @@ -0,0 +1,50 @@ +//===- CIRABIRewriteContext.h - CIR-specific ABI rewriting ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines CIRABIRewriteContext, the CIR dialect's implementation of +// the shared ABIRewriteContext interface. It rewrites CIR function definitions +// and call sites to match ABI-lowered signatures. 
+// +//===----------------------------------------------------------------------===// + +#ifndef CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H +#define CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H + +#include "mlir/ABI/ABIRewriteContext.h" +#include "mlir/IR/BuiltinOps.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" + +namespace cir { + +/// CIR-specific implementation of the ABIRewriteContext interface. +/// +/// This class knows how to rewrite CIR FuncOps and CallOps to match +/// ABI-lowered signatures, using CIR operations for coercion (alloca, +/// load, store, cast, etc.). +class CIRABIRewriteContext : public mlir::abi::ABIRewriteContext { + mlir::ModuleOp module; + +public: + explicit CIRABIRewriteContext(mlir::ModuleOp module) : module(module) {} + + mlir::LogicalResult + rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp, + const mlir::abi::FunctionClassification &fc, + mlir::OpBuilder &rewriter) override; + + mlir::LogicalResult + rewriteCallSite(mlir::Operation *callOp, + const mlir::abi::FunctionClassification &fc, + mlir::OpBuilder &rewriter) override; + + mlir::StringRef getDialectNamespace() const override { return "cir"; } +}; + +} // namespace cir + +#endif // CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt index 86502b7f5dd4e..9833952623708 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt @@ -1,4 +1,5 @@ add_clang_library(MLIRCIRTargetLowering + CIRABIRewriteContext.cpp CIRCXXABI.cpp LowerModule.cpp LowerItaniumCXXABI.cpp @@ -15,6 +16,7 @@ add_clang_library(MLIRCIRTargetLowering LINK_LIBS PUBLIC clangBasic + MLIRABI MLIRIR MLIRPass MLIRDLTIDialect diff --git a/clang/unittests/CIR/CIRABIRewriteContextTest.cpp 
b/clang/unittests/CIR/CIRABIRewriteContextTest.cpp new file mode 100644 index 0000000000000..5cf8637346ae3 --- /dev/null +++ b/clang/unittests/CIR/CIRABIRewriteContextTest.cpp @@ -0,0 +1,406 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Unit tests for CIRABIRewriteContext, the CIR dialect's concrete +// implementation of the shared ABIRewriteContext interface. Each test +// constructs a FunctionClassification manually (no ABI library needed) +// and verifies the resulting IR after rewriting. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ABI/ABIRewriteContext.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "clang/CIR/Dialect/IR/CIRAttrs.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/IR/CIRTypes.h" +#include "gtest/gtest.h" + +// The header is private to the Transforms library, so we include it +// via the path relative to the source tree. The CMakeLists arranges +// the include directories. +#include "../../lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h" + +using namespace mlir; +using namespace mlir::abi; + +namespace { + +class CIRABIRewriteTest : public ::testing::Test { +protected: + CIRABIRewriteTest() : builder(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect<cir::CIRDialect>(); + } + + MLIRContext context; + OpBuilder builder; + Location loc; + + /// Create a ModuleOp containing a single CIR FuncOp with the given + /// argument types and return type. 
If \p addBody is true, the + /// function gets an entry block whose cir.return returns the first block + /// argument (it must have type retType), or returns nothing if void. + std::pair<ModuleOp, cir::FuncOp> createFunc(StringRef name, + ArrayRef<Type> argTypes, + Type retType, + bool addBody = true) { + auto module = ModuleOp::create(loc); + builder.setInsertionPointToEnd(module.getBody()); + + auto funcTy = cir::FuncType::get(argTypes, retType); + auto funcOp = cir::FuncOp::create(builder, loc, name, funcTy); + + if (addBody) { + Block *entry = funcOp.addEntryBlock(); + builder.setInsertionPointToEnd(entry); + if (isa<cir::VoidType>(retType)) + cir::ReturnOp::create(builder, loc); + else + cir::ReturnOp::create(builder, loc, + mlir::ValueRange{entry->getArgument(0)}); + } + + return {module, funcOp}; + } + + /// Create a ModuleOp containing a caller function that calls a + /// callee. The caller passes its own block arguments to the callee. + struct CallFixture { + ModuleOp module; + cir::FuncOp callee; + cir::FuncOp caller; + cir::CallOp callOp; + }; + + CallFixture createCallPair(StringRef calleeName, ArrayRef<Type> argTypes, + Type retType) { + auto module = ModuleOp::create(loc); + builder.setInsertionPointToEnd(module.getBody()); + + auto funcTy = cir::FuncType::get(argTypes, retType); + + // Callee (declaration only). + auto callee = cir::FuncOp::create(builder, loc, calleeName, funcTy); + + // Caller with a body that calls the callee.
+ auto caller = cir::FuncOp::create(builder, loc, "caller", funcTy); + Block *entry = caller.addEntryBlock(); + builder.setInsertionPointToEnd(entry); + + SmallVector<Value> args; + for (unsigned i = 0; i < argTypes.size(); ++i) + args.push_back(entry->getArgument(i)); + + cir::CallOp call; + if (isa<cir::VoidType>(retType)) { + auto voidTy = cir::VoidType::get(&context); + call = cir::CallOp::create( + builder, loc, mlir::FlatSymbolRefAttr::get(&context, calleeName), + voidTy, args); + cir::ReturnOp::create(builder, loc); + } else { + call = cir::CallOp::create( + builder, loc, mlir::FlatSymbolRefAttr::get(&context, calleeName), + retType, args); + cir::ReturnOp::create(builder, loc, mlir::ValueRange{call.getResult()}); + } + + return {module, callee, caller, call}; + } +}; + +// ---- rewriteFunctionDefinition tests ---- + +TEST_F(CIRABIRewriteTest, DirectPassthrough) { + auto i32Ty = cir::IntType::get(&context, 32, true); + auto [module, funcOp] = createFunc("f", {i32Ty}, i32Ty); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getDirect(); + fc.argInfos.push_back(ArgClassification::getDirect()); + + cir::CIRABIRewriteContext rewriteCtx(module); + OpBuilder rewriter(funcOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter))); + + auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType()); + EXPECT_EQ(fnTy.getInputs().size(), 1u); + EXPECT_EQ(fnTy.getInputs()[0], i32Ty); + EXPECT_EQ(fnTy.getReturnType(), i32Ty); + + module->erase(); +} + +TEST_F(CIRABIRewriteTest, DirectReturnCoercion) { + auto i32Ty = cir::IntType::get(&context, 32, true); + auto i64Ty = cir::IntType::get(&context, 64, false); + auto [module, funcOp] = createFunc("f", {i32Ty}, i32Ty); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getDirect(i64Ty); + fc.argInfos.push_back(ArgClassification::getDirect()); + + cir::CIRABIRewriteContext rewriteCtx(module); + OpBuilder rewriter(funcOp); + ASSERT_TRUE( + 
succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter))); + + auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType()); + EXPECT_EQ(fnTy.getReturnType(), i64Ty); + + module->erase(); +} + +TEST_F(CIRABIRewriteTest, ExtendArg) { + auto i8Ty = cir::IntType::get(&context, 8, true); + auto i32Ty = cir::IntType::get(&context, 32, true); + auto voidTy = cir::VoidType::get(&context); + auto [module, funcOp] = createFunc("f", {i8Ty}, voidTy); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getDirect(); + fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true)); + + cir::CIRABIRewriteContext rewriteCtx(module); + OpBuilder rewriter(funcOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter))); + + auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType()); + EXPECT_EQ(fnTy.getInputs().size(), 1u); + EXPECT_EQ(fnTy.getInputs()[0], i32Ty); + + // Verify signext attribute was attached. + auto argAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs"); + ASSERT_TRUE(argAttrs != nullptr); + ASSERT_EQ(argAttrs.size(), 1u); + auto dict = cast<DictionaryAttr>(argAttrs[0]); + EXPECT_TRUE(dict.get("llvm.signext") != nullptr); + + // Verify the entry block has a cir.cast (integral) to adapt i32 + // back to i8 for body uses. 
+ Block &entry = funcOp->getRegion(0).front(); + bool foundCast = false; + for (Operation &op : entry) { + if (auto cast = dyn_cast<cir::CastOp>(op)) { + if (cast.getKind() == cir::CastKind::integral) { + EXPECT_EQ(cast.getResult().getType(), i8Ty); + foundCast = true; + } + } + } + EXPECT_TRUE(foundCast); + + module->erase(); +} + +TEST_F(CIRABIRewriteTest, ExtendReturn) { + auto i8Ty = cir::IntType::get(&context, 8, true); + auto i32Ty = cir::IntType::get(&context, 32, true); + auto [module, funcOp] = createFunc("f", {i8Ty}, i8Ty); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getExtend(i32Ty, false); + fc.argInfos.push_back(ArgClassification::getDirect()); + + cir::CIRABIRewriteContext rewriteCtx(module); + OpBuilder rewriter(funcOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter))); + + auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType()); + EXPECT_EQ(fnTy.getReturnType(), i32Ty); + + // Verify zeroext attribute on return. 
+ auto resAttrs = funcOp->getAttrOfType<ArrayAttr>("res_attrs"); + ASSERT_TRUE(resAttrs != nullptr); + ASSERT_EQ(resAttrs.size(), 1u); + auto dict = cast<DictionaryAttr>(resAttrs[0]); + EXPECT_TRUE(dict.get("llvm.zeroext") != nullptr); + + module->erase(); +} + +TEST_F(CIRABIRewriteTest, IgnoreReturn) { + auto i32Ty = cir::IntType::get(&context, 32, true); + auto [module, funcOp] = createFunc("f", {i32Ty}, i32Ty); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getIgnore(); + fc.argInfos.push_back(ArgClassification::getDirect()); + + cir::CIRABIRewriteContext rewriteCtx(module); + OpBuilder rewriter(funcOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter))); + + auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType()); + auto voidTy = cir::VoidType::get(&context); + EXPECT_EQ(fnTy.getReturnType(), voidTy); + + module->erase(); +} + +TEST_F(CIRABIRewriteTest, IgnoreArg) { + auto i32Ty = cir::IntType::get(&context, 32, true); + auto voidTy = cir::VoidType::get(&context); + auto [module, funcOp] = createFunc("f", {i32Ty}, voidTy); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getDirect(); + fc.argInfos.push_back(ArgClassification::getIgnore()); + + cir::CIRABIRewriteContext rewriteCtx(module); + OpBuilder rewriter(funcOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter))); + + auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType()); + EXPECT_EQ(fnTy.getInputs().size(), 0u); + + Block &entry = funcOp->getRegion(0).front(); + EXPECT_EQ(entry.getNumArguments(), 0u); + + module->erase(); +} + +TEST_F(CIRABIRewriteTest, DeclarationRewrite) { + auto i8Ty = cir::IntType::get(&context, 8, true); + auto i32Ty = cir::IntType::get(&context, 32, true); + auto [module, funcOp] = createFunc("f", {i8Ty}, i8Ty, /*addBody=*/false); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getExtend(i32Ty, true); + 
fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true)); + + cir::CIRABIRewriteContext rewriteCtx(module); + OpBuilder rewriter(funcOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter))); + + auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType()); + EXPECT_EQ(fnTy.getInputs()[0], i32Ty); + EXPECT_EQ(fnTy.getReturnType(), i32Ty); + + // Verify both signext attributes. + auto argAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs"); + ASSERT_TRUE(argAttrs != nullptr); + auto dict = cast<DictionaryAttr>(argAttrs[0]); + EXPECT_TRUE(dict.get("llvm.signext") != nullptr); + + auto resAttrs = funcOp->getAttrOfType<ArrayAttr>("res_attrs"); + ASSERT_TRUE(resAttrs != nullptr); + auto rdict = cast<DictionaryAttr>(resAttrs[0]); + EXPECT_TRUE(rdict.get("llvm.signext") != nullptr); + + module->erase(); +} + +// ---- rewriteCallSite tests ---- + +TEST_F(CIRABIRewriteTest, CallSiteDirectPassthrough) { + auto i32Ty = cir::IntType::get(&context, 32, true); + auto fixture = createCallPair("callee", {i32Ty}, i32Ty); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getDirect(); + fc.argInfos.push_back(ArgClassification::getDirect()); + + cir::CIRABIRewriteContext rewriteCtx(fixture.module); + OpBuilder rewriter(fixture.callOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter))); + + // The original call should still be there (no changes needed). 
+ EXPECT_EQ(fixture.callOp->getNumResults(), 1u); + EXPECT_EQ(fixture.callOp->getResult(0).getType(), i32Ty); + + fixture.module->erase(); +} + +TEST_F(CIRABIRewriteTest, CallSiteExtendArg) { + auto i8Ty = cir::IntType::get(&context, 8, true); + auto i32Ty = cir::IntType::get(&context, 32, true); + auto voidTy = cir::VoidType::get(&context); + auto fixture = createCallPair("callee", {i8Ty}, voidTy); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getDirect(); + fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true)); + + cir::CIRABIRewriteContext rewriteCtx(fixture.module); + OpBuilder rewriter(fixture.callOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter))); + + // The old call was erased and replaced. Look for a CallOp whose + // argument is i32 (the extended type). + Block &callerEntry = fixture.caller->getRegion(0).front(); + cir::CallOp newCall; + for (Operation &op : callerEntry) + if (auto c = dyn_cast<cir::CallOp>(op)) + newCall = c; + ASSERT_TRUE(newCall != nullptr); + EXPECT_EQ(newCall.getArgOperands()[0].getType(), i32Ty); + + fixture.module->erase(); +} + +TEST_F(CIRABIRewriteTest, CallSiteIgnoreReturn) { + auto i32Ty = cir::IntType::get(&context, 32, true); + auto fixture = createCallPair("callee", {i32Ty}, i32Ty); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getIgnore(); + fc.argInfos.push_back(ArgClassification::getDirect()); + + cir::CIRABIRewriteContext rewriteCtx(fixture.module); + OpBuilder rewriter(fixture.callOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter))); + + // Find the replacement void call. 
+ Block &callerEntry = fixture.caller->getRegion(0).front(); + cir::CallOp newCall; + for (Operation &op : callerEntry) + if (auto c = dyn_cast<cir::CallOp>(op)) + newCall = c; + ASSERT_TRUE(newCall != nullptr); + EXPECT_EQ(newCall.getNumResults(), 0u); + + fixture.module->erase(); +} + +TEST_F(CIRABIRewriteTest, CallSiteIgnoreArg) { + auto i32Ty = cir::IntType::get(&context, 32, true); + auto voidTy = cir::VoidType::get(&context); + auto fixture = createCallPair("callee", {i32Ty}, voidTy); + + FunctionClassification fc; + fc.returnInfo = ArgClassification::getDirect(); + fc.argInfos.push_back(ArgClassification::getIgnore()); + + cir::CIRABIRewriteContext rewriteCtx(fixture.module); + OpBuilder rewriter(fixture.callOp); + ASSERT_TRUE( + succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter))); + + // Find the replacement call -- it should have zero args. + Block &callerEntry = fixture.caller->getRegion(0).front(); + cir::CallOp newCall; + for (Operation &op : callerEntry) + if (auto c = dyn_cast<cir::CallOp>(op)) + newCall = c; + ASSERT_TRUE(newCall != nullptr); + EXPECT_EQ(newCall.getArgOperands().size(), 0u); + + fixture.module->erase(); +} + +} // namespace diff --git a/clang/unittests/CIR/CMakeLists.txt b/clang/unittests/CIR/CMakeLists.txt index 650fde38c48a9..d318810d33fe5 100644 --- a/clang/unittests/CIR/CMakeLists.txt +++ b/clang/unittests/CIR/CMakeLists.txt @@ -1,15 +1,20 @@ set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include ) set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include) +set(CLANG_CIR_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../lib/CIR) include_directories(SYSTEM ${MLIR_INCLUDE_DIR}) include_directories(${MLIR_TABLEGEN_OUTPUT_DIR}) +include_directories(${CLANG_CIR_SRC_DIR}) add_distinct_clang_unittest(CIRUnitTests + CIRABIRewriteContextTest.cpp PointerLikeTest.cpp LLVM_COMPONENTS Core LINK_LIBS + MLIRABI MLIRCIR + MLIRCIRTargetLowering CIROpenACCSupport MLIRIR MLIROpenACCDialect 
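The unit tests above check one mapping from scalar classifications to rewritten signatures. An ignored argument is erased. An ignored result becomes void. An extended value is widened and tagged with `llvm.signext` or `llvm.zeroext`. A direct value passes through unchanged. The sketch below is a minimal, MLIR-free model of that mapping, written only to make the expected behaviour easy to see.

It rests on a few assumptions. Every name in it is hypothetical: `Kind`, `ArgInfo`, `Param`, `Signature`, `lowerOne` and `lowerSignature` are not the `mlir::abi` or CIR types or APIs. Integer types are reduced to bit widths, and a width of 0 stands in for `cir::VoidType`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the classification kinds used by the tests.
// NOT the mlir::abi / CIR types: widths are plain bit counts, 0 == void.
enum class Kind { Direct, Extend, Ignore };

struct ArgInfo {
  Kind kind = Kind::Direct;
  unsigned coercedBits = 0; // Extend: width of the widened integer
  bool isSigned = false;    // Extend: signext (true) or zeroext (false)
};

struct Param {
  unsigned bits = 0;   // 0 models a void result
  std::string extAttr; // "", "llvm.signext" or "llvm.zeroext"
};

struct Signature {
  std::vector<Param> inputs;
  Param result;
};

// Lower a single value; std::nullopt means the value is erased (Ignore).
std::optional<Param> lowerOne(unsigned bits, const ArgInfo &info) {
  switch (info.kind) {
  case Kind::Direct:
    return Param{bits, ""}; // passthrough, no extension attribute
  case Kind::Extend:
    return Param{info.coercedBits,
                 info.isSigned ? "llvm.signext" : "llvm.zeroext"};
  case Kind::Ignore:
    break;
  }
  return std::nullopt;
}

// Build the lowered signature the way the tests expect it to look:
// ignored arguments disappear, an ignored result becomes void, and
// extended values are widened and tagged with signext/zeroext.
Signature lowerSignature(const std::vector<unsigned> &argBits, unsigned retBits,
                         const ArgInfo &retInfo,
                         const std::vector<ArgInfo> &argInfos) {
  assert(argBits.size() == argInfos.size() && "one classification per arg");
  Signature sig;
  for (std::size_t i = 0; i < argBits.size(); ++i)
    if (std::optional<Param> p = lowerOne(argBits[i], argInfos[i]))
      sig.inputs.push_back(*p);
  sig.result = lowerOne(retBits, retInfo).value_or(Param{});
  return sig;
}
```

As an example, `lowerSignature({8}, 8, ArgInfo{Kind::Extend, 32, true}, {ArgInfo{Kind::Extend, 32, true}})` mirrors the `DeclarationRewrite` test. It yields one 32-bit input and a 32-bit result, both tagged `llvm.signext`. `rewriteCallSite` consumes the same `FunctionClassification`, so the `CallSite*` tests expect the replacement `cir::CallOp` to follow the same shape: widened operands, dropped ignored operands, and no result for an ignored return.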
_______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
