Author: Qinkun Bao
Date: 2025-06-24T10:09:20-04:00
New Revision: b8672c3278bf3ee83e8c44053d03558632ba46e0
URL: https://github.com/llvm/llvm-project/commit/b8672c3278bf3ee83e8c44053d03558632ba46e0
DIFF: https://github.com/llvm/llvm-project/commit/b8672c3278bf3ee83e8c44053d03558632ba46e0.diff

LOG: Revert "[mlir][mesh] adding option for traversal order in sharding propagatio…"

This reverts commit 43e1a5a411d972fe06a1afb86ffd5ba21fd2a376.

Added:


Modified:
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
    mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
    mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Removed:
    mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
    mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
    mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index c4d512b60bc51..1dc178586e918 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -206,6 +206,9 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                         OpOperand &operand, OpBuilder &builder,
+                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                          OpResult result, OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index a2424d43a8ba9..83399d10beaae 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,18 +19,6 @@ class FuncOp;
 
 namespace mesh {
 
-/// This enum controls the traversal order for the sharding propagation.
-enum class TraversalOrder {
-  /// Forward traversal.
-  Forward,
-  /// Backward traversal.
-  Backward,
-  /// Forward then backward traversal.
-  ForwardBackward,
-  /// Backward then forward traversal.
-  BackwardForward
-};
-
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 11ec7e78cd5e6..06ebf151e7d64 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,21 +24,6 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
  }];
-  let options = [
-    Option<"traversal", "traversal",
-           "mlir::mesh::TraversalOrder",
-           /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
-           "Traversal order to use for sharding propagation:",
-           [{::llvm::cl::values(
-             clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
-             "Forward only traversal."),
-             clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
-             "backward only traversal."),
-             clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
-             "forward-backward traversal."),
-             clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
-             "backward-forward traversal.")
-           )}]>,
-  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b8cc91da722f0..0a01aaf776e7d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -298,12 +298,13 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
-                                                    Value &operandValue,
-                                                    Operation *operandOp,
-                                                    OpBuilder &builder,
-                                                    ShardOp &newShardOp) {
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder,
+                                                     ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -322,8 +323,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
         builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                 /*annotate_for_users*/ false);
   }
-  operandValue.replaceUsesWithIf(
-      newShardOp, [operandOp, operandValue](OpOperand &use) {
+  IRRewriter rewriter(builder);
+  rewriter.replaceUsesWithIf(
+      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
@@ -334,20 +336,15 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
   auto newShardOp2 =
       builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                               newShardOp.getSharding(),
                               /*annotate_for_users*/ true);
-  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
+  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
-  SmallVector<std::pair<Value, Operation *>> uses;
-  for (auto &use : result.getUses()) {
-    uses.emplace_back(use.get(), use.getOwner());
-  }
-  for (auto &[operandValue, operandOp] : uses) {
-    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
-                                            builder, newShardOp);
+  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
+    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
   }
 }

diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 6751fafaf1776..4452dd65fce9d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,9 +362,6 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
-
-  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
-
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -385,31 +382,18 @@ struct ShardingPropagation
           shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
     });
 
-    auto traverse = [&](auto &&range, OpBuilder &builder,
-                        const char *order) -> bool {
-      for (Operation &op : range) {
-        if (failed(visitOp(&op, builder))) {
-          signalPassFailure();
-          return true;
-        }
-      }
-      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
-                        << funcOp << "\n");
-      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
-      return false;
-    };
-
-    // 1. Propagate in reversed order.
-    if (traversal == TraversalOrder::Backward ||
-        traversal == TraversalOrder::BackwardForward)
-      traverse(llvm::reverse(block), builder, "backward");
-
-    // 2. Propagate in original order.
-    if (traversal != TraversalOrder::Backward)
-      traverse(block, builder, "forward");
-
-    // 3. Propagate in backward order if needed.
-    if (traversal == TraversalOrder::ForwardBackward)
-      traverse(llvm::reverse(block), builder, "backward");
+    // 1. propagate in reversed order
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
+
+    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+                      << funcOp << "\n");
+    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+
+    // 2. propagate in original order
+    for (Operation &op : llvm::make_early_inc_range(block))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
   }
 };

diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
deleted file mode 100644
index 4223d01d65111..0000000000000
--- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> tensor<6x6xi32> {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: tensor.empty()
-    %0 = tensor.empty() : tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    // CHECK-COUNT-2: mesh.shard
-    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
-    // CHECK: tensor.empty()
-    // CHECK-NOT: mesh.shard @
-    %2 = tensor.empty() : tensor<6x6xi32>
-    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    // CHECK: return
-    return %3 : tensor<6x6xi32>
-  }
-}

diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
deleted file mode 100644
index dd2eee2f7def8..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> tensor<6x6xi32> {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: tensor.empty()
-    %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
-    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
-    %2 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
-    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
-    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
-    // CHECK: return
-    return %annotated_col : tensor<6x6xi32>
-  }
-}

diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
deleted file mode 100644
index 98e9931b8de94..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ /dev/null
@@ -1,49 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
-    %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[v1:%.*]] = linalg.fill ins
-    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
-    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
-    %3 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
-    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
-    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
-    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
-    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    %c0_i32 = arith.constant 0 : i32
-    %6 = tensor.empty() : tensor<i32>
-    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
-    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
-    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
-    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
-      (%in: i32, %init: i32) {
-        %9 = arith.addi %in, %init : i32
-        linalg.yield %9 : i32
-      }
-    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
-    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
-    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
-    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
-  }
-}
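For context, the reverted change selected the propagation direction through a
pass option; the deleted tests drove it with pipelines such as the following
(input.mlir is a placeholder file name):

    mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" input.mlir

Below is a condensed sketch of the dispatch logic this revert removes,
reconstructed from the ShardingPropagation.cpp hunk above. It relies on the
surrounding pass members (traversal, block, builder, funcOp, visitOp), so it
is illustrative rather than a standalone buildable excerpt:

    // Run one sweep over `range`, visiting each op in turn; on error, fail
    // the pass and return true so the caller can stop early.
    auto traverse = [&](auto &&range, OpBuilder &builder,
                        const char *order) -> bool {
      for (Operation &op : range) {
        if (failed(visitOp(&op, builder))) {
          signalPassFailure();
          return true;
        }
      }
      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
                        << funcOp << "\n");
      return false;
    };

    // Leading backward sweep for backward and backward-forward orders.
    if (traversal == TraversalOrder::Backward ||
        traversal == TraversalOrder::BackwardForward)
      traverse(llvm::reverse(block), builder, "backward");
    // Forward sweep for every order except pure backward.
    if (traversal != TraversalOrder::Backward)
      traverse(block, builder, "forward");
    // Trailing backward sweep only for forward-backward.
    if (traversal == TraversalOrder::ForwardBackward)
      traverse(llvm::reverse(block), builder, "backward");

With the revert applied, the pass always runs a backward sweep followed by a
forward sweep, with no option to change the traversal order.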