Author: Eugene Zhulenev Date: 2020-12-08T10:30:14-08:00 New Revision: 94e645f9cce8fba26b4aec069103794f1779065f
URL: https://github.com/llvm/llvm-project/commit/94e645f9cce8fba26b4aec069103794f1779065f DIFF: https://github.com/llvm/llvm-project/commit/94e645f9cce8fba26b4aec069103794f1779065f.diff LOG: [mlir] Async: Add numWorkerThreads argument to createAsyncParallelForPass Add an option to pass the number of worker threads, which selects the number of async regions for the parallel-for transformation. ``` std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass(int numWorkerThreads); ``` Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D92835 Added: Modified: mlir/include/mlir/Dialect/Async/Passes.h mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h index 9716bde76593..ab5abdc28611 100644 --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -19,6 +19,9 @@ namespace mlir { std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass(); +std::unique_ptr<OperationPass<FuncOp>> +createAsyncParallelForPass(int numWorkerThreads); + std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass(); std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass(); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index c6508610c796..d6553974bc38 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -96,6 +96,10 @@ struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> { struct AsyncParallelForPass : public AsyncParallelForBase<AsyncParallelForPass> { AsyncParallelForPass() = default; + AsyncParallelForPass(int numWorkerThreads) { + assert(numWorkerThreads >= 1); + numConcurrentAsyncExecute = numWorkerThreads; + } void runOnFunction() 
override; }; @@ -276,3 +280,8 @@ void AsyncParallelForPass::runOnFunction() { std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() { return std::make_unique<AsyncParallelForPass>(); } + +std::unique_ptr<OperationPass<FuncOp>> +mlir::createAsyncParallelForPass(int numWorkerThreads) { + return std::make_unique<AsyncParallelForPass>(numWorkerThreads); +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits