https://github.com/adams381 created https://github.com/llvm/llvm-project/pull/192119

Adds a CIR-specific concrete subclass of the `ABIRewriteContext` interface
introduced in #190661.  It rewrites CIR `FuncOp`s and `CallOp`s to match
their ABI-lowered signatures.

This first PR handles the scalar cases:

- Direct passthrough and scalar coercion (bitcast)
- Extend (integer widening with signext/zeroext attrs)
- Ignore (void returns, empty-struct arg erasure)
- Call-site rewrites for all of the above

Struct coercion (sret, byval, multi-register flattening) will follow in a later PR.

Adds 11 C++ unit tests.  Each one constructs a `FunctionClassification` by
hand and verifies the rewritten IR, so the tests do not depend on an ABI
classifier.

Depends on #190661.

Made with [Cursor](https://cursor.com)

>From 23ef4674ee8ac04d7777d5442071a781b8a09b7b Mon Sep 17 00:00:00 2001
From: Adam Smith <[email protected]>
Date: Mon, 6 Apr 2026 12:14:06 -0700
Subject: [PATCH 1/3] [mlir][ABI] Add ABITypeMapper and ABIRewriteContext

Dialect-agnostic layer bridging MLIR and the LLVM ABI library.
ABITypeMapper maps built-in MLIR types to llvm::abi types (other types
fall back to DataLayout size/alignment); ABIRewriteContext is the
interface dialects implement for ABI rewrites (see
clang/docs/ClangIRABILowering.md, Section 4).

Made-with: Cursor
---
 mlir/include/mlir/ABI/ABIRewriteContext.h    | 159 +++++++++++++++++
 mlir/include/mlir/ABI/ABITypeMapper.h        |  66 +++++++
 mlir/lib/ABI/ABITypeMapper.cpp               | 102 +++++++++++
 mlir/lib/ABI/CMakeLists.txt                  |  14 ++
 mlir/lib/CMakeLists.txt                      |   1 +
 mlir/unittests/ABI/ABIRewriteContextTest.cpp |  99 +++++++++++
 mlir/unittests/ABI/ABITypeMapperTest.cpp     | 173 +++++++++++++++++++
 mlir/unittests/ABI/CMakeLists.txt            |  12 ++
 mlir/unittests/CMakeLists.txt                |   1 +
 9 files changed, 627 insertions(+)
 create mode 100644 mlir/include/mlir/ABI/ABIRewriteContext.h
 create mode 100644 mlir/include/mlir/ABI/ABITypeMapper.h
 create mode 100644 mlir/lib/ABI/ABITypeMapper.cpp
 create mode 100644 mlir/lib/ABI/CMakeLists.txt
 create mode 100644 mlir/unittests/ABI/ABIRewriteContextTest.cpp
 create mode 100644 mlir/unittests/ABI/ABITypeMapperTest.cpp
 create mode 100644 mlir/unittests/ABI/CMakeLists.txt

diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h
new file mode 100644
index 0000000000000..71d5a56b599b7
--- /dev/null
+++ b/mlir/include/mlir/ABI/ABIRewriteContext.h
@@ -0,0 +1,159 @@
+//===- ABIRewriteContext.h - Dialect-specific ABI rewriting -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ABIRewriteContext, the abstract interface for dialect-
+// specific ABI lowering rewrites.  Each MLIR dialect that wants ABI lowering
+// (CIR, FIR, etc.) provides a concrete subclass.
+//
+// ABIRewriteContext consumes ABI classification results and drives the
+// creation of lowered function signatures, argument coercions, and call
+// site rewrites using dialect-specific operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ABI_ABIREWRITECONTEXT_H
+#define MLIR_ABI_ABIREWRITECONTEXT_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "llvm/Support/Alignment.h"
+
+namespace mlir {
+namespace abi {
+
+/// Classification of how a single argument or return value should be
+/// passed at the ABI level.
+///
+/// This is a dialect-agnostic representation.  It mirrors the kinds
+/// found in the LLVM ABI library and in CIR's ABIArgInfo, but does
+/// not depend on either.
+enum class ArgKind : uint8_t {
+  /// Pass directly in registers, possibly coerced to a different type.
+  Direct,
+
+  /// Like Direct, but with a sign/zero extension attribute.
+  Extend,
+
+  /// Pass indirectly via a pointer (sret for returns; args may be byval).
+  Indirect,
+
+  /// Not passed at all (e.g. void return, empty struct).
+  Ignore,
+
+  /// Expand an aggregate into its constituent scalar fields.
+  Expand,
+};
+
+/// Describes how a single argument or return value is passed after ABI
+/// lowering.
+struct ArgClassification {
+  ArgKind Kind = ArgKind::Direct;
+
+  /// The ABI-coerced type, if different from the original.  Null means
+  /// use the original type.
+  Type CoercedType = nullptr;
+
+  /// For Indirect: alignment of the pointed-to object.
+  llvm::Align IndirectAlign = llvm::Align(1);
+
+  /// For Extend: whether to sign-extend (true) or zero-extend (false).
+  bool SignExtend = false;
+
+  /// For Direct: whether a struct coercion can be flattened into
+  /// individual register-width arguments.
+  bool CanFlatten = true;
+
+  /// For Indirect: whether the callee gets ownership (byval).
+  bool ByVal = false;
+
+  static ArgClassification getDirect(Type coerced = nullptr) {
+    ArgClassification c;
+    c.Kind = ArgKind::Direct;
+    c.CoercedType = coerced;
+    return c;
+  }
+
+  static ArgClassification getIgnore() {
+    ArgClassification c;
+    c.Kind = ArgKind::Ignore;
+    return c;
+  }
+
+  static ArgClassification getIndirect(llvm::Align align, bool byVal = true) {
+    ArgClassification c;
+    c.Kind = ArgKind::Indirect;
+    c.IndirectAlign = align;
+    c.ByVal = byVal;
+    return c;
+  }
+
+  static ArgClassification getExtend(Type coerced, bool signExt) {
+    ArgClassification c;
+    c.Kind = ArgKind::Extend;
+    c.CoercedType = coerced;
+    c.SignExtend = signExt;
+    return c;
+  }
+};
+
+/// Holds the full ABI classification for a function: return type and
+/// all arguments.
+struct FunctionClassification {
+  ArgClassification ReturnInfo;
+  SmallVector<ArgClassification> ArgInfos;
+};
+
+/// ABIRewriteContext is the abstract interface that each dialect
+/// implements to perform ABI-specific rewrites on its operations.
+///
+/// The pass orchestrator calls these methods after ABI classification
+/// to rewrite function definitions and call sites.
+class ABIRewriteContext {
+public:
+  virtual ~ABIRewriteContext() = default;
+
+  /// Rewrite a function definition to use ABI-lowered types.
+  ///
+  /// This creates a new function with the lowered signature, rewrites
+  /// the function body to adapt between the ABI types and the
+  /// original high-level types, and replaces the original function.
+  ///
+  /// \param funcOp  The function to rewrite (via FunctionOpInterface).
+  /// \param fc      The ABI classification for this function.
+  /// \param rewriter  The builder used to create the rewritten IR.
+  /// \returns success() if the function was rewritten.
+  virtual LogicalResult
+  rewriteFunctionDefinition(FunctionOpInterface funcOp,
+                            const FunctionClassification &fc,
+                            OpBuilder &rewriter) = 0;
+
+  /// Rewrite a call operation to match the callee's ABI-lowered
+  /// signature.
+  ///
+  /// This coerces arguments, handles indirect returns (sret), and
+  /// adapts the call result back to the original high-level type.
+  ///
+  /// \param callOp  The call operation to rewrite.
+  /// \param fc      The ABI classification for the callee.
+  /// \param rewriter  The builder used to create the rewritten IR.
+  /// \returns success() if the call was rewritten.
+  virtual LogicalResult rewriteCallSite(Operation *callOp,
+                                        const FunctionClassification &fc,
+                                        OpBuilder &rewriter) = 0;
+
+  /// Return the dialect namespace this context handles (e.g. "cir").
+  virtual StringRef getDialectNamespace() const = 0;
+};
+
+} // namespace abi
+} // namespace mlir
+
+#endif // MLIR_ABI_ABIREWRITECONTEXT_H
diff --git a/mlir/include/mlir/ABI/ABITypeMapper.h b/mlir/include/mlir/ABI/ABITypeMapper.h
new file mode 100644
index 0000000000000..2180c9c8a918d
--- /dev/null
+++ b/mlir/include/mlir/ABI/ABITypeMapper.h
@@ -0,0 +1,66 @@
+//===- ABITypeMapper.h - Map MLIR types to ABI types -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ABITypeMapper, which translates mlir::Type instances into
+// the llvm::abi::Type hierarchy defined in llvm/ABI/Types.h.  Dialect-specific
+// types are handled via MLIR's DataLayoutTypeInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ABI_ABITYPEMAPPER_H
+#define MLIR_ABI_ABITYPEMAPPER_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "llvm/ABI/Types.h"
+#include "llvm/Support/Allocator.h"
+
+namespace mlir {
+namespace abi {
+
+/// ABITypeMapper translates mlir::Type values into the llvm::abi::Type
+/// hierarchy used by the LLVM ABI Lowering Library.
+///
+/// IntegerType, FloatType, IndexType, VectorType (flattened to 1-D; scalable
+/// vectors are not special-cased), MemRefType (as a pointer) and NoneType
+/// (as void) are mapped directly; any other type becomes an unsigned integer
+/// sized and aligned per the DataLayout.
+///
+/// Callers supply the DataLayout (typically from the enclosing module); only
+/// a reference is stored, so it must outlive the mapper.  The mapper owns the
+/// BumpPtrAllocator backing its TypeBuilder, so returned abi::Type pointers
+/// are valid for the lifetime of the mapper.
+class ABITypeMapper {
+public:
+  explicit ABITypeMapper(const DataLayout &dl);
+
+  /// Map an MLIR type to its ABI type representation.  Unhandled types go
+  /// through a DataLayout fallback that has no nullptr return path.
+  const llvm::abi::Type *map(mlir::Type type);
+
+  /// Access the underlying TypeBuilder for advanced use.
+  llvm::abi::TypeBuilder &getTypeBuilder() { return Builder; }
+
+private:
+  const llvm::abi::Type *mapIntegerType(mlir::IntegerType type);
+  const llvm::abi::Type *mapFloatType(mlir::FloatType type);
+  const llvm::abi::Type *mapIndexType(mlir::IndexType type);
+  const llvm::abi::Type *mapVectorType(mlir::VectorType type);
+  const llvm::abi::Type *mapMemRefType(mlir::MemRefType type);
+  const llvm::abi::Type *mapNoneType(mlir::NoneType type);
+
+  const DataLayout &DL;
+  llvm::BumpPtrAllocator Allocator;
+  llvm::abi::TypeBuilder Builder;
+};
+
+} // namespace abi
+} // namespace mlir
+
+#endif // MLIR_ABI_ABITYPEMAPPER_H
diff --git a/mlir/lib/ABI/ABITypeMapper.cpp b/mlir/lib/ABI/ABITypeMapper.cpp
new file mode 100644
index 0000000000000..c7a69780bbe64
--- /dev/null
+++ b/mlir/lib/ABI/ABITypeMapper.cpp
@@ -0,0 +1,102 @@
+//===- ABITypeMapper.cpp - Map MLIR types to ABI types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABITypeMapper.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/Alignment.h"
+
+using namespace mlir;
+using namespace mlir::abi;
+
+ABITypeMapper::ABITypeMapper(const DataLayout &dl)
+    : DL(dl), Builder(Allocator) {}
+
+const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) {
+  if (auto intTy = dyn_cast<mlir::IntegerType>(type))
+    return mapIntegerType(intTy);
+
+  if (auto floatTy = dyn_cast<mlir::FloatType>(type))
+    return mapFloatType(floatTy);
+
+  if (auto indexTy = dyn_cast<mlir::IndexType>(type))
+    return mapIndexType(indexTy);
+
+  if (auto vecTy = dyn_cast<mlir::VectorType>(type))
+    return mapVectorType(vecTy);
+
+  if (auto memRefTy = dyn_cast<mlir::MemRefType>(type))
+    return mapMemRefType(memRefTy);
+
+  if (auto noneTy = dyn_cast<mlir::NoneType>(type))
+    return mapNoneType(noneTy);
+
+  // For dialect-specific types, fall back to DataLayout queries.
+  // The type must implement DataLayoutTypeInterface for this to work.
+  llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  return Builder.getIntegerType(sizeInBits.getFixedValue(),
+                                llvm::Align(abiAlign),
+                                /*Signed=*/false);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapIntegerType(mlir::IntegerType type) {
+  uint64_t width = type.getWidth();
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  // MLIR signless integers are treated as signed for ABI purposes.
+  // Most C/C++ integer types are signless in MLIR but behave as
+  // signed for ABI classification (sign extension, etc.).
+  bool isSigned = type.isSigned() || type.isSignless();
+  return Builder.getIntegerType(width, llvm::Align(abiAlign), isSigned);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapFloatType(mlir::FloatType type) {
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  const llvm::fltSemantics &semantics = type.getFloatSemantics();
+  return Builder.getFloatType(semantics, llvm::Align(abiAlign));
+}
+
+const llvm::abi::Type *ABITypeMapper::mapIndexType(mlir::IndexType type) {
+  llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  return Builder.getIntegerType(sizeInBits.getFixedValue(),
+                                llvm::Align(abiAlign),
+                                /*Signed=*/false);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapVectorType(mlir::VectorType type) {
+  const llvm::abi::Type *elementTy = map(type.getElementType());
+  if (!elementTy)
+    return nullptr;
+
+  auto shape = type.getShape();
+  // MLIR VectorType is always fixed-length and can be multi-dimensional.
+  // Flatten to a single dimension for ABI purposes.
+  uint64_t totalElements = 1;
+  for (int64_t dim : shape)
+    totalElements *= dim;
+
+  llvm::ElementCount ec = llvm::ElementCount::getFixed(totalElements);
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  return Builder.getVectorType(elementTy, ec, llvm::Align(abiAlign));
+}
+
+const llvm::abi::Type *ABITypeMapper::mapMemRefType(mlir::MemRefType type) {
+  // MemRef is pointer-like for ABI purposes.
+  llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  unsigned addrSpace = 0;
+  if (auto as = type.getMemorySpace())
+    if (auto intAttr = dyn_cast<IntegerAttr>(as))
+      addrSpace = intAttr.getInt();
+  return Builder.getPointerType(sizeInBits.getFixedValue(),
+                                llvm::Align(abiAlign), addrSpace);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapNoneType(mlir::NoneType type) {
+  return Builder.getVoidType();
+}
diff --git a/mlir/lib/ABI/CMakeLists.txt b/mlir/lib/ABI/CMakeLists.txt
new file mode 100644
index 0000000000000..eb434d25dd390
--- /dev/null
+++ b/mlir/lib/ABI/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_library(MLIRABI
+  ABITypeMapper.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/ABI
+
+  LINK_COMPONENTS
+  ABI
+  Support
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRDataLayoutInterfaces
+  )
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index 91ed05f6548d7..d7a6e28d98586 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Enable errors for any global constructors.
 add_flag_if_supported("-Werror=global-constructors" WERROR_GLOBAL_CONSTRUCTOR)
 
+add_subdirectory(ABI)
 add_subdirectory(Analysis)
 add_subdirectory(AsmParser)
 add_subdirectory(Bytecode)
diff --git a/mlir/unittests/ABI/ABIRewriteContextTest.cpp b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
new file mode 100644
index 0000000000000..04c28991cc752
--- /dev/null
+++ b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
@@ -0,0 +1,99 @@
+//===- ABIRewriteContextTest.cpp - Unit tests for ABIRewriteContext -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+class MockRewriteContext : public ABIRewriteContext {
+public:
+  LogicalResult rewriteFunctionDefinition(FunctionOpInterface,
+                                          const FunctionClassification &,
+                                          OpBuilder &) override {
+    return success();
+  }
+
+  LogicalResult rewriteCallSite(Operation *, const FunctionClassification &,
+                                OpBuilder &) override {
+    return success();
+  }
+
+  StringRef getDialectNamespace() const override { return "mock"; }
+};
+
+TEST(ABIRewriteContextTest, MockCanBeConstructedAndDestroyed) {
+  MockRewriteContext ctx;
+  EXPECT_EQ(ctx.getDialectNamespace(), "mock");
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationDirect) {
+  auto c = ArgClassification::getDirect();
+  EXPECT_EQ(c.Kind, ArgKind::Direct);
+  EXPECT_EQ(c.CoercedType, nullptr);
+  EXPECT_TRUE(c.CanFlatten);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationDirectWithType) {
+  MLIRContext mlirCtx;
+  auto i32 = IntegerType::get(&mlirCtx, 32);
+  auto c = ArgClassification::getDirect(i32);
+  EXPECT_EQ(c.Kind, ArgKind::Direct);
+  EXPECT_EQ(c.CoercedType, i32);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIgnore) {
+  auto c = ArgClassification::getIgnore();
+  EXPECT_EQ(c.Kind, ArgKind::Ignore);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIndirect) {
+  auto c = ArgClassification::getIndirect(llvm::Align(8), true);
+  EXPECT_EQ(c.Kind, ArgKind::Indirect);
+  EXPECT_EQ(c.IndirectAlign, llvm::Align(8));
+  EXPECT_TRUE(c.ByVal);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIndirectNoByVal) {
+  auto c = ArgClassification::getIndirect(llvm::Align(16), false);
+  EXPECT_EQ(c.Kind, ArgKind::Indirect);
+  EXPECT_EQ(c.IndirectAlign, llvm::Align(16));
+  EXPECT_FALSE(c.ByVal);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationExtend) {
+  MLIRContext mlirCtx;
+  auto i8 = IntegerType::get(&mlirCtx, 8);
+
+  auto signExt = ArgClassification::getExtend(i8, true);
+  EXPECT_EQ(signExt.Kind, ArgKind::Extend);
+  EXPECT_TRUE(signExt.SignExtend);
+
+  auto zeroExt = ArgClassification::getExtend(i8, false);
+  EXPECT_EQ(zeroExt.Kind, ArgKind::Extend);
+  EXPECT_FALSE(zeroExt.SignExtend);
+}
+
+TEST(ABIRewriteContextTest, FunctionClassificationHoldsReturnAndArgs) {
+  FunctionClassification fc;
+  fc.ReturnInfo = ArgClassification::getDirect();
+  fc.ArgInfos.push_back(ArgClassification::getDirect());
+  fc.ArgInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true));
+  fc.ArgInfos.push_back(ArgClassification::getIgnore());
+
+  EXPECT_EQ(fc.ReturnInfo.Kind, ArgKind::Direct);
+  EXPECT_EQ(fc.ArgInfos.size(), 3u);
+  EXPECT_EQ(fc.ArgInfos[0].Kind, ArgKind::Direct);
+  EXPECT_EQ(fc.ArgInfos[1].Kind, ArgKind::Indirect);
+  EXPECT_EQ(fc.ArgInfos[2].Kind, ArgKind::Ignore);
+}
+
+} // namespace
diff --git a/mlir/unittests/ABI/ABITypeMapperTest.cpp b/mlir/unittests/ABI/ABITypeMapperTest.cpp
new file mode 100644
index 0000000000000..4a7989298a149
--- /dev/null
+++ b/mlir/unittests/ABI/ABITypeMapperTest.cpp
@@ -0,0 +1,173 @@
+//===- ABITypeMapperTest.cpp - Unit tests for ABITypeMapper ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABITypeMapper.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "llvm/ABI/Types.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+class ABITypeMapperTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    ctx.loadDialect<DLTIDialect>();
+    module = ModuleOp::create(UnknownLoc::get(&ctx));
+  }
+
+  void TearDown() override { module->destroy(); }
+
+  MLIRContext ctx;
+  ModuleOp module;
+};
+
+TEST_F(ABITypeMapperTest, MapI32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto i32 = IntegerType::get(&ctx, 32);
+  const llvm::abi::Type *result = mapper.map(i32);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isInteger());
+
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 32u);
+}
+
+TEST_F(ABITypeMapperTest, MapI1) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto i1 = IntegerType::get(&ctx, 1);
+  const llvm::abi::Type *result = mapper.map(i1);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isInteger());
+
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 1u);
+}
+
+TEST_F(ABITypeMapperTest, MapI64) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto i64 = IntegerType::get(&ctx, 64);
+  const llvm::abi::Type *result = mapper.map(i64);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isInteger());
+
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 64u);
+}
+
+TEST_F(ABITypeMapperTest, MapF32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f32 = Float32Type::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(f32);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isFloat());
+
+  auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+  EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 32u);
+}
+
+TEST_F(ABITypeMapperTest, MapF64) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f64 = Float64Type::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(f64);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isFloat());
+
+  auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+  EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 64u);
+}
+
+TEST_F(ABITypeMapperTest, MapF16) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f16 = Float16Type::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(f16);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isFloat());
+
+  auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+  EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 16u);
+}
+
+TEST_F(ABITypeMapperTest, MapNoneType) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto none = NoneType::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(none);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isVoid());
+}
+
+TEST_F(ABITypeMapperTest, MapVectorOf4xF32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f32 = Float32Type::get(&ctx);
+  auto vec = VectorType::get({4}, f32);
+  const llvm::abi::Type *result = mapper.map(vec);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isVector());
+
+  auto *vecTy = llvm::cast<llvm::abi::VectorType>(result);
+  EXPECT_EQ(vecTy->getNumElements().getFixedValue(), 4u);
+  EXPECT_TRUE(vecTy->getElementType()->isFloat());
+}
+
+TEST_F(ABITypeMapperTest, MapSignedI32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto si32 = IntegerType::get(&ctx, 32, IntegerType::Signed);
+  const llvm::abi::Type *result = mapper.map(si32);
+
+  ASSERT_NE(result, nullptr);
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_TRUE(intTy->isSigned());
+}
+
+TEST_F(ABITypeMapperTest, MapUnsignedI32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto ui32 = IntegerType::get(&ctx, 32, IntegerType::Unsigned);
+  const llvm::abi::Type *result = mapper.map(ui32);
+
+  ASSERT_NE(result, nullptr);
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_FALSE(intTy->isSigned());
+}
+
+} // namespace
diff --git a/mlir/unittests/ABI/CMakeLists.txt b/mlir/unittests/ABI/CMakeLists.txt
new file mode 100644
index 0000000000000..39f955a8efea6
--- /dev/null
+++ b/mlir/unittests/ABI/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_unittest(MLIRABITests
+  ABIRewriteContextTest.cpp
+  ABITypeMapperTest.cpp
+)
+
+mlir_target_link_libraries(MLIRABITests
+  PRIVATE
+  MLIRABI
+  MLIRDataLayoutInterfaces
+  MLIRDLTIDialect
+  MLIRIR
+)
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 89332bce5fe05..654ec44d90b04 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
   add_unittest(MLIRUnitTests ${test_dirname} ${ARGN})
 endfunction()
 
+add_subdirectory(ABI)
 add_subdirectory(Analysis)
 add_subdirectory(Bytecode)
 add_subdirectory(Conversion)

>From ca2c452c7fd71bda991e72c00d7248f17e3a7a62 Mon Sep 17 00:00:00 2001
From: Adam Smith <[email protected]>
Date: Mon, 13 Apr 2026 15:07:34 -0700
Subject: [PATCH 2/3] [mlir][ABI][NFC] Rename member variables to camelCase

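Rename struct and class members from PascalCase to camelCase
per the MLIR style guide.  Applies to ArgClassification,
FunctionClassification, and ABITypeMapper members.  Also renames the
ABITypeMapper constructor parameter to dataLayout and drops three
explanatory comments from ABITypeMapper.cpp (signless-as-signed
integers, vector flattening, memref as pointer).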

Addresses review feedback from @andykaylor.

Made-with: Cursor
---
 mlir/include/mlir/ABI/ABIRewriteContext.h    | 34 ++++++-------
 mlir/include/mlir/ABI/ABITypeMapper.h        |  8 +--
 mlir/lib/ABI/ABITypeMapper.cpp               | 42 +++++++---------
 mlir/unittests/ABI/ABIRewriteContextTest.cpp | 52 ++++++++++----------
 4 files changed, 65 insertions(+), 71 deletions(-)

diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h
index 71d5a56b599b7..7c48c626207bb 100644
--- a/mlir/include/mlir/ABI/ABIRewriteContext.h
+++ b/mlir/include/mlir/ABI/ABIRewriteContext.h
@@ -55,51 +55,51 @@ enum class ArgKind : uint8_t {
 /// Describes how a single argument or return value is passed after ABI
 /// lowering.
 struct ArgClassification {
-  ArgKind Kind = ArgKind::Direct;
+  ArgKind kind = ArgKind::Direct;
 
   /// The ABI-coerced type, if different from the original.  Null means
   /// use the original type.
-  Type CoercedType = nullptr;
+  Type coercedType = nullptr;
 
   /// For Indirect: alignment of the pointed-to object.
-  llvm::Align IndirectAlign = llvm::Align(1);
+  llvm::Align indirectAlign = llvm::Align(1);
 
   /// For Extend: whether to sign-extend (true) or zero-extend (false).
-  bool SignExtend = false;
+  bool signExtend = false;
 
   /// For Direct: whether a struct coercion can be flattened into
   /// individual register-width arguments.
-  bool CanFlatten = true;
+  bool canFlatten = true;
 
   /// For Indirect: whether the callee gets ownership (byval).
-  bool ByVal = false;
+  bool byVal = false;
 
   static ArgClassification getDirect(Type coerced = nullptr) {
     ArgClassification c;
-    c.Kind = ArgKind::Direct;
-    c.CoercedType = coerced;
+    c.kind = ArgKind::Direct;
+    c.coercedType = coerced;
     return c;
   }
 
   static ArgClassification getIgnore() {
     ArgClassification c;
-    c.Kind = ArgKind::Ignore;
+    c.kind = ArgKind::Ignore;
     return c;
   }
 
   static ArgClassification getIndirect(llvm::Align align, bool byVal = true) {
     ArgClassification c;
-    c.Kind = ArgKind::Indirect;
-    c.IndirectAlign = align;
-    c.ByVal = byVal;
+    c.kind = ArgKind::Indirect;
+    c.indirectAlign = align;
+    c.byVal = byVal;
     return c;
   }
 
   static ArgClassification getExtend(Type coerced, bool signExt) {
     ArgClassification c;
-    c.Kind = ArgKind::Extend;
-    c.CoercedType = coerced;
-    c.SignExtend = signExt;
+    c.kind = ArgKind::Extend;
+    c.coercedType = coerced;
+    c.signExtend = signExt;
     return c;
   }
 };
@@ -107,8 +107,8 @@ struct ArgClassification {
 /// Holds the full ABI classification for a function: return type and
 /// all arguments.
 struct FunctionClassification {
-  ArgClassification ReturnInfo;
-  SmallVector<ArgClassification> ArgInfos;
+  ArgClassification returnInfo;
+  SmallVector<ArgClassification> argInfos;
 };
 
 /// ABIRewriteContext is the abstract interface that each dialect
diff --git a/mlir/include/mlir/ABI/ABITypeMapper.h b/mlir/include/mlir/ABI/ABITypeMapper.h
index 2180c9c8a918d..2499e910ca797 100644
--- a/mlir/include/mlir/ABI/ABITypeMapper.h
+++ b/mlir/include/mlir/ABI/ABITypeMapper.h
@@ -45,7 +45,7 @@ class ABITypeMapper {
   const llvm::abi::Type *map(mlir::Type type);
 
   /// Access the underlying TypeBuilder for advanced use.
-  llvm::abi::TypeBuilder &getTypeBuilder() { return Builder; }
+  llvm::abi::TypeBuilder &getTypeBuilder() { return builder; }
 
 private:
   const llvm::abi::Type *mapIntegerType(mlir::IntegerType type);
@@ -55,9 +55,9 @@ class ABITypeMapper {
   const llvm::abi::Type *mapMemRefType(mlir::MemRefType type);
   const llvm::abi::Type *mapNoneType(mlir::NoneType type);
 
-  const DataLayout &DL;
-  llvm::BumpPtrAllocator Allocator;
-  llvm::abi::TypeBuilder Builder;
+  const DataLayout &dl;
+  llvm::BumpPtrAllocator allocator;
+  llvm::abi::TypeBuilder builder;
 };
 
 } // namespace abi
diff --git a/mlir/lib/ABI/ABITypeMapper.cpp b/mlir/lib/ABI/ABITypeMapper.cpp
index c7a69780bbe64..83dc6990ec5bb 100644
--- a/mlir/lib/ABI/ABITypeMapper.cpp
+++ b/mlir/lib/ABI/ABITypeMapper.cpp
@@ -13,8 +13,8 @@
 using namespace mlir;
 using namespace mlir::abi;
 
-ABITypeMapper::ABITypeMapper(const DataLayout &dl)
-    : DL(dl), Builder(Allocator) {}
+ABITypeMapper::ABITypeMapper(const DataLayout &dataLayout)
+    : dl(dataLayout), builder(allocator) {}
 
 const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) {
   if (auto intTy = dyn_cast<mlir::IntegerType>(type))
@@ -37,33 +37,30 @@ const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) {
 
   // For dialect-specific types, fall back to DataLayout queries.
   // The type must implement DataLayoutTypeInterface for this to work.
-  llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
-  uint64_t abiAlign = DL.getTypeABIAlignment(type);
-  return Builder.getIntegerType(sizeInBits.getFixedValue(),
+  llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type);
+  uint64_t abiAlign = dl.getTypeABIAlignment(type);
+  return builder.getIntegerType(sizeInBits.getFixedValue(),
                                 llvm::Align(abiAlign),
                                 /*Signed=*/false);
 }
 
 const llvm::abi::Type *ABITypeMapper::mapIntegerType(mlir::IntegerType type) {
   uint64_t width = type.getWidth();
-  uint64_t abiAlign = DL.getTypeABIAlignment(type);
-  // MLIR signless integers are treated as signed for ABI purposes.
-  // Most C/C++ integer types are signless in MLIR but behave as
-  // signed for ABI classification (sign extension, etc.).
+  uint64_t abiAlign = dl.getTypeABIAlignment(type);
   bool isSigned = type.isSigned() || type.isSignless();
-  return Builder.getIntegerType(width, llvm::Align(abiAlign), isSigned);
+  return builder.getIntegerType(width, llvm::Align(abiAlign), isSigned);
 }
 
 const llvm::abi::Type *ABITypeMapper::mapFloatType(mlir::FloatType type) {
-  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  uint64_t abiAlign = dl.getTypeABIAlignment(type);
   const llvm::fltSemantics &semantics = type.getFloatSemantics();
-  return Builder.getFloatType(semantics, llvm::Align(abiAlign));
+  return builder.getFloatType(semantics, llvm::Align(abiAlign));
 }
 
 const llvm::abi::Type *ABITypeMapper::mapIndexType(mlir::IndexType type) {
-  llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
-  uint64_t abiAlign = DL.getTypeABIAlignment(type);
-  return Builder.getIntegerType(sizeInBits.getFixedValue(),
+  llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type);
+  uint64_t abiAlign = dl.getTypeABIAlignment(type);
+  return builder.getIntegerType(sizeInBits.getFixedValue(),
                                 llvm::Align(abiAlign),
                                 /*Signed=*/false);
 }
@@ -74,29 +71,26 @@ const llvm::abi::Type 
*ABITypeMapper::mapVectorType(mlir::VectorType type) {
     return nullptr;
 
   auto shape = type.getShape();
-  // MLIR VectorType is always fixed-length and can be multi-dimensional.
-  // Flatten to a single dimension for ABI purposes.
   uint64_t totalElements = 1;
   for (int64_t dim : shape)
     totalElements *= dim;
 
   llvm::ElementCount ec = llvm::ElementCount::getFixed(totalElements);
-  uint64_t abiAlign = DL.getTypeABIAlignment(type);
-  return Builder.getVectorType(elementTy, ec, llvm::Align(abiAlign));
+  uint64_t abiAlign = dl.getTypeABIAlignment(type);
+  return builder.getVectorType(elementTy, ec, llvm::Align(abiAlign));
 }
 
 const llvm::abi::Type *ABITypeMapper::mapMemRefType(mlir::MemRefType type) {
   // MemRef is pointer-like for ABI purposes.
-  llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
-  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type);
+  uint64_t abiAlign = dl.getTypeABIAlignment(type);
   unsigned addrSpace = 0;
+  // Only an integer memory space is used as the address space; any other
+  // memory-space attribute falls back to address space 0.
   if (auto as = type.getMemorySpace())
     if (auto intAttr = dyn_cast<IntegerAttr>(as))
       addrSpace = intAttr.getInt();
-  return Builder.getPointerType(sizeInBits.getFixedValue(),
+  return builder.getPointerType(sizeInBits.getFixedValue(),
                                  llvm::Align(abiAlign), addrSpace);
 }
 
 const llvm::abi::Type *ABITypeMapper::mapNoneType(mlir::NoneType type) {
-  return Builder.getVoidType();
+  // NoneType maps to the ABI void type.
+  return builder.getVoidType();
 }
diff --git a/mlir/unittests/ABI/ABIRewriteContextTest.cpp 
b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
index 04c28991cc752..59a5307225e0d 100644
--- a/mlir/unittests/ABI/ABIRewriteContextTest.cpp
+++ b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
@@ -37,36 +37,36 @@ TEST(ABIRewriteContextTest, 
MockCanBeConstructedAndDestroyed) {
 
 TEST(ABIRewriteContextTest, ArgClassificationDirect) {
+  // getDirect() with no type: Direct kind, no coerced type, flattening allowed.
   auto c = ArgClassification::getDirect();
-  EXPECT_EQ(c.Kind, ArgKind::Direct);
-  EXPECT_EQ(c.CoercedType, nullptr);
-  EXPECT_TRUE(c.CanFlatten);
+  EXPECT_EQ(c.kind, ArgKind::Direct);
+  EXPECT_EQ(c.coercedType, nullptr);
+  EXPECT_TRUE(c.canFlatten);
 }
 
 TEST(ABIRewriteContextTest, ArgClassificationDirectWithType) {
+  // getDirect(type) records the given type as the coerced type.
   MLIRContext mlirCtx;
   auto i32 = IntegerType::get(&mlirCtx, 32);
   auto c = ArgClassification::getDirect(i32);
-  EXPECT_EQ(c.Kind, ArgKind::Direct);
-  EXPECT_EQ(c.CoercedType, i32);
+  EXPECT_EQ(c.kind, ArgKind::Direct);
+  EXPECT_EQ(c.coercedType, i32);
 }
 
 TEST(ABIRewriteContextTest, ArgClassificationIgnore) {
+  // getIgnore() yields the Ignore kind.
   auto c = ArgClassification::getIgnore();
-  EXPECT_EQ(c.Kind, ArgKind::Ignore);
+  EXPECT_EQ(c.kind, ArgKind::Ignore);
 }
 
 TEST(ABIRewriteContextTest, ArgClassificationIndirect) {
+  // getIndirect(align, /*byVal=*/true) records both the alignment and byval.
   auto c = ArgClassification::getIndirect(llvm::Align(8), true);
-  EXPECT_EQ(c.Kind, ArgKind::Indirect);
-  EXPECT_EQ(c.IndirectAlign, llvm::Align(8));
-  EXPECT_TRUE(c.ByVal);
+  EXPECT_EQ(c.kind, ArgKind::Indirect);
+  EXPECT_EQ(c.indirectAlign, llvm::Align(8));
+  EXPECT_TRUE(c.byVal);
 }
 
 TEST(ABIRewriteContextTest, ArgClassificationIndirectNoByVal) {
+  // getIndirect(align, /*byVal=*/false) keeps the alignment, clears byval.
   auto c = ArgClassification::getIndirect(llvm::Align(16), false);
-  EXPECT_EQ(c.Kind, ArgKind::Indirect);
-  EXPECT_EQ(c.IndirectAlign, llvm::Align(16));
-  EXPECT_FALSE(c.ByVal);
+  EXPECT_EQ(c.kind, ArgKind::Indirect);
+  EXPECT_EQ(c.indirectAlign, llvm::Align(16));
+  EXPECT_FALSE(c.byVal);
 }
 
 TEST(ABIRewriteContextTest, ArgClassificationExtend) {
@@ -74,26 +74,26 @@ TEST(ABIRewriteContextTest, ArgClassificationExtend) {
   auto i8 = IntegerType::get(&mlirCtx, 8);
 
   auto signExt = ArgClassification::getExtend(i8, true);
-  EXPECT_EQ(signExt.Kind, ArgKind::Extend);
-  EXPECT_TRUE(signExt.SignExtend);
+  EXPECT_EQ(signExt.kind, ArgKind::Extend);
+  EXPECT_TRUE(signExt.signExtend);
 
   auto zeroExt = ArgClassification::getExtend(i8, false);
-  EXPECT_EQ(zeroExt.Kind, ArgKind::Extend);
-  EXPECT_FALSE(zeroExt.SignExtend);
+  EXPECT_EQ(zeroExt.kind, ArgKind::Extend);
+  EXPECT_FALSE(zeroExt.signExtend);
 }
 
 TEST(ABIRewriteContextTest, FunctionClassificationHoldsReturnAndArgs) {
+  // FunctionClassification stores the return classification plus one entry
+  // per argument, in argument order.
   FunctionClassification fc;
-  fc.ReturnInfo = ArgClassification::getDirect();
-  fc.ArgInfos.push_back(ArgClassification::getDirect());
-  fc.ArgInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true));
-  fc.ArgInfos.push_back(ArgClassification::getIgnore());
-
-  EXPECT_EQ(fc.ReturnInfo.Kind, ArgKind::Direct);
-  EXPECT_EQ(fc.ArgInfos.size(), 3u);
-  EXPECT_EQ(fc.ArgInfos[0].Kind, ArgKind::Direct);
-  EXPECT_EQ(fc.ArgInfos[1].Kind, ArgKind::Indirect);
-  EXPECT_EQ(fc.ArgInfos[2].Kind, ArgKind::Ignore);
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getDirect());
+  fc.argInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true));
+  fc.argInfos.push_back(ArgClassification::getIgnore());
+
+  EXPECT_EQ(fc.returnInfo.kind, ArgKind::Direct);
+  EXPECT_EQ(fc.argInfos.size(), 3u);
+  EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Direct);
+  EXPECT_EQ(fc.argInfos[1].kind, ArgKind::Indirect);
+  EXPECT_EQ(fc.argInfos[2].kind, ArgKind::Ignore);
 }
 
 } // namespace

>From 8ce99db064c122d3b5026e540d4345bb8ea932b6 Mon Sep 17 00:00:00 2001
From: Adam Smith <[email protected]>
Date: Tue, 14 Apr 2026 12:50:29 -0700
Subject: [PATCH 3/3] [CIR] Add CIRABIRewriteContext for ABI function/call
 rewriting

Add CIRABIRewriteContext, the CIR dialect's concrete implementation
of the shared ABIRewriteContext interface from #190661.  This class
rewrites CIR FuncOps and CallOps to match ABI-lowered signatures.

This initial PR covers Direct, Extend, and Ignore argument/return
kinds with 11 unit tests.  Struct coercion (sret, byval, multi-
register flattening) will follow in a subsequent PR.

Depends on #190661.

Made-with: Cursor
---
 .../TargetLowering/CIRABIRewriteContext.cpp   | 469 ++++++++++++++++++
 .../TargetLowering/CIRABIRewriteContext.h     |  50 ++
 .../Transforms/TargetLowering/CMakeLists.txt  |   2 +
 .../CIR/CIRABIRewriteContextTest.cpp          | 406 +++++++++++++++
 clang/unittests/CIR/CMakeLists.txt            |   5 +
 5 files changed, 932 insertions(+)
 create mode 100644 
clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
 create mode 100644 
clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
 create mode 100644 clang/unittests/CIR/CIRABIRewriteContextTest.cpp

diff --git 
a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
new file mode 100644
index 0000000000000..cab6faf44eda4
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
@@ -0,0 +1,469 @@
+//===- CIRABIRewriteContext.cpp - CIR-specific ABI rewriting 
--------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CIRABIRewriteContext.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Types.h"
+#include "clang/CIR/Dialect/IR/CIRAttrs.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
+
+using namespace cir;
+using namespace mlir;
+using namespace mlir::abi;
+
+/// Emit a coercion of \p src to \p dstTy at the current insertion point of
+/// \p rewriter and return the coerced value.
+///
+/// * Identical types: \p src is returned unchanged.
+/// * Integer to integer: a `cir.cast integral`, which extends or truncates
+///   as needed.  A bitcast cannot change an integer's bit width.
+/// * Other pairs with no record/complex side, where both or neither side
+///   is a vector: a `cir.cast bitcast`.
+/// * Otherwise (a record or complex type is involved, or a vector is mixed
+///   with a non-vector): LLVM IR's bitcast cannot reinterpret such values,
+///   so go through memory: alloca srcTy -> store src -> bitcast the
+///   pointer -> load dstTy.
+///
+/// NOTE(review): the memory path uses a srcTy-sized slot with a fixed
+/// 8-byte alignment and loads dstTy from it, so a dstTy that is wider than
+/// srcTy (or needs stricter alignment) is not handled safely yet.
+static Value emitCoercion(OpBuilder &rewriter, Location loc, Type dstTy,
+                          Value src) {
+  Type srcTy = src.getType();
+  if (srcTy == dstTy)
+    return src;
+
+  // Integer <-> integer changes width and/or signedness: that is an
+  // integral conversion, never a bit-level reinterpretation.
+  if (mlir::isa<cir::IntType>(srcTy) && mlir::isa<cir::IntType>(dstTy))
+    return cir::CastOp::create(rewriter, loc, dstTy, cir::CastKind::integral,
+                               src);
+
+  bool needsMemory =
+      mlir::isa<cir::RecordType, cir::ComplexType>(srcTy) ||
+      mlir::isa<cir::RecordType, cir::ComplexType>(dstTy) ||
+      (mlir::isa<cir::VectorType>(srcTy) != mlir::isa<cir::VectorType>(dstTy));
+
+  if (!needsMemory)
+    return cir::CastOp::create(rewriter, loc, dstTy, cir::CastKind::bitcast,
+                               src);
+
+  // Round-trip through a stack slot of the source type.
+  auto srcPtrTy = cir::PointerType::get(srcTy);
+  auto dstPtrTy = cir::PointerType::get(dstTy);
+
+  auto alloca =
+      cir::AllocaOp::create(rewriter, loc, srcPtrTy, srcTy,
+                            /*name=*/rewriter.getStringAttr("coerce"),
+                            /*alignment=*/rewriter.getI64IntegerAttr(8));
+
+  cir::StoreOp::create(rewriter, loc, src, alloca,
+                       /*isVolatile=*/mlir::UnitAttr(),
+                       /*alignment=*/mlir::IntegerAttr(),
+                       /*sync_scope=*/cir::SyncScopeKindAttr(),
+                       /*mem_order=*/cir::MemOrderAttr());
+
+  auto ptrCast = cir::CastOp::create(rewriter, loc, dstPtrTy,
+                                     cir::CastKind::bitcast, alloca);
+
+  return cir::LoadOp::create(rewriter, loc, dstTy, ptrCast,
+                             /*isDeref=*/mlir::UnitAttr(),
+                             /*isVolatile=*/mlir::UnitAttr(),
+                             /*alignment=*/mlir::IntegerAttr(),
+                             /*sync_scope=*/cir::SyncScopeKindAttr(),
+                             /*mem_order=*/cir::MemOrderAttr());
+}
+
+/// Rewrite every value-returning cir.return in \p funcOp so that it yields
+/// \p coercedRetTy.  The coercion (built by emitCoercion) is placed right
+/// before each return.  Void returns and returns that already yield
+/// \p coercedRetTy are left untouched.  \p origRetTy is currently unused.
+static void insertReturnCoercion(FunctionOpInterface funcOp, Type origRetTy,
+                                 Type coercedRetTy, OpBuilder &rewriter) {
+  // Gather the returns first so the IR is not modified while it is walked.
+  SmallVector<cir::ReturnOp> pending;
+  funcOp->walk([&](cir::ReturnOp op) { pending.push_back(op); });
+
+  for (cir::ReturnOp op : pending) {
+    auto inputs = op.getInput();
+    if (inputs.empty())
+      continue;
+    Value value = inputs[0];
+    if (value.getType() == coercedRetTy)
+      continue;
+    rewriter.setInsertionPoint(op);
+    op->setOperand(0,
+                   emitCoercion(rewriter, op.getLoc(), coercedRetTy, value));
+  }
+}
+
+/// Adapt the entry block of \p funcOp to the ABI argument types in \p fc.
+///
+/// For each argument classified Extend or Direct whose coercedType differs
+/// from the entry-block argument's current type:
+///  * the block argument is retyped in place to the ABI (coerced) type;
+///  * a value of the original type is rebuilt at the top of the entry
+///    block (Extend: `cir.cast integral`; Direct: a round trip through a
+///    stack slot);
+///  * all other uses of the block argument are redirected to that value.
+/// Adaptation code for successive arguments is emitted in argument order.
+///
+/// NOTE(review): classification indices are used directly as entry-block
+/// argument indices, so this must run before Ignore'd block arguments are
+/// erased.
+static void insertArgAdaptation(FunctionOpInterface funcOp,
+                                const FunctionClassification &fc,
+                                OpBuilder &rewriter) {
+  Region &body = funcOp->getRegion(0);
+  if (body.empty())
+    return;
+
+  Block &entryBlock = body.front();
+  // Last op emitted for the previous argument; the next argument's
+  // adaptation code is inserted after it to keep argument order.
+  Operation *lastInserted = nullptr;
+
+  for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) {
+    if (!argClass.coercedType)
+      continue;
+
+    if (argClass.kind != ArgKind::Extend && argClass.kind != ArgKind::Direct)
+      continue;
+
+    BlockArgument blockArg = entryBlock.getArgument(idx);
+    Type oldArgTy = blockArg.getType();
+    Type newArgTy = argClass.coercedType;
+
+    if (oldArgTy == newArgTy)
+      continue;
+
+    // Retype the block argument in place; its uses are redirected below.
+    blockArg.setType(newArgTy);
+
+    if (lastInserted)
+      rewriter.setInsertionPointAfter(lastInserted);
+    else
+      rewriter.setInsertionPointToStart(&entryBlock);
+
+    Value adapted;
+    // Ops that consume blockArg to build `adapted`; excluded from the RAUW.
+    SmallPtrSet<Operation *, 4> coercionOps;
+
+    if (argClass.kind == ArgKind::Extend) {
+      // Extend: convert the ABI-widened integer back to the source type.
+      auto cast = cir::CastOp::create(rewriter, funcOp.getLoc(), oldArgTy,
+                                      cir::CastKind::integral, blockArg);
+      adapted = cast;
+      coercionOps.insert(cast.getOperation());
+    } else {
+      // Direct: store as the ABI type, reload as the source type.
+      // NOTE(review): the slot is sized for newArgTy with a fixed 8-byte
+      // alignment; if oldArgTy is wider, the load reads past the slot.
+      auto srcPtrTy = cir::PointerType::get(newArgTy);
+      auto dstPtrTy = cir::PointerType::get(oldArgTy);
+      Location loc = funcOp.getLoc();
+
+      auto alloca =
+          cir::AllocaOp::create(rewriter, loc, srcPtrTy, newArgTy,
+                                /*name=*/rewriter.getStringAttr("coerce"),
+                                /*alignment=*/rewriter.getI64IntegerAttr(8));
+
+      auto store = cir::StoreOp::create(rewriter, loc, blockArg, alloca,
+                                        /*isVolatile=*/mlir::UnitAttr(),
+                                        /*alignment=*/mlir::IntegerAttr(),
+                                        /*sync_scope=*/cir::SyncScopeKindAttr(),
+                                        /*mem_order=*/cir::MemOrderAttr());
+
+      auto ptrCast = cir::CastOp::create(rewriter, loc, dstPtrTy,
+                                         cir::CastKind::bitcast, alloca);
+
+      auto load = cir::LoadOp::create(rewriter, loc, oldArgTy, ptrCast,
+                                      /*isDeref=*/mlir::UnitAttr(),
+                                      /*isVolatile=*/mlir::UnitAttr(),
+                                      /*alignment=*/mlir::IntegerAttr(),
+                                      /*sync_scope=*/cir::SyncScopeKindAttr(),
+                                      /*mem_order=*/cir::MemOrderAttr());
+
+      adapted = load;
+      coercionOps.insert(alloca.getOperation());
+      coercionOps.insert(store.getOperation());
+      coercionOps.insert(ptrCast.getOperation());
+      coercionOps.insert(load.getOperation());
+    }
+    lastInserted = adapted.getDefiningOp();
+
+    blockArg.replaceAllUsesExcept(adapted, coercionOps);
+  }
+}
+
+/// Rewrite the signature of \p funcOp (and, for definitions, its body) to
+/// the ABI-lowered types described by \p fc.
+///
+/// Direct/Extend arguments and returns take their coerced type; Ignore'd
+/// arguments are removed and an Ignore'd return becomes void.
+/// Indirect/Expand arguments and unclassified trailing arguments keep their
+/// types.  Extend arguments/returns get `llvm.signext`/`llvm.zeroext`.
+/// Fails if \p fc classifies more arguments than the function has.
+LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
+    FunctionOpInterface funcOp, const FunctionClassification &fc,
+    OpBuilder &rewriter) {
+  ArrayRef<Type> oldArgTypes = funcOp.getArgumentTypes();
+  ArrayRef<Type> oldResultTypes = funcOp.getResultTypes();
+  bool isDecl = funcOp.isDeclaration();
+  MLIRContext *ctx = funcOp->getContext();
+
+  // The code below indexes the old argument types and the entry block by
+  // classification index, so every classified argument must exist.
+  if (fc.argInfos.size() > oldArgTypes.size())
+    return funcOp->emitOpError()
+           << "ABI classification describes " << fc.argInfos.size()
+           << " arguments, but the function has " << oldArgTypes.size();
+
+  // Compute the new argument types.  Every non-Ignore classified argument
+  // contributes exactly one entry; Ignore'd arguments contribute none.
+  bool hasArgChanges = false;
+  SmallVector<unsigned> ignoredArgIndices;
+  SmallVector<Type> newArgTypes;
+  for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) {
+    Type origTy = oldArgTypes[idx];
+    switch (argClass.kind) {
+    case ArgKind::Direct:
+    case ArgKind::Extend:
+      newArgTypes.push_back(argClass.coercedType ? argClass.coercedType
+                                                 : origTy);
+      if (argClass.coercedType && argClass.coercedType != origTy)
+        hasArgChanges = true;
+      break;
+    case ArgKind::Ignore:
+      ignoredArgIndices.push_back(idx);
+      hasArgChanges = true;
+      break;
+    case ArgKind::Indirect:
+    case ArgKind::Expand:
+      // Struct coercion is not implemented yet; keep the original type.
+      newArgTypes.push_back(origTy);
+      break;
+    }
+  }
+  // Trailing arguments without a classification keep their types, so the
+  // signature stays in sync with the entry block (which keeps them too).
+  llvm::append_range(newArgTypes,
+                     oldArgTypes.drop_front(fc.argInfos.size()));
+
+  // Compute the new result type.  CIR's FuncType::clone expects exactly
+  // one result type (VoidType for void-returning functions).
+  auto voidTy = cir::VoidType::get(ctx);
+  Type origRetTy = oldResultTypes.empty() ? voidTy : oldResultTypes[0];
+  Type newRetTy = origRetTy;
+  bool isExtendReturn = fc.returnInfo.kind == ArgKind::Extend;
+  bool returnCoerced = false;
+  if (fc.returnInfo.kind == ArgKind::Direct || isExtendReturn) {
+    if (fc.returnInfo.coercedType && !oldResultTypes.empty() &&
+        fc.returnInfo.coercedType != origRetTy) {
+      newRetTy = fc.returnInfo.coercedType;
+      returnCoerced = true;
+    }
+  } else if (fc.returnInfo.kind == ArgKind::Ignore) {
+    newRetTy = voidTy;
+  }
+  SmallVector<Type> newResultTypes = {newRetTy};
+
+  bool hasExtendArg =
+      llvm::any_of(fc.argInfos, [](const ArgClassification &argClass) {
+        return argClass.kind == ArgKind::Extend;
+      });
+
+  // Nothing to do unless a type changes or signext/zeroext attributes must
+  // be attached.
+  if (!hasArgChanges && !hasExtendArg && !isExtendReturn && !returnCoerced &&
+      newRetTy == origRetTy && llvm::equal(newArgTypes, oldArgTypes))
+    return success();
+
+  // Body modifications only apply to definitions.
+  if (!isDecl) {
+    if (hasArgChanges)
+      insertArgAdaptation(funcOp, fc, rewriter);
+
+    // Erase block arguments of Ignore'd args, back to front so the
+    // remaining indices stay valid.  Any remaining uses are first
+    // redirected to a load from an uninitialized local.
+    Region &body = funcOp->getRegion(0);
+    if (!ignoredArgIndices.empty() && !body.empty()) {
+      Block &entry = body.front();
+      for (unsigned blockIdx : llvm::reverse(ignoredArgIndices)) {
+        if (blockIdx >= entry.getNumArguments())
+          continue;
+        BlockArgument arg = entry.getArgument(blockIdx);
+        if (!arg.use_empty()) {
+          rewriter.setInsertionPointToStart(&entry);
+          Type argTy = arg.getType();
+          auto alloca = cir::AllocaOp::create(
+              rewriter, funcOp.getLoc(), cir::PointerType::get(argTy), argTy,
+              /*name=*/rewriter.getStringAttr("ignored"),
+              /*alignment=*/rewriter.getI64IntegerAttr(1));
+          auto load = cir::LoadOp::create(
+              rewriter, funcOp.getLoc(), argTy, alloca,
+              /*isDeref=*/mlir::UnitAttr(),
+              /*isVolatile=*/mlir::UnitAttr(),
+              /*alignment=*/mlir::IntegerAttr(),
+              /*sync_scope=*/cir::SyncScopeKindAttr(),
+              /*mem_order=*/cir::MemOrderAttr());
+          arg.replaceAllUsesWith(load);
+        }
+        entry.eraseArgument(blockIdx);
+      }
+    }
+
+    if (returnCoerced) {
+      if (isExtendReturn) {
+        // An Extend return widens an integer, so use an integral cast (as
+        // insertArgAdaptation does for Extend arguments); a bitcast cannot
+        // change the bit width.
+        SmallVector<cir::ReturnOp> returnOps;
+        funcOp->walk([&](cir::ReturnOp op) { returnOps.push_back(op); });
+        for (cir::ReturnOp retOp : returnOps) {
+          if (retOp.getInput().empty())
+            continue;
+          Value origVal = retOp.getInput()[0];
+          if (origVal.getType() == newRetTy)
+            continue;
+          rewriter.setInsertionPoint(retOp);
+          Value widened =
+              cir::CastOp::create(rewriter, retOp.getLoc(), newRetTy,
+                                  cir::CastKind::integral, origVal);
+          retOp->setOperand(0, widened);
+        }
+      } else {
+        insertReturnCoercion(funcOp, origRetTy, newRetTy, rewriter);
+      }
+    }
+
+    // When the return is Ignore'd (e.g. an empty struct), every return op
+    // that carries a value is replaced by a void return.  The returns are
+    // collected first so none is erased during the walk.
+    if (fc.returnInfo.kind == ArgKind::Ignore && !oldResultTypes.empty()) {
+      SmallVector<cir::ReturnOp> valueReturns;
+      funcOp->walk([&](cir::ReturnOp op) {
+        if (op.getNumOperands() > 0)
+          valueReturns.push_back(op);
+      });
+      for (cir::ReturnOp retOp : valueReturns) {
+        rewriter.setInsertionPoint(retOp);
+        cir::ReturnOp::create(rewriter, retOp.getLoc());
+        retOp->erase();
+      }
+    }
+  }
+
+  Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes);
+  funcOp.setFunctionTypeAttr(TypeAttr::get(newFnTy));
+
+  // Rebuild arg_attrs when Extend args need signext/zeroext, or when args
+  // were dropped and existing per-arg attributes must be re-positioned.
+  unsigned numArgs = newArgTypes.size();
+  if (hasExtendArg ||
+      (!ignoredArgIndices.empty() && funcOp->hasAttr("arg_attrs"))) {
+    SmallVector<Attribute> argAttrDicts(numArgs, DictionaryAttr::get(ctx));
+
+    // Carry existing per-arg attributes to their new positions; Ignore'd
+    // arguments drop theirs.
+    if (auto existingAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs")) {
+      unsigned newIdx = 0;
+      for (unsigned oldIdx = 0, e = existingAttrs.size(); oldIdx < e;
+           ++oldIdx) {
+        if (oldIdx < fc.argInfos.size() &&
+            fc.argInfos[oldIdx].kind == ArgKind::Ignore)
+          continue;
+        if (newIdx < numArgs)
+          argAttrDicts[newIdx] = existingAttrs[oldIdx];
+        ++newIdx;
+      }
+    }
+
+    // Walk the classifications in order, tracking each kept argument's
+    // position in the *new* signature.  Ignore'd arguments occupy no slot,
+    // so the classification index cannot be used directly.
+    unsigned newIdx = 0;
+    for (const auto &argClass : fc.argInfos) {
+      if (argClass.kind == ArgKind::Ignore)
+        continue;
+      unsigned slot = newIdx++;
+      if (argClass.kind != ArgKind::Extend)
+        continue;
+      NamedAttrList attrs(mlir::cast<DictionaryAttr>(argAttrDicts[slot]));
+      // set() replaces an existing entry, so re-running the rewrite cannot
+      // create duplicate keys (DictionaryAttr requires unique names).
+      attrs.set(argClass.signExtend ? "llvm.signext" : "llvm.zeroext",
+                rewriter.getUnitAttr());
+      argAttrDicts[slot] = attrs.getDictionary(ctx);
+    }
+
+    funcOp->setAttr("arg_attrs", ArrayAttr::get(ctx, argAttrDicts));
+  }
+
+  // Add signext/zeroext to the return value for Extend returns.
+  if (isExtendReturn) {
+    NamedAttrList retAttrs;
+    if (auto existing = funcOp->getAttrOfType<ArrayAttr>("res_attrs"))
+      if (existing.size() > 0)
+        retAttrs = NamedAttrList(mlir::cast<DictionaryAttr>(existing[0]));
+    retAttrs.set(fc.returnInfo.signExtend ? "llvm.signext" : "llvm.zeroext",
+                 rewriter.getUnitAttr());
+    Attribute resAttrDict = retAttrs.getDictionary(ctx);
+    funcOp->setAttr("res_attrs", ArrayAttr::get(ctx, {resAttrDict}));
+  }
+
+  return success();
+}
+
+/// Rebuild the cir.call \p callOp so that its operands and result match the
+/// ABI-lowered signature described by \p fc.
+///
+/// Direct/Extend operands are coerced to their ABI type and Ignore'd
+/// operands are dropped; unclassified trailing operands are forwarded.  The
+/// result is converted back to the original type for existing users (or,
+/// for an Ignore'd return, replaced by a load from an uninitialized local).
+/// \p callOp must be a cir::CallOp.  Indirect calls that would need a
+/// rewrite are rejected.
+LogicalResult CIRABIRewriteContext::rewriteCallSite(
+    Operation *callOp, const FunctionClassification &fc, OpBuilder &rewriter) {
+  auto call = cast<cir::CallOp>(callOp);
+  auto argOperands = call.getArgOperands();
+  bool hasResult = call.getNumResults() > 0;
+  Type origRetTy = hasResult ? call.getResult().getType() : Type();
+
+  // Build the new operand list.
+  SmallVector<Value> newArgs;
+  bool argsChanged = false;
+  for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) {
+    if (idx >= argOperands.size())
+      break;
+
+    Value arg = argOperands[idx];
+    if (argClass.kind == ArgKind::Ignore) {
+      argsChanged = true;
+      continue;
+    }
+
+    bool coercible = argClass.kind == ArgKind::Extend ||
+                     argClass.kind == ArgKind::Direct;
+    if (!coercible || !argClass.coercedType ||
+        arg.getType() == argClass.coercedType) {
+      newArgs.push_back(arg);
+      continue;
+    }
+
+    rewriter.setInsertionPoint(call);
+    Value coerced;
+    if (argClass.kind == ArgKind::Extend)
+      coerced =
+          cir::CastOp::create(rewriter, call.getLoc(), argClass.coercedType,
+                              cir::CastKind::integral, arg);
+    else
+      coerced =
+          emitCoercion(rewriter, call.getLoc(), argClass.coercedType, arg);
+    newArgs.push_back(coerced);
+    argsChanged = true;
+  }
+
+  // Operands beyond the classified arguments are forwarded unchanged.
+  for (unsigned i = fc.argInfos.size(); i < argOperands.size(); ++i)
+    newArgs.push_back(argOperands[i]);
+
+  // A Direct/Extend return needs work only if its type actually changes.
+  bool ignoreReturn = hasResult && fc.returnInfo.kind == ArgKind::Ignore;
+  Type coercedRetTy;
+  if (hasResult &&
+      (fc.returnInfo.kind == ArgKind::Direct ||
+       fc.returnInfo.kind == ArgKind::Extend) &&
+      fc.returnInfo.coercedType && fc.returnInfo.coercedType != origRetTy)
+    coercedRetTy = fc.returnInfo.coercedType;
+  bool returnCoerced = static_cast<bool>(coercedRetTy);
+
+  if (!argsChanged && !returnCoerced && !ignoreReturn)
+    return success();
+
+  // The call is rebuilt from its callee symbol.  An indirect call has no
+  // callee symbol, so the rebuilt call would lose its target; reject it
+  // until indirect calls are supported.
+  if (!call.getCalleeAttr())
+    return call.emitOpError(
+        "ABI rewriting of indirect calls is not supported yet");
+
+  Type callRetTy = cir::VoidType::get(call.getContext());
+  if (hasResult && !ignoreReturn)
+    callRetTy = returnCoerced ? coercedRetTy : origRetTy;
+
+  rewriter.setInsertionPoint(call);
+  auto newCall = cir::CallOp::create(rewriter, call.getLoc(),
+                                     call.getCalleeAttr(), callRetTy, newArgs);
+  for (NamedAttribute attr : call->getAttrs())
+    if (!newCall->hasAttr(attr.getName()))
+      newCall->setAttr(attr.getName(), attr.getValue());
+
+  if (ignoreReturn) {
+    // The callee no longer returns a value; feed remaining users from an
+    // uninitialized local of the original type.
+    if (!call.getResult().use_empty()) {
+      rewriter.setInsertionPointAfter(newCall);
+      auto ptrTy = cir::PointerType::get(origRetTy);
+      auto alloca =
+          cir::AllocaOp::create(rewriter, call.getLoc(), ptrTy, origRetTy,
+                                /*name=*/rewriter.getStringAttr("ignored"),
+                                /*alignment=*/rewriter.getI64IntegerAttr(1));
+      auto load =
+          cir::LoadOp::create(rewriter, call.getLoc(), origRetTy, alloca,
+                              /*isDeref=*/mlir::UnitAttr(),
+                              /*isVolatile=*/mlir::UnitAttr(),
+                              /*alignment=*/mlir::IntegerAttr(),
+                              /*sync_scope=*/cir::SyncScopeKindAttr(),
+                              /*mem_order=*/cir::MemOrderAttr());
+      call.getResult().replaceAllUsesWith(load);
+    }
+  } else if (returnCoerced) {
+    rewriter.setInsertionPointAfter(newCall);
+    Value castBack;
+    if (fc.returnInfo.kind == ArgKind::Extend) {
+      // Narrow the ABI-widened integer back to the source type; a bitcast
+      // cannot change the bit width.
+      castBack = cir::CastOp::create(rewriter, call.getLoc(), origRetTy,
+                                     cir::CastKind::integral,
+                                     newCall.getResult());
+    } else {
+      castBack = emitCoercion(rewriter, call.getLoc(), origRetTy,
+                              newCall.getResult());
+    }
+    call.getResult().replaceAllUsesWith(castBack);
+  } else if (hasResult) {
+    call.getResult().replaceAllUsesWith(newCall.getResult());
+  }
+
+  call->erase();
+  return success();
+}
diff --git 
a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
new file mode 100644
index 0000000000000..93d968c9db123
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
@@ -0,0 +1,50 @@
+//===- CIRABIRewriteContext.h - CIR-specific ABI rewriting ------*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines CIRABIRewriteContext, the CIR dialect's implementation of
+// the shared ABIRewriteContext interface.  It rewrites CIR function 
definitions
+// and call sites to match ABI-lowered signatures.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
+#define CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+
+namespace cir {
+
+/// CIR-specific implementation of the ABIRewriteContext interface.
+///
+/// This class knows how to rewrite CIR FuncOps and CallOps to match
+/// ABI-lowered signatures, using CIR operations for coercion (alloca,
+/// load, store, cast, etc.).  Direct, Extend and Ignore classifications
+/// are rewritten; Indirect and Expand arguments are currently left as-is.
+class CIRABIRewriteContext : public mlir::abi::ABIRewriteContext {
+  /// Module passed at construction.  NOTE(review): not read by any method
+  /// in this patch; presumably reserved for upcoming struct coercion work.
+  mlir::ModuleOp module;
+
+public:
+  explicit CIRABIRewriteContext(mlir::ModuleOp module) : module(module) {}
+
+  /// Rewrite \p funcOp's signature to the ABI types in \p fc.  For
+  /// definitions the body is adapted as well (block arguments retyped or
+  /// erased, return values coerced); Extend args/returns get
+  /// `llvm.signext`/`llvm.zeroext` attributes.
+  mlir::LogicalResult
+  rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp,
+                            const mlir::abi::FunctionClassification &fc,
+                            mlir::OpBuilder &rewriter) override;
+
+  /// Replace the call \p callOp (which must be a cir::CallOp) with a new
+  /// call whose operands and result use the ABI types in \p fc, converting
+  /// the result back to the original type for existing users.
+  mlir::LogicalResult
+  rewriteCallSite(mlir::Operation *callOp,
+                  const mlir::abi::FunctionClassification &fc,
+                  mlir::OpBuilder &rewriter) override;
+
+  /// Name of the dialect this context rewrites ("cir").
+  mlir::StringRef getDialectNamespace() const override { return "cir"; }
+};
+
+} // namespace cir
+
+#endif // CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
index 86502b7f5dd4e..9833952623708 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_clang_library(MLIRCIRTargetLowering
+  CIRABIRewriteContext.cpp
   CIRCXXABI.cpp
   LowerModule.cpp
   LowerItaniumCXXABI.cpp
@@ -15,6 +16,7 @@ add_clang_library(MLIRCIRTargetLowering
   LINK_LIBS PUBLIC
 
   clangBasic
+  MLIRABI
   MLIRIR
   MLIRPass
   MLIRDLTIDialect
diff --git a/clang/unittests/CIR/CIRABIRewriteContextTest.cpp 
b/clang/unittests/CIR/CIRABIRewriteContextTest.cpp
new file mode 100644
index 0000000000000..5cf8637346ae3
--- /dev/null
+++ b/clang/unittests/CIR/CIRABIRewriteContextTest.cpp
@@ -0,0 +1,406 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Unit tests for CIRABIRewriteContext, the CIR dialect's concrete
+// implementation of the shared ABIRewriteContext interface.  Each test
+// constructs a FunctionClassification manually (no ABI library needed)
+// and verifies the resulting IR after rewriting.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "clang/CIR/Dialect/IR/CIRAttrs.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/IR/CIRTypes.h"
+#include "gtest/gtest.h"
+
+// The header is private to the Transforms library, so we include it
+// via a path relative to this file's directory (clang/unittests/CIR).
+// The CMakeLists also adds clang/lib/CIR to the include directories.
+#include "../../lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h"
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+class CIRABIRewriteTest : public ::testing::Test {
+protected:
+  CIRABIRewriteTest() : builder(&context), loc(UnknownLoc::get(&context)) {
+    context.loadDialect<cir::CIRDialect>();
+  }
+
+  MLIRContext context;
+  OpBuilder builder;
+  Location loc;
+
+  /// Create a ModuleOp containing a single CIR FuncOp with the given
+  /// argument types and return type.  If \p addBody is true, the
+  /// function gets an entry block terminated by a cir.return: a void
+  /// return when \p retType is void, otherwise a return of the first
+  /// block argument whose type equals \p retType.  The returned module
+  /// is detached; the test owns it and must erase it.
+  std::pair<ModuleOp, cir::FuncOp> createFunc(StringRef name,
+                                              ArrayRef<Type> argTypes,
+                                              Type retType,
+                                              bool addBody = true) {
+    auto module = ModuleOp::create(loc);
+    builder.setInsertionPointToEnd(module.getBody());
+
+    auto funcTy = cir::FuncType::get(argTypes, retType);
+    auto funcOp = cir::FuncOp::create(builder, loc, name, funcTy);
+
+    if (!addBody)
+      return {module, funcOp};
+
+    Block *entry = funcOp.addEntryBlock();
+    builder.setInsertionPointToEnd(entry);
+    if (isa<cir::VoidType>(retType)) {
+      cir::ReturnOp::create(builder, loc);
+      return {module, funcOp};
+    }
+
+    // Select the returned argument by type rather than by position so the
+    // body is well-typed and argument 0 is not required to exist.
+    Value retVal;
+    for (BlockArgument arg : entry->getArguments()) {
+      if (arg.getType() == retType) {
+        retVal = arg;
+        break;
+      }
+    }
+    if (!retVal) {
+      // No value of the result type is available; flag the misuse and
+      // leave the block unterminated instead of indexing out of range.
+      ADD_FAILURE() << "createFunc: no argument has the return type";
+      return {module, funcOp};
+    }
+    cir::ReturnOp::create(builder, loc, mlir::ValueRange{retVal});
+    return {module, funcOp};
+  }
+
+  /// Handles to the IR built by createCallPair().  The module contains the
+  /// other three ops; the test owns the module and must erase it.
+  struct CallFixture {
+    ModuleOp module;
+    cir::FuncOp callee;
+    cir::FuncOp caller;
+    cir::CallOp callOp;
+  };
+
+  /// Create a ModuleOp containing a declaration of \p calleeName and a
+  /// function "caller" with the same signature.  The caller passes all of
+  /// its block arguments to a single call of the callee and returns the
+  /// call's result (or returns nothing when \p retType is void).
+  CallFixture createCallPair(StringRef calleeName, ArrayRef<Type> argTypes,
+                             Type retType) {
+    auto module = ModuleOp::create(loc);
+    builder.setInsertionPointToEnd(module.getBody());
+
+    auto funcTy = cir::FuncType::get(argTypes, retType);
+
+    // Callee (declaration only).
+    auto callee = cir::FuncOp::create(builder, loc, calleeName, funcTy);
+
+    // Caller with a body that calls the callee.
+    auto caller = cir::FuncOp::create(builder, loc, "caller", funcTy);
+    Block *entry = caller.addEntryBlock();
+    builder.setInsertionPointToEnd(entry);
+
+    // The caller shares the callee's signature, so its block arguments are
+    // exactly the call operands.  retType is used as the call's result type
+    // directly, including when it is !cir.void.
+    auto call = cir::CallOp::create(
+        builder, loc, mlir::FlatSymbolRefAttr::get(&context, calleeName),
+        retType, mlir::ValueRange(entry->getArguments()));
+    if (isa<cir::VoidType>(retType))
+      cir::ReturnOp::create(builder, loc);
+    else
+      cir::ReturnOp::create(builder, loc, mlir::ValueRange{call.getResult()});
+
+    return {module, callee, caller, call};
+  }
+};
+
+// ---- rewriteFunctionDefinition tests ----
+
+TEST_F(CIRABIRewriteTest, DirectPassthrough) {
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {i32Ty}, i32Ty);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  // Direct with no coerced type must leave the signature untouched.
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  ASSERT_EQ(fnTy.getInputs().size(), 1u);
+  EXPECT_EQ(fnTy.getInputs()[0], i32Ty);
+  EXPECT_EQ(fnTy.getReturnType(), i32Ty);
+}
+
+TEST_F(CIRABIRewriteTest, DirectReturnCoercion) {
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto i64Ty = cir::IntType::get(&context, 64, false);
+  auto [module, funcOp] = createFunc("f", {i32Ty}, i32Ty);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect(i64Ty);
+  fc.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  EXPECT_EQ(fnTy.getReturnType(), i64Ty);
+
+  // Only the return was coerced; the Direct argument keeps its type.
+  ASSERT_EQ(fnTy.getInputs().size(), 1u);
+  EXPECT_EQ(fnTy.getInputs()[0], i32Ty);
+}
+
+TEST_F(CIRABIRewriteTest, ExtendArg) {
+  auto i8Ty = cir::IntType::get(&context, 8, true);
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto voidTy = cir::VoidType::get(&context);
+  auto [module, funcOp] = createFunc("f", {i8Ty}, voidTy);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true));
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  ASSERT_EQ(fnTy.getInputs().size(), 1u);
+  EXPECT_EQ(fnTy.getInputs()[0], i32Ty);
+
+  // Verify signext attribute was attached.
+  auto argAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
+  ASSERT_TRUE(argAttrs != nullptr);
+  ASSERT_EQ(argAttrs.size(), 1u);
+  auto dict = dyn_cast<DictionaryAttr>(argAttrs[0]);
+  ASSERT_TRUE(dict != nullptr);
+  EXPECT_TRUE(dict.get("llvm.signext") != nullptr);
+
+  // Verify the entry block has a cir.cast (integral) to adapt i32
+  // back to i8 for body uses.
+  Block &entry = funcOp->getRegion(0).front();
+  bool foundCast = false;
+  for (Operation &op : entry) {
+    auto castOp = dyn_cast<cir::CastOp>(op);
+    if (!castOp || castOp.getKind() != cir::CastKind::integral)
+      continue;
+    EXPECT_EQ(castOp.getResult().getType(), i8Ty);
+    foundCast = true;
+  }
+  EXPECT_TRUE(foundCast);
+}
+
+TEST_F(CIRABIRewriteTest, ExtendReturn) {
+  auto i8Ty = cir::IntType::get(&context, 8, true);
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {i8Ty}, i8Ty);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getExtend(i32Ty, false);
+  fc.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  EXPECT_EQ(fnTy.getReturnType(), i32Ty);
+
+  // The Direct argument keeps its original type.
+  ASSERT_EQ(fnTy.getInputs().size(), 1u);
+  EXPECT_EQ(fnTy.getInputs()[0], i8Ty);
+
+  // Verify zeroext attribute on return.
+  auto resAttrs = funcOp->getAttrOfType<ArrayAttr>("res_attrs");
+  ASSERT_TRUE(resAttrs != nullptr);
+  ASSERT_EQ(resAttrs.size(), 1u);
+  auto dict = dyn_cast<DictionaryAttr>(resAttrs[0]);
+  ASSERT_TRUE(dict != nullptr);
+  EXPECT_TRUE(dict.get("llvm.zeroext") != nullptr);
+}
+
+TEST_F(CIRABIRewriteTest, IgnoreReturn) {
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {i32Ty}, i32Ty);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getIgnore();
+  fc.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  auto voidTy = cir::VoidType::get(&context);
+  EXPECT_EQ(fnTy.getReturnType(), voidTy);
+
+  // The Direct argument is unaffected by ignoring the return.
+  ASSERT_EQ(fnTy.getInputs().size(), 1u);
+  EXPECT_EQ(fnTy.getInputs()[0], i32Ty);
+}
+
+TEST_F(CIRABIRewriteTest, IgnoreArg) {
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto voidTy = cir::VoidType::get(&context);
+  auto [module, funcOp] = createFunc("f", {i32Ty}, voidTy);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getIgnore());
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  // The ignored argument is removed from both the type and the entry block.
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  EXPECT_EQ(fnTy.getInputs().size(), 0u);
+
+  Block &entry = funcOp->getRegion(0).front();
+  EXPECT_EQ(entry.getNumArguments(), 0u);
+}
+
+TEST_F(CIRABIRewriteTest, DeclarationRewrite) {
+  auto i8Ty = cir::IntType::get(&context, 8, true);
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {i8Ty}, i8Ty, /*addBody=*/false);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getExtend(i32Ty, true);
+  fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true));
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  ASSERT_EQ(fnTy.getInputs().size(), 1u);
+  EXPECT_EQ(fnTy.getInputs()[0], i32Ty);
+  EXPECT_EQ(fnTy.getReturnType(), i32Ty);
+
+  // Verify both signext attributes.  Sizes are checked before indexing so
+  // a missing entry fails the test instead of reading out of bounds.
+  auto argAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
+  ASSERT_TRUE(argAttrs != nullptr);
+  ASSERT_EQ(argAttrs.size(), 1u);
+  auto dict = dyn_cast<DictionaryAttr>(argAttrs[0]);
+  ASSERT_TRUE(dict != nullptr);
+  EXPECT_TRUE(dict.get("llvm.signext") != nullptr);
+
+  auto resAttrs = funcOp->getAttrOfType<ArrayAttr>("res_attrs");
+  ASSERT_TRUE(resAttrs != nullptr);
+  ASSERT_EQ(resAttrs.size(), 1u);
+  auto rdict = dyn_cast<DictionaryAttr>(resAttrs[0]);
+  ASSERT_TRUE(rdict != nullptr);
+  EXPECT_TRUE(rdict.get("llvm.signext") != nullptr);
+}
+
+// ---- rewriteCallSite tests ----
+
+TEST_F(CIRABIRewriteTest, CallSiteDirectPassthrough) {
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto fixture = createCallPair("callee", {i32Ty}, i32Ty);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(fixture.module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter)));
+
+  // No changes are needed, so the call keeps its signature.  Re-fetch it
+  // from the caller's body rather than dereferencing fixture.callOp, which
+  // would dangle if the rewrite had replaced the op.
+  Block &callerEntry = fixture.caller->getRegion(0).front();
+  cir::CallOp call;
+  for (Operation &op : callerEntry)
+    if (auto c = dyn_cast<cir::CallOp>(op))
+      call = c;
+  ASSERT_TRUE(call != nullptr);
+  ASSERT_EQ(call->getNumResults(), 1u);
+  EXPECT_EQ(call->getResult(0).getType(), i32Ty);
+  ASSERT_EQ(call.getArgOperands().size(), 1u);
+  EXPECT_EQ(call.getArgOperands()[0].getType(), i32Ty);
+}
+
+TEST_F(CIRABIRewriteTest, CallSiteExtendArg) {
+  auto i8Ty = cir::IntType::get(&context, 8, true);
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto voidTy = cir::VoidType::get(&context);
+  auto fixture = createCallPair("callee", {i8Ty}, voidTy);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(fixture.module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true));
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter)));
+
+  // The old call was erased and replaced.  Look for a CallOp whose
+  // argument is i32 (the extended type).
+  Block &callerEntry = fixture.caller->getRegion(0).front();
+  cir::CallOp newCall;
+  for (Operation &op : callerEntry)
+    if (auto c = dyn_cast<cir::CallOp>(op))
+      newCall = c;
+  ASSERT_TRUE(newCall != nullptr);
+  ASSERT_EQ(newCall.getArgOperands().size(), 1u);
+  EXPECT_EQ(newCall.getArgOperands()[0].getType(), i32Ty);
+}
+
+TEST_F(CIRABIRewriteTest, CallSiteIgnoreReturn) {
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto fixture = createCallPair("callee", {i32Ty}, i32Ty);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(fixture.module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getIgnore();
+  fc.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter)));
+
+  // Find the replacement void call.
+  Block &callerEntry = fixture.caller->getRegion(0).front();
+  cir::CallOp newCall;
+  for (Operation &op : callerEntry)
+    if (auto c = dyn_cast<cir::CallOp>(op))
+      newCall = c;
+  ASSERT_TRUE(newCall != nullptr);
+  EXPECT_EQ(newCall.getNumResults(), 0u);
+
+  // The Direct argument is still forwarded unchanged.
+  ASSERT_EQ(newCall.getArgOperands().size(), 1u);
+  EXPECT_EQ(newCall.getArgOperands()[0].getType(), i32Ty);
+}
+
+TEST_F(CIRABIRewriteTest, CallSiteIgnoreArg) {
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto voidTy = cir::VoidType::get(&context);
+  auto fixture = createCallPair("callee", {i32Ty}, voidTy);
+  // Own the module so it is erased even if an ASSERT_* returns early.
+  OwningOpRef<ModuleOp> moduleOwner(fixture.module);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getIgnore());
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter)));
+
+  // Find the replacement call -- it should have zero args.
+  Block &callerEntry = fixture.caller->getRegion(0).front();
+  cir::CallOp newCall;
+  for (Operation &op : callerEntry)
+    if (auto c = dyn_cast<cir::CallOp>(op))
+      newCall = c;
+  ASSERT_TRUE(newCall != nullptr);
+  EXPECT_EQ(newCall.getArgOperands().size(), 0u);
+}
+
+} // namespace
diff --git a/clang/unittests/CIR/CMakeLists.txt b/clang/unittests/CIR/CMakeLists.txt
index 650fde38c48a9..d318810d33fe5 100644
--- a/clang/unittests/CIR/CMakeLists.txt
+++ b/clang/unittests/CIR/CMakeLists.txt
@@ -1,15 +1,20 @@
 set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include )
 set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include)
+set(CLANG_CIR_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../lib/CIR)
 include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
 include_directories(${MLIR_TABLEGEN_OUTPUT_DIR})
+include_directories(${CLANG_CIR_SRC_DIR})
 
 add_distinct_clang_unittest(CIRUnitTests
+  CIRABIRewriteContextTest.cpp
   PointerLikeTest.cpp
   LLVM_COMPONENTS
   Core
 
   LINK_LIBS
+  MLIRABI
   MLIRCIR
+  MLIRCIRTargetLowering
   CIROpenACCSupport
   MLIRIR
   MLIROpenACCDialect

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to