[llvm-branch-commits] [flang] [Flang][OpenMP] Add frontend support for -fopenmp-targets (PR #100155)
@@ -894,6 +894,7 @@ static bool parseDiagArgs(CompilerInvocation &res, llvm::opt::ArgList &args,
 /// options accordingly. Returns false if new errors are generated.
 static bool parseDialectArgs(CompilerInvocation &res, llvm::opt::ArgList &args,

banach-space wrote:

Thanks 🙏🏻

https://github.com/llvm/llvm-project/pull/100155
[llvm-branch-commits] [flang] [Flang][OpenMP] Add frontend support for -fopenmp-targets (PR #100155)
https://github.com/banach-space approved this pull request.

https://github.com/llvm/llvm-project/pull/100155
[llvm-branch-commits] [mlir] [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (PR #119975)
https://github.com/banach-space approved this pull request.

The non-GPU changes LGTM. The CHECK lines removed in tests were just dead code, so thanks for the clean-up!

The GPU parts look reasonable, but it might be worth waiting a few days in case someone more experienced wants to take a look. If there are no comments, I would just land this.

https://github.com/llvm/llvm-project/pull/119975
[llvm-branch-commits] [mlir] [mlir][LLVM] Delete `LLVMFixedVectorType` and `LLVMScalableVectorType` (PR #133286)
https://github.com/banach-space approved this pull request.

Nice cleanup, thanks! LGTM

https://github.com/llvm/llvm-project/pull/133286
[llvm-branch-commits] [mlir] [mlir][LLVM] Delete `LLVMFixedVectorType` and `LLVMScalableVectorType` (PR #133286)
@@ -150,8 +150,7 @@ generatedTypePrinter(Type def, AsmPrinter &printer);

```cpp
 bool LLVMArrayType::isValidElementType(Type type) {
   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
-                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
-      type);
+                    LLVMFunctionType, LLVMTokenType>(type);
 }
```

banach-space wrote:

That `LLVMScalableVectorType` was added long before SME:

* https://reviews.llvm.org/D85663

But yes, "arrays of scalable vectors" are a thing and we rely on them. That said, I don't see any SME/SVE tests failing (I also checked end-to-end locally), so this should be fine.

https://github.com/llvm/llvm-project/pull/133286
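For context, a minimal illustration of an "array of scalable vectors" in the LLVM dialect, using the built-in scalable vector type (the shapes here are made up for the example):

```mlir
llvm.func @array_of_scalable_vectors(%n : i32) {
  // An LLVM-dialect array whose element type is a scalable vector.
  %p = llvm.alloca %n x !llvm.array<4 x vector<[4]xf32>> : (i32) -> !llvm.ptr
  llvm.return
}
```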
[llvm-branch-commits] [mlir] [mlir][LLVM] Delete `LLVMFixedVectorType` and `LLVMScalableVectorType` (PR #133286)
@@ -1033,7 +1033,7 @@ llvm.func @scalable_vector() -> i16 {

```mlir
 llvm.func @scalable_llvm_vector() -> i16 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.vec<? x 4 x ppc_fp128> : (i32) -> !llvm.ptr
+  %1 = llvm.alloca %0 x vector<[4] x !llvm.ppc_fp128> : (i32) -> !llvm.ptr
```

banach-space wrote:

The element type shouldn't matter, right? "Scalability" is a fairly abstract concept.

https://github.com/llvm/llvm-project/pull/133286
[llvm-branch-commits] [mlir] [mlir][LLVM] Delete `LLVMFixedVectorType` and `LLVMScalableVectorType` (PR #133286)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/133286
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/135634
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
https://github.com/banach-space approved this pull request.

LGTM % nit

Thanks!

https://github.com/llvm/llvm-project/pull/135634
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",

```tablegen
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }

+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
```

banach-space wrote:

This indentation is inconsistent with the other ops, but the existing indentation feels a bit ad-hoc and I like yours much more. Would you mind updating the other definitions so that we maintain consistency? Probably as a separate PR, to keep the history clean, but updating this PR instead would also be fine with me.

https://github.com/llvm/llvm-project/pull/135634
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/banach-space commented:

Thanks! This one is a bit longer, so I may need to wait till Thursday before I can review. One high-level question - would sharing some code between NEON and SVE be possible?

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/banach-space commented:

Thanks Momchil - this is great!

I skimmed through the pattern logic, and it's very neatly written. It's actually quite easy to follow, despite the underlying logic being a bit convoluted - well done! I've left a few minor suggestions, but nothing major. Also, it seems like we should be able to extend this fairly easily to support NEON as well. Worth thinking about 🙂

Now, overall this patch is quite large, and I’d suggest extracting the end-to-end / integration tests into a separate PR. Additionally, the remaining tests currently use `--convert-vector-to-llvm=`, which lowers all the way to LLVM (i.e., it exercises a lot of patterns). Instead, I’d recommend testing `LowerContractionToSVEI8MMPattern` in isolation and only verifying that the correct sequence of ArmSVE ops (plus some Vector ops) is generated - for example:

```mlir
(...)
%33 = arm_sve.smmla %23, %7, %15 : vector<[16]xi8> to vector<[4]xi32>
%34 = arm_sve.smmla %24, %7, %16 : vector<[16]xi8> to vector<[4]xi32>
%35 = arm_sve.smmla %31, %13, %15 : vector<[16]xi8> to vector<[4]xi32>
%36 = arm_sve.smmla %32, %13, %16 : vector<[16]xi8> to vector<[4]xi32>
```

That way, we will:

* reduce noise in the test output (by focusing on a single pattern),
* simplify the expected output (fewer ops to match),
* avoid re-testing functionality already covered elsewhere (e.g., the `arm_sve.smmla` -> `arm_sve.intr.smmla` lowering).

Btw, this is already looking great, and I know I’m asking for a bit of a rewrite (especially around the tests), but I really think it’ll help with long-term maintainability.

https://github.com/llvm/llvm-project/pull/135636
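For illustration, such an isolated test could be shaped roughly like this (the `-test-lower-to-arm-sve-i8mm` flag is hypothetical - substitute whatever option the patch exposes for running just this pattern; the shapes follow the constraints the pattern checks):

```mlir
// RUN: mlir-opt %s -test-lower-to-arm-sve-i8mm | FileCheck %s

#packed_maps = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]

// CHECK-LABEL: @contract_to_smmla
// CHECK-COUNT-4: arm_sve.smmla {{.*}} : vector<[16]xi8> to vector<[4]xi32>
func.func @contract_to_smmla(%lhs: vector<4x8xi8>, %rhs: vector<[4]x8xi8>,
                             %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
  // The pattern expects the contract operands to be sign-extended from i8.
  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
  %2 = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
    %0, %1, %acc : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
  return %2 : vector<4x[4]xi32>
}
```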
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
@@ -0,0 +1,304 @@

```cpp
//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// SVE I8MM operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Dialect/UB/IR/UBOps.h"

#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"

using namespace mlir;
using namespace mlir::arm_sve;

namespace {
// Check if the given value is a result of the operation `T` (which must be
// sign- or zero- extend) from i8 to i32. Return the value before the
// extension.
template <typename T>
inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
                         std::is_base_of_v<arith::ExtUIOp, T>),
                        std::optional<Value>>
extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
  if (!extOp)
    return {};

  auto inOp = extOp.getIn();
  auto inTy = dyn_cast<mlir::VectorType>(inOp.getType());
  if (!inTy || inTy.getElementType() != i8Ty)
    return {};

  auto outTy = dyn_cast<mlir::VectorType>(extOp.getType());
  if (!outTy || outTy.getElementType() != i32Ty)
    return {};

  return inOp;
}

// Designate the operation (resp. instruction) used to do sub-tile matrix
// multiplications.
enum class MMLA {
  Signed,      // smmla
  Unsigned,    // ummla
  Mixed,       // usmmla
  MixedSwapped // usmmla with LHS and RHS swapped
};

// Create the matrix multiply and accumulate operation according to `op`.
Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
  switch (op) {
  case MMLA::Signed:
    return rewriter.create<SmmlaOp>(loc, accType, acc, lhs, rhs);
  case MMLA::Unsigned:
    return rewriter.create<UmmlaOp>(loc, accType, acc, lhs, rhs);
  case MMLA::Mixed:
    return rewriter.create<UsmmlaOp>(loc, accType, acc, lhs, rhs);
  case MMLA::MixedSwapped:
    // The accumulator comes transposed and the result will be transposed
    // later, so all we have to do here is swap the operands.
    return rewriter.create<UsmmlaOp>(loc, accType, acc, rhs, lhs);
  }
}

class LowerContractionToSVEI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();

    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
    // eventually expect from MMT4D. M and N dimensions must be even and at
    // least 2.
    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
        rhsType.getRank() != 2)
      return failure();

    if (lhsType.isScalable() || !rhsType.isScalable())
      return failure();

    // M, N, and K are the conventional names for matrix dimensions in the
    // context of matrix multiplication.
    auto M = lhsType.getDimSize(0);
    auto N = rhsType.getDimSize(0);
    auto K = rhsType.getDimSize(1);

    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 ||
        N < 2 || N % 2 != 0 || !rhsType.getScalableDims()[0])
      return failure();

    // Check permutation maps. For now only accept
    //   lhs: (d0, d1, d2) -> (d0, d2)
    //   rhs: (d0, d1, d2) -> (d1, d2)
    //   acc: (d0, d1, d2) -> (d0, d1)
    // Note: RHS is transposed.
    if (op.getIndexingMapsArray()[0] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
                                                 op.getContext()) ||
        op.getIndexingMapsArray()[1] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
                                                 op.getContext()) ||
        op.getIndexingMapsArray()[2] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
                                                 op.getContext()))
      return failure();

    // Check iterator types for matrix multiplication.
    auto itTypes = op.getIteratorTypesArray();
    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
        itTypes[1] != vector::IteratorType::parallel ||
        itTypes[2] != vector::IteratorType::reduction)
      return failure();
```
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
```cpp
// This file implements lowering patterns from vector.contract to
// SVE I8MM operations.
```

banach-space wrote:

Could you add a note that `vector.contract` needs to be accompanied by `arith.extsi` (or `arith.extui`) ops? Also, is I8MM the official name? Shouldn't that be FEAT_I8MM? Basically, could we document a bit more?

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
```cpp
class LowerContractionToSVEI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
```

banach-space wrote:

It's a very long pattern. Could you document the high-level logic?

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
```cpp
// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
// eventually expect from MMT4D. M and N dimensions must be even and at
// least 2.
```

banach-space wrote:

[nit] We shouldn't be concerned with MMT4D in this dialect - it's a much higher-level abstraction, and this logic should be valid irrespective of how the input is generated.

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
banach-space wrote:

```suggestion
//===----------------------------------------------------------------------===//
```

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
```cpp
if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
    rhsType.getRank() != 2)
  return failure();
```

banach-space wrote:

Could you use `notifyMatchFailure` with some descriptive error message instead? Thanks! Same comment for the other instances of `failure`.
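A sketch of what the suggested change could look like (the message text is illustrative):

```cpp
// Attach a reason that -debug / pattern logging can surface, instead of
// returning a bare failure().
if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
    rhsType.getRank() != 2)
  return rewriter.notifyMatchFailure(op, "operands must be rank-2 vectors");
```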
https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
```cpp
if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
    rhsType.getRank() != 2)
  return failure();
```

banach-space wrote:

IIRC, inputs to `vector.contract` are required to be vectors, hence `lhsType.hasRank()` should always be true, no?

https://github.com/llvm/llvm-project/pull/135636
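If so, the check could be reduced to the rank test alone - a sketch (assuming the operands are indeed always ranked vectors):

```cpp
// vector.contract operands are vectors, so the types are always ranked and
// hasRank() is redundant.
if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
  return rewriter.notifyMatchFailure(op, "operands must be rank-2 vectors");
```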
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
```cpp
template <typename T>
inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
                         std::is_base_of_v<arith::ExtUIOp, T>),
                        std::optional<Value>>
extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
```

banach-space wrote:

Why not a simple `isa<...>(v.getDefiningOp())` check inside the function instead of this? That's more common from what I've seen (there's very little SFINAE in the Dialect code).

https://github.com/llvm/llvm-project/pull/135636
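A sketch of the `isa<>`-based alternative (signedness would then need to be communicated some other way, e.g. returned alongside the value; names follow the quoted code):

```cpp
// Check if `v` is produced by a sign- or zero-extension from i8 to i32 and,
// if so, return the value before the extension.
static std::optional<Value> extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
  Operation *def = v.getDefiningOp();
  if (!isa_and_nonnull<arith::ExtSIOp, arith::ExtUIOp>(def))
    return std::nullopt;

  Value in = def->getOperand(0);
  auto inTy = dyn_cast<VectorType>(in.getType());
  if (!inTy || inTy.getElementType() != i8Ty)
    return std::nullopt;

  auto outTy = dyn_cast<VectorType>(def->getResult(0).getType());
  if (!outTy || outTy.getElementType() != i32Ty)
    return std::nullopt;

  return in;
}
```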
[llvm-branch-commits] [mlir] users/banach space/sme/remove ConvertIllegalShapeCastOpsToTransposes (PR #139706)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/139706
[llvm-branch-commits] [mlir] [mlir][ArmSME] Remove `ConvertIllegalShapeCastOpsToTransposes` (PR #139706)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/139706
[llvm-branch-commits] [mlir] users/banach space/sme/remove ConvertIllegalShapeCastOpsToTransposes (PR #139706)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/139706
[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)
```cpp
// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
// eventually expect from MMT4D. M and N dimensions must be even and at
// least 2.
```

banach-space wrote:

Perhaps just expand this comment a bit (e.g. by noting that MMT4D is the main use-case ATM)?

https://github.com/llvm/llvm-project/pull/135636
[llvm-branch-commits] [mlir] users/banach space/vector/update create write (PR #141567)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/141567
[llvm-branch-commits] [mlir] users/banach space/vector/update create write (PR #141567)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/141567
[llvm-branch-commits] [mlir] [mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC) (PR #141567)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/141567
[llvm-branch-commits] [mlir] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM (PR #140573)
banach-space wrote:

Thanks - great to finally be reaching this stage! I have a few high-level questions and suggestions:

**1. Why is the scalable dimension always [4]?**

From the current tests, it looks like the scalable dim is always `[4]`. Could you remind me why that value is chosen?

**2. Reduce duplication in the 4x8x4 tests**

The current tests differ only in terms of **input**/**output** and `extsi` vs `extui`. It should be possible to reduce duplication by extracting shared logic into helpers, and writing 4 separate entry points (set via `entry_point`) to isolate the differences. For example:

```mlir
func.func @main_smmla() {
  // Init LHS, RHS, ACC

  // CHECK-LINES for LHS
  print(lhs);

  // CHECK-LINES for RHS
  print(rhs);

  arith.extsi (lhs)
  arith.extsi (rhs)
  vector.contract

  // CHECK-LINES for ACC
  print(acc);
}
```

This would keep the test logic focused and easier to maintain.

**3. Add checks for generated IR (LLVM dialect)**

It would be good to verify that the lowered IR includes the correct SME MMLA intrinsics. For example:

```mlir
// CHECK-COUNT-4: llvm.intr.smmla
```

This would help confirm both correctness and that the expected number of operations are emitted.

**4. Consider toggling VL within tests**

Have you considered toggling the scalable vector length (`VL`) within the test? That would allow verifying behaviour for multiple `VL` values. From what I can tell, this would only work if the inputs are generated inside a loop, similar to this example:

https://github.com/llvm/llvm-project/blob/88f61f2c5c0ad9dad9c8df2fb86352629e7572c1/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir#L19-L37

That might be a nice validation of the "scalability" aspect.

https://github.com/llvm/llvm-project/pull/140573
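For reference, toggling VL in-test could look roughly like this (assuming the `setArmVLBits` helper from `mlir_arm_runner_utils`, as in the linked example; `@test_smmla` is a placeholder for the actual test body):

```mlir
func.func private @setArmVLBits(i32)
func.func private @test_smmla()

func.func @main() {
  // Run the same computation at two different vector lengths.
  %c128 = arith.constant 128 : i32
  func.call @setArmVLBits(%c128) : (i32) -> ()
  func.call @test_smmla() : () -> ()

  %c256 = arith.constant 256 : i32
  func.call @setArmVLBits(%c256) : (i32) -> ()
  func.call @test_smmla() : () -> ()
  return
}
```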
[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
@@ -83,16 +84,48 @@ func.func @transfer_read_dims_mismatch_contiguous(

```mlir
   return %res : vector<1x1x2x2xi8>
 }

-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
 // CHECK-SAME:    %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:         %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:         %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:         %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
-// CHECK:         %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK:         %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
+// CHECK:         %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>, vector<4xi8>
 // CHECK:         %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:         return %[[VAL_5]] : vector<1x1x2x2xi8>

-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
 // CHECK-128B: memref.collapse_shape

+// -----
+
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable"
+
+func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
+  return %res : vector<2x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
+// CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME:      : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
+// CHECK:         %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]], %[[C0]]], %[[C0_I8]] {in_bounds = [true]}
+// CHECK-SAME:      : memref<5x24xi8, strided<[24, 1], offset: ?>>, vector<12xi8>
+// CHECK:         %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
+// CHECK:         return %[[VEC]] : vector<2x3x2xi8>
```

banach-space wrote:

> I don't understand the rationale behind having these in a particular order.

The current ordering feels reversed to me. In my head, it's clearer to start with the most basic case - e.g., the one with no leading unit dims - and then progressively build on it by adding complexity, such as leading unit dims. Right now, the more complex case comes first, which makes the overall structure harder to follow.

From another angle, the naming in

* `@transfer_read_dims_mismatch_contiguous_non_unit_dims`

is confusing, especially when compared to tests like

* `@transfer_read_dims_match_contiguous_empty_stride`.

Why is the absence of unit dims significant here? That may be obvious now, but it's not something I’m likely to remember when revisiting this file in the future.

To improve readability and flow, I suggest:

* Rename `@transfer_read_dims_mismatch_contiguous_non_unit_dims` -> `@transfer_read_dims_mismatch_contiguous`
* Move it before `@transfer_read_dims_mismatch_contiguous_unit_dims` to preserve a "simple-to-complex" test progression.

Thanks!

https://github.com/llvm/llvm-project/pull/142422
[llvm-branch-commits] [mlir] [MLIR][AArch64] Add integration test for lowering of `vector.contract` to Neon FEAT_I8MM (PR #144699)
@@ -0,0 +1,336 @@

```mlir
// REQUIRES: arm-emulator

// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-neon enable-arm-i8mm' \
// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \
// DEFINE:   --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
// DEFINE: -o %t

// DEFINE: %{entry_point} = main

// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+neon,+i8mm" \
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils

// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s

#packed_maps = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]

//
// Test the lowering of `vector.contract` using the `LowerContractionToNeonI8MMPattern`
//
// The operation that the `vector.contract` in this test performs is matrix
// multiplication with accumulate
//     OUT = ACC + LHS * RHS
// of two 8-bit integer matrices LHS and RHS, and a 32-bit integer matrix ACC
// into a 32-bit integer matrix OUT. The LHS and RHS can be sign- or
// zero-extended; this test covers all the possible variants.
//
// Tested are calculations as well as that the relevant `ArmNeon` dialect
// operations (`arm_neon.smmla`, `arm_neon.ummla`, etc) are emitted.
//
// That pattern above handles (therefore this test prepares) input/output
// vectors with specific shapes:
//   * LHS:      vector<MxKxi8>
//   * RHS:      vector<NxKxi8>
//   * ACC, OUT: vector<MxNxi32>
// where the M and N are even and K is divisible by 8.
// Note that the RHS is transposed.
// This data layout makes it efficient to load data into SIMD
// registers in the layout expected by FEAT_I8MM instructions.
// Such a `vector.contract` is representative of the code we aim to generate
// by vectorisation of `linalg.mmt4d`.
//
// In this specific test we use M == 4, N == 4, and K == 8.
```

banach-space wrote:

Isn't K = 16 in the code below?

https://github.com/llvm/llvm-project/pull/144699
[llvm-branch-commits] [mlir] [MLIR][AArch64] Add integration test for lowering of `vector.contract` to Neon FEAT_I8MM (PR #144699)
https://github.com/banach-space approved this pull request.

Thanks, it's great to see more tests for `i8mm`. The documentation makes it relatively easy to follow (despite this being fairly complex!) - that's much appreciated!

Overall LGTM, but I have one request. Could you unify the input data between SVE and NEON? I am happy for the actual code to be duplicated.

Btw, could you share how you generated the expected output? If that's some short numpy snippet, could you include it for future reference, should these tests start to fail?

Thank you!

https://github.com/llvm/llvm-project/pull/144699
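Something along these lines would do for the reference values, if numpy was indeed the tool (a sketch; the shapes and input values are illustrative, not taken from the actual tests):

```python
import numpy as np

# OUT = ACC + LHS * RHS^T, with the RHS stored transposed as in the tests.
M, N, K = 4, 4, 16
lhs = np.arange(M * K, dtype=np.int8).reshape(M, K)
rhs = np.arange(N * K, dtype=np.int8).reshape(N, K)
acc = np.zeros((M, N), dtype=np.int32)

# Widen to i32 first, matching arith.extsi feeding vector.contract
# (use zero-extension / np.uint8 inputs for the ummla variants).
out = acc + lhs.astype(np.int32) @ rhs.astype(np.int32).T
print(out)
```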
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/143146
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {

```cpp
/// Transforms a `transfer_read` operation so it reads vector of a type that
/// can be mapped to an LLVM type. This is done by collapsing trailing
/// dimensions so we obtain a vector type with a single scalable dimension in
/// the rightmost position.
///
/// Example:
/// ```
/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
///   {in_bounds = [false, true, true, true]}
///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
/// ```
/// is rewritten to
/// ```
/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
///   : memref<?x?x2x8xi8> into memref<?x?xi8>
/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
///   {in_bounds = [false, true]}
///   : memref<?x?xi8>, vector<2x[64]xi8>
/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
/// ```
struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {

    // Do not try to transform masked reads. For example, if we have a
    // transfer to a `vector<[4]x4xi8>` we could have a mask like
    //    1 1 1 0
    //    1 1 1 0
    //    1 1 1 0
    //    0 0 0 0
    // Flattening this mask would look like
    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
    // and we have not yet figured out an efficient way to build such a mask,
    // neither from the mask operand, nor from the original
    // `vector.create_mask` operation (if visible at all).
    if (readOp.isMasked() || readOp.getMask())
      return rewriter.notifyMatchFailure(readOp,
                                         "masked transfers not-supported");

    if (!readOp.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");

    // We handle transfers of vectors with rank >= 2 and a single scalable
    // dimension.
    VectorType origVT = readOp.getVectorType();
    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
    const int64_t origVRank = origVT.getRank();
    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");

    // Number of trailing dimensions to collapse, including the scalable
    // dimension. Nothing to do if the single scalable dimension is already
    // the last one.
    const int64_t numCollapseDims = std::distance(
        llvm::find(origScalableDims, true), origScalableDims.end());
    if (numCollapseDims < 2)
      return rewriter.notifyMatchFailure(readOp,
                                         "scalable dimension is trailing");

    // We want a simple memref (not a tensor) with contiguous elements for at
    // least all the trailing dimensions up to and including the scalable one.
    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
      return rewriter.notifyMatchFailure(
          readOp, "non-contiguous memref dimensions to collapse");

    // The collapsed dimensions (excluding the scalable one) of the vector
    // and the memref must match and the corresponding indices must be
    // in-bounds (it follows these indices would be zero). This guarantees
    // that the operation transfers a contiguous block.
```

banach-space wrote:

> // The collapsed dimensions (excluding the scalable one) of the vector and
> // the memref must match

What about dynamic dim sizes in the memref? If that's not supported, is there a test?

https://github.com/llvm/llvm-project/pull/143146
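A test exercising dynamic leading dimensions could look roughly like this (illustrative; the trailing dims that get collapsed stay static so they can be matched against the vector, mirroring the example in the pattern's documentation):

```mlir
func.func @transfer_read_dynamic_leading_dims(
    %mem : memref<?x?x2x8xi8>, %i : index, %j : index) -> vector<2x[4]x2x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8
  %v = vector.transfer_read %mem[%i, %j, %c0, %c0], %c0_i8
    {in_bounds = [false, true, true, true]}
    : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
  return %v : vector<2x[4]x2x8xi8>
}
```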
[llvm-branch-commits] [mlir] [MLIR][AArch64] Add integration test for lowering of `vector.contract` to Neon FEAT_I8MM (PR #144699)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/144699 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/banach-space commented: Great work, Momchil - thank you! I've left a number of comments, but nothing major. My main high-level suggestion is to follow the guidance in [MLIR's Testing Guide](https://mlir.llvm.org/getting_started/TestingGuide/#contributor-guidelines) a bit more closely. It’s a relatively new (and long!) document, so I’ve included specific in-line suggestions to make it easier to see where things could align better. For additional context, this [RFC](https://discourse.llvm.org/t/rfc-should-we-aim-for-more-consistency-in-tests/) provides some of the rationale behind that approach. Also - what about memrefs with dynamic dimensions? https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
banach-space wrote: [nit] Avoid using the word `test` in test function names. It's just noise that doesn't add any new info. Instead, try to convey what makes a particular test case unique. See here for MLIR guidelines: https://mlir.llvm.org/getting_started/TestingGuide/#test-formatting-best-practices https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
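As a purely illustrative example of that convention (the name and shapes below are invented, not from the PR), the distinguishing feature of a case can go straight into the function name:

```mlir
// Instead of e.g. @test_base_case: say what makes the case unique.
func.func @transfer_read_scalable_non_trailing(%i : index, %j : index,
    %m : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %m[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]}
      : memref<?x?x?x8xi8>, vector<[4]x8xi8>
  return %v : vector<[4]x8xi8>
}
```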
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. +VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) + return rewriter.notifyMatchFailure(readOp, "wrong dimensions"); + +// Number of trailing dimensions to collapse, including the scalable +// dimension. Nothing to do if the single scalable dimension is already the +// last one. +const int64_t numCollapseDims = std::distance( +llvm::find(origScalableDims, true), origScalableDims.end()); +if (numCollapseDims < 2) + return rewriter.notifyMatchFailure(readOp, + "scalable dimension is trailing"); + +// We want a simple memref (not a tensor) with contiguous elements for at +// least all the trailing dimensions up to and including the scalable one. +auto memTy = dyn_cast(readOp.getBase().getType()); +if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims))) + return rewriter.notifyMatchFailure( + readOp, "non-contiguous memref dimensions to collapse"); + +// The collapsed dimensions (excluding the scalable one) of the vector and +// the memref must match and the corresponding indices must be in-bounds (it +// follows these indices would be zero). This guarantees that the operation +// transfers a contiguous block. 
+if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1), + origVT.getShape().take_back(numCollapseDims - 1))) + return rewriter.notifyMatchFailure( + readOp, "memref and vector dimensions do not match"); + +SmallVector origInBounds = readOp.getInBoundsValues(); +if (!llvm::all_of( +ArrayRef(origInBounds).take_back(numCollapseDims - 1), +[](bool v) { return v; })) + return rewriter.notifyMatchFailure(readOp, + "out-if-bounds index to collapse"); banach-space wrote: Note, it's not really the index that's out-of-bounds, but the corresponding memory access. So, the index could be in-bounds, but we might be reading "more" than is available to read (starting at that index). For example:
```mlir
vector.transfer_read %mem[5] : memref<7xi8>, vector<7xi8>
```
```suggestion
"out-of-bounds index to collapse");
```
https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); banach-space wrote: Would supporting non-identity permutations be a problem? It would be good to add a comment, either:
* `TODO: We haven't needed this, so leaving it for later.` or
* "Too complex because of <reason>, disabling".
Any hint for future developers would be helpful. https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
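For reference, the kind of op such a comment would cover looks like the following (a constructed example, not from the PR's tests): a read whose permutation map transposes dimensions, which `isMinorIdentity()` rejects.

```mlir
// A transposed read: the permutation map is not a minor identity, so the
// pattern bails out. Supporting it would require inserting transposes
// around the collapsed read.
#transposed = affine_map<(d0, d1) -> (d1, d0)>

func.func @negative_non_identity_map(%i : index, %j : index,
    %m : memref<?x?xi8>) -> vector<[4]x8xi8> {
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %m[%i, %j], %pad
      {permutation_map = #transposed, in_bounds = [true, true]}
      : memref<?x?xi8>, vector<[4]x8xi8>
  return %v : vector<[4]x8xi8>
}
```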
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -0,0 +1,262 @@ +// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s + +// - + +// CHECK-LABEL: @test_base_case +// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]: banach-space wrote: Is it guaranteed that `%i` will be renamed as `arg0` after the transformation? AFAIK, no, but perhaps I am missing something? https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
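One way to sidestep the question entirely (a sketch, not the PR's actual CHECK lines) is to capture whatever name the argument gets via a regex, rather than hard-coding `arg0`:

```mlir
// CHECK-LABEL: @base_case
// CHECK-SAME:  %[[I:[A-Za-z0-9_]+]]: index, %[[J:[A-Za-z0-9_]+]]: index,
// CHECK-SAME:  %[[M:[A-Za-z0-9_]+]]: memref
// CHECK:       memref.collapse_shape %[[M]]
```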
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. +VectorType origVT = readOp.getVectorType(); +ArrayRef origScalableDims = origVT.getScalableDims(); +const int64_t origVRank = origVT.getRank(); +if (origVRank < 2 || llvm::count(origScalableDims, true) != 1) banach-space wrote: [nit] [getNumScalableDims](https://github.com/banach-space/llvm-project/blob/c15e7dddaea765eab4f9ed73e79b762138dc4ac0/mlir/include/mlir/IR/BuiltinTypes.td#L1368-L1371) would be more canonical than `llvm::count` https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern { } }; +/// Transforms a `transfer_read` operation so it reads vector of a type that +/// can be mapped to an LLVM type. This is done by collapsing trailing +/// dimensions so we obtain a vector type with a single scalable dimension in +/// the rightmost position. +/// +/// Example: +/// ``` +/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 +/// {in_bounds = [false, true, true, true]} +/// : memref, vector<2x[4]x2x8xi8> +/// ``` +/// is rewritten to +/// ``` +/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]] +/// : memref into memref +/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8 +/// {in_bounds = [false, true]} +/// : memref, vector<2x[64]xi8> +/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8> +/// ``` +struct LegalizeTransferRead : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, +PatternRewriter &rewriter) const override { + +// Do not try to transform masked reads. For example, if we have a transfer +// to a `vector<[4]x4xi8>` we could have a mask like +//1 1 1 0 +//1 1 1 0 +//1 1 1 0 +//0 0 0 0 +// Flattening this mask would look like +//1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0 +// and we have not yet figured out an efficient way to build such a mask, +// neither from the mask operand, nor from the original `vector.create_mask` +// operation (if visible at all). +if (readOp.isMasked() || readOp.getMask()) + return rewriter.notifyMatchFailure(readOp, + "masked transfers not-supported"); + +if (!readOp.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(readOp, "non-identity permutation"); + +// We handle transfers of vectors with rank >= 2 and a single scalable +// dimension. banach-space wrote: [nit] It would be helpful to add _why_:
* Don't need to worry about 1D, that's supported by default.
* More than one scalable dim is tricky (how to collapse e.g. `vscale * vscale`?)
https://github.com/llvm/llvm-project/pull/143146 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
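Both boundary cases can be made concrete (constructed examples; neither is from the PR's tests):

```mlir
func.func @why_the_constraints(%i : index, %j : index,
    %m1 : memref<?xi8>, %m2 : memref<?x?xi8>) {
  %pad = arith.constant 0 : i8
  // Rank 1 with a trailing scalable dim: already maps onto an LLVM type,
  // nothing to legalize.
  %v0 = vector.transfer_read %m1[%i], %pad {in_bounds = [true]}
      : memref<?xi8>, vector<[16]xi8>
  // Two scalable dims: a flattened read would need a `vscale * vscale`
  // extent, which a type with a single scalable dim cannot express.
  %v1 = vector.transfer_read %m2[%i, %j], %pad {in_bounds = [true, true]}
      : memref<?x?xi8>, vector<[4]x[8]xi8>
  return
}
```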
[llvm-branch-commits] [mlir] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op (PR #140572)
https://github.com/banach-space approved this pull request. LGTM, thanks! https://github.com/llvm/llvm-project/pull/140572 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir] Nominate MLIR Egress category maintainers (PR #149487)
https://github.com/banach-space approved this pull request. https://github.com/llvm/llvm-project/pull/149487 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits