[llvm-branch-commits] [clang] [AArch64][ARM] Add a release note about _BitInt (PR #101521)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/101521

None

>From 3079b62127b9fd2878f9a1bf8ffeb2a5be90ab5b Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Thu, 1 Aug 2024 17:58:18 +0100
Subject: [PATCH] [AArch64][ARM] Add a release note about _BitInt

---
 clang/docs/ReleaseNotes.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index b4ef1e9672a5d..0bdc3cebbb068 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -1203,6 +1203,8 @@ Arm and AArch64 Support
 * Arm Neoverse-N3 (neoverse-n3).
 * Arm Neoverse-V3 (neoverse-v3).
 * Arm Neoverse-V3AE (neoverse-v3ae).
+- The C23 ``_BitInt`` implementation was brought into compliance with AAPCS32
+  and AAPCS64
 
 Android Support
 ^^^

___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [clang] [AArch64] Add some release notes items (PR #79983)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/79983

None

>From d647f6a4754807648dd11480b3e942571a9e1e25 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Tue, 30 Jan 2024 11:13:42 +
Subject: [PATCH] [AArch64] Add some release notes items

---
 clang/docs/ReleaseNotes.rst | 5 +
 llvm/docs/ReleaseNotes.rst | 8
 2 files changed, 13 insertions(+)

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 060bc7669b72a..ab5fb5aee61a8 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -1164,6 +1164,11 @@ Arm and AArch64 Support
 * Cortex-A720 (cortex-a720).
 * Cortex-X4 (cortex-x4).
 
+- Alpha support has been added for SVE2.1 intrinsics.
+
+- Support has been added for `-fstack-clash-protection` and `-mstack-probe-size`
+  command line options.
+
 Android Support
 ^^^
 
diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst
index 7b6a3f10d6377..990f5a5f73e84 100644
--- a/llvm/docs/ReleaseNotes.rst
+++ b/llvm/docs/ReleaseNotes.rst
@@ -105,6 +105,14 @@ Changes to the AArch64 Backend
 Armv9.0a has the same features enabled as Armv8.5a, with the exception of crypto.
 
+* Assembler/disassembler support has been added for 2023 architecture
+  extensions.
+
+* Support has been added for Stack Clash Protection. During function frame
+  creation and dynamic stack allocations, the compiler will issue memory
+  accesses at regular intervals so that a guard area at the top of the stack
+  can't be skipped over.
+
 Changes to the AMDGPU Backend
 -
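The Stack Clash Protection item in the patch above describes memory accesses issued "at regular intervals". As a rough illustration only, the C++ sketch below models which offsets below the incoming stack pointer would be touched while allocating a frame. The names `probeOffsets`, `frameSize` and `probeSize` are hypothetical and exist only for this example; they are not LLVM APIs, and the model makes no attempt to reproduce the instruction sequence the AArch64 backend actually emits.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified model (an assumption of this sketch, not LLVM's algorithm):
// while growing the stack by `frameSize` bytes, touch memory every
// `probeSize` bytes. The returned values are byte offsets below the
// incoming stack pointer at which a probe access is issued.
std::vector<std::uint64_t> probeOffsets(std::uint64_t frameSize,
                                        std::uint64_t probeSize) {
  assert(probeSize > 0 && "probe interval must be non-zero");
  std::vector<std::uint64_t> offsets;
  // Counting whole intervals avoids overflow in the running offset.
  const std::uint64_t count = frameSize / probeSize;
  for (std::uint64_t i = 1; i <= count; ++i)
    offsets.push_back(i * probeSize);
  return offsets;
}
```

In this model two consecutive probes are never more than `probeSize` bytes apart, so a guard region at least that large cannot be stepped over without being touched. For example, `probeOffsets(10000, 4096)` yields the offsets 4096 and 8192.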
[llvm-branch-commits] [llvm] [clang] [AArch64] Add some release notes items (PR #79983)
https://github.com/momchil-velikov milestoned https://github.com/llvm/llvm-project/pull/79983
[llvm-branch-commits] [clang] release/20.x: [AArch64] Add MSVC mangling for the __mfp8 type (#124968) (PR #125066)
https://github.com/momchil-velikov approved this pull request.

https://github.com/llvm/llvm-project/pull/125066
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {

momchil-velikov wrote:

It's inconsistent only if others were consistent with each other. I've updated some, but really with a lack of a policy (or automated process) it's rather futile to try to guess what one or another reviewer would like.

https://github.com/llvm/llvm-project/pull/135634
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/135636

Supersedes https://github.com/llvm/llvm-project/pull/135359

>From 2e61d3ee7b9ac88ae1be8ca248dad1a0880ccff4 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td | 4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h | 3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt | 1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 +
 .../LowerContractionToSMMLAPattern.cpp | 5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 +
 .../LowerContractionToSVEI8MMPattern.cpp | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
            "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/135634

Supersedes https://github.com/llvm/llvm-project/pull/135358

>From 71e2f13ad5922bf93961c5d81fd9d1f5899c80b0 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Thu, 10 Apr 2025 14:38:27 +
Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to `svusmmla`

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 32 +++
 .../Transforms/LegalizeForLLVMExport.cpp | 4 +++
 .../Dialect/ArmSVE/legalize-for-llvm.mlir | 12 +++
 mlir/test/Dialect/ArmSVE/roundtrip.mlir | 11 +++
 mlir/test/Target/LLVMIR/arm-sve.mlir | 12 +++
 5 files changed, 71 insertions(+)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 1a59062ccc93d..da2a8f89b4cfd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+    ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+    ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+    ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint : TypesMatchWith<
   "expected corresponding svbool type widened to [16]xi1",
   lhsArg, rhsArg,
@@ -568,6 +596,10 @@ def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index fe13ed03356b2..b1846e15196fc 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern;
 using UdotOpLowering = OneToOneConvertToLLVMPattern;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern;
 using DupQLaneLowering = OneToOneConvertToLLVMPattern;
 using ScalableMaskedAddIOpLowering =
@@ -194,6 +195,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
                DupQLaneLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
@@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
                     DupQLaneIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
@@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      UsmmlaOp,
                       DupQLaneOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 5d044517e0ea8..47587aa26506c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ f
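The `UsmmlaOp` description in the patch above can be read as a per-segment 2×8 by 8×2 product accumulated into a 2×2 tile. The following C++ sketch is a scalar reference model of that reading, not the dialect's or the backend's implementation. `usmmlaModel` is a hypothetical name, and the element layout is an assumption of this sketch: each 128-bit segment holds the first operand as two rows of eight bytes, the second operand as two groups of eight bytes that each form one column of the 8×2 matrix, and the accumulator as a row-major 2×2 tile.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of an unsigned-by-signed matrix multiply-accumulate over a
// vector made of N 128-bit segments (N plays the role of vscale).
// Assumed layout per segment: acc = 2x2 i32 (row-major), lhs = 2x8 u8
// (row-major), rhs = two groups of 8 s8 values, one group per column.
std::vector<std::int32_t> usmmlaModel(std::vector<std::int32_t> acc,
                                      const std::vector<std::uint8_t> &lhs,
                                      const std::vector<std::int8_t> &rhs) {
  assert(acc.size() % 4 == 0 && "accumulator must hold whole 2x2 tiles");
  const std::size_t segments = acc.size() / 4;
  assert(lhs.size() == 16 * segments && rhs.size() == 16 * segments);
  for (std::size_t s = 0; s < segments; ++s)
    for (std::size_t i = 0; i < 2; ++i)
      for (std::size_t j = 0; j < 2; ++j) {
        std::int32_t sum = 0;
        for (std::size_t k = 0; k < 8; ++k)
          sum += std::int32_t(lhs[16 * s + 8 * i + k]) *
                 std::int32_t(rhs[16 * s + 8 * j + k]);
        // Accumulate with wrap-around instead of signed-overflow UB.
        const std::uint32_t wrapped =
            std::uint32_t(acc[4 * s + 2 * i + j]) + std::uint32_t(sum);
        acc[4 * s + 2 * i + j] = std::int32_t(wrapped);
      }
  return acc;
}
```

With one segment, a first operand whose rows are all 1 and all 255, a second operand whose columns are all -1 and all 3, and an accumulator of {10, 20, 30, 40}, the model returns {2, 44, -2010, 6160}. The third entry shows the first operand being treated as unsigned: 255 contributes as 255, not as -1.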
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135634

>From 5e91c2eb411cba43794fa7db918e88099885849e Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Thu, 10 Apr 2025 14:38:27 +
Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to `svusmmla`

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 95 +++
 .../Transforms/LegalizeForLLVMExport.cpp | 4 +
 .../Dialect/ArmSVE/legalize-for-llvm.mlir | 12 +++
 mlir/test/Dialect/ArmSVE/roundtrip.mlir | 11 +++
 mlir/test/Target/LLVMIR/arm-sve.mlir | 12 +++
 5 files changed, 96 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 3a990f8464ef8..7385bb73b449a 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -147,11 +147,9 @@ class ScalableMaskedIOp,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def SdotOp : ArmSVE_Op<"sdot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     SDOT: Signed integer addition of dot product.
@@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def SmmlaOp : ArmSVE_Op<"smmla",
-                        [Pure,
-                         AllTypesMatch<["src1", "src2"]>,
-                         AllTypesMatch<["acc", "dst"]>,
-                        ]> {
+def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     SMMLA: Signed integer matrix multiply-accumulate.
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UdotOp : ArmSVE_Op<"udot",
-                       [Pure,
-                        AllTypesMatch<["src1", "src2"]>,
-                        AllTypesMatch<["acc", "dst"]>,
-                       ]> {
+def UdotOp : ArmSVE_Op<"udot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     UDOT: Unsigned integer addition of dot product.
@@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UmmlaOp : ArmSVE_Op<"ummla",
-                        [Pure,
-                         AllTypesMatch<["src1", "src2"]>,
-                         AllTypesMatch<["acc", "dst"]>,
-                        ]> {
+def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     UMMLA: Unsigned integer matrix multiply-accumulate.
@@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+    ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+    ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+    ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint : TypesMatchWith<
   "expected corresponding svbool type widened to [16]xi1",
   lhsArg, rhsArg,
   "VectorType(VectorType::Builder(::llvm::cast($_self)).setDim(::llvm::cast($_self).getRank() - 1, 16))">;
 
 def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
-[Pure, SvboolTypeConstraint<"result", "source">]>
-{
+
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636

>From 8e87a7f3b1438d9542d28c90eb9593ebe8cf6500 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td | 4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h | 3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt | 1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 +
 .../LowerContractionToSMMLAPattern.cpp | 5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 +
 .../LowerContractionToSVEI8MMPattern.cpp | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
            "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
 populateVectorStepLoweringPatterns(pa
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
momchil-velikov wrote:

> One high-level question - would sharing some code between NEON and SVE be
> possible?

No, I can't see it happening and resulting in less, or simpler, or easier to maintain code. However, it might be possible to add Neon lowering to this patch and see if the result is any good.

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!extOp)
+    return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<mlir::VectorType>(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+    return {};
+
+  auto outTy = dyn_cast<mlir::VectorType>(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+    // eventually expect from MMT4D. M and N dimensions must be even and at
+    // least 2.
+    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+        rhsType.getRank() != 2)
+      return failure();
+
+    if (lhsType.isScalable() || !rhsType.isScalable())
+      return failure();
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0])
+      return failure();
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return failure();
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType
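The `extractExtOperand` helper quoted above restricts its template parameter to the two extension ops with `std::enable_if_t`. The following is a minimal standalone sketch of that constraint technique, not the patch's code: `MockOp`, `SignExtend`, `ZeroExtend`, `Unrelated`, and `IsAccepted` are hypothetical stand-ins invented for illustration, and the operand is identified by a plain integer instead of an MLIR `Value`.

```cpp
#include <cassert>
#include <optional>
#include <type_traits>

// Hypothetical stand-ins for MLIR ops; these are NOT the real MLIR classes.
enum class Kind { SignExtend, ZeroExtend, Other };

struct MockOp {
  Kind kind;   // which operation defines the value
  int inBits;  // element width before the extension
  int outBits; // element width after the extension
  int inId;    // identifies the operand being extended
};

struct SignExtend { static constexpr Kind kind = Kind::SignExtend; };
struct ZeroExtend { static constexpr Kind kind = Kind::ZeroExtend; };
struct Unrelated  { static constexpr Kind kind = Kind::Other; };

// Only participates in overload resolution when T is (derived from) one of
// the two extension tags. Returns the pre-extension operand on a match.
template <typename T>
inline std::enable_if_t<(std::is_base_of_v<SignExtend, T> ||
                         std::is_base_of_v<ZeroExtend, T>),
                        std::optional<int>>
extractExtOperand(const MockOp *def, int inBits, int outBits) {
  if (!def || def->kind != T::kind)
    return {};
  if (def->inBits != inBits || def->outBits != outBits)
    return {};
  return def->inId;
}

// Detection helper: true only if extractExtOperand<T> can be instantiated.
template <typename T, typename = void>
struct IsAccepted : std::false_type {};
template <typename T>
struct IsAccepted<T, std::void_t<decltype(extractExtOperand<T>(nullptr, 8, 32))>>
    : std::true_type {};

static_assert(IsAccepted<SignExtend>::value);
static_assert(IsAccepted<ZeroExtend>::value);
static_assert(!IsAccepted<Unrelated>::value);
```

The design point is that a wrong instantiation is rejected at compile time, while a non-matching value is reported at run time through an empty `std::optional`.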
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
+class LowerContractionToSVEI8MMPattern

momchil-velikov wrote:

Done.

https://github.com/llvm/llvm-project/pull/135636
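The `MixedSwapped` case in the quoted patch relies on the identity (A·Bᵀ)ᵀ = B·Aᵀ: a signed-by-unsigned tile product can be obtained from the unsigned-by-signed operation by swapping the operands and transposing the accumulator and the result. Below is a small scalar sketch of that reasoning for a single 2x2 tile. The names `Tile`, `usmmlaRef`, `summlaDirect`, and `summlaViaSwap` are illustrative assumptions, and the row-of-eight operand layout is assumed from Arm's description of the instruction.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Tile = std::array<int32_t, 4>; // 2x2 accumulator, row-major

// Scalar model of one unsigned-by-signed tile: out[i][j] = acc[i][j] +
// dot(row i of u, row j of s), with each operand holding two rows of 8 bytes.
Tile usmmlaRef(Tile acc, const std::array<uint8_t, 16> &u,
               const std::array<int8_t, 16> &s) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[2 * i + j] += int32_t(u[8 * i + k]) * int32_t(s[8 * j + k]);
  return acc;
}

// Transpose a 2x2 row-major tile.
Tile transpose(Tile t) { return {t[0], t[2], t[1], t[3]}; }

// Signed-by-unsigned product computed directly, for comparison only.
Tile summlaDirect(Tile acc, const std::array<int8_t, 16> &s,
                  const std::array<uint8_t, 16> &u) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[2 * i + j] += int32_t(s[8 * i + k]) * int32_t(u[8 * j + k]);
  return acc;
}

// The same product via the unsigned-by-signed model: transpose the
// accumulator, swap the operands, transpose the result back.
Tile summlaViaSwap(Tile acc, const std::array<int8_t, 16> &s,
                   const std::array<uint8_t, 16> &u) {
  return transpose(usmmlaRef(transpose(acc), u, s));
}
```

In the patch the two transposes happen outside `createMMLA`, which is why that function only has to swap `lhs` and `rhs`.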
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
+    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+    // eventually expect from MMT4D. M and N dimensions must be even and at

momchil-velikov wrote:

Done (in the top-level description).

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135634
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636

>From f397467bc167d94a28a919a45c009a8f08b6351b Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Tue, 8 Apr 2025 14:43:54 +
Subject: [PATCH 1/2] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td          |   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h      |   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt      |   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp    |   7 +
 .../LowerContractionToSMMLAPattern.cpp          |   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt    |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp        | 304 ++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir         |  94 ++
 .../Vector/CPU/ArmSVE/vector-summla.mlir        |  85 +
 .../Vector/CPU/ArmSVE/vector-ummla.mlir         |  94 ++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir        |  95 ++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir     | 117 +++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir     | 159 +
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir    | 118 +++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir     | 119 +++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir    | 117 +++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 10557658d5d7d..b496ee0114910 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
            "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0ee6dce9ee94b..293e01a5bf4d4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPattern
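The quoted patch is cut off before the hunk that consumes the new `enable-arm-i8mm` option, so how the pass uses it is not visible here. The sketch below only illustrates the general idea of an option that gates an extra pattern set and defaults to off. Everything in it is an assumption made for illustration: `Options`, `PatternSet`, `collectPatterns`, the pattern names, and in particular the nesting of the I8MM flag with the Neon and SVE flags.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a pattern set: it just records pattern names.
using PatternSet = std::vector<std::string>;

// Hypothetical mirror of the pass options; all default to false, as the
// `enable-arm-i8mm` option does in the quoted Passes.td hunk.
struct Options {
  bool armNeon = false;
  bool armSVE = false;
  bool armI8MM = false;
};

// ASSUMPTION: the I8MM patterns are added only when the I8MM flag and the
// corresponding dialect flag are both set. The real gating logic is not
// shown in the quoted patch.
PatternSet collectPatterns(const Options &opts) {
  PatternSet patterns{"generic-vector-lowering"};
  if (opts.armI8MM) {
    if (opts.armNeon)
      patterns.push_back("neon-i8mm-contraction");
    if (opts.armSVE)
      patterns.push_back("sve-i8mm-contraction");
  }
  return patterns;
}
```

With a default-off flag, existing users of the pass see no change in behaviour unless they opt in explicitly.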
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135634

>From 528237309c0bfd7bbb51a8fea37b54e07f21ad1d Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Thu, 10 Apr 2025 14:38:27 +
Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to `svusmmla`

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 95 +++
 .../Transforms/LegalizeForLLVMExport.cpp      |  4 +
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     | 12 +++
 mlir/test/Dialect/ArmSVE/roundtrip.mlir       | 11 +++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 12 +++
 5 files changed, 96 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 3a990f8464ef8..7385bb73b449a 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -147,11 +147,9 @@ class ScalableMaskedIOp,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def SdotOp : ArmSVE_Op<"sdot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     SDOT: Signed integer addition of dot product.
@@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def SmmlaOp : ArmSVE_Op<"smmla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     SMMLA: Signed integer matrix multiply-accumulate.
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UdotOp : ArmSVE_Op<"udot",
-               [Pure,
-               AllTypesMatch<["src1", "src2"]>,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def UdotOp : ArmSVE_Op<"udot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     UDOT: Unsigned integer addition of dot product.
@@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UmmlaOp : ArmSVE_Op<"ummla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     UMMLA: Unsigned integer matrix multiply-accumulate.
@@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint : TypesMatchWith<
   "expected corresponding svbool type widened to [16]xi1",
   lhsArg, rhsArg,
   "VectorType(VectorType::Builder(::llvm::cast($_self)).setDim(::llvm::cast($_self).getRank() - 1, 16))">;
 
 def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
-                            [Pure, SvboolTypeConstraint<"result", "source">]>
-{
+
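The USMMLA description above can be made concrete with a scalar reference model. This is a sketch under stated assumptions, not code from the patch: `usmmlaSegments` is a hypothetical name, the vectors are modelled as `std::vector` with `4 * vscale` accumulator elements and `16 * vscale` source bytes, and each 128-bit segment is assumed to be processed independently, with the second operand's 8×2 matrix stored as two groups of eight bytes (one group per column).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of an unsigned-by-signed 8-bit matrix multiply-accumulate over
// a scalable vector. Per 128-bit segment: a 2x8 unsigned matrix (src1) times
// an 8x2 signed matrix (src2, stored one column per group of 8 bytes) is
// added to a 2x2 row-major int32 accumulator.
std::vector<int32_t> usmmlaSegments(std::vector<int32_t> acc,
                                    const std::vector<uint8_t> &src1,
                                    const std::vector<int8_t> &src2) {
  const std::size_t vscale = acc.size() / 4;
  assert(acc.size() == 4 * vscale);
  assert(src1.size() == 16 * vscale && src2.size() == 16 * vscale);

  for (std::size_t seg = 0; seg < vscale; ++seg)
    for (std::size_t i = 0; i < 2; ++i)
      for (std::size_t j = 0; j < 2; ++j) {
        int32_t sum = 0;
        for (std::size_t k = 0; k < 8; ++k)
          sum += int32_t(src1[16 * seg + 8 * i + k]) *
                 int32_t(src2[16 * seg + 8 * j + k]);
        acc[4 * seg + 2 * i + j] += sum;
      }
  return acc;
}
```

The mixed signedness matters: the byte `0xFF` contributes 255 when it comes from `src1` and -1 when it comes from `src2`.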
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned, // ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multiply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at

momchil-velikov wrote:

There's no dependency on MMT4D - this comment is merely a rationale for why we have chosen these particular operand shapes.

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned, // ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multiply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||

momchil-velikov wrote:

Done

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.

momchil-velikov wrote:

The `vector.contract` implicitly sign-extends its operands, so it does not need to be accompanied by explicit extend operations. I'll add code to handle this case too.

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template +inline std::enable_if_t<(std::is_base_of_v || + std::is_base_of_v), +std::optional> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null(v.getDefiningOp()); + if (!extOp) +return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) +return {}; + + auto outTy = dyn_cast(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) +return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. 
+enum class MMLA { + Signed, // smmla + Unsigned, // ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multiply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: +return rewriter.create(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: +// The accumulator comes transposed and the result will be transposed +// later, so all we have to do here is swap the operands. +return rewriter.create(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, +PatternRewriter &rewriter) const override { + +Location loc = op.getLoc(); +mlir::VectorType lhsType = op.getLhsType(); +mlir::VectorType rhsType = op.getRhsType(); + +// For now handle LHS and RHS<8x[N]> - these are the types we +// eventually expect from MMT4D. M and N dimensions must be even and at +// least 2. +if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || +rhsType.getRank() != 2) + return failure(); + +if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + +// M, N, and K are the conventional names for matrix dimensions in the +// context of matrix multiplication. +auto M = lhsType.getDimSize(0); +auto N = rhsType.getDimSize(0); +auto K = rhsType.getDimSize(1); + +if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || +N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + +// Check permutation maps. 
For now only accept +// lhs: (d0, d1, d2) -> (d0, d2) +// rhs: (d0, d1, d2) -> (d1, d2) +// acc: (d0, d1, d2) -> (d0, d1) +// Note: RHS is transposed. +if (op.getIndexingMapsArray()[0] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[1] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || +op.getIndexingMapsArray()[2] != +AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + +// Check iterator types for matrix multiplication. +auto itTypes = op.getIteratorTypesArray(); +if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || +itTypes[1] != vector::IteratorType
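The dimension checks in the quoted pattern can be restated as a small standalone predicate. This is only a minimal sketch: it assumes plain integer dimensions instead of `mlir::VectorType`, and the name `isSupportedI8MMShape` and its parameters are hypothetical, not part of the patch.

```cpp
#include <cstdint>

// Hypothetical helper (not from the patch): mirrors the shape conditions
// quoted above using plain integers.
//   lhs is M x K (fixed-size), rhs is [N] x K (leading dimension scalable).
bool isSupportedI8MMShape(int64_t lhsRows, int64_t lhsCols, int64_t rhsRows,
                          int64_t rhsCols, bool rhsRowsScalable) {
  const int64_t M = lhsRows;
  const int64_t N = rhsRows;
  const int64_t K = rhsCols;

  // Both operands must agree on K, and K must be exactly 8.
  if (lhsCols != K || K != 8)
    return false;
  // M and N must be even and at least 2.
  if (M < 2 || M % 2 != 0)
    return false;
  if (N < 2 || N % 2 != 0)
    return false;
  // The leading RHS dimension must be the scalable one.
  return rhsRowsScalable;
}
```

With this sketch, a 4x8 LHS and a [4]x8 RHS pass, while an odd M, a K other than 8, or a non-scalable RHS leading dimension are rejected.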
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template
+inline std::enable_if_t<(std::is_base_of_v ||
+ std::is_base_of_v),
+std::optional>

momchil-velikov wrote:

Because it does not need to be checked at runtime. But fair enough, it should be a `static_assert`, as there's no ambiguity to resolve, and such use is quite common across the code base, e.g.

```
static_assert(llvm::is_one_of::value,
              "applies to only pack or unpack operations");
```

https://github.com/llvm/llvm-project/pull/135636
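The `static_assert` approach mentioned in the review comment above can be illustrated in isolation. This is a minimal sketch under stated assumptions: `SignExtend`, `ZeroExtend` and `extractOperand` are hypothetical stand-ins invented for this example, not MLIR classes or the final code of the PR.

```cpp
#include <optional>
#include <type_traits>

// Hypothetical stand-ins for the two permitted extension operations.
struct SignExtend { int in; };
struct ZeroExtend { int in; };

// Returns the operand before the extension, or nothing if `op` is null.
// The static_assert rejects any other T with a readable compile-time
// diagnostic; no overload resolution (SFINAE) is involved because there is
// only one candidate function.
template <typename T>
std::optional<int> extractOperand(const T *op) {
  static_assert(std::is_same_v<T, SignExtend> || std::is_same_v<T, ZeroExtend>,
                "applies only to sign- or zero-extend operations");
  if (!op)
    return std::nullopt;
  return op->in;
}
```

Compared with `std::enable_if_t`, instantiating `extractOperand` with an unsupported type fails with the message given in the assertion instead of a "no matching function" error, which is usually easier to read when there are no competing overloads.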
[llvm-branch-commits] [mlir] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op (PR #140572)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/140572 None >From 251f93ea5b87acefac1fbcd6951c3b7870eff83c Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Fri, 16 May 2025 15:47:36 + Subject: [PATCH] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op --- .../mlir/Dialect/ArmSVE/CMakeLists.txt| 1 + .../TransformOps/ArmSVEVectorTransformOps.h | 31 +++ .../TransformOps/ArmSVEVectorTransformOps.td | 26 ++ .../ArmSVE/TransformOps/CMakeLists.txt| 6 + mlir/include/mlir/InitAllExtensions.h | 2 + mlir/lib/Dialect/ArmSVE/CMakeLists.txt| 1 + .../TransformOps/ArmSVEVectorTransformOps.cpp | 54 .../ArmSVE/TransformOps/CMakeLists.txt| 19 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 263 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 123 .../Vector/CPU/ArmSVE/vector-ummla.mlir | 138 + .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 138 + 12 files changed, 512 insertions(+), 290 deletions(-) create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h new file mode 100644 index 0..7f22cd1fe6435 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmSVEVectorTransformOps.h - Vector transform ops *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// + +#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H +#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===--===// +// ArmSVE Vector Transform Operations +//===--===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_sve { +void registerTransformDialectExtension(DialectRegistry &registry); + +} // namespace arm_sve +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td new file mode 100644 index 0..81b59340f3b0d --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -0,0 +1,26 @@ +//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +#ifndef ARMSVE_VECTOR_TRANSFORM_OPS +#define ARMSVE_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +def ApplyArmSVELowerContractionPatternsOp +: Op]> { + let description = [{ +Indicates that vector contraction-like operations should be lowered to +finer-grained vector primitives using the ArmSVE dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // ARMSVE_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0..ce8d8fea7f188 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td) +mlir_tablegen(ArmSVEVectorTrans
[llvm-branch-commits] [mlir] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM (PR #140573)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/140573 None >From 194c1c7737ea7baa74971666f93312a071f5703d Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 19 May 2025 14:50:45 + Subject: [PATCH] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM --- .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 + .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 ++ .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 + .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 + .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 + 5 files changed, 630 insertions(+) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir new file mode 100644 index 0..88534dd2aab1e --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir @@ -0,0 +1,117 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils 
+ +// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +func.func private @setArmVLBits(%bits : i32) + +func.func @main() { + %c128 = arith.constant 128 : i32 + func.call @setArmVLBits(%c128) : (i32) -> () + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c0_i8 = arith.constant 0 : i8 + +// Accumulator test data + %acc_cst = arith.constant dense<[[-44, 20, 44, -46], + [ -8, 25, -34, 26], + [-20, -36, -3, 39], + [-48, -31, -25, -21]]> : vector<4x4xi32> + %acc_m = memref.alloca() : memref<4x4xi32> + vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32> + + %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32> + %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32> + %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32> + + vector.print str "ACC:\n" + %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32> + %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32> + %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32> + %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32> + vector.print %acc0 : vector<[4]xi32> + vector.print %acc1 : vector<[4]xi32> + vector.print %acc2 : vector<[4]xi32> + vector.print %acc3 : vector<[4]xi32> + + // LHS test data + %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33], + [-20, 17, -32, -47, 37, 22, -7, -21], + [ -7, -35, 20, -4, 39, 46, -23, 40], + [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8> + + %lhs_m = memref.alloca() : memref<4x8xi8> + vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8> + %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8> + + vector.print str "LHS:\n" + 
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8> + %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8> + %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8> + %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8> + vector.print %lhs0 : vector<8xi8> + vector.print %lhs1 : vector<8xi8> + vector.print %lhs2 : vector<8xi8> + vector.print %lhs3 : vector<8xi8> + + // RHS test data + %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33], + [-35, -24, 37, -32, 33, 30, -11, -17], +
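For orientation, the three maps in `#packed_maps` describe a matrix multiplication whose RHS is stored transposed: `acc[i][j] += sum over k of lhs[i][k] * rhs[j][k]`, with the i8 operands widened to i32 before multiplying. A minimal scalar reference for the signed case, assuming fixed-size arrays rather than scalable vectors, might look as follows; `contractReference` is a hypothetical name and is not part of the test.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for a contraction with indexing maps
//   lhs: (d0, d1, d2) -> (d0, d2)
//   rhs: (d0, d1, d2) -> (d1, d2)   (RHS is transposed)
//   acc: (d0, d1, d2) -> (d0, d1)
// Operands are signed 8-bit and are sign-extended to 32 bits.
template <std::size_t M, std::size_t N, std::size_t K>
std::array<std::array<int32_t, N>, M>
contractReference(const std::array<std::array<int8_t, K>, M> &lhs,
                  const std::array<std::array<int8_t, K>, N> &rhs,
                  std::array<std::array<int32_t, N>, M> acc) {
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j)
      for (std::size_t k = 0; k < K; ++k)
        acc[i][j] += static_cast<int32_t>(lhs[i][k]) *
                     static_cast<int32_t>(rhs[j][k]);
  return acc;
}
```

A reference of this kind can be used to derive expected values for a given set of input matrices by hand or in a host-side check; the unsigned and mixed-sign variants would differ only in how the operands are widened.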
[llvm-branch-commits] [mlir] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op (PR #140572)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/140572 >From df54d59d29e8afc04740e86281bce6be5dd157da Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Fri, 16 May 2025 15:47:36 + Subject: [PATCH 1/2] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op --- .../mlir/Dialect/ArmSVE/CMakeLists.txt| 1 + .../TransformOps/ArmSVEVectorTransformOps.h | 31 +++ .../TransformOps/ArmSVEVectorTransformOps.td | 26 ++ .../ArmSVE/TransformOps/CMakeLists.txt| 6 + mlir/include/mlir/InitAllExtensions.h | 2 + mlir/lib/Dialect/ArmSVE/CMakeLists.txt| 1 + .../TransformOps/ArmSVEVectorTransformOps.cpp | 54 .../ArmSVE/TransformOps/CMakeLists.txt| 19 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 263 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 123 .../Vector/CPU/ArmSVE/vector-ummla.mlir | 138 + .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 138 + 12 files changed, 512 insertions(+), 290 deletions(-) create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h new file mode 100644 index 0..7f22cd1fe6435 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmSVEVectorTransformOps.h - Vector transform ops *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// + +#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H +#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===--===// +// ArmSVE Vector Transform Operations +//===--===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_sve { +void registerTransformDialectExtension(DialectRegistry &registry); + +} // namespace arm_sve +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td new file mode 100644 index 0..81b59340f3b0d --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -0,0 +1,26 @@ +//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +#ifndef ARMSVE_VECTOR_TRANSFORM_OPS +#define ARMSVE_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +def ApplyArmSVELowerContractionPatternsOp +: Op]> { + let description = [{ +Indicates that vector contraction-like operations should be lowered to +finer-grained vector primitives using the ArmSVE dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // ARMSVE_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0..ce8d8fea7f188 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td) +mlir_tablegen(ArmSVEVectorTransfo
[llvm-branch-commits] [mlir] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM (PR #140573)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/140573 >From 87f29647c650bdff25f93cef6b1d3ccc63eca63b Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 19 May 2025 14:50:45 + Subject: [PATCH] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM --- .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 + .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 ++ .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 + .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 + .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 + 5 files changed, 630 insertions(+) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir new file mode 100644 index 0..88534dd2aab1e --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir @@ -0,0 +1,117 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils + 
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +func.func private @setArmVLBits(%bits : i32) + +func.func @main() { + %c128 = arith.constant 128 : i32 + func.call @setArmVLBits(%c128) : (i32) -> () + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c0_i8 = arith.constant 0 : i8 + +// Accumulator test data + %acc_cst = arith.constant dense<[[-44, 20, 44, -46], + [ -8, 25, -34, 26], + [-20, -36, -3, 39], + [-48, -31, -25, -21]]> : vector<4x4xi32> + %acc_m = memref.alloca() : memref<4x4xi32> + vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32> + + %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32> + %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32> + %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32> + + vector.print str "ACC:\n" + %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32> + %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32> + %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32> + %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32> + vector.print %acc0 : vector<[4]xi32> + vector.print %acc1 : vector<[4]xi32> + vector.print %acc2 : vector<[4]xi32> + vector.print %acc3 : vector<[4]xi32> + + // LHS test data + %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33], + [-20, 17, -32, -47, 37, 22, -7, -21], + [ -7, -35, 20, -4, 39, 46, -23, 40], + [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8> + + %lhs_m = memref.alloca() : memref<4x8xi8> + vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8> + %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8> + + vector.print str "LHS:\n" + 
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8> + %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8> + %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8> + %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8> + vector.print %lhs0 : vector<8xi8> + vector.print %lhs1 : vector<8xi8> + vector.print %lhs2 : vector<8xi8> + vector.print %lhs3 : vector<8xi8> + + // RHS test data + %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33], + [-35, -24, 37, -32, 33, 30, -11, -17], +
[llvm-branch-commits] [mlir] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM (PR #140573)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/140573 >From 87f29647c650bdff25f93cef6b1d3ccc63eca63b Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 19 May 2025 14:50:45 + Subject: [PATCH] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM --- .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 + .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 ++ .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 + .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 + .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 + 5 files changed, 630 insertions(+) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir new file mode 100644 index 0..88534dd2aab1e --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir @@ -0,0 +1,117 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils + 
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +func.func private @setArmVLBits(%bits : i32) + +func.func @main() { + %c128 = arith.constant 128 : i32 + func.call @setArmVLBits(%c128) : (i32) -> () + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c0_i8 = arith.constant 0 : i8 + +// Accumulator test data + %acc_cst = arith.constant dense<[[-44, 20, 44, -46], + [ -8, 25, -34, 26], + [-20, -36, -3, 39], + [-48, -31, -25, -21]]> : vector<4x4xi32> + %acc_m = memref.alloca() : memref<4x4xi32> + vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32> + + %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32> + %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32> + %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32> + + vector.print str "ACC:\n" + %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32> + %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32> + %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32> + %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32> + vector.print %acc0 : vector<[4]xi32> + vector.print %acc1 : vector<[4]xi32> + vector.print %acc2 : vector<[4]xi32> + vector.print %acc3 : vector<[4]xi32> + + // LHS test data + %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33], + [-20, 17, -32, -47, 37, 22, -7, -21], + [ -7, -35, 20, -4, 39, 46, -23, 40], + [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8> + + %lhs_m = memref.alloca() : memref<4x8xi8> + vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8> + %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8> + + vector.print str "LHS:\n" + 
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8> + %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8> + %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8> + %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8> + vector.print %lhs0 : vector<8xi8> + vector.print %lhs1 : vector<8xi8> + vector.print %lhs2 : vector<8xi8> + vector.print %lhs3 : vector<8xi8> + + // RHS test data + %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33], + [-35, -24, 37, -32, 33, 30, -11, -17], +
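For readers checking the printed results by hand: the `#packed_maps` in the test above index the LHS as (d0, d2), the RHS as (d1, d2) and the accumulator as (d0, d1), so the RHS is supplied with the reduction dimension innermost. A minimal scalar reference sketch of that contraction, for the all-signed (`smmla`) case only, could look like the following. The names `contractReference`, `Matrix32` and `Matrix8` are hypothetical and do not come from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Matrix32 = std::vector<std::vector<int32_t>>;
using Matrix8 = std::vector<std::vector<int8_t>>;

// Scalar reference for the contraction described by the indexing maps
//   LHS: (d0, d2), RHS: (d1, d2), ACC: (d0, d1)
// i.e. acc[m][n] += sum over k of lhs[m][k] * rhs[n][k].
// Both operands are sign-extended to 32 bits before multiplying.
// Hypothetical helper for illustration only, not part of the patch.
Matrix32 contractReference(Matrix32 acc, const Matrix8 &lhs,
                           const Matrix8 &rhs) {
  assert(acc.size() == lhs.size());
  for (size_t m = 0; m < lhs.size(); ++m) {
    assert(acc[m].size() == rhs.size());
    for (size_t n = 0; n < rhs.size(); ++n) {
      assert(lhs[m].size() == rhs[n].size());
      int32_t sum = 0;
      for (size_t k = 0; k < lhs[m].size(); ++k)
        sum += int32_t(lhs[m][k]) * int32_t(rhs[n][k]);
      acc[m][n] += sum;
    }
  }
  return acc;
}
```

The mixed-sign variants of the test would need the corresponding operand zero-extended instead; that is not shown here.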
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135634 >From e60ca5aadf1043a0cb59d50da5f3dbf68bd50c51 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Thu, 10 Apr 2025 14:38:27 + Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to `svusmmla` --- mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 95 +++ .../Transforms/LegalizeForLLVMExport.cpp | 4 + .../Dialect/ArmSVE/legalize-for-llvm.mlir | 12 +++ mlir/test/Dialect/ArmSVE/roundtrip.mlir | 11 +++ mlir/test/Target/LLVMIR/arm-sve.mlir | 12 +++ 5 files changed, 96 insertions(+), 38 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 3a990f8464ef8..7385bb73b449a 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -147,11 +147,9 @@ class ScalableMaskedIOp, - AllTypesMatch<["acc", "dst"]>, - ]> { +def SdotOp : ArmSVE_Op<"sdot", [Pure, +AllTypesMatch<["src1", "src2"]>, +AllTypesMatch<["acc", "dst"]>]> { let summary = "Vector-vector dot product and accumulate op"; let description = [{ SDOT: Signed integer addition of dot product. @@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } -def SmmlaOp : ArmSVE_Op<"smmla", -[Pure, -AllTypesMatch<["src1", "src2"]>, -AllTypesMatch<["acc", "dst"]>, - ]> { +def SmmlaOp : ArmSVE_Op<"smmla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { let summary = "Matrix-matrix multiply and accumulate op"; let description = [{ SMMLA: Signed integer matrix multiply-accumulate. 
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } -def UdotOp : ArmSVE_Op<"udot", - [Pure, - AllTypesMatch<["src1", "src2"]>, - AllTypesMatch<["acc", "dst"]>, - ]> { +def UdotOp : ArmSVE_Op<"udot", [Pure, +AllTypesMatch<["src1", "src2"]>, +AllTypesMatch<["acc", "dst"]>]> { let summary = "Vector-vector dot product and accumulate op"; let description = [{ UDOT: Unsigned integer addition of dot product. @@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } -def UmmlaOp : ArmSVE_Op<"ummla", -[Pure, -AllTypesMatch<["src1", "src2"]>, -AllTypesMatch<["acc", "dst"]>, - ]> { +def UmmlaOp : ArmSVE_Op<"ummla", [Pure, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>]> { let summary = "Matrix-matrix multiply and accumulate op"; let description = [{ UMMLA: Unsigned integer matrix multiply-accumulate. @@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla", "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; } +def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure, +AllTypesMatch<["src1", "src2"]>, +AllTypesMatch<["acc", "dst"]>]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ +USMMLA: Unsigned by signed integer matrix multiply-accumulate. + +The unsigned by signed integer matrix multiply-accumulate operation +multiplies the 2×8 matrix of unsigned 8-bit integer values held in +the first source vector by the 8×2 matrix of signed 8-bit integer +values in the second source vector. The resulting 2×2 widened 32-bit +integer matrix product is then added to the 32-bit integer matrix +accumulator. 
+ +Source: +https://developer.arm.com/documentation/100987/ + }]; + // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4], [I32]>:$acc, + ScalableVectorOfLengthAndType<[16], [I8]>:$src1, + ScalableVectorOfLengthAndType<[16], [I8]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst); + let assemblyFormat = +"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + class SvboolTypeConstraint : TypesMatchWith< "expected corresponding svbool type widened to [16]xi1", lhsArg, rhsArg, "VectorType(VectorType::Builder(::llvm::cast($_self)).setDim(::llvm::cast($_self).getRank() - 1, 16))">; def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool", -[Pure, SvboolTypeConstraint<"result", "source">]> -{ +
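The USMMLA description above can be illustrated with a scalar sketch of a single 128-bit segment. This is only a sketch under stated assumptions: the function name `usmmlaSegment` is hypothetical, the 2×8 unsigned matrix is assumed to be stored row-major in the first source, the 8×2 signed matrix is assumed to be stored as two runs of eight elements (one per column) in the second source, and the 2×2 accumulator is assumed row-major. A real SVE vector would repeat this per 128-bit segment.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Scalar model of one 128-bit segment of an unsigned-by-signed matrix
// multiply-accumulate. Assumed layouts (not taken from the patch):
//   a:   2x8 unsigned bytes, row i at a[8*i .. 8*i+7]
//   b:   8x2 signed bytes, column j at b[8*j .. 8*j+7]
//   acc: 2x2 32-bit integers, element (i, j) at acc[2*i + j]
std::array<int32_t, 4> usmmlaSegment(std::array<int32_t, 4> acc,
                                     const std::array<uint8_t, 16> &a,
                                     const std::array<int8_t, 16> &b) {
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      int32_t sum = 0;
      for (int k = 0; k < 8; ++k) {
        // Zero-extend the first operand, sign-extend the second.
        sum += int32_t(a[8 * i + k]) * int32_t(b[8 * j + k]);
      }
      // Accumulate with wrap-around, avoiding signed-overflow UB.
      acc[2 * i + j] = static_cast<int32_t>(
          static_cast<uint32_t>(acc[2 * i + j]) + static_cast<uint32_t>(sum));
    }
  }
  return acc;
}
```

Feeding a first source whose first byte is 255 and a second source whose first byte is -1 gives -255 in the first accumulator lane, which distinguishes the unsigned-by-signed behaviour from the all-signed `smmla` case.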
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636 >From aa8a667f206874af3b26811ec04d58be12ad43de Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 + Subject: [PATCH 1/3] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td| 4 + .../Dialect/ArmSVE/Transforms/Transforms.h| 3 + .../Conversion/VectorToLLVM/CMakeLists.txt| 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp| 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 + .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 + .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 10557658d5d7d..b496ee0114910 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, +Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( +RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0ee6dce9ee94b..293e01a5bf4d4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPattern
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 2eb6c95955dc22b6b59eb4e5ba269e4744bbdd2a Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/3] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leading two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
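The contiguity rule documented in the patch above (ignore leading unit dims of the vector, then require (a) the N trailing memref dims to be contiguous and (b) all but the first of the N trailing dims to match) can be sketched outside MLIR on plain shape and stride arrays. This is an illustration only, assuming fully static shapes and strides; `trailingDimsContiguous` and `isContiguousSliceSketch` are hypothetical names and are not the functions in `VectorUtils.cpp`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: true if the last `n` dims of a statically shaped,
// strided buffer are laid out contiguously in row-major order.
// Simplification: unit-sized buffer dims get no special treatment.
bool trailingDimsContiguous(const std::vector<int64_t> &shape,
                            const std::vector<int64_t> &strides, size_t n) {
  assert(shape.size() == strides.size() && n <= shape.size());
  int64_t expected = 1;
  for (size_t k = 0; k < n; ++k) {
    size_t i = shape.size() - 1 - k;
    if (strides[i] != expected)
      return false;
    expected *= shape[i];
  }
  return true;
}

// Sketch of the documented rule, on static shapes only.
bool isContiguousSliceSketch(const std::vector<int64_t> &memShape,
                             const std::vector<int64_t> &memStrides,
                             const std::vector<int64_t> &vecShape) {
  // Ignore the leading sequence of unit dims of the vector.
  size_t lead = 0;
  while (lead < vecShape.size() && vecShape[lead] == 1)
    ++lead;
  size_t n = vecShape.size() - lead;
  if (n > memShape.size())
    return false;
  // (a) The n trailing dims of the buffer must be contiguous.
  if (!trailingDimsContiguous(memShape, memStrides, n))
    return false;
  // (b) The trailing n dims must match, except for the first of them.
  for (size_t k = 1; k < n; ++k)
    if (vecShape[lead + k] != memShape[memShape.size() - n + k])
      return false;
  return true;
}
```

With `memref<5x4x3x2xi32>` modelled as shape {5, 4, 3, 2} and strides {24, 6, 2, 1}, this sketch reproduces the outcomes listed in Ex.1 through Ex.8 of the doc comment.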
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -0,0 +1,262 @@ +// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s + +// - + +// CHECK-LABEL: @test_base_case +// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]: +// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]] +// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] +// CHECK-SAME:: memref into memref +// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]} +// CHECK-SAME:: memref, vector<[32]xi8> +// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8> +// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8> + +func.func @test_base_case(%i : index, %j : index, %M : memref) -> vector<[4]x8xi8> { + %c0 = arith.constant 0 : index + %c0_i8 = arith.constant 0 : i8 + + %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref, vector<[4]x8xi8> + + return %A : vector<[4]x8xi8> +} + +// - + +// CHECK-LABEL: @test_using_strided_layout +// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]] +// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]] +// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] +// CHECK-SAME:: memref> into +// CHECK-SAME: memref> +// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]} +// CHECK-SAME:: memref>, vector<[32]xi8> +// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8> +// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8> + +#s0 = strided<[?, ?, 8, 1]> + +func.func @test_using_strided_layout(%i : index, %j : index, %M : memref) -> vector<[4]x8xi8> { + %c0 = arith.constant 0 : index + %c0_i8 = arith.constant 0 : i8 + + %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref, vector<[4]x8xi8> + + return %A : vector<[4]x8xi8> +} + +// - + +// CHECK-LABEL: @test_3d_vector +// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]] +// 
CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]] +// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] +// CHECK-SAME:: memref> into +// CHECK-SAME: memref> +// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]} +// CHECK-SAME:: memref>, vector<[64]xi8> +// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8> +// CHECK-NEXT: return %[[T1]] : vector<[4]x2x8xi8> + +#s1 = strided<[?, 16, 8, 1]> + +func.func @test_3d_vector(%i : index, %j : index, %M : memref) -> vector<[4]x2x8xi8> { + %c0 = arith.constant 0 : index + %c0_i8 = arith.constant 0 : i8 + + %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref, vector<[4]x2x8xi8> + + return %A : vector<[4]x2x8xi8> +} + +// - + +// CHECK-LABEL: @test_4d_vector +// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]] +// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]] +// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] +// CHECK-SAME: : memref> into +// CHECK-SAME: memref> +// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]} +// CHECK-SAME: : memref>, vector<2x[64]xi8> +// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +// CHECK-NEXT: return %[[T1]] : vector<2x[4]x2x8xi8> + +#s2 = strided<[?, 16, 8, 1]> + +func.func @test_4d_vector(%i : index, %j : index, %M : memref) -> vector<2x[4]x2x8xi8> { + %c0 = arith.constant 0 : index + %c0_i8 = arith.constant 0 : i8 + + %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref, vector<2x[4]x2x8xi8> + + return %A : vector<2x[4]x2x8xi8> +} + +// - + +// CHECK-LABEL: @negative_test_vector_legal_non_scalable +// CHECK-NOT: memref.collapse + +func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref) -> vector<8x8xi8> { + %c0 = arith.constant 0 : index + %c0_i8 = 
arith.constant 0 : i8 + + %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref, vector<8x8xi8> + + return %A : vector<8x8xi8> +} + +// - + +// CHECK-LABEL: @negative_test_vector_legal_scalable_0 +// CHECK-NOT: memref.collapse + +func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref) -> vector<[8]xi8> { + %c0 = arith.constant 0 : index + %c0_i8 = arith.constant 0 : i8 + + %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref, vector<[8]xi8> + + return %A : ve
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. momchil-velikov wrote: Comment added. 
https://github.com/llvm/llvm-project/pull/143146
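Editorial illustration of the masked-read comment quoted above. This is a minimal plain C++ sketch, not MLIR API: `flattenMask` and `isPrefixMask` are hypothetical helper names, and the scalable dimension is modelled with a fixed size of 4 (i.e. assuming vscale == 1). It shows that the row-major flattening of the 2-D mask from the comment is not a "first N lanes" mask, which is presumably why it cannot be rebuilt from a single lane count.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major flattening of a 2-D mask in which the first `activeRows` rows and
// the first `activeCols` columns are enabled (the kind of mask a
// `vector.create_mask` describes). Hypothetical helper, plain C++ model.
std::vector<bool> flattenMask(std::size_t rows, std::size_t cols,
                              std::size_t activeRows, std::size_t activeCols) {
  assert(activeRows <= rows && activeCols <= cols);
  std::vector<bool> flat;
  flat.reserve(rows * cols);
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      flat.push_back(r < activeRows && c < activeCols);
  return flat;
}

// Returns true if the mask is a run of enabled lanes followed only by
// disabled lanes, i.e. it can be described by a single lane count.
bool isPrefixMask(const std::vector<bool> &mask) {
  bool seenDisabled = false;
  for (bool lane : mask) {
    if (!lane)
      seenDisabled = true;
    else if (seenDisabled)
      return false; // an enabled lane after a disabled one
  }
  return true;
}
```

For the 4x4 example with three active rows and three active columns, `flattenMask(4, 4, 3, 3)` yields the interleaved pattern from the comment and `isPrefixMask` rejects it. A mask that enables whole rows only, such as `flattenMask(4, 4, 2, 4)`, stays a prefix after flattening.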
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. 
+VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) momchil-velikov wrote: Done. https://github.com/llvm/llvm-project/pull/143146
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. 
+VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one. +auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. momchil-velikov wrote: This part wasn't tested at all. Test cases added. https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
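Editorial illustration of the precondition discussed in the quoted code. This is a sketch in plain C++ over `std::vector`, not the MLIR types used by the patch. The names `numTrailingDimsToCollapse`, `readsContiguousBlock` and `kDynamic` are hypothetical, and the memref shape in the usage note is an assumption, since the type parameters are not visible in the quoted text.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical marker for a dynamic memref dimension in this sketch.
constexpr int64_t kDynamic = -1;

// Number of trailing vector dimensions to collapse, counted from the single
// scalable dimension to the end of the shape (the scalable one included).
int64_t numTrailingDimsToCollapse(const std::vector<bool> &scalableDims) {
  auto firstScalable =
      std::find(scalableDims.begin(), scalableDims.end(), true);
  return std::distance(firstScalable, scalableDims.end());
}

// Models the check from the quoted comment: the collapsed dimensions other
// than the scalable one must have equal sizes in the memref and the vector,
// and must all be marked in-bounds.
bool readsContiguousBlock(const std::vector<int64_t> &memrefShape,
                          const std::vector<int64_t> &vectorShape,
                          const std::vector<bool> &inBounds,
                          int64_t numCollapseDims) {
  const int64_t n = numCollapseDims - 1; // exclude the scalable dimension
  assert(n >= 1);
  assert(static_cast<int64_t>(memrefShape.size()) >= n);
  assert(static_cast<int64_t>(vectorShape.size()) >= n);
  assert(static_cast<int64_t>(inBounds.size()) >= n);
  const bool shapesMatch = std::equal(memrefShape.end() - n, memrefShape.end(),
                                      vectorShape.end() - n);
  const bool allInBounds = std::all_of(inBounds.end() - n, inBounds.end(),
                                       [](bool b) { return b; });
  return shapesMatch && allInBounds;
}
```

For a `vector<2x[4]x2x8xi8>` the scalable flags are `{false, true, false, false}`, so three dimensions are collapsed. Assuming a memref whose two trailing dimensions are 2 and 8, the check passes with `in_bounds = [false, true, true, true]`, and fails if either trailing size differs or a trailing dimension is not in-bounds.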
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/143146 From 198ed819841270aeec7159fe2a9a4c092b8d8af7 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Wed, 14 May 2025 09:03:49 + Subject: [PATCH 1/4] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors This patch adds a transform of `transfer_read` operations that changes the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position. --- .../Transforms/LegalizeVectorStorage.cpp | 110 - .../ArmSVE/legalize-transfer-read.mlir| 226 ++ .../transfer-read-scalable-not-rightmost.mlir | 72 ++ 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index d2ac850a5f70b..f16d33c004fec 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position.
+/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. +VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one.
+auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. +if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1), + origVT.getShape().take_back(numCollapseDims - 1))) + return rewriter.notifyMatchFailure( + readOp, "memref and vector dimensions do not match"); + +SmallVector origInBounds = readOp.getInBoundsValues(); +if (!llvm::all_of( +ArrayRef(origInBounds).take_back(numCollapseDims - 1), +[](bool v) { return v; })) + return rewriter.notifyMatchFailure(readOp, + "out-of-bounds index to collapse"); + +// Collapse the trailing dimensions of the memref. +SmallVector reassoc; +for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i) + reassoc.push_back({i}); +for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank(); + ++i) + reassoc.
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); momchil-velikov wrote: Done. https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. 
+VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one. +auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. +if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1), + origVT.getShape().take_back(numCollapseDims - 1))) + return rewriter.notifyMatchFailure( + readOp, "memref and vector dimensions do not match"); + +SmallVector origInBounds = readOp.getInBoundsValues(); +if (!llvm::all_of( +ArrayRef(origInBounds).take_back(numCollapseDims - 1), +[](bool v) { return v; })) + return rewriter.notifyMatchFailure(readOp, + "out-if-bounds index to collapse"); momchil-velikov wrote: Fixed. https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
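Editorial illustration of the collapse step that follows these checks in the pattern. This is a hedged sketch in plain C++: the first patch revision in this thread is cut off at `reassoc.`, so the assumption here is that the second loop appends the remaining indices to the last group, which matches the `[[0], [1, 2, 3]]` and `[[0], [1], [2, 3]]` groupings checked by the tests. `buildReassociation` and `collapseTrailingDims` are hypothetical names, not functions from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using ReassociationIndices = std::vector<int64_t>;

// Builds the reassociation groups for collapsing the last `numCollapseDims`
// dimensions of a rank-`rank` memref into one dimension.
std::vector<ReassociationIndices> buildReassociation(int64_t rank,
                                                     int64_t numCollapseDims) {
  assert(numCollapseDims >= 1 && numCollapseDims <= rank);
  std::vector<ReassociationIndices> reassoc;
  const int64_t firstCollapsed = rank - numCollapseDims;
  // Leading dimensions, and the first collapsed one, each start a group.
  for (int64_t i = 0; i <= firstCollapsed; ++i)
    reassoc.push_back({i});
  // Assumption: the remaining dimensions are appended to the last group.
  for (int64_t i = firstCollapsed + 1; i < rank; ++i)
    reassoc.back().push_back(i);
  return reassoc;
}

// Collapses the last `numCollapseDims` sizes of `shape` into their product.
// For a scalable dimension the size is the multiplier of vscale.
std::vector<int64_t> collapseTrailingDims(const std::vector<int64_t> &shape,
                                          int64_t numCollapseDims) {
  assert(numCollapseDims >= 1 &&
         numCollapseDims <= static_cast<int64_t>(shape.size()));
  const std::size_t keep = shape.size() - numCollapseDims;
  std::vector<int64_t> result(shape.begin(), shape.begin() + keep);
  int64_t product = 1;
  for (std::size_t i = keep; i < shape.size(); ++i)
    product *= shape[i];
  result.push_back(product);
  return result;
}
```

With rank 4 and three collapsed dimensions this gives `[[0], [1, 2, 3]]`, and `vector<2x[4]x2x8xi8>` becomes `vector<2x[64]xi8>`, as in the example from the doc comment.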
[llvm-branch-commits] [mlir] [MLIR][AArch64] Add integration test for lowering of `vector.contract` to Neon FEAT_I8MM (PR #144699)
@@ -0,0 +1,336 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-neon enable-arm-i8mm' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \ +// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+neon,+i8mm" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils + +// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] + +// +// Test the lowering of `vector.contract` using the `LowerContractionToNeonI8MMPattern` +// +// The operation that the `vector.contract` in this test performs is matrix +// multiplication with accumulate +// OUT = ACC + LHS * RHS +// of two 8-bit integer matrices LHS and RHS, and a 32-bit integer matrix ACC +// into a 32-bit integer matrix OUT. The LHS and RHS can be sign- or zero-extended, +// this test covers all the possible variants. +// +// Tested are the calculations, as well as that the relevant `ArmNeon` dialect +// operations (`arm_neon.smmla`, `arm_neon.ummla`, etc.) are emitted. +// +// The pattern above handles (therefore this test prepares) input/output vectors with +// specific shapes: +// * LHS: vector +// * RHS: vector +// * ACC, OUT: vector +// where the M and N are even and K is divisible by 8. +// Note that the RHS is transposed. +// This data layout makes it efficient to load data into SIMD +// registers in the layout expected by FEAT_I8MM instructions. +// Such a `vector.contract` is representative of the code we aim to generate +// by vectorisation of `linalg.mmt4d`.
+//
+// In this specific test we use M == 4, N == 4, and K == 8.

momchil-velikov wrote:

Ah, yes it is. Comment fixed.

https://github.com/llvm/llvm-project/pull/144699
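For reference, the computation this test exercises can be written as a plain scalar loop. The sketch below is not part of the patch. It assumes row-major storage, an M x K LHS and an N x K (transposed) RHS, as the `#packed_maps` indexing suggests, and it covers only the signed-by-signed variant. The name `mmlaReference` is made up for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference for OUT = ACC + LHS * RHS^T.
// lhs is M x K, rhs is N x K (already transposed), acc and the result are
// M x N; all are stored row-major. Both inputs are sign-extended to 32 bits.
std::vector<int32_t> mmlaReference(const std::vector<int8_t> &lhs,
                                   const std::vector<int8_t> &rhs,
                                   const std::vector<int32_t> &acc,
                                   int m, int n, int k) {
  std::vector<int32_t> out(acc);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p)
        out[i * n + j] += int32_t(lhs[i * k + p]) * int32_t(rhs[j * k + p]);
  return out;
}
```

Because the RHS is indexed as `rhs[j * k + p]`, each output element is a dot product of two contiguous rows, which is the property the test comment attributes to this data layout.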
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov edited https://github.com/llvm/llvm-project/pull/142422
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422

>From b950757c234900db941ed950ea3469b520d2e28a Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Mon, 2 Jun 2025 15:13:13 +
Subject: [PATCH 1/8] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice`

Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 -
 .../Transforms/VectorTransferOpTransforms.cpp | 8 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++--
 .../Vector/vector-transfer-flatten.mlir | 108 +-
 4 files changed, 120 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op);
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant).
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
momchil-velikov wrote: Commit message updated. https://github.com/llvm/llvm-project/pull/142422
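The contiguity rule described in the updated `isContiguousSlice` comment can be sketched outside MLIR with plain shape and stride arrays. This is a minimal illustration, not the code from the patch: it assumes static shapes, row-major strides, and a simple stride check in place of MLIR's own trailing-dimension contiguity query. The names `trailingDimsContiguous` and `isContiguousSliceSketch` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Dims = std::vector<int64_t>;

// True if the last n dimensions are laid out densely in row-major order:
// the innermost stride is 1 and each outer stride equals the product of
// the inner extents. (Assumed simplification of the real MLIR check.)
bool trailingDimsContiguous(const Dims &shape, const Dims &strides, size_t n) {
  if (n > shape.size() || shape.size() != strides.size())
    return false;
  int64_t expected = 1;
  for (size_t i = shape.size(); i > shape.size() - n; --i) {
    if (strides[i - 1] != expected)
      return false;
    expected *= shape[i - 1];
  }
  return true;
}

// Sketch of the documented rule: ignore leading unit dims of the vector,
// let N be the number of remaining dims, then require (a) the N trailing
// memref dims to be contiguous and (b) the trailing N - 1 dims to match.
bool isContiguousSliceSketch(const Dims &memShape, const Dims &memStrides,
                             const Dims &vecShape) {
  size_t lead = 0;
  while (lead < vecShape.size() && vecShape[lead] == 1)
    ++lead;
  const size_t n = vecShape.size() - lead;
  if (!trailingDimsContiguous(memShape, memStrides, n))
    return false;
  for (size_t i = 1; i < n; ++i)
    if (vecShape[vecShape.size() - i] != memShape[memShape.size() - i])
      return false;
  return true;
}
```

Fed the shapes from Ex.1 through Ex.8 of the doc comment, with identity strides `[24, 6, 2, 1]` for `memref<5x4x3x2xi32>`, this sketch agrees with the contiguous/non-contiguous labels given there.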
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/143146

>From 493955781f28b8b6caaeff1b45f7b7a06b43086c Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Wed, 14 May 2025 09:03:49 +
Subject: [PATCH 1/3] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors

This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.
---
 .../Transforms/LegalizeVectorStorage.cpp | 110 -
 .../ArmSVE/legalize-transfer-read.mlir| 226 ++
 .../transfer-read-scalable-not-rightmost.mlir | 72 ++
 3 files changed, 407 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir

diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern {
 }
 };
+/// Transforms a `transfer_read` operation so it reads vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewriten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. +VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one. 
+auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. +if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1), + origVT.getShape().take_back(numCollapseDims - 1))) + return rewriter.notifyMatchFailure( + readOp, "memref and vector dimensions do not match"); + +SmallVector origInBounds = readOp.getInBoundsValues(); +if (!llvm::all_of( +ArrayRef(origInBounds).take_back(numCollapseDims - 1), +[](bool v) { return v; })) + return rewriter.notifyMatchFailure(readOp, + "out-if-bounds index to collapse"); + +// Collapse the trailing dimensions of the memref. +SmallVector reassoc; +for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i) + reassoc.push_back({i}); +for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank(); + ++i) + reassoc.
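The quoted patch is cut off inside the loop that builds the reassociation indices. The standalone sketch below shows one way those groups could be formed so that the result matches the `[[0], [1, 2, 3]]` reassociation in the pattern's doc comment. It uses plain `std::vector` rather than MLIR types, and the helper names and the `reassoc.back().push_back(i)` step are assumptions, not text recovered from the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Number of trailing vector dims to collapse, counted from the (single)
// scalable dim to the end, mirroring the std::distance/find idiom above.
int64_t numTrailingDimsToCollapse(const std::vector<bool> &scalableDims) {
  return std::distance(
      std::find(scalableDims.begin(), scalableDims.end(), true),
      scalableDims.end());
}

// Hypothetical helper: every leading memref dim gets its own group, and the
// last numCollapseDims dims share one group.
// Assumes 1 <= numCollapseDims <= memRank.
std::vector<std::vector<int64_t>> buildReassociation(int64_t memRank,
                                                     int64_t numCollapseDims) {
  std::vector<std::vector<int64_t>> reassoc;
  for (int64_t i = 0; i < memRank - numCollapseDims + 1; ++i)
    reassoc.push_back({i});
  for (int64_t i = memRank - numCollapseDims + 1; i < memRank; ++i)
    reassoc.back().push_back(i);
  return reassoc;
}
```

For the `vector<2x[4]x2x8xi8>` example, the scalable dimension is at index 1, so three trailing dimensions are collapsed and a rank-4 memref is regrouped as `[[0], [1, 2, 3]]`.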
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -83,16 +84,48 @@ func.func @transfer_read_dims_mismatch_contiguous( return %res : vector<1x1x2x2xi8> } -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous( +// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims( // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> { // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8 // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>> -// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8> +// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] +// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] +// CHECK-SAME:: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>> +// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>, vector<4xi8> // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8> // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8> -// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous( +// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims( +// CHECK-128B: memref.collapse_shape + +// - + +// The shape of the memref and the vector don't match, but the vector is a +// contiguous subset of the memref, so "flattenable" + +func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims( +%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst : +memref<5x4x3x2xi8, 
strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
+ return %res : vector<2x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-SAME:%[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
+// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]], %[[C0]]], %[[C0_I8]] {in_bounds = [true]}
+// CHECK-SAME: : memref<5x24xi8, strided<[24, 1], offset: ?>>, vector<12xi8>
+// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
+// CHECK: return %[[VEC]] : vector<2x3x2xi8>

momchil-velikov wrote:

I don't understand the rationale behind having these in a particular order.

https://github.com/llvm/llvm-project/pull/142422
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422

>From b950757c234900db941ed950ea3469b520d2e28a Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Mon, 2 Jun 2025 15:13:13 +
Subject: [PATCH 1/9] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice`

Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 -
 .../Transforms/VectorTransferOpTransforms.cpp | 8 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++--
 .../Vector/vector-transfer-flatten.mlir | 108 +-
 4 files changed, 120 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op);
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant).
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -630,7 +639,10 @@ class FlattenContiguousRowMajorTransferReadPattern
 if (transferReadOp.getMask())
 return failure();
-int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

momchil-velikov wrote:

With leading unit dimensions of a vector, the memref might not even be contiguous on that many dimensions. Example: `memref<2x2x2xi8, strided<[8, 4, 1]>>` and `vector<1x1x2xi8>`.

Another consideration: in principle, a memref can be collapsed in several different ways for given memref and vector types, anywhere from losing just one dimension to having just one dimension remaining. I've chosen to do maximum collapsing, as it seems simpler to do, and the resulting IR is arguably simpler.

https://github.com/llvm/llvm-project/pull/142422
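To make the effect of the removed `firstDimToCollapse` line concrete, the sketch below compares the old rule (memref rank minus vector rank) with a rule that first drops the vector's leading unit dimensions. It is a plain C++ illustration with hypothetical helper names. The second rule is inferred from the updated CHECK lines in the flattening test elsewhere in this thread and is not copied from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Dims = std::vector<int64_t>;

// Index of the first memref dim to collapse. With ignoreLeadingUnitDims set
// to false this is the removed rule: memref rank minus vector rank.
// Assumes the memref rank is at least the vector rank.
size_t firstDimToCollapse(size_t memRank, const Dims &vecShape,
                          bool ignoreLeadingUnitDims) {
  size_t lead = 0;
  if (ignoreLeadingUnitDims)
    while (lead < vecShape.size() && vecShape[lead] == 1)
      ++lead;
  return memRank - (vecShape.size() - lead);
}

// Shape obtained by folding dims [first, rank) into a single dimension.
Dims collapseFrom(const Dims &shape, size_t first) {
  Dims out(shape.begin(), shape.begin() + first);
  int64_t prod = 1;
  for (size_t i = first; i < shape.size(); ++i)
    prod *= shape[i];
  if (first < shape.size())
    out.push_back(prod);
  return out;
}
```

For `memref<5x4x3x2xi8>` and `vector<1x1x2x2xi8>`, the old rule collapses everything into `120`, while ignoring the two leading unit dimensions yields `5x4x6`. Both shapes appear in the old and new CHECK lines of the test.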
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/143146

This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.

>From 62ad29ddc4d8c1ffa7c5af5dbadd9bb0647964ea Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Wed, 14 May 2025 09:03:49 +
Subject: [PATCH] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors

This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.
---
 .../Transforms/LegalizeVectorStorage.cpp | 110 -
 .../ArmSVE/legalize-transfer-read.mlir| 226 ++
 .../transfer-read-scalable-not-rightmost.mlir | 72 ++
 3 files changed, 407 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir

diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern {
 }
 };
+/// Transforms a `transfer_read` operation so it reads vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewriten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. +VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one. 
+auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. +if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1), + origVT.getShape().take_back(numCollapseDims - 1))) + return rewriter.notifyMatchFailure( + readOp, "memref and vector dimensions do not match"); + +SmallVector origInBounds = readOp.getInBoundsValues(); +if (!llvm::all_of( +ArrayRef(origInBounds).take_back(numCollapseDims - 1), +[](bool v) { return v; })) + return rewriter.notifyMatchFailure(readOp, + "out-if-bounds index to collapse"); + +// Collapse the trailing dimension
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422

>From b950757c234900db941ed950ea3469b520d2e28a Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Mon, 2 Jun 2025 15:13:13 +
Subject: [PATCH 1/6] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice`

Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 -
 .../Transforms/VectorTransferOpTransforms.cpp | 8 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++--
 .../Vector/vector-transfer-flatten.mlir | 108 +-
 4 files changed, 120 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op);
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant).
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From b950757c234900db941ed950ea3469b520d2e28a Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/6] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>``. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write`` ops. The pattern used to collapse a number of dimensions equal the vector rank, which may be is incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leading two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
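The rule stated in the updated doc comment (ignore leading unit dims of the vector, require the N trailing memref dims to be contiguous, and require the trailing N dims to match except for the first of them) can be sketched outside MLIR as below. This is a minimal standalone C++ sketch, not the code from the patch: `MemRefDesc`, `rowMajor`, `trailingDimsContiguous` and this free-standing `isContiguousSlice` are hypothetical stand-ins that model a statically shaped memref by its shape and element strides. Treating an all-unit vector as trivially contiguous is an assumption the doc comment does not spell out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical model of a statically shaped, strided memref type.
struct MemRefDesc {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides; // in elements, one per dimension
};

// Build a descriptor with the default (dense, row-major) strides.
static MemRefDesc rowMajor(std::vector<int64_t> shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return {std::move(shape), std::move(strides)};
}

// True if the `n` trailing dims are laid out densely in row-major order.
// Simplified: unlike a production check, unit dims get no special treatment.
static bool trailingDimsContiguous(const MemRefDesc &m, std::size_t n) {
  if (n > m.shape.size())
    return false;
  int64_t expected = 1;
  for (std::size_t k = 0; k < n; ++k) {
    std::size_t i = m.shape.size() - 1 - k;
    if (m.strides[i] != expected)
      return false;
    expected *= m.shape[i];
  }
  return true;
}

// Sketch of the documented rule for "vector is a contiguous slice of memref".
static bool isContiguousSlice(const MemRefDesc &m,
                              const std::vector<int64_t> &vecShape) {
  // Skip the leading sequence of unit dimensions of the vector.
  std::size_t lead = 0;
  while (lead < vecShape.size() && vecShape[lead] == 1)
    ++lead;
  const std::size_t n = vecShape.size() - lead;
  if (n == 0)
    return true; // assumption: a single-element vector is always contiguous

  // (a) The N trailing memref dims must be contiguous.
  if (!trailingDimsContiguous(m, n))
    return false;

  // (b) The trailing N dims must match, except the first (outermost) of them.
  for (std::size_t k = 0; k + 1 < n; ++k)
    if (vecShape[vecShape.size() - 1 - k] != m.shape[m.shape.size() - 1 - k])
      return false;
  return true;
}
```

Run against Ex.1 to Ex.8 from the doc comment, the sketch gives the same contiguous/non-contiguous answers; in particular `vector<1x1x2xi32>` only needs the last dimension of `memref<2x2x2xi32, strided<[8, 4, 1]>>` to be contiguous, while `vector<1x2x2xi32>` needs the last two and fails.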
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/143146 >From 6a6d6037b6da51b2da474c99751433542cf35603 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Wed, 14 May 2025 09:03:49 + Subject: [PATCH 1/2] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position. --- .../Transforms/LegalizeVectorStorage.cpp | 110 - .../ArmSVE/legalize-transfer-read.mlir| 226 ++ .../transfer-read-scalable-not-rightmost.mlir | 72 ++ 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index d2ac850a5f70b..f16d33c004fec 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads a vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. 
+/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. +VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one. 
+auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. +if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1), + origVT.getShape().take_back(numCollapseDims - 1))) + return rewriter.notifyMatchFailure( + readOp, "memref and vector dimensions do not match"); + +SmallVector origInBounds = readOp.getInBoundsValues(); +if (!llvm::all_of( +ArrayRef(origInBounds).take_back(numCollapseDims - 1), +[](bool v) { return v; })) + return rewriter.notifyMatchFailure(readOp, + "out-of-bounds index to collapse"); + +// Collapse the trailing dimensions of the memref. +SmallVector reassoc; +for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i) + reassoc.push_back({i}); +for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank(); + ++i) + reassoc.
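The index arithmetic behind the collapse (how many trailing dimensions to fold, and which reassociation groups that yields for `memref.collapse_shape`) can be illustrated without MLIR. The following is a hedged standalone C++ sketch: the names `numTrailingDimsToCollapse`, `collapseTrailing` and `collapsedTrailingSize` are hypothetical, and the body of the second loop (appending to the last group) is an assumption inferred from the `[[0], [1, 2, 3]]` example in the doc comment, since the quoted patch is cut off at that point.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

using ReassociationGroups = std::vector<std::vector<int64_t>>;

// Number of trailing vector dims to collapse: from the first scalable
// dimension (inclusive) to the end. Returns 0 if no dimension is scalable.
static int64_t
numTrailingDimsToCollapse(const std::vector<bool> &scalableDims) {
  auto it = std::find(scalableDims.begin(), scalableDims.end(), true);
  return std::distance(it, scalableDims.end());
}

// Reassociation groups that keep the leading dims as they are and fold the
// last `numCollapseDims` dims into a single one.
// Precondition: 1 <= numCollapseDims <= memrefRank.
static ReassociationGroups collapseTrailing(int64_t memrefRank,
                                            int64_t numCollapseDims) {
  ReassociationGroups groups;
  const int64_t firstCollapsed = memrefRank - numCollapseDims;
  // One singleton group per leading dim, plus one for the first collapsed dim.
  for (int64_t i = 0; i <= firstCollapsed; ++i)
    groups.push_back({i});
  // Assumed continuation: remaining dims join the last group.
  for (int64_t i = firstCollapsed + 1; i < memrefRank; ++i)
    groups.back().push_back(i);
  return groups;
}

// Size of the collapsed trailing vector dimension (product of the folded
// sizes).
static int64_t collapsedTrailingSize(const std::vector<int64_t> &vecShape,
                                     int64_t numCollapseDims) {
  int64_t size = 1;
  for (std::size_t i = vecShape.size() - numCollapseDims; i < vecShape.size();
       ++i)
    size *= vecShape[i];
  return size;
}
```

For the `vector<2x[4]x2x8xi8>` example, the scalable flags are `{false, true, false, false}`, so three trailing dimensions are collapsed, the groups are `[[0], [1, 2, 3]]`, and the folded dimension has size 4 * 2 * 8 = 64, matching `vector<2x[64]xi8>`.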
[llvm-branch-commits] [mlir] [MLIR][AArch64] Add integration test for lowering of `vector.contract` to Neon FEAT_I8MM (PR #144699)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/144699 >From 3240b5d582f6fe5d430f792f0f7d208d80bf7f48 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Wed, 18 Jun 2025 13:04:45 + Subject: [PATCH 1/3] [MLIR][AArch64] Add integration test for lowering of `vector.contract` to Neon FEAT_I8MM --- .../CPU/ArmNeon/vector-contract-i8mm.mlir | 336 ++ 1 file changed, 336 insertions(+) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir new file mode 100644 index 0..337429e0a9688 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir @@ -0,0 +1,336 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-neon enable-arm-i8mm' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \ +// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+neon,+i8mm" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils + +// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] + +// +// Test the lowering of `vector.contract` using the `LowerContractionToSMMLAPattern` +// +// The operation that the `vector.contract` in this test performs is matrix +// multiplication with accumulate +// OUT = ACC + LHS * RHS +// of two 8-bit integer matrices LHS and RHS, and a 32-bit 
integer matrix ACC +// into a 32-bit integer matrix OUT. The LHS and RHS can be sign- or zero-extended; +// this test covers all the possible variants. +// +// The test checks the computed results and that the relevant `ArmNeon` dialect +// operations (`arm_neon.smmla`, `arm_neon.ummla`, etc.) are emitted. +// +// The pattern above handles (therefore this test prepares) input/output vectors with +// specific shapes: +// * LHS: vector +// * RHS: vector +// * ACC, OUT: vector +// where the M and N are even and K is divisible by 8. +// Note that the RHS is transposed. +// This data layout makes it efficient to load data into SIMD +// registers in the layout expected by FEAT_I8MM instructions. +// Such a `vector.contract` is representative of the code we aim to generate +// by vectorisation of `linalg.mmt4d`. +// +// In this specific test we use M == 4, N == 4, and K == 8. +// + +// Test the operation where both LHS and RHS are interpreted as signed, hence +// we ultimately emit and execute the `smmla` instruction. 
+ +// CHECK-IR-LABEL: llvm.func @test_smmla +// CHECK-IR-COUNT-4: arm_neon.intr.smmla +func.func @test_smmla() { + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c0_i8 = arith.constant 0 : i8 + + // Accumulator test data + %acc_cst = arith.constant dense<[[ -1, -9, -4, 0], + [ 6, 5, 7, 2], + [ -8, -7, 9, -10], + [ 9, 4, -4, 0]]> : vector<4x4xi32> + + %acc_mem = memref.alloca() : memref<4x4xi32> + vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32> + %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : memref<4x4xi32>, vector<4x4xi32> + + // LHS test data + %lhs_cst = arith.constant dense<[[ -4, -4, -4, -6, 0, 1, 6, 2, -1, 4, 5, -8, 9, 5, 4, 9], + [ -1, 6, 0, 7, -7, 8, 5, 8, -7, 6, -2, 1, 1, 5, -4, -4], + [ 4, -10, 10, -3, 5, 3, 2, 3, -7, 9, -9, -10, 7, -8, -5, -2], + [ 9, 5, 8, 9, 6, -3, -9, 7, -4, -7, -2, 7, -8, 2, 8, 7]]> : vector<4x16xi8> + + %lhs_mem = memref.alloca() : memref<4x16xi8> + vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] : vector<4x16xi8>, memref<4x16xi8> + %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<4x16xi8>, vector<4x16xi8> + + // RHS test data + %rhs_cst = arith.constant dense<[[ 1, 2, -3, 5, 10, 8, 10, -2, 1, 10, -5, 2, 4, 3, -9, 4], + [ -3, -3, -3, 4, 6, -1, 0, -5, 6, 3, -1, 9, -3, 3, -2, 4], + [ 1, 9, -1, 1, -5, 4, 9, -10, -1, -7,
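The arithmetic this test exercises, OUT = ACC + LHS * RHS with a transposed RHS (indexing maps `(m, k)`, `(n, k)`, `(m, n)`), can be written as a scalar reference loop. The following standalone C++ sketch covers only the signed-by-signed variant; `contractSigned`, `MatI8` and `MatI32` are hypothetical names for illustration and do not appear in the patch. It does not model the Neon instructions or the shape restrictions (M and N even, K divisible by 8) required by the lowering pattern.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using MatI8 = std::vector<std::vector<int8_t>>;
using MatI32 = std::vector<std::vector<int32_t>>;

// Scalar reference for OUT[m][n] = ACC[m][n] + sum_k LHS[m][k] * RHS[n][k].
// Both 8-bit operands are sign-extended to 32 bits before multiplying.
// Note that RHS is stored transposed: it is indexed as [n][k].
static MatI32 contractSigned(const MatI32 &acc, const MatI8 &lhs,
                             const MatI8 &rhs) {
  MatI32 out = acc;
  for (std::size_t m = 0; m < lhs.size(); ++m)
    for (std::size_t n = 0; n < rhs.size(); ++n)
      for (std::size_t k = 0; k < lhs[m].size(); ++k)
        out[m][n] += static_cast<int32_t>(lhs[m][k]) *
                     static_cast<int32_t>(rhs[n][k]);
  return out;
}
```

A zero-extended operand would be modelled by reading it as `uint8_t` before widening; the mixed-signedness variants combine the two.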
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 1590fe674c00e07171c8842805a06907ed68f242 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/2] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leading two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 887f383aa07cca3fe023cd64b3b119cbf013c17b Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/3] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leading two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 887f383aa07cca3fe023cd64b3b119cbf013c17b Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/3] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>``. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write`` ops. The pattern used to collapse a number of dimensions equal the vector rank, which may be is incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5 contiguous slice, leading two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6 non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +/// dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last +/// two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
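The rule stated in the new doc comment (ignore the leading unit dims of the vector, then require (a) contiguity of the N trailing memref dims and (b) a shape match on all but the first of those N dims) can be sketched in plain C++. This is only an illustration under stated assumptions, not the code from the patch: `MemRefDesc`, `trailingDimsContiguous` and `isContiguousSliceSketch` are hypothetical names, shapes are assumed fully static, and strides are assumed to be given in elements.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a statically shaped, strided memref type.
struct MemRefDesc {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides; // in elements, same length as `shape`
};

// True if the `n` trailing dims of `m` form a dense row-major block, i.e.
// each of their strides equals the product of the sizes to its right.
bool trailingDimsContiguous(const MemRefDesc &m, size_t n) {
  assert(m.shape.size() == m.strides.size() && n <= m.shape.size());
  int64_t expected = 1;
  for (size_t i = 0; i < n; ++i) {
    size_t d = m.shape.size() - 1 - i;
    if (m.strides[d] != expected)
      return false;
    expected *= m.shape[d];
  }
  return true;
}

// Sketch of the documented rule for static shapes only.
bool isContiguousSliceSketch(const MemRefDesc &m,
                             const std::vector<int64_t> &vecShape) {
  // Skip the leading run of unit dims of the vector.
  size_t lead = 0;
  while (lead < vecShape.size() && vecShape[lead] == 1)
    ++lead;
  size_t n = vecShape.size() - lead;
  if (n > m.shape.size())
    return false;
  // (a) the N trailing memref dims must be contiguous.
  if (!trailingDimsContiguous(m, n))
    return false;
  // (b) all but the first of the N trailing dims must match.
  for (size_t i = 1; i < n; ++i)
    if (vecShape[lead + i] != m.shape[m.shape.size() - n + i])
      return false;
  return true;
}
```

Under these assumptions the sketch agrees with Ex.1 through Ex.8 from the doc comment: for instance, `vector<1x2x2xi32>` is accepted for a dense `memref<5x4x3x2xi32>` but rejected for `memref<2x2x2xi32, strided<[8, 4, 1]>>`.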
[llvm-branch-commits] [mlir] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op (PR #140572)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/140572 >From 0d2dca54b02ab76d4b847eed764a5284b74fc5f3 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Fri, 16 May 2025 15:47:36 + Subject: [PATCH 1/3] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op --- .../mlir/Dialect/ArmSVE/CMakeLists.txt| 1 + .../TransformOps/ArmSVEVectorTransformOps.h | 31 +++ .../TransformOps/ArmSVEVectorTransformOps.td | 26 ++ .../ArmSVE/TransformOps/CMakeLists.txt| 6 + mlir/include/mlir/InitAllExtensions.h | 2 + mlir/lib/Dialect/ArmSVE/CMakeLists.txt| 1 + .../TransformOps/ArmSVEVectorTransformOps.cpp | 54 .../ArmSVE/TransformOps/CMakeLists.txt| 19 ++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 263 ++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 123 .../Vector/CPU/ArmSVE/vector-ummla.mlir | 138 + .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 138 + 12 files changed, 512 insertions(+), 290 deletions(-) create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h new file mode 100644 index 0..7f22cd1fe6435 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmSVEVectorTransformOps.h - Vector transform ops *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// + +#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H +#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===--===// +// ArmSVE Vector Transform Operations +//===--===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_sve { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace arm_sve +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td new file mode 100644 index 0..81b59340f3b0d --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -0,0 +1,26 @@ +//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--===// +#ifndef ARMSVE_VECTOR_TRANSFORM_OPS +#define ARMSVE_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +def ApplyArmSVELowerContractionPatternsOp +: Op]> { + let description = [{ +Indicates that vector contraction-like operations should be lowered to +finer-grained vector primitives using the ArmSVE dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // ARMSVE_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0..ce8d8fea7f188 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td) +mlir_tablegen(ArmSVEVectorTransfo
[llvm-branch-commits] [mlir] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM (PR #140573)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/140573 >From b5865b5daacc46f53e948bcd6347f4eafc8d2938 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 19 May 2025 14:50:45 + Subject: [PATCH 1/2] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM --- .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 + .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 ++ .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 + .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 + .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 + 5 files changed, 630 insertions(+) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir new file mode 100644 index 0..88534dd2aab1e --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir @@ -0,0 +1,117 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils 
+ +// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +func.func private @setArmVLBits(%bits : i32) + +func.func @main() { + %c128 = arith.constant 128 : i32 + func.call @setArmVLBits(%c128) : (i32) -> () + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c0_i8 = arith.constant 0 : i8 + +// Accumulator test data + %acc_cst = arith.constant dense<[[-44, 20, 44, -46], + [ -8, 25, -34, 26], + [-20, -36, -3, 39], + [-48, -31, -25, -21]]> : vector<4x4xi32> + %acc_m = memref.alloca() : memref<4x4xi32> + vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32> + + %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32> + %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32> + %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32> + + vector.print str "ACC:\n" + %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32> + %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32> + %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32> + %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32> + vector.print %acc0 : vector<[4]xi32> + vector.print %acc1 : vector<[4]xi32> + vector.print %acc2 : vector<[4]xi32> + vector.print %acc3 : vector<[4]xi32> + + // LHS test data + %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33], + [-20, 17, -32, -47, 37, 22, -7, -21], + [ -7, -35, 20, -4, 39, 46, -23, 40], + [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8> + + %lhs_m = memref.alloca() : memref<4x8xi8> + vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8> + %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8> + + vector.print str "LHS:\n" + 
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8> + %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8> + %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8> + %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8> + vector.print %lhs0 : vector<8xi8> + vector.print %lhs1 : vector<8xi8> + vector.print %lhs2 : vector<8xi8> + vector.print %lhs3 : vector<8xi8> + + // RHS test data + %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33], + [-35, -24, 37, -32, 33, 30, -11, -17], +
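For orientation, the `#packed_maps` above index the LHS as `(d0, d2)`, the RHS as `(d1, d2)` and the accumulator as `(d0, d1)`, which describes `ACC += LHS * RHS^T`. A scalar reference for that contraction might look like the sketch below. It assumes the `i8` operands are sign-extended to `i32` before multiplication, as the `smmla` name suggests; the truncated test does not show that step. `contractSketch`, `MatI8` and `MatI32` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using MatI8 = std::vector<std::vector<int8_t>>;
using MatI32 = std::vector<std::vector<int32_t>>;

// Reference contraction for the maps (d0, d2), (d1, d2) -> (d0, d1):
//   acc[d0][d1] += sext(lhs[d0][d2]) * sext(rhs[d1][d2])
// Sign extension is an assumption made for this sketch.
void contractSketch(const MatI8 &lhs, const MatI8 &rhs, MatI32 &acc) {
  assert(acc.size() == lhs.size());
  for (size_t d0 = 0; d0 < lhs.size(); ++d0) {
    assert(acc[d0].size() == rhs.size());
    for (size_t d1 = 0; d1 < rhs.size(); ++d1) {
      assert(lhs[d0].size() == rhs[d1].size());
      for (size_t d2 = 0; d2 < lhs[d0].size(); ++d2)
        acc[d0][d1] += static_cast<int32_t>(lhs[d0][d2]) *
                       static_cast<int32_t>(rhs[d1][d2]);
    }
  }
}
```

Note that the RHS is stored with the reduction dimension innermost, so both operands are read row by row.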
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/142422 Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing `n` dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. >From d9a470e098553dbac74e81f98e0077718f6d9ed1 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.
--- .../mlir/Dialect/Utils/IndexingUtils.h| 3 +- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - mlir/lib/Dialect/Utils/IndexingUtils.cpp | 6 +- .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 6 files changed, 126 insertions(+), 78 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h index 99218f491ddef..a1c8ec2db056a 100644 --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -40,7 +40,8 @@ class ArrayAttr; /// Assuming `sizes` is `[s0, .. sn]`, return the vector /// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`. /// -/// `sizes` elements are asserted to be non-negative. +/// `sizes` element `s0` is asserted to be kDynamic or non-negative. +/// `sizes` elements `s1` to `sn` are asserted to be non-negative. /// /// Return an empty vector if `sizes` is empty. SmallVector computeSuffixProduct(ArrayRef sizes); diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. 
Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2
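The `IndexingUtils.h` hunk in this message relaxes the documented precondition of `computeSuffixProduct` so that the first element of `sizes` may be `kDynamic`. That is safe because `s0` never enters any of the returned products. A minimal sketch of the documented behaviour follows; `suffixProductSketch` and `kDynamicSketch` are hypothetical stand-ins, not the MLIR implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel standing in for MLIR's dynamic-size marker.
constexpr int64_t kDynamicSketch = std::numeric_limits<int64_t>::min();

// For sizes [s0, ..., sn] returns [s1*...*sn, s2*...*sn, ..., sn, 1].
// s0 may be dynamic because it is never multiplied into the result.
std::vector<int64_t> suffixProductSketch(const std::vector<int64_t> &sizes) {
  if (sizes.empty())
    return {};
  assert(sizes[0] == kDynamicSketch || sizes[0] >= 0);
  std::vector<int64_t> result(sizes.size(), 1);
  for (size_t i = sizes.size() - 1; i > 0; --i) {
    assert(sizes[i] >= 0 && "only the leading size may be dynamic");
    result[i - 1] = result[i] * sizes[i];
  }
  return result;
}
```

For a shape such as `?x4x3x2` the sketch yields the same strides as for `5x4x3x2`, which is what makes the relaxed assertion harmless.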
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -582,6 +582,15 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, namespace { +/// Helper function to return the index of the last dynamic dimension in `shape`. momchil-velikov wrote: `std::distance` returns a `ptrdiff_t`, there's a chance it's not the same as `int64_t` (e.g. `long long` vs. `long`). https://github.com/llvm/llvm-project/pull/142422
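The point about `std::distance` can be illustrated with a small sketch of such a helper. The body of the helper in the patch is not shown in this message, so the code below is an assumption: `lastDynIndexSketch` and `kDynamicSketch` are hypothetical names, and returning -1 when no dimension is dynamic is a choice made only for this example. The explicit cast makes the conversion from the iterator difference type to `int64_t` visible.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <vector>

// Hypothetical sentinel standing in for MLIR's dynamic-size marker.
constexpr int64_t kDynamicSketch = std::numeric_limits<int64_t>::min();

// Index of the last dynamic dimension in `shape`, or -1 if there is none
// (the -1 convention is an assumption of this sketch).
int64_t lastDynIndexSketch(const std::vector<int64_t> &shape) {
  auto rit = std::find(shape.rbegin(), shape.rend(), kDynamicSketch);
  // std::distance yields the iterator's difference_type (std::ptrdiff_t
  // for std::vector), which need not be the same type as int64_t.
  return static_cast<int64_t>(std::distance(rit, shape.rend())) - 1;
}
```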
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -203,21 +206,21 @@ func.func @transfer_read_dynamic_dim_to_flatten( return %res : vector<1x2x6xi32> } -// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)> +// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)> // CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten // CHECK-SAME:%[[IDX_1:arg0]] // CHECK-SAME:%[[IDX_2:arg1]] // CHECK-SAME:%[[MEM:arg2]] -// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32 momchil-velikov wrote: I don't think it matters; I wouldn't bother changing `*` to `+` unless I'm modifying the line anyway, in which case the `+` is "more correct" as we're in fact expecting a non-empty variable name. https://github.com/llvm/llvm-project/pull/142422
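The difference between the two quantifiers discussed in this comment is that `.*` also accepts an empty match while `.+` needs at least one character. The sketch below shows this with `std::regex`; FileCheck uses its own regex engine, so this only illustrates the quantifier semantics, which are the same in both. `nameMatches` is a hypothetical helper.

```cpp
#include <cassert>
#include <regex>
#include <string>

// True if `name` as a whole matches the regular expression `pattern`.
bool nameMatches(const std::string &name, const std::string &pattern) {
  return std::regex_match(name, std::regex(pattern));
}
```

With this helper, an empty name is accepted by `.*` and rejected by `.+`, while a name such as `c0_i32` is accepted by both.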
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 4c1207b4f72314683eed7d5f8672c6d83db1ef54 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/5] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant).
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5 contiguous slice, leading two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6 non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +/// dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last +/// two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 4c1207b4f72314683eed7d5f8672c6d83db1ef54 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/5] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>``. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write`` ops. The pattern used to collapse a number of dimensions equal the vector rank, which may be is incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
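The rule stated in the new doc comment (ignore the leading run of unit dims in the vector, require the N trailing memref dims to be contiguous, and require all but the first of those N dims to match) can be sketched in plain C++ without any MLIR types. This is only an illustrative model under stated assumptions, not the code from the patch: the names `trailingDimsContiguous` and `isContiguousSliceSketch` are hypothetical, shapes are static, strides are given in elements, and an all-unit vector is assumed to count as contiguous.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: true if the last `n` dims of a memref with the given
// static shape and element strides are densely packed in row-major order.
bool trailingDimsContiguous(const std::vector<int64_t> &shape,
                            const std::vector<int64_t> &strides,
                            std::size_t n) {
  assert(shape.size() == strides.size());
  if (n > shape.size())
    return false;
  int64_t expected = 1; // stride a dense row-major layout would have
  for (std::size_t k = 0; k < n; ++k) {
    const std::size_t i = shape.size() - 1 - k;
    if (strides[i] != expected)
      return false;
    expected *= shape[i];
  }
  return true;
}

// Hypothetical model of the documented rule, not the MLIR implementation.
bool isContiguousSliceSketch(const std::vector<int64_t> &vecShape,
                             const std::vector<int64_t> &memShape,
                             const std::vector<int64_t> &memStrides) {
  // Ignore the leading sequence of unit dims of the vector.
  std::size_t first = 0;
  while (first < vecShape.size() && vecShape[first] == 1)
    ++first;
  const std::size_t n = vecShape.size() - first;
  if (n == 0)
    return true; // assumption: a single-element vector is contiguous
  if (n > memShape.size())
    return false;
  // a) the N trailing dims of the memref must be contiguous.
  if (!trailingDimsContiguous(memShape, memStrides, n))
    return false;
  // b) the trailing N dims must match, except the first of them.
  for (std::size_t k = 1; k < n; ++k)
    if (vecShape[first + k] != memShape[memShape.size() - n + k])
      return false;
  return true;
}
```

With row-major strides `{24, 6, 2, 1}` for `memref<5x4x3x2xi32>`, this model classifies Ex.1 through Ex.6 the same way the doc comment does, and the explicit strides `{8, 4, 1}` reproduce Ex.7 and Ex.8.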
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 8f9a4002820dcd3de2a5986d53749386a2507eab Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/4] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/143146 >From 4d13aa2c48a24e5a76618e78adde41713115f895 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Wed, 14 May 2025 09:03:49 + Subject: [PATCH] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors This patch adds a transform of the `transfer_read` operation that changes the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position. --- .../Transforms/LegalizeVectorStorage.cpp | 110 - .../ArmSVE/legalize-transfer-read.mlir| 226 ++ .../transfer-read-scalable-not-rightmost.mlir | 72 ++ 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index d2ac850a5f70b..f16d33c004fec 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. 
+/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewriten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. +VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one. 
+auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. +if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1), + origVT.getShape().take_back(numCollapseDims - 1))) + return rewriter.notifyMatchFailure( + readOp, "memref and vector dimensions do not match"); + +SmallVector origInBounds = readOp.getInBoundsValues(); +if (!llvm::all_of( +ArrayRef(origInBounds).take_back(numCollapseDims - 1), +[](bool v) { return v; })) + return rewriter.notifyMatchFailure(readOp, + "out-if-bounds index to collapse"); + +// Collapse the trailing dimensions of the memref. +SmallVector reassoc; +for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i) + reassoc.push_back({i}); +for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank(); + ++i) + reassoc.back
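The quoted diff is cut off in the middle of the loop that builds the reassociation indices for `memref.collapse_shape`. As a rough, MLIR-free sketch of the idea: the number of trailing dimensions to collapse is the distance from the single scalable dimension to the end of the vector shape, every dimension up to and including the first collapsed one starts its own group, and the remaining trailing dimensions join the last group. The function names below are hypothetical, and the step that appends to the last group is an assumption chosen to reproduce the `[[0], [1, 2, 3]]` grouping from the doc-comment example, since the original statement is truncated.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical helper: number of trailing dims from the first scalable dim
// (inclusive) to the end; 0 if there is no scalable dim at all.
int64_t numTrailingDimsToCollapse(const std::vector<bool> &scalableDims) {
  auto it = std::find(scalableDims.begin(), scalableDims.end(), true);
  return static_cast<int64_t>(std::distance(it, scalableDims.end()));
}

// Hypothetical helper: reassociation groups that keep the leading dims
// separate and merge the last `numCollapseDims` dims into one group.
std::vector<std::vector<int64_t>>
buildCollapseReassociation(int64_t rank, int64_t numCollapseDims) {
  assert(numCollapseDims >= 1 && numCollapseDims <= rank);
  std::vector<std::vector<int64_t>> reassoc;
  const int64_t firstCollapsed = rank - numCollapseDims;
  // One singleton group per dim, up to and including the first collapsed dim.
  for (int64_t i = 0; i <= firstCollapsed; ++i)
    reassoc.push_back({i});
  // Assumed continuation: fold the remaining dims into the last group.
  for (int64_t i = firstCollapsed + 1; i < rank; ++i)
    reassoc.back().push_back(i);
  return reassoc;
}
```

For the example in the doc comment, `vector<2x[4]x2x8xi8>` has scalable flags `{false, true, false, false}`, so three trailing dimensions are collapsed and a rank-4 memref gets the groups `[[0], [1, 2, 3]]`.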
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 937afe27b55fb6a61ccef6252eeb33159bca3e99 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/5] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422 >From 2eb6c95955dc22b6b59eb4e5ba269e4744bbdd2a Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Mon, 2 Jun 2025 15:13:13 + Subject: [PATCH 1/3] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 - .../Transforms/VectorTransferOpTransforms.cpp | 8 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++-- .../Vector/vector-transfer-flatten.mlir | 108 +- 4 files changed, 120 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 6609b28d77b6c..ed06d7a029494 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -49,35 +49,37 @@ FailureOr> isTranspose2DSlice(vector::TransposeOp op); /// Return true if `vectorType` is a contiguous slice of `memrefType`. /// -/// Only the N = vectorType.getRank() trailing dims of `memrefType` are -/// checked (the other dims are not relevant). 
Note that for `vectorType` to be -/// a contiguous slice of `memrefType`, the trailing dims of the latter have -/// to be contiguous - this is checked by looking at the corresponding strides. +/// The leading unit dimensions of the vector type are ignored as they +/// are not relevant to the result. Let N be the number of the vector +/// dimensions after ignoring a leading sequence of unit ones. /// -/// There might be some restriction on the leading dim of `VectorType`: +/// For `vectorType` to be a contiguous slice of `memrefType` +/// a) the N trailing dimensions of the latter must be contiguous, and +/// b) the trailing N dimensions of `vectorType` and `memrefType`, +/// except the first of them, must match. /// -/// Case 1. If all the trailing dims of `vectorType` match the trailing dims -/// of `memrefType` then the leading dim of `vectorType` can be -/// arbitrary. -/// -///Ex. 1.1 contiguous slice, perfect match -/// vector<4x3x2xi32> from memref<5x4x3x2xi32> -///Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4) -/// vector<2x3x2xi32> from memref<5x4x3x2xi32> -/// -/// Case 2. If an "internal" dim of `vectorType` does not match the -/// corresponding trailing dim in `memrefType` then the remaining -/// leading dims of `vectorType` have to be 1 (the first non-matching -/// dim can be arbitrary). +/// Examples: /// -///Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1> -/// vector<2x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1> -/// vector<1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1> -/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> -///Ex. 2.4. 
non-contiguous slice, 2 != 3 and the leading dims != <1x1> -/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.1 contiguous slice, perfect match +/// vector<4x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.2 contiguous slice, the leading dim does not match (2 != 4) +/// vector<2x3x2xi32> from memref<5x4x3x2xi32> +/// Ex.3 non-contiguous slice, 2 != 3 +/// vector<2x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.4 contiguous slice, leading unit dimension of the vector ignored, +///2 != 3 (allowed) +/// vector<1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored, +/// 2 != 3 (allowed) +/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32> +/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims +/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>) +/// Ex.7 contiguous slice, memref needs to be contiguous only on the last +///dimension +/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> +/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last +///two dimensions, and it isn't +/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>> bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -630,7 +639,10 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

momchil-velikov wrote:

> > I've chosen to do maximum collapsing, as it seems simpler to do, and the
> > resulting IR is arguably simpler.
> Ok. Maybe it is simpler in these 2 senses. I guess minimising change
> (compared to pre-PR and to pre-pass) could be another metric.

On second thought, perhaps _minimum_ collapsing would be more appropriate, since it discards less information.

https://github.com/llvm/llvm-project/pull/142422
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -630,7 +639,10 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

momchil-velikov wrote:

> when might lastDynIndex(sourceType.getShape()) be larger than
> sourceType.getRank() - sourceType.getNumContiguousTrailingDims()) ?

For memrefs with dynamic dimensions and no strides or maps, e.g. `memref<2x?x2xi8>`.

https://github.com/llvm/llvm-project/pull/142422
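The reply above can be illustrated with a small standalone sketch. It does not use MLIR: `kDynamicSketch`, `lastDynIndexSketch`, and the other helpers are hypothetical stand-ins. It assumes that `lastDynIndex` yields the index of the last dynamic dimension, and that a memref without strides or maps counts as contiguous in all of its dimensions. Neither assumption is confirmed by the thread beyond the one-line answer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical marker for a dynamic extent ('?'); not MLIR's own constant.
constexpr int64_t kDynamicSketch = -1;

// Assumption: index of the last dynamic dimension, or -1 if there is none.
int64_t lastDynIndexSketch(const std::vector<int64_t> &shape) {
  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i)
    if (shape[static_cast<std::size_t>(i)] == kDynamicSketch)
      return i;
  return -1;
}

// Assumption: with no strides or maps, every dimension counts as contiguous.
int64_t numContiguousTrailingDimsIdentitySketch(
    const std::vector<int64_t> &shape) {
  return static_cast<int64_t>(shape.size());
}

// Rank minus the number of contiguous trailing dims: the first dimension
// that the contiguity bound alone would allow collapsing from.
int64_t firstContiguousDimSketch(const std::vector<int64_t> &shape) {
  return static_cast<int64_t>(shape.size()) -
         numContiguousTrailingDimsIdentitySketch(shape);
}
```

Under these assumptions, the shape of `memref<2x?x2xi8>` gives a last dynamic index of 1, while rank minus the number of contiguous trailing dims is 0, so the dynamic-dimension bound is the larger of the two.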
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
 ///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+///   a) the N trailing dimensions of the latter must be contiguous, and
+///   b) the trailing N dimensions of `vectorType` and `memrefType`,
+///      except the first of them, must match.
 ///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-///         of `memrefType` then the leading dim of `vectorType` can be
-///         arbitrary.
-///
-///        Ex. 1.1 contiguous slice, perfect match
-///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-///         corresponding trailing dim in `memrefType` then the remaining
-///         leading dims of `vectorType` have to be 1 (the first non-matching
-///         dim can be arbitrary).
+/// Examples:
 ///
-///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
-///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-///          vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.1 contiguous slice, perfect match
+///     vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+///     vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.3 non-contiguous slice, 2 != 3
+///     vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+///        2 != 3 (allowed)
+///     vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.5. contiguous slice, leading two unit dims of the vector ignored,
+///         2 != 3 (allowed)
+///     vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
+///     vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.7 contiguous slice, memref needs to be contiguous only on the last
+///        dimension
+///     vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+///   Ex.8 non-contiguous slice, memref needs to be contiguous on the last
+///        two dimensions, and it isn't
+///     vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>

momchil-velikov wrote:

These examples contain redundancies. I've checked the cases and came up with:

* Ex.1 contiguous slice, perfect match
  `vector<4x3x2xi32>` from `memref<5x4x3x2xi32>`
  tested by:
    `transfer_read_dims_match_contiguous`
    `transfer_write_dims_match_contiguous`

* Ex.2 contiguous slice, the leading dim does not match (2 != 4)
  `vector<2x3x2xi32>` from `memref<5x4x3x2xi32>`
  tested by:
    `transfer_read_dims_mismatch_contiguous_non_unit_dims`
    `transfer_write_dims_mismatch_contiguous_non_unit_dims`

* Ex.5. contiguous slice, leading two unit dims of the vector ignored, 2 != 3 (allowed)
  `vector<1x1x2x2xi32>` from `memref<5x4x3x2xi32>`
  tested by:
    `transfer_read_dims_mismatch_contiguous_unit_dims`
    `transfer_write_dims_mismatch_contiguous_unit_dims`

* Ex.7 contiguous slice, memref needs to be contiguous only in the last dimension
  `vector<1x1x2xi32>` from `memref<2x2x2xi32, strided<[8, 4, 1]>>`
  tested by:
    `transfer_read_dims_mismatch_non_contiguous_non_zero_indices`
    `transfer_write_dims_mismatch_non_contiguous_non_zero_indices`

* Ex.6. non-contiguous slice, 2 != 3
  `vector<2x1x2x2xi32>` from `memref<5x4x3x2xi32>`
  tested by:
    `transfer_read_dims_mismatch_non_contiguous_slice`
    `transfer_write_dims_mismatch_non_contiguous_slice`

https://github.com/llvm/llvm-project/pull/142422
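Rules (a) and (b) from the revised doc comment can be checked against the listed examples with a standalone sketch. This is not the MLIR implementation: `Shape`, `numContiguousTrailingDimsSketch`, and `isContiguousSliceSketch` are hypothetical names. The sketch assumes static shapes and explicit row-major strides with one stride per dimension, and it treats an all-unit vector shape as trivially contiguous.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper: counts how many trailing dimensions of a statically
// shaped buffer are contiguous in row-major order, given explicit strides.
std::size_t numContiguousTrailingDimsSketch(const Shape &shape,
                                            const Shape &strides) {
  std::size_t count = 0;
  int64_t expectedStride = 1;
  for (std::size_t i = shape.size(); i-- > 0;) {
    if (strides[i] != expectedStride)
      break;
    ++count;
    expectedStride *= shape[i];
  }
  return count;
}

// Hypothetical restatement of rules (a) and (b) from the doc comment.
bool isContiguousSliceSketch(const Shape &vec, const Shape &mem,
                             const Shape &memStrides) {
  // Ignore the leading unit dimensions of the vector shape.
  std::size_t first = 0;
  while (first < vec.size() && vec[first] == 1)
    ++first;
  const std::size_t n = vec.size() - first;
  if (n == 0)
    return true; // Assumption: only unit dims means a single element.
  if (n > mem.size())
    return false;
  // (a) The n trailing memref dimensions must be contiguous.
  if (numContiguousTrailingDimsSketch(mem, memStrides) < n)
    return false;
  // (b) The trailing n dims must match, except the first of them.
  for (std::size_t i = 1; i < n; ++i)
    if (vec[first + i] != mem[mem.size() - n + i])
      return false;
  return true;
}
```

For instance, `isContiguousSliceSketch({1, 1, 2, 2}, {5, 4, 3, 2}, {24, 6, 2, 1})` corresponds to Ex.5 and holds, while `{2, 1, 2, 2}` against the same memref corresponds to Ex.6 and does not, because its leading dimension is not a unit dimension and so nothing is stripped.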
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422

>From b950757c234900db941ed950ea3469b520d2e28a Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Mon, 2 Jun 2025 15:13:13 +
Subject: [PATCH 1/7] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice`

Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 -
 .../Transforms/VectorTransferOpTransforms.cpp | 8 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++--
 .../Vector/vector-transfer-flatten.mlir | 108 +-
 4 files changed, 120 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
 ///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+///   a) the N trailing dimensions of the latter must be contiguous, and
+///   b) the trailing N dimensions of `vectorType` and `memrefType`,
+///      except the first of them, must match.
 ///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-///         of `memrefType` then the leading dim of `vectorType` can be
-///         arbitrary.
-///
-///        Ex. 1.1 contiguous slice, perfect match
-///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-///         corresponding trailing dim in `memrefType` then the remaining
-///         leading dims of `vectorType` have to be 1 (the first non-matching
-///         dim can be arbitrary).
+/// Examples:
 ///
-///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
-///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-///          vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.1 contiguous slice, perfect match
+///     vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+///     vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.3 non-contiguous slice, 2 != 3
+///     vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+///        2 != 3 (allowed)
+///     vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.5. contiguous slice, leading two unit dims of the vector ignored,
+///         2 != 3 (allowed)
+///     vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
+///     vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.7 contiguous slice, memref needs to be contiguous only on the last
+///        dimension
+///     vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+///   Ex.8 non-contiguous slice, memref needs to be contiguous on the last
+///        two dimensions, and it isn't
+///     vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
 bool isContiguo
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -630,7 +639,10 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

momchil-velikov wrote:

So, I changed it like this.

https://github.com/llvm/llvm-project/pull/142422