llvmbot wrote:
@llvm/pr-subscribers-mlir-vector

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Clean up `populateVectorToLLVMConversionPatterns` so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite.

The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.

Depends on #119973.

---
Full diff: https://github.com/llvm/llvm-project/pull/119975.diff

6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+4)
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+12)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+14-13)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+5-1)
- (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+2-2)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (-5)

``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                            int64_t targetRank = 1,
                                            PatternBenefit benefit = 1);
 
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2fe3b1302e5e5b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Error.h"
@@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
+
+  // Perform progressive lowering of vector transfer operations.
+  {
+    RewritePatternSet patterns(&getContext());
+    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+    vector::populateVectorTransferLoweringPatterns(patterns,
+                                                   /*maxTransferRank=*/1);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..577b74bb7e0c26 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion
 
 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
 /// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
-    : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+    : public OpConversionPattern<vector::CreateMaskOp> {
 public:
-  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                         bool enableIndexOpt)
-      : OpRewritePattern<vector::CreateMaskOp>(context),
+      : OpConversionPattern<vector::CreateMaskOp>(context),
         force32BitVectorIndices(enableIndexOpt) {}
 
-  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
-                                PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
       return failure();
@@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
-                                                 op.getOperand(0));
+                                                 adaptor.getOperands()[0]);
     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
@@ -1896,16 +1896,19 @@ struct VectorScalableStepOpLowering
 
 } // namespace
 
+void mlir::vector::populateVectorRankReducingFMAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions, bool force32BitVectorIndices) {
+  // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
-  populateVectorInsertExtractStridedSliceTransforms(patterns);
-  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter,
                                             reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1925,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
 
 void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and all contraction
-  // operations. Also materializes masks, applies folding and DCE.
+  // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+  // applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
                                               force32BitVectorIndices);
+    populateVectorInsertExtractStridedSliceTransforms(patterns);
+    populateVectorStepLoweringPatterns(patterns);
+    populateVectorRankReducingFMAPattern(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
-  func.func @func(%arg: vector<11xf32>) {
+  func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
     // CHECK: vector.mask
     // CHECK-SAME: vector.yield %arg0
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
-    return
+    return %127 : vector<11xf32>
   }
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
 // CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
 // CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
 // CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
 // CHECK: %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
   return %0 : vector<4x4x4xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 // CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 
 // -----
 
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
   return %0 : vector<4x4x[4]xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 // CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 
 // -----
 
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
   return %0 : vector<4x4x4xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 // CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 
 // -----
 
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
   return %0 : vector<4x4x[4]xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 // CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 
 // -----
``````````

</details>

https://github.com/llvm/llvm-project/pull/119975
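For downstream pipelines, the upshot is that `populateVectorToLLVMConversionPatterns` alone no longer lowers n-D FMAs, strided-slice insert/extract, `vector.step`, or transfer ops. Below is a minimal sketch of the resulting two-phase structure, modeled on the updated `ConvertVectorToLLVMPass` in the diff; the pass name `MyVectorToLLVMPass` is hypothetical, and a real pipeline may need further populate calls and target configuration:

```cpp
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
/// Hypothetical downstream pass illustrating the split between greedy
/// vector-to-vector rewrites and the vector-to-LLVM dialect conversion.
struct MyVectorToLLVMPass
    : public PassWrapper<MyVectorToLLVMPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyVectorToLLVMPass)

  void runOnOperation() override {
    // Phase 1: greedy rewrites that the populate function no longer adds.
    // These rewrite vector ops into simpler vector ops.
    {
      RewritePatternSet patterns(&getContext());
      vector::populateVectorInsertExtractStridedSliceTransforms(patterns);
      vector::populateVectorStepLoweringPatterns(patterns);
      vector::populateVectorRankReducingFMAPattern(patterns);
      // Transfer ops with rank > 1 must still go through VectorToSCF first.
      vector::populateVectorTransferLoweringPatterns(patterns,
                                                     /*maxTransferRank=*/1);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }

    // Phase 2: the dialect conversion proper; the populate function now
    // contributes only ConversionPatterns.
    LLVMTypeConverter converter(&getContext());
    LLVMConversionTarget target(getContext());
    target.addLegalOp<ModuleOp>();
    RewritePatternSet patterns(&getContext());
    populateVectorToLLVMConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```

Running the greedy phase first matters: once the conversion driver takes over, an op that only a RewritePattern could have simplified (e.g. an n-D `vector.fma`) would otherwise fail legalization.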