Author: Eugene Zhulenev
Date: 2020-12-08T10:30:14-08:00
New Revision: 94e645f9cce8fba26b4aec069103794f1779065f

URL: 
https://github.com/llvm/llvm-project/commit/94e645f9cce8fba26b4aec069103794f1779065f
DIFF: 
https://github.com/llvm/llvm-project/commit/94e645f9cce8fba26b4aec069103794f1779065f.diff

LOG: [mlir] Async: Add numWorkerThreads argument to createAsyncParallelForPass

Add an overload that takes the number of worker threads and uses it as the
number of async regions created by the parallel-for transformation:
```
std::unique_ptr<OperationPass<FuncOp>>
createAsyncParallelForPass(int numWorkerThreads);
```

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D92835

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 9716bde76593..ab5abdc28611 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -19,6 +19,11 @@ namespace mlir {
 
 std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
 
+/// Creates an async parallel-for pass that uses `numWorkerThreads` as the
+/// number of concurrent async regions; `numWorkerThreads` must be >= 1.
+std::unique_ptr<OperationPass<FuncOp>>
+createAsyncParallelForPass(int numWorkerThreads);
+
 std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index c6508610c796..d6553974bc38 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -96,6 +96,12 @@ struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
 struct AsyncParallelForPass
     : public AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
+  /// Overrides `numConcurrentAsyncExecute` (left at its base-class default by
+  /// the default constructor) with `numWorkerThreads`, which must be >= 1.
+  explicit AsyncParallelForPass(int numWorkerThreads) {
+    assert(numWorkerThreads >= 1 && "numWorkerThreads must be at least 1");
+    numConcurrentAsyncExecute = numWorkerThreads;
+  }
   void runOnFunction() override;
 };
 
@@ -276,3 +282,8 @@ void AsyncParallelForPass::runOnFunction() {
 std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncParallelForPass(int numWorkerThreads) {
+  return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
+}


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to