https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/72801
>From 8abbf36f741c8363155e0f3cbf2450ff7f1f0801 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Sun, 19 Nov 2023 18:31:38 +0100 Subject: [PATCH 1/3] [mlir][async] Avoid crash when not using `func.func` --- .../Async/Transforms/AsyncParallelFor.cpp | 4 ++++ .../Async/async-parallel-for-compute-fn.mlir | 19 +++++++++++++++++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2 ++ 3 files changed, 25 insertions(+) diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index 12a28c2e23b221a..639bc7f9ec7f112 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -102,6 +102,10 @@ struct AsyncParallelForPass : public impl::AsyncParallelForBase<AsyncParallelForPass> { AsyncParallelForPass() = default; + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<async::AsyncDialect, func::FuncDialect>(); + } + AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads, int32_t minTaskSize) { this->asyncDispatch = asyncDispatch; diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir index 2115b1881fa6d66..fa3b53dd839c6c6 100644 --- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir +++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir @@ -69,6 +69,25 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) { // ----- +// Smoke test that parallel for doesn't crash when func dialect is not loaded. 
+ +// CHECK-LABEL: llvm.func @without_func_dialect() +llvm.func @without_func_dialect() { + %cst = arith.constant 0.0 : f32 + + %c0 = arith.constant 0 : index + %c22 = arith.constant 22 : index + %c1 = arith.constant 1 : index + %54 = memref.alloc() : memref<22xf32> + %alloc_4 = memref.alloc() : memref<22xf32> + scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) { + memref.store %cst, %alloc_4[%arg0] : memref<22xf32> + } + llvm.return +} + +// ----- + // Check that for statically known inner loop bound block size is aligned and // inner loop uses statically known loop trip counts. diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 842964b853d084d..963c52fd4191657 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() { const char *const getAttrName = R"( assert(index < {0} && "invalid attribute index"); assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(!name.getAttributeNames().empty() && "empty attribute names. 
Is a new " + "op created without having initialized its dialect?"); return name.getAttributeNames()[index]; )"; method->body() << formatv(getAttrName, attributes.size()); >From eb09cc895d7d1c08f745df22345cd0fae5432c7a Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Mon, 20 Nov 2023 19:23:49 +0100 Subject: [PATCH 2/3] Declare dependentDialects in `Passes.td` --- mlir/include/mlir/Dialect/Async/Passes.td | 1 + mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td index c7ee4ba39aecdf0..f0ef83ca3fd4f1a 100644 --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -36,6 +36,7 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> { let dependentDialects = [ "arith::ArithDialect", "async::AsyncDialect", + "func::FuncDialect", "scf::SCFDialect" ]; } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index 639bc7f9ec7f112..12a28c2e23b221a 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -102,10 +102,6 @@ struct AsyncParallelForPass : public impl::AsyncParallelForBase<AsyncParallelForPass> { AsyncParallelForPass() = default; - void getDependentDialects(DialectRegistry &registry) const override { - registry.insert<async::AsyncDialect, func::FuncDialect>(); - } - AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads, int32_t minTaskSize) { this->asyncDispatch = asyncDispatch; >From 77ba982eba8f7511543e9e06864a15c839feece8 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Mon, 20 Nov 2023 21:19:37 +0100 Subject: [PATCH 3/3] Update assertion --- mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir | 2 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 4 ++-- 2
files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir index fa3b53dd839c6c6..6f068c0e8d74cc7 100644 --- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir +++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir @@ -69,7 +69,7 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) { // ----- -// Smoke test that parallel for doesn't crash when func dialect is not loaded. +// Smoke test that parallel for doesn't crash when func dialect is not used. // CHECK-LABEL: llvm.func @without_func_dialect() llvm.func @without_func_dialect() { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 963c52fd4191657..57392434285ff89 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1143,8 +1143,8 @@ void OpEmitter::genAttrNameGetters() { const char *const getAttrName = R"( assert(index < {0} && "invalid attribute index"); assert(name.getStringRef() == getOperationName() && "invalid operation name"); - assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new " - "op created without having initialized its dialect?"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); return name.getAttributeNames()[index]; )"; method->body() << formatv(getAttrName, attributes.size()); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits