[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
https://github.com/aaupov edited https://github.com/llvm/llvm-project/pull/119652 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
https://github.com/aaupov updated https://github.com/llvm/llvm-project/pull/119652 >From cc7cea0276fef36e896ecef149ef680b66bb9c1f Mon Sep 17 00:00:00 2001 From: Amir Ayupov Date: Wed, 11 Dec 2024 19:11:07 -0800 Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?= =?UTF-8?q?itial=20version?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 --- bolt/test/merge-fdata.test | 8 ++ bolt/tools/merge-fdata/merge-fdata.cpp | 39 ++ 2 files changed, 36 insertions(+), 11 deletions(-) create mode 100644 bolt/test/merge-fdata.test diff --git a/bolt/test/merge-fdata.test b/bolt/test/merge-fdata.test new file mode 100644 index 00..0ea0f2cb4398b6 --- /dev/null +++ b/bolt/test/merge-fdata.test @@ -0,0 +1,8 @@ +## Reproduces an issue where counts were not accumulated by merge-fdata + +# RUN: split-file %s %t +# RUN: merge-fdata %t/fdata | FileCheck %s +# CHECK: 1 main 10 1 main 1a 1 20 +#--- fdata +1 main 10 1 main 1a 0 10 +1 main 10 1 main 1a 1 10 diff --git a/bolt/tools/merge-fdata/merge-fdata.cpp b/bolt/tools/merge-fdata/merge-fdata.cpp index 89ca46c1c0a8fa..a74e9099df9059 100644 --- a/bolt/tools/merge-fdata/merge-fdata.cpp +++ b/bolt/tools/merge-fdata/merge-fdata.cpp @@ -19,6 +19,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/Regex.h" #include "llvm/Support/Signals.h" #include "llvm/Support/ThreadPool.h" #include @@ -266,7 +267,19 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { errs() << "Using legacy profile format.\n"; std::optional BoltedCollection; std::mutex BoltedCollectionMutex; - typedef StringMap ProfileTy; + constexpr static const char *const FdataCountersPattern = + "(.*) ([0-9]+) ([0-9]+)"; + Regex FdataRegex(FdataCountersPattern); + struct CounterTy { +uint64_t Count; +uint64_t MispredCount; +CounterTy &operator+(const CounterTy &O) { + Count += O.Count; + 
MispredCount += O.MispredCount; + return *this; +} + }; + typedef StringMap ProfileTy; auto ParseProfile = [&](const std::string &Filename, auto &Profiles) { const llvm::thread::id tid = llvm::this_thread::get_id(); @@ -304,15 +317,19 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { SmallVector Lines; SplitString(Buf, Lines, "\n"); for (StringRef Line : Lines) { - size_t Pos = Line.rfind(" "); - if (Pos == StringRef::npos) + CounterTy CurrCount; + SmallVector Fields; + if (!FdataRegex.match(Line, &Fields)) report_error(Filename, "Malformed / corrupted profile"); - StringRef Signature = Line.substr(0, Pos); - uint64_t Count; - if (Line.substr(Pos + 1, Line.size() - Pos).getAsInteger(10, Count)) -report_error(Filename, "Malformed / corrupted profile counter"); - Count += Profile->lookup(Signature); - Profile->insert_or_assign(Signature, Count); + StringRef Signature = Fields[1]; + if (Fields[2].getAsInteger(10, CurrCount.MispredCount)) +report_error(Filename, "Malformed / corrupted execution count"); + if (Fields[3].getAsInteger(10, CurrCount.Count)) +report_error(Filename, "Malformed / corrupted misprediction count"); + + CounterTy Counter = Profile->lookup(Signature); + Counter = Counter + CurrCount; + Profile->insert_or_assign(Signature, Counter); } }; @@ -330,14 +347,14 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { ProfileTy MergedProfile; for (const auto &[Thread, Profile] : ParsedProfiles) for (const auto &[Key, Value] : Profile) { - uint64_t Count = MergedProfile.lookup(Key) + Value; + auto Count = MergedProfile.lookup(Key) + Value; MergedProfile.insert_or_assign(Key, Count); } if (BoltedCollection.value_or(false)) output() << "boltedcollection\n"; for (const auto &[Key, Value] : MergedProfile) -output() << Key << " " << Value << "\n"; +output() << Key << " " << Value.MispredCount << " " << Value.Count << "\n"; errs() << "Profile from " << Filenames.size() << " files merged.\n"; } >From 040deec056483f8268d134ce826a20f1e511dcd4 
Mon Sep 17 00:00:00 2001 From: Amir Ayupov Date: Wed, 11 Dec 2024 19:22:23 -0800 Subject: [PATCH 2/2] simplify Created using spr 1.3.4 --- bolt/tools/merge-fdata/merge-fdata.cpp | 24 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/bolt/tools/merge-fdata/merge-fdata.cpp b/bolt/tools/merge-fdata/merge-fdata.cpp index a74e9099df9059..83616e1de99254 100644 --- a/bolt/tools/merge-fdata/merge-fdata.cpp +++ b/bolt/tools/merge-fdata/merge-fdata.cpp @@ -271,13 +271,14 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { "(.*) ([0-9]+) ([0-9]+)"; Regex FdataRegex(FdataCountersPattern); struct CounterTy {
[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
https://github.com/aaupov updated https://github.com/llvm/llvm-project/pull/119652 >From cc7cea0276fef36e896ecef149ef680b66bb9c1f Mon Sep 17 00:00:00 2001 From: Amir Ayupov Date: Wed, 11 Dec 2024 19:11:07 -0800 Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?= =?UTF-8?q?itial=20version?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 --- bolt/test/merge-fdata.test | 8 ++ bolt/tools/merge-fdata/merge-fdata.cpp | 39 ++ 2 files changed, 36 insertions(+), 11 deletions(-) create mode 100644 bolt/test/merge-fdata.test diff --git a/bolt/test/merge-fdata.test b/bolt/test/merge-fdata.test new file mode 100644 index 00..0ea0f2cb4398b6 --- /dev/null +++ b/bolt/test/merge-fdata.test @@ -0,0 +1,8 @@ +## Reproduces an issue where counts were not accumulated by merge-fdata + +# RUN: split-file %s %t +# RUN: merge-fdata %t/fdata | FileCheck %s +# CHECK: 1 main 10 1 main 1a 1 20 +#--- fdata +1 main 10 1 main 1a 0 10 +1 main 10 1 main 1a 1 10 diff --git a/bolt/tools/merge-fdata/merge-fdata.cpp b/bolt/tools/merge-fdata/merge-fdata.cpp index 89ca46c1c0a8fa..a74e9099df9059 100644 --- a/bolt/tools/merge-fdata/merge-fdata.cpp +++ b/bolt/tools/merge-fdata/merge-fdata.cpp @@ -19,6 +19,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/Regex.h" #include "llvm/Support/Signals.h" #include "llvm/Support/ThreadPool.h" #include @@ -266,7 +267,19 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { errs() << "Using legacy profile format.\n"; std::optional BoltedCollection; std::mutex BoltedCollectionMutex; - typedef StringMap ProfileTy; + constexpr static const char *const FdataCountersPattern = + "(.*) ([0-9]+) ([0-9]+)"; + Regex FdataRegex(FdataCountersPattern); + struct CounterTy { +uint64_t Count; +uint64_t MispredCount; +CounterTy &operator+(const CounterTy &O) { + Count += O.Count; + 
MispredCount += O.MispredCount; + return *this; +} + }; + typedef StringMap ProfileTy; auto ParseProfile = [&](const std::string &Filename, auto &Profiles) { const llvm::thread::id tid = llvm::this_thread::get_id(); @@ -304,15 +317,19 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { SmallVector Lines; SplitString(Buf, Lines, "\n"); for (StringRef Line : Lines) { - size_t Pos = Line.rfind(" "); - if (Pos == StringRef::npos) + CounterTy CurrCount; + SmallVector Fields; + if (!FdataRegex.match(Line, &Fields)) report_error(Filename, "Malformed / corrupted profile"); - StringRef Signature = Line.substr(0, Pos); - uint64_t Count; - if (Line.substr(Pos + 1, Line.size() - Pos).getAsInteger(10, Count)) -report_error(Filename, "Malformed / corrupted profile counter"); - Count += Profile->lookup(Signature); - Profile->insert_or_assign(Signature, Count); + StringRef Signature = Fields[1]; + if (Fields[2].getAsInteger(10, CurrCount.MispredCount)) +report_error(Filename, "Malformed / corrupted execution count"); + if (Fields[3].getAsInteger(10, CurrCount.Count)) +report_error(Filename, "Malformed / corrupted misprediction count"); + + CounterTy Counter = Profile->lookup(Signature); + Counter = Counter + CurrCount; + Profile->insert_or_assign(Signature, Counter); } }; @@ -330,14 +347,14 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { ProfileTy MergedProfile; for (const auto &[Thread, Profile] : ParsedProfiles) for (const auto &[Key, Value] : Profile) { - uint64_t Count = MergedProfile.lookup(Key) + Value; + auto Count = MergedProfile.lookup(Key) + Value; MergedProfile.insert_or_assign(Key, Count); } if (BoltedCollection.value_or(false)) output() << "boltedcollection\n"; for (const auto &[Key, Value] : MergedProfile) -output() << Key << " " << Value << "\n"; +output() << Key << " " << Value.MispredCount << " " << Value.Count << "\n"; errs() << "Profile from " << Filenames.size() << " files merged.\n"; } >From 040deec056483f8268d134ce826a20f1e511dcd4 
Mon Sep 17 00:00:00 2001 From: Amir Ayupov Date: Wed, 11 Dec 2024 19:22:23 -0800 Subject: [PATCH 2/2] simplify Created using spr 1.3.4 --- bolt/tools/merge-fdata/merge-fdata.cpp | 24 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/bolt/tools/merge-fdata/merge-fdata.cpp b/bolt/tools/merge-fdata/merge-fdata.cpp index a74e9099df9059..83616e1de99254 100644 --- a/bolt/tools/merge-fdata/merge-fdata.cpp +++ b/bolt/tools/merge-fdata/merge-fdata.cpp @@ -271,13 +271,14 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { "(.*) ([0-9]+) ([0-9]+)"; Regex FdataRegex(FdataCountersPattern); struct CounterTy {
[llvm-branch-commits] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)
https://github.com/matthias-springer edited https://github.com/llvm/llvm-project/pull/116524 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)
https://github.com/matthias-springer edited https://github.com/llvm/llvm-project/pull/116524 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (PR #119975)
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/119975 Clean up `populateVectorToLLVMConversionPatterns` so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite. The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions. Depends on #119973. >From e5926b63835eec731c513d91ce9c451c429ca572 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 14 Dec 2024 17:19:18 +0100 Subject: [PATCH] [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` --- .../Vector/Transforms/LoweringPatterns.h | 4 +++ .../GPUCommon/GPUToLLVMConversion.cpp | 12 + .../VectorToLLVM/ConvertVectorToLLVM.cpp | 27 ++- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 6 - .../Conversion/GPUCommon/lower-vector.mlir| 4 +-- .../VectorToLLVM/vector-to-llvm.mlir | 5 6 files changed, 37 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 3d643c96b45008..c507b23c6d4de6 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank = 1, PatternBenefit benefit = 1); +/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where +/// n > 1. 
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); + } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 1497d662dcdbdd..2fe3b1302e5e5b 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -32,10 +32,12 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Error.h" @@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp) void GpuToLLVMConversionPass::runOnOperation() { MLIRContext *context = &getContext(); + + // Perform progressive lowering of vector transfer operations. + { +RewritePatternSet patterns(&getContext()); +// Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
+vector::populateVectorTransferLoweringPatterns(patterns, + /*maxTransferRank=*/1); +(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + LowerToLLVMOptions options(context); options.useBarePtrCallConv = hostBarePtrCallConv; RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a9a07c323c7358..577b74bb7e0c26 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). /// Non-scalable versions of this operation are handled in Vector Transforms. -class VectorCreateMaskOpRewritePattern -: public OpRewritePattern { +class VectorCreateMaskOpConversion +: public OpConversionPattern { public: - explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, + explicit VectorCreateMaskOpConversion(MLIRContext *context, bool enableIndexOpt) - : OpRewritePattern(context), + : OpConversionPattern(context), force32BitVectorIndices(enableIndexOpt) {} - LogicalResult matchAndRewrite(vector::CreateMaskOp op, -PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor, +ConversionPatternRewriter &rewriter) const override { auto dstType = op.getType(); if (dstType.getRank() != 1 || !cast(dstType).isScalable()) return failure(); @@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
[llvm-branch-commits] [mlir] [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (PR #119975)
llvmbot wrote: @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes Clean up `populateVectorToLLVMConversionPatterns` so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite. The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions. Depends on #119973. --- Full diff: https://github.com/llvm/llvm-project/pull/119975.diff 6 Files Affected: - (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+4) - (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+12) - (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+14-13) - (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+5-1) - (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+2-2) - (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (-5) ``diff diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 3d643c96b45008..c507b23c6d4de6 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank = 1, PatternBenefit benefit = 1); +/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where +/// n > 1. 
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); + } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 1497d662dcdbdd..2fe3b1302e5e5b 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -32,10 +32,12 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Error.h" @@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp) void GpuToLLVMConversionPass::runOnOperation() { MLIRContext *context = &getContext(); + + // Perform progressive lowering of vector transfer operations. + { +RewritePatternSet patterns(&getContext()); +// Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
+vector::populateVectorTransferLoweringPatterns(patterns, + /*maxTransferRank=*/1); +(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + LowerToLLVMOptions options(context); options.useBarePtrCallConv = hostBarePtrCallConv; RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a9a07c323c7358..577b74bb7e0c26 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). /// Non-scalable versions of this operation are handled in Vector Transforms. -class VectorCreateMaskOpRewritePattern -: public OpRewritePattern { +class VectorCreateMaskOpConversion +: public OpConversionPattern { public: - explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, + explicit VectorCreateMaskOpConversion(MLIRContext *context, bool enableIndexOpt) - : OpRewritePattern(context), + : OpConversionPattern(context), force32BitVectorIndices(enableIndexOpt) {} - LogicalResult matchAndRewrite(vector::CreateMaskOp op, -PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor, +ConversionPatternRewriter &rewriter) const override { auto dstType = op.getType(); if (dstType.getRank() != 1 || !cast(dstType).isScalable()) return failure(); @@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern loc, LLVM::getVectorType(idxType, dstType.getShape()[0], /*isScalable=*/true)); auto bound = g
[llvm-branch-commits] [mlir] [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (PR #119975)
llvmbot wrote: @llvm/pr-subscribers-mlir-vector Author: Matthias Springer (matthias-springer) Changes Clean up `populateVectorToLLVMConversionPatterns` so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite. The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions. Depends on #119973. --- Full diff: https://github.com/llvm/llvm-project/pull/119975.diff 6 Files Affected: - (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+4) - (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+12) - (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+14-13) - (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+5-1) - (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+2-2) - (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (-5) ``diff diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 3d643c96b45008..c507b23c6d4de6 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank = 1, PatternBenefit benefit = 1); +/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where +/// n > 1. 
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); + } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 1497d662dcdbdd..2fe3b1302e5e5b 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -32,10 +32,12 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Error.h" @@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp) void GpuToLLVMConversionPass::runOnOperation() { MLIRContext *context = &getContext(); + + // Perform progressive lowering of vector transfer operations. + { +RewritePatternSet patterns(&getContext()); +// Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
+vector::populateVectorTransferLoweringPatterns(patterns, + /*maxTransferRank=*/1); +(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + LowerToLLVMOptions options(context); options.useBarePtrCallConv = hostBarePtrCallConv; RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a9a07c323c7358..577b74bb7e0c26 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). /// Non-scalable versions of this operation are handled in Vector Transforms. -class VectorCreateMaskOpRewritePattern -: public OpRewritePattern { +class VectorCreateMaskOpConversion +: public OpConversionPattern { public: - explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, + explicit VectorCreateMaskOpConversion(MLIRContext *context, bool enableIndexOpt) - : OpRewritePattern(context), + : OpConversionPattern(context), force32BitVectorIndices(enableIndexOpt) {} - LogicalResult matchAndRewrite(vector::CreateMaskOp op, -PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor, +ConversionPatternRewriter &rewriter) const override { auto dstType = op.getType(); if (dstType.getRank() != 1 || !cast(dstType).isScalable()) return failure(); @@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern loc, LLVM::getVectorType(idxType, dstType.getShape()[0], /*isScalable=*/true)); auto bo
[llvm-branch-commits] [mlir] [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (PR #119975)
github-actions[bot] wrote: :warning: C/C++ code formatter, clang-format found issues in your code. :warning: You can test this locally with the following command: ``bash git-clang-format --diff bf0d13553b2bc2124a266e398976ba80a1114580 e5926b63835eec731c513d91ce9c451c429ca572 --extensions cpp,h -- mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp `` View the diff from clang-format here. ``diff diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 577b74bb7e..9657f583c3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1479,12 +1479,13 @@ class VectorCreateMaskOpConversion : public OpConversionPattern { public: explicit VectorCreateMaskOpConversion(MLIRContext *context, -bool enableIndexOpt) +bool enableIndexOpt) : OpConversionPattern(context), force32BitVectorIndices(enableIndexOpt) {} - LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor, -ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto dstType = op.getType(); if (dstType.getRank() != 1 || !cast(dstType).isScalable()) return failure(); `` https://github.com/llvm/llvm-project/pull/119975 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (PR #119975)
https://github.com/banach-space approved this pull request. The non-GPU changes LGTM. The CHECK lines removed in tests were just dead code, so thanks for the clean-up! The GPU parts look reasonable, but it might be worth waiting a few days in case someone more experienced wants to take a look. If there are no comments, I would just land this. https://github.com/llvm/llvm-project/pull/119975 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [DirectX] Split resource info into type and binding info. NFC (PR #119773)
https://github.com/hekota edited https://github.com/llvm/llvm-project/pull/119773 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [DirectX] Split resource info into type and binding info. NFC (PR #119773)
https://github.com/hekota edited https://github.com/llvm/llvm-project/pull/119773 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [DirectX] Split resource info into type and binding info. NFC (PR #119773)
@@ -303,44 +289,157 @@ class ResourceInfo { dxil::SamplerFeedbackType getFeedbackType() const; uint32_t getMultiSampleCount() const; - StringRef getName() const { -// TODO: Get the name from the symbol once we include one here. -return ""; - } dxil::ResourceClass getResourceClass() const { return RC; } dxil::ResourceKind getResourceKind() const { return Kind; } + bool operator==(const ResourceTypeInfo &RHS) const; + bool operator!=(const ResourceTypeInfo &RHS) const { return !(*this == RHS); } + bool operator<(const ResourceTypeInfo &RHS) const; + + void print(raw_ostream &OS, const DataLayout &DL) const; +}; + +//===--===// + +class ResourceBindingInfo { +public: + struct ResourceBinding { +uint32_t RecordID; +uint32_t Space; +uint32_t LowerBound; +uint32_t Size; + +bool operator==(const ResourceBinding &RHS) const { + return std::tie(RecordID, Space, LowerBound, Size) == + std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size); +} +bool operator!=(const ResourceBinding &RHS) const { + return !(*this == RHS); +} +bool operator<(const ResourceBinding &RHS) const { + return std::tie(RecordID, Space, LowerBound, Size) < + std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size); +} + }; + +private: + ResourceBinding Binding; + TargetExtType *HandleTy; + +public: + ResourceBindingInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound, + uint32_t Size, TargetExtType *HandleTy) + : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy) {} + void setBindingID(unsigned ID) { Binding.RecordID = ID; } const ResourceBinding &getBinding() const { return Binding; } + TargetExtType *getHandleTy() const { return HandleTy; } + const StringRef getName() const { +// TODO: Get the name from the symbol once we include one here. 
+return ""; + } - MDTuple *getAsMetadata(Module &M) const; - std::pair getAnnotateProps(Module &M) const; + MDTuple *getAsMetadata(Module &M, DXILResourceTypeMap &DRTM) const; + MDTuple *getAsMetadata(Module &M, dxil::ResourceTypeInfo RTI) const; - bool operator==(const ResourceInfo &RHS) const; - bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); } - bool operator<(const ResourceInfo &RHS) const; + std::pair + getAnnotateProps(Module &M, DXILResourceTypeMap &DRTM) const; + std::pair + getAnnotateProps(Module &M, dxil::ResourceTypeInfo RTI) const; - void print(raw_ostream &OS, const DataLayout &DL) const; + bool operator==(const ResourceBindingInfo &RHS) const { +return std::tie(Binding, HandleTy) == std::tie(RHS.Binding, RHS.HandleTy); + } + bool operator!=(const ResourceBindingInfo &RHS) const { +return !(*this == RHS); + } + bool operator<(const ResourceBindingInfo &RHS) const { +return Binding < RHS.Binding; + } + + void print(raw_ostream &OS, DXILResourceTypeMap &DRTM, + const DataLayout &DL) const; + void print(raw_ostream &OS, dxil::ResourceTypeInfo RTI, + const DataLayout &DL) const; }; } // namespace dxil //===--===// -class DXILResourceMap { - SmallVector Infos; +class DXILResourceTypeMap { + struct Info { +dxil::ResourceClass RC; +dxil::ResourceKind Kind; +bool GloballyCoherent; +bool HasCounter; + }; + DenseMap Infos; + +public: + bool invalidate(Module &M, const PreservedAnalyses &PA, + ModuleAnalysisManager::Invalidator &Inv); + + dxil::ResourceTypeInfo operator[](TargetExtType *Ty) { +Info I = Infos[Ty]; +return dxil::ResourceTypeInfo(Ty, I.RC, I.Kind, I.GloballyCoherent, + I.HasCounter); + } + + void setGloballyCoherent(TargetExtType *Ty, bool GloballyCoherent) { +Infos[Ty].GloballyCoherent = GloballyCoherent; + } + + void setHasCounter(TargetExtType *Ty, bool HasCounter) { +Infos[Ty].HasCounter = HasCounter; + } +}; + +class DXILResourceTypeAnalysis +: public AnalysisInfoMixin { + friend AnalysisInfoMixin; + + static 
AnalysisKey Key; + +public: + using Result = DXILResourceTypeMap; + + DXILResourceTypeMap run(Module &M, ModuleAnalysisManager &AM) { +return Result(); + } hekota wrote: It would be nice to add a comment on this class explaining how the DXILResourceTypeMap will get populated (that it happens in the DXILResourceBindingAnalysis, if I understand it correctly). https://github.com/llvm/llvm-project/pull/119773 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [DirectX] Split resource info into type and binding info. NFC (PR #119773)
https://github.com/hekota edited https://github.com/llvm/llvm-project/pull/119773 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [DirectX] Split resource info into type and binding info. NFC (PR #119773)
https://github.com/hekota edited https://github.com/llvm/llvm-project/pull/119773 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [DirectX] Create symbols for resource handles (PR #119775)
@@ -0,0 +1,48 @@ +; RUN: opt -S -passes=dxil-translate-metadata %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.6-compute" + +%struct.S = type { <4 x float>, <4 x i32> } + +define void @test() { + ; Buffer + %float4 = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0) + @llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false) + ; CHECK: %TypedBuffer = type { <4 x float> } + + ; Buffer + %int = call target("dx.TypedBuffer", i32, 0, 0, 1) + @llvm.dx.handle.fromBinding(i32 0, i32 1, i32 1, i32 0, i1 false) + ; CHECK: %TypedBuffer.0 = type { i32 } + + ; Buffer + %uint3 = call target("dx.TypedBuffer", <3 x i32>, 0, 0, 0) + @llvm.dx.handle.fromBinding(i32 0, i32 2, i32 1, i32 0, i1 false) + ; CHECK: %TypedBuffer.1 = type { <3 x i32> } + + ; StructuredBuffer + %struct0 = call target("dx.RawBuffer", %struct.S, 0, 0) + @llvm.dx.handle.fromBinding(i32 0, i32 10, i32 1, i32 0, i1 true) + ; CHECK: %StructuredBuffer = type { %struct.S } + + ; ByteAddressBuffer + %byteaddr = call target("dx.RawBuffer", i8, 0, 0) + @llvm.dx.handle.fromBinding(i32 0, i32 20, i32 1, i32 0, i1 false) + ; CHECK: %ByteAddressBuffer = type { i32 } + + ret void +} + +; CHECK: @[[T0:.*]] = external constant %TypedBuffer +; CHECK-NEXT: @[[T1:.*]] = external constant %TypedBuffer.0 +; CHECK-NEXT: @[[T2:.*]] = external constant %TypedBuffer.1 +; CHECK-NEXT: @[[S0:.*]] = external constant %StructuredBuffer +; CHECK-NEXT: @[[B0:.*]] = external constant %ByteAddressBuffer + +; CHECK: !{i32 0, ptr @[[T0]], !"" hekota wrote: Is the name of the symbol expected to be an empty string here? https://github.com/llvm/llvm-project/pull/119775 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [DirectX] Create symbols for resource handles (PR #119775)
hekota wrote: If there already is a global variable for the resource in the module, shouldn't we be using that instead of creating a new symbol? https://github.com/llvm/llvm-project/pull/119775 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [NFC][BoundsChecking] Add TrapBB local variable (PR #119983)
llvmbot wrote: @llvm/pr-subscribers-llvm-transforms Author: Vitaly Buka (vitalybuka) Changes --- Full diff: https://github.com/llvm/llvm-project/pull/119983.diff 1 Files Affected: - (modified) llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp (+4-2) ``diff diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index c86d967716a5a0..c4511d574f2185 100644 --- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -126,16 +126,18 @@ static void insertBoundsCheck(Value *Or, BuilderTy &IRB, GetTrapBBT GetTrapBB) { BasicBlock *Cont = OldBB->splitBasicBlock(SplitI); OldBB->getTerminator()->eraseFromParent(); + BasicBlock * TrapBB = GetTrapBB(IRB); + if (C) { // If we have a constant zero, unconditionally branch. // FIXME: We should really handle this differently to bypass the splitting // the block. -BranchInst::Create(GetTrapBB(IRB), OldBB); +BranchInst::Create(TrapBB, OldBB); return; } // Create the conditional branch. - BranchInst::Create(GetTrapBB(IRB), Cont, Or, OldBB); + BranchInst::Create(TrapBB, Cont, Or, OldBB); } static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, `` https://github.com/llvm/llvm-project/pull/119983 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [NFC][BoundsChecking] Add TrapBB local variable (PR #119983)
https://github.com/vitalybuka converted_to_draft https://github.com/llvm/llvm-project/pull/119983 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [NFC][BoundsChecking] Add TrapBB local variable (PR #119983)
github-actions[bot] wrote: :warning: C/C++ code formatter, clang-format found issues in your code. :warning: You can test this locally with the following command: ``bash git-clang-format --diff 8bfb53a797a0d1cb2290f188ef23418c90b7d254 daf365e3e7fdc383e18e8171f1aa89a6ce8f72ec --extensions cpp -- llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp `` View the diff from clang-format here. ``diff diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index c4511d574f..769fbe92b0 100644 --- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -126,7 +126,7 @@ static void insertBoundsCheck(Value *Or, BuilderTy &IRB, GetTrapBBT GetTrapBB) { BasicBlock *Cont = OldBB->splitBasicBlock(SplitI); OldBB->getTerminator()->eraseFromParent(); - BasicBlock * TrapBB = GetTrapBB(IRB); + BasicBlock *TrapBB = GetTrapBB(IRB); if (C) { // If we have a constant zero, unconditionally branch. `` https://github.com/llvm/llvm-project/pull/119983 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [NFC][BoundsChecking] Add TrapBB local variable (PR #119983)
https://github.com/vitalybuka created https://github.com/llvm/llvm-project/pull/119983 None ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
@@ -330,14 +347,14 @@ void mergeLegacyProfiles(const SmallVectorImpl &Filenames) { ProfileTy MergedProfile; for (const auto &[Thread, Profile] : ParsedProfiles) for (const auto &[Key, Value] : Profile) { - uint64_t Count = MergedProfile.lookup(Key) + Value; + auto Count = MergedProfile.lookup(Key) + Value; aaupov wrote: The type of `Count` has changed; it is now spelled out explicitly in the updated version. https://github.com/llvm/llvm-project/pull/119652 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
https://github.com/aaupov edited https://github.com/llvm/llvm-project/pull/119652 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
aaupov wrote: > Since you are switching to regex, are there any changes in parsing performance? With the updated version (on top of #119942), it's neutral or slightly faster than the current merge-fdata: > 1.04 ± 0.07 times faster https://github.com/llvm/llvm-project/pull/119652 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
https://github.com/aaupov edited https://github.com/llvm/llvm-project/pull/119652 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [llvm] [BOLT] Fix counts aggregation in merge-fdata (PR #119652)
https://github.com/maksfb approved this pull request. LGTM https://github.com/llvm/llvm-project/pull/119652 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits