https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/76003
>From 860a2f794bdf12ff1f08d4802570757e805264b0 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Mon, 18 Dec 2023 15:53:41 -0600 Subject: [PATCH 1/7] [mlir][Linalg] Support dynamic sizes in `lower_pack` transform --- .../Linalg/TransformOps/LinalgTransformOps.td | 3 +- .../Dialect/Linalg/Transforms/Transforms.h | 2 +- .../Dialect/Linalg/Transforms/Transforms.cpp | 69 +++++++++++++------ .../Dialect/Linalg/transform-lower-pack.mlir | 20 ++++++ 4 files changed, 70 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 77ed9db5e71bd1..4abd3740b57105 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -498,7 +498,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [ let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target); let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op, - Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op, + Type<Or<[Transform_ConcreteOpType<"tensor.expand_shape">.predicate, + Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op, Transform_ConcreteOpType<"linalg.transpose">:$transpose_op); let assemblyFormat = [{ $target attr-dict `:` functional-type(operands, results) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index a848d12fbbb50e..344e801835ccc9 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op, struct LowerPackResult { tensor::PadOp padOp; - tensor::ExpandShapeOp expandShapeOp; + Operation *expandShapeOp; linalg::TransposeOp transposeOp; }; diff --git 
a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 9d230e2c2e5749..359274866748fc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -218,21 +218,11 @@ struct PackedOperandsDimList { FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, tensor::PackOp packOp) { - // 1. Filter out NYI cases. - auto packedTensorType = - cast<RankedTensorType>(packOp->getResultTypes().front()); - if (llvm::any_of(packOp.getStaticInnerTiles(), - [](int64_t size) { return ShapedType::isDynamic(size); })) { - return rewriter.notifyMatchFailure( - packOp, - "non-static shape NYI, needs a more powerful tensor.expand_shape op"); - } - Location loc = packOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(packOp); - // 2. Compute the permutation vector to shuffle packed shape into the shape + // 1. Compute the permutation vector to shuffle packed shape into the shape. // before any outer or inner permutations have been applied. The permutation // can be obtained from two permutations: // a) Compute the permutation vector to move the last `numPackedDims` into @@ -240,6 +230,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, // b) Compute the permutation vector to move outer dims if the pack op // has outer_dims_perm. // Apply (b) permutation on (a) permutation to get the final permutation. + auto packedTensorType = + cast<RankedTensorType>(packOp->getResultTypes().front()); int64_t numPackedDims = packOp.getInnerDimsPos().size(); int64_t packedRank = packedTensorType.getRank(); auto lastDims = llvm::to_vector( @@ -259,12 +251,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm; applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm); - // 3. 
Compute the stripMinedShape: this is the packed shape before any outer + // 2. Compute the stripMinedShape: this is the packed shape before any outer. // or inner permutations have been applied. SmallVector<int64_t> stripMinedShape(packedTensorType.getShape()); applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm); - // 4. Pad the source of packOp to a shape we can expand into stripMinedShape. + // 3. Pad the source of packOp to a shape we can expand into stripMinedShape. SmallVector<OpFoldResult> lows(packOp.getSourceRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> highs(packOp.getSourceRank(), @@ -351,24 +343,57 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, /*transposeOp=*/nullptr}; } } - // 5. Expand from the padded result to the stripMinedShape. - auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>( - loc, - RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), - padOp.getResult(), packingMetadata.reassociations); - // 6. Transpose stripMinedShape to packedShape. + RankedTensorType expandSourceType = padOp.getResult().getType().cast<RankedTensorType>(); + RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); + + // Dynamic dim is factorable only if the expanded version has at most one dynamic dim + bool isFactorable = true; + for (const auto &[i, rIndcs] : llvm::enumerate(packingMetadata.reassociations)) { + if (!expandSourceType.isDynamicDim(i)) + continue; + int64_t numDyn = 0; + for (auto j : rIndcs) { + if ((stripMinedShape[j] == ShapedType::kDynamic) && (++numDyn > 1)) { + isFactorable = false; + break; + } + } + } + + // 4. Expand from the padded result to the stripMinedShape. 
SmallVector<int64_t> transpPerm = invertPermutationVector(packedToStripMinedShapePerm); + Operation *reshapeOp; + if (!isFactorable) { + SmallVector<OpFoldResult> sizes = + tensor::getMixedSizes(rewriter, loc, packOp.getDest()); + applyPermutationToVector(sizes, transpPerm); + Value shapeInitTensor = + rewriter.create<tensor::EmptyOp>(loc, RankedTensorType::get({expandDestType.getRank()}, rewriter.getIndexType()), ValueRange{}); + Value shapeTensor = shapeInitTensor; + for (const auto &[i, size] : llvm::enumerate(sizes)) { + Value dim = (expandDestType.isDynamicDim(i)) ? cast<Value>(size) : rewriter.create<arith::ConstantIndexOp>(loc, getConstantIntValue(size).value()).getResult(); + shapeTensor = rewriter.create<tensor::InsertOp>(loc, dim, shapeTensor, SmallVector<Value>({rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()})); + } + reshapeOp = rewriter.create<tensor::ReshapeOp>(loc, expandDestType, padOp.getResult(), shapeTensor); + } else { + reshapeOp = rewriter.create<tensor::ExpandShapeOp>( + loc, + expandDestType, + padOp.getResult(), packingMetadata.reassociations); + } + + // 5. Transpose stripMinedShape to packedShape. auto transposeOp = rewriter.create<linalg::TransposeOp>( - loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); + loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); - DBGS() << "reshape op: " << reshapeOp; DBGSNL(); + DBGS() << "reshape op: " << &reshapeOp; DBGSNL(); llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: "); DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); - // 7. Replace packOp by transposeOp. + // 6. Replace packOp by transposeOp. 
rewriter.replaceOp(packOp, transposeOp->getResults()); return LowerPackResult{padOp, reshapeOp, transposeOp}; diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 316df431a9c0c8..6a203dab91e58b 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -61,6 +61,26 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @pack_all_dyn( +func.func @pack_all_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 + : tensor<?x?xf32> -> tensor<?x?x?x?xf32> + + return %pack : tensor<?x?x?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!transform.any_op) -> !transform.op<"tensor.pack"> + transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">) + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @pack_as_pad( func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> { %cst_0 = arith.constant 0.0 : f32 >From c8db4ac07c017dbdfbd8f91d47f32015ca9dce67 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Tue, 19 Dec 2023 19:11:22 -0600 Subject: [PATCH 2/7] Refactor --- .../Dialect/Linalg/Transforms/Transforms.cpp | 54 ++++++++++--------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 359274866748fc..21446d07b784a9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp 
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -344,44 +344,46 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, } } - RankedTensorType expandSourceType = padOp.getResult().getType().cast<RankedTensorType>(); - RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); - - // Dynamic dim is factorable only if the expanded version has at most one dynamic dim - bool isFactorable = true; - for (const auto &[i, rIndcs] : llvm::enumerate(packingMetadata.reassociations)) { - if (!expandSourceType.isDynamicDim(i)) - continue; - int64_t numDyn = 0; - for (auto j : rIndcs) { - if ((stripMinedShape[j] == ShapedType::kDynamic) && (++numDyn > 1)) { - isFactorable = false; - break; - } - } - } - // 4. Expand from the padded result to the stripMinedShape. + // Check if any dims are not factorable. A dim is factorable if the expansion + // requires at most dynamnic dim + RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); SmallVector<int64_t> transpPerm = invertPermutationVector(packedToStripMinedShapePerm); Operation *reshapeOp; - if (!isFactorable) { + if (llvm::any_of(packingMetadata.reassociations, + [&](const auto &rAssoc) -> bool { + return llvm::count_if(rAssoc, [&](int64_t r) { + return stripMinedShape[r] == ShapedType::kDynamic; + }) > 1; + })) { SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(rewriter, loc, packOp.getDest()); applyPermutationToVector(sizes, transpPerm); - Value shapeInitTensor = - rewriter.create<tensor::EmptyOp>(loc, RankedTensorType::get({expandDestType.getRank()}, rewriter.getIndexType()), ValueRange{}); + // Create a `tensor` of `index` types for the `shape` operand of `tensor.reshape` + Value shapeInitTensor = rewriter.create<tensor::EmptyOp>( + loc, + RankedTensorType::get({expandDestType.getRank()}, + rewriter.getIndexType()), + ValueRange{}); Value shapeTensor = shapeInitTensor; for (const auto &[i, size] : 
llvm::enumerate(sizes)) { - Value dim = (expandDestType.isDynamicDim(i)) ? cast<Value>(size) : rewriter.create<arith::ConstantIndexOp>(loc, getConstantIntValue(size).value()).getResult(); - shapeTensor = rewriter.create<tensor::InsertOp>(loc, dim, shapeTensor, SmallVector<Value>({rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()})); + Value dim = (expandDestType.isDynamicDim(i)) + ? cast<Value>(size) + : rewriter + .create<arith::ConstantIndexOp>( + loc, getConstantIntValue(size).value()) + .getResult(); + shapeTensor = rewriter.create<tensor::InsertOp>( + loc, dim, shapeTensor, + SmallVector<Value>( + {rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()})); } - reshapeOp = rewriter.create<tensor::ReshapeOp>(loc, expandDestType, padOp.getResult(), shapeTensor); + reshapeOp = rewriter.create<tensor::ReshapeOp>( + loc, expandDestType, padOp.getResult(), shapeTensor); } else { reshapeOp = rewriter.create<tensor::ExpandShapeOp>( - loc, - expandDestType, - padOp.getResult(), packingMetadata.reassociations); + loc, expandDestType, padOp.getResult(), packingMetadata.reassociations); } // 5. Transpose stripMinedShape to packedShape. 
>From e68b32e372de420b2e6ece98e574836920014c54 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Tue, 19 Dec 2023 21:49:38 -0600 Subject: [PATCH 3/7] Add regression test --- .../Dialect/Linalg/transform-lower-pack.mlir | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 6a203dab91e58b..13d74cbe433264 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -61,11 +61,37 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func.func @pack_all_dyn( -func.func @pack_all_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { - %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 - : tensor<?x?xf32> -> tensor<?x?x?x?xf32> - +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)> +// CHECK: func.func @pack_dyn_tiles( +// CHECK-SAME: %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]] +// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>, +// CHECK-SAME: %[[TILE0:.*]]: index, +// CHECK-SAME: %[[TILE1:.*]]: index +func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<?x?x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<?x?x?x?xf32> { +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]] +// CHECK-DAG: %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]] +// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]] +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]] +// CHECK-NEXT: 
^bb0 +// CHECK-NEXT: tensor.yield %[[CST]] : f32 +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] +// CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]] +// CHECK-NEXT: %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex> +// CHECK-NEXT: %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]] +// CHECK-NEXT: %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]] +// CHECK-NEXT: %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]] +// CHECK-NEXT: %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]] +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]]) +// CHECK-NEXT: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3] + %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1 + : tensor<64x128xf32> -> tensor<?x?x?x?xf32> return %pack : tensor<?x?x?x?xf32> } >From 0975552abe2d404388af48eafc39b464f69a4834 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Tue, 19 Dec 2023 21:53:42 -0600 Subject: [PATCH 4/7] Fix comment --- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 21446d07b784a9..1f63d0ab706cdb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -345,12 +345,14 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, } // 4. Expand from the padded result to the stripMinedShape. - // Check if any dims are not factorable. 
A dim is factorable if the expansion - // requires at most dynamnic dim - RankedTensorType expandDestType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); + RankedTensorType expandDestType = + RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); SmallVector<int64_t> transpPerm = invertPermutationVector(packedToStripMinedShapePerm); Operation *reshapeOp; + // Check if any dims are not factorable and thus need a `tensor.reshape` + // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion + // requires at most one dynamic dim. if (llvm::any_of(packingMetadata.reassociations, [&](const auto &rAssoc) -> bool { return llvm::count_if(rAssoc, [&](int64_t r) { @@ -360,7 +362,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(rewriter, loc, packOp.getDest()); applyPermutationToVector(sizes, transpPerm); - // Create a `tensor` of `index` types for the `shape` operand of `tensor.reshape` + // Create a `tensor` of `index` types for the `shape` operand of + // `tensor.reshape` Value shapeInitTensor = rewriter.create<tensor::EmptyOp>( loc, RankedTensorType::get({expandDestType.getRank()}, >From 48deca06d650959ba3727df9697566a0fd6a6cd2 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Tue, 19 Dec 2023 22:31:12 -0600 Subject: [PATCH 5/7] Properly check optional value --- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 1f63d0ab706cdb..2a1c72942df0bb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -371,12 +371,15 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, ValueRange{}); Value shapeTensor = shapeInitTensor; for (const auto &[i, size] :
llvm::enumerate(sizes)) { - Value dim = (expandDestType.isDynamicDim(i)) - ? cast<Value>(size) - : rewriter - .create<arith::ConstantIndexOp>( - loc, getConstantIntValue(size).value()) - .getResult(); + auto maybeConstInt = getConstantIntValue(size); + assert(maybeConstInt.has_value() || + expandDestType.isDynamicDim(i) && "expected dynamic dim"); + Value dim = + (maybeConstInt.has_value()) + ? rewriter + .create<arith::ConstantIndexOp>(loc, maybeConstInt.value()) + .getResult() + : cast<Value>(size); shapeTensor = rewriter.create<tensor::InsertOp>( loc, dim, shapeTensor, SmallVector<Value>( >From 194f8194659908f8127b99a807033192e1477def Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Tue, 19 Dec 2023 22:37:10 -0600 Subject: [PATCH 6/7] Revert accidental change --- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 2a1c72942df0bb..6018d58b94eb72 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -372,8 +372,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, Value shapeTensor = shapeInitTensor; for (const auto &[i, size] : llvm::enumerate(sizes)) { auto maybeConstInt = getConstantIntValue(size); - assert(maybeConstInt.has_value() || - expandDestType.isDynamicDim(i) && "expected dynamic dim"); + assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) && + "expected dynamic dim"); Value dim = (maybeConstInt.has_value()) ? 
rewriter @@ -397,7 +397,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); - DBGS() << "reshape op: " << &reshapeOp; DBGSNL(); + DBGS() << "reshape op: " << reshapeOp; DBGSNL(); llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: "); DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); >From cf0cb00740031db55929dfc058a87d363680cf5a Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Tue, 19 Dec 2023 23:00:44 -0600 Subject: [PATCH 7/7] Add clarifying comment --- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 344e801835ccc9..06e8586f4288b4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op, struct LowerPackResult { tensor::PadOp padOp; - Operation *expandShapeOp; + Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp` linalg::TransposeOp transposeOp; }; _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits