Author: Benjamin Kramer Date: 2020-12-10T10:49:25+01:00 New Revision: 1d00508c5bf0d43203e11765ce84cdd6cf257856
URL: https://github.com/llvm/llvm-project/commit/1d00508c5bf0d43203e11765ce84cdd6cf257856 DIFF: https://github.com/llvm/llvm-project/commit/1d00508c5bf0d43203e11765ce84cdd6cf257856.diff LOG: [mlir][Shape] Make sure tensor_cast(constant_shape) folding uses the correct type This is still subtle, but I think the test cases are sufficient to show that it works. Differential Revision: https://reviews.llvm.org/D92927 Added: Modified: mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td mlir/test/Dialect/Shape/canonicalize.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td index 43c670a8582e..4e6d062a232f 100644 --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -32,5 +32,7 @@ def SizeToIndexToSizeCanonicalization : Pat< (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)), (replaceWithValue $arg)>; +// Fold tensor_cast(const_shape) to const_shape. This changes the type of +// const_shape to the destination type of the cast. def TensorCastConstShape : Pat < - (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>; + (TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 9cb01da75901..aa43f515f753 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -872,13 +872,24 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) { // ----- -// Fold tensor_cast of a const_shape to const_shape -// CHECK-LABEL: @fold_tensor_cast_of_const_shape -func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) { +// Verify that tensor_cast folding uses the correct type +// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned +func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> { + // CHECK: constant dense<2> : tensor<1xindex> // CHECK-NOT: tensor_cast %0 = shape.const_shape [2] : tensor<?xindex> %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex> - %2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex> - "consume.witness"(%2) : (!shape.witness) -> () - return + return %1 : tensor<1xindex> +} + +// ----- + +// Verify that tensor_cast folding uses the correct type +// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic +func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> { + // CHECK: shape.const_shape [2] : tensor<?xindex> + // CHECK-NOT: tensor_cast + %0 = shape.const_shape [2] : tensor<1xindex> + %1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex> + return %1 : tensor<?xindex> } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits