Author: Sean Silva Date: 2021-01-19T13:49:25-08:00 New Revision: be7352c00d51f4358db3a23ed6a077f7cb48eafd
URL: https://github.com/llvm/llvm-project/commit/be7352c00d51f4358db3a23ed6a077f7cb48eafd DIFF: https://github.com/llvm/llvm-project/commit/be7352c00d51f4358db3a23ed6a077f7cb48eafd.diff LOG: [mlir][splitting std] move 2 more ops to `tensor` - DynamicTensorFromElementsOp - TensorFromElements Differential Revision: https://reviews.llvm.org/D94994 Added: Modified: mlir/include/mlir/Dialect/StandardOps/IR/Ops.td mlir/include/mlir/Dialect/Tensor/IR/Tensor.h mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp mlir/lib/Dialect/StandardOps/IR/Ops.cpp mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp mlir/lib/Dialect/Tensor/IR/CMakeLists.txt mlir/lib/Dialect/Tensor/IR/TensorOps.cpp mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt mlir/lib/Dialect/Tensor/Transforms/PassDetail.h mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir mlir/test/Dialect/Standard/bufferize.mlir mlir/test/Dialect/Standard/canonicalize.mlir mlir/test/Dialect/Standard/invalid.mlir mlir/test/Dialect/Standard/ops.mlir mlir/test/Dialect/Tensor/bufferize.mlir mlir/test/Dialect/Tensor/canonicalize.mlir mlir/test/Dialect/Tensor/invalid.mlir mlir/test/Dialect/Tensor/ops.mlir mlir/test/IR/core-ops.mlir mlir/test/IR/invalid-ops.mlir mlir/test/Transforms/canonicalize.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 6eabe1179234..8e3f1f1a7a85 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1591,47 +1591,6 @@ def DivFOp : FloatArithmeticOp<"divf"> { let summary = "floating point division operation"; } 
-//===----------------------------------------------------------------------===// -// DynamicTensorFromElementsOp -//===----------------------------------------------------------------------===// - -def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements", - [RecursiveSideEffects, SingleBlockImplicitTerminator<"YieldOp">]> { - string summary = "Creates a dynamically sized tensor from elements"; - string description = [{ - This operation creates a dynamically sized tensor with elements of any type. - It expects one index operand per dynamic extent of the result tensor. - - The body region defines the tensor's elements. It takes index operands as - its region arguments that span the index space. The element at the given - position is yielded with the `yield` operation (see `YieldOp`). There is - no defined ordering to the invocations of the body. It is conceptually - a "parallel map" operation. - - Example: - - ```mlir - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - ... - yield %elem : f32 - } : tensor<?x3x?f32> - ``` - }]; - - let arguments = (ins Variadic<Index>:$dynamicExtents); - let results = (outs AnyRankedTensor:$result); - let regions = (region SizedRegion<1>:$body); - - let builders = [ - // Build op and populate its body per callback function. 
- OpBuilderDAG<(ins "Type":$resultTy, "ValueRange":$dynamicExtents, - "function_ref<void(OpBuilder &, Location, ValueRange)>")>, - ]; - - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // ExpOp //===----------------------------------------------------------------------===// @@ -1672,46 +1631,6 @@ def Exp2Op : FloatUnaryOp<"exp2"> { let summary = "base-2 exponential of the specified value"; } -//===----------------------------------------------------------------------===// -// TensorFromElementsOp -//===----------------------------------------------------------------------===// - -def TensorFromElementsOp : Std_Op<"tensor_from_elements", [ - NoSideEffect, - TypesMatchWith<"operand types match result element type", - "result", "elements", "SmallVector<Type, 2>(" - "$_self.cast<ShapedType>().getDimSize(0), " - "$_self.cast<ShapedType>().getElementType())"> - ]> { - string summary = "tensor from elements operation."; - string description = [{ - Create a 1D tensor from a range of same-type arguments. - - Example: - - ```mlir - tensor_from_elements(i_1, ..., i_N) : tensor<Nxindex> - ``` - }]; - - let arguments = (ins Variadic<AnyType>:$elements); - let results = (outs 1DTensorOf<[AnyType]>:$result); - - let assemblyFormat = "$elements attr-dict `:` type($result)"; - - // This op is fully verified by its traits. - let verifier = ?; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilderDAG<(ins "Type":$elementType, "ValueRange":$elements)>, - // Special case builder for when `elements` has size >=1. 
- OpBuilderDAG<(ins "ValueRange":$elements)> - ]; - - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// @@ -3837,24 +3756,6 @@ def ViewOp : Std_Op<"view", [ let hasCanonicalizer = 1; } -//===----------------------------------------------------------------------===// -// YieldOp -//===----------------------------------------------------------------------===// - -def YieldOp : Std_Op<"yield", [NoSideEffect, ReturnLike, Terminator, - HasParent<"DynamicTensorFromElementsOp">]> { - let summary = "Yield a value from a region"; - let description = [{ - This operation is used to yield a single value from a within a region. It - is used to create dynamically sized tensors - (see `DynamicTensorFromElementsOp`). - }]; - - let arguments = (ins AnyType:$value); - let assemblyFormat = "$value attr-dict `:` type($value)"; - let verifier = ?; -} - //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h index 53980db64dc0..3a1a20835959 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -13,6 +13,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index e0500b8fcfa6..e7776c4e8a9b 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -10,6 +10,7 @@ 
#define TENSOR_OPS include "mlir/Dialect/Tensor/IR/TensorBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" class Tensor_Op<string mnemonic, list<OpTrait> traits = []> @@ -105,4 +106,109 @@ def Tensor_ExtractOp : Tensor_Op<"extract", let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// FromElementsOp +//===----------------------------------------------------------------------===// + +def Tensor_FromElementsOp : Tensor_Op<"from_elements", [ + NoSideEffect, + TypesMatchWith<"operand types match result element type", + "result", "elements", "SmallVector<Type, 2>(" + "$_self.cast<ShapedType>().getDimSize(0), " + "$_self.cast<ShapedType>().getElementType())"> + ]> { + string summary = "tensor from elements operation."; + string description = [{ + Create a 1D tensor from a range of same-type arguments. + + Example: + + ```mlir + tensor.from_elements i_1, ..., i_N : tensor<Nxindex> + ``` + }]; + + let arguments = (ins Variadic<AnyType>:$elements); + let results = (outs 1DTensorOf<[AnyType]>:$result); + + let assemblyFormat = "$elements attr-dict `:` type($result)"; + + // This op is fully verified by its traits. + let verifier = ?; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilderDAG<(ins "Type":$elementType, "ValueRange":$elements)>, + // Special case builder for when `elements` has size >=1.
+ OpBuilderDAG<(ins "ValueRange":$elements)> + ]; + + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// GenerateOp +//===----------------------------------------------------------------------===// + +def Tensor_GenerateOp : Tensor_Op<"generate", + [RecursiveSideEffects, + SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { + string summary = "Creates a dynamically sized tensor from elements"; + string description = [{ + This operation creates a dynamically sized tensor with elements of any type. + It expects one index operand per dynamic extent of the result tensor. + + The body region defines the tensor's elements. It takes index operands as + its region arguments that span the index space. The element at the given + position is yielded with the `yield` operation (see `YieldOp`). There is + no defined ordering to the invocations of the body. It is conceptually + a "parallel map" operation. + + Example: + + ```mlir + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : index): + ... + tensor.yield %elem : f32 + } : tensor<?x3x?f32> + ``` + }]; + + let arguments = (ins Variadic<Index>:$dynamicExtents); + let results = (outs AnyRankedTensor:$result); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "$dynamicExtents $body attr-dict `:` type($result)"; + + let builders = [ + // Build op and populate its body per callback function.
+ OpBuilderDAG<(ins "Type":$resultTy, "ValueRange":$dynamicExtents, + "function_ref<void(OpBuilder &, Location, ValueRange)>")>, + ]; + + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +def Tensor_YieldOp : Tensor_Op<"yield", + [NoSideEffect, ReturnLike, Terminator, + HasParent<"::mlir::tensor::GenerateOp">]> { + let summary = "Yield a value from a region"; + let description = [{ + This operation is used to yield a single value from within a region. It + is used to create dynamically sized tensors + (see `tensor.generate` op). + }]; + + let arguments = (ins AnyType:$value); + let assemblyFormat = "$value attr-dict `:` type($value)"; + // Dummy builder to appease code in templated ensureTerminator that + // GenerateOp's auto-generated parser calls. + let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>]; + let verifier = ?; +} + #endif // TENSOR_OPS diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td index 327c7499e0c8..7abb3daed2fe 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td @@ -14,6 +14,7 @@ include "mlir/Pass/PassBase.td" def TensorBufferize : FunctionPass<"tensor-bufferize"> { let summary = "Bufferize the `tensor` dialect"; let constructor = "mlir::createTensorBufferizePass()"; + let dependentDialects = ["scf::SCFDialect"]; } #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt index 25c835d97723..f65e9aec3142 100644 --- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ -20,6 +20,7 @@ add_mlir_conversion_library(MLIRShapeToStandard MLIREDSC MLIRIR MLIRShape +
MLIRTensor MLIRPass MLIRSCF MLIRTransforms diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index 0d87d4f10975..0eeea250f19f 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -113,7 +113,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( Value rankDiff = rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); - rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>( + rewriter.replaceOpWithNewOp<tensor::GenerateOp>( op, getExtentTensorType(op.getContext()), ValueRange{greaterRank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value outputDimension = args[0]; @@ -151,7 +151,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( greaterRankOperandExtent); b.create<scf::YieldOp>(loc, broadcastedExtent); }); - b.create<mlir::YieldOp>(loc, ifOp.getResult(0)); + b.create<tensor::YieldOp>(loc, ifOp.getResult(0)); }); return success(); } @@ -184,7 +184,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite( } Type indexTy = rewriter.getIndexType(); Value tensor = - rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands); + rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands); Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); return success(); @@ -503,7 +503,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( if (op.getType().isa<ShapeType>()) return failure(); - // For ranked tensor arguments, lower to `tensor_from_elements`. + // For ranked tensor arguments, lower to `tensor.from_elements`. auto loc = op.getLoc(); ShapeOfOp::Adaptor transformed(operands); Value tensor = transformed.arg(); @@ -526,22 +526,22 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( } // Materialize extent tensor. 
- Value staticExtentTensor = rewriter.create<TensorFromElementsOp>( + Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( loc, rewriter.getIndexType(), extentValues); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), staticExtentTensor); return success(); } - // Lower to `dynamic_tensor_from_elements` otherwise. + // Lower to `tensor.generate` otherwise. auto *ctx = rewriter.getContext(); Value rank = rewriter.create<mlir::RankOp>(loc, tensor); - rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>( + rewriter.replaceOpWithNewOp<tensor::GenerateOp>( op, getExtentTensorType(ctx), ValueRange{rank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value dim = args.front(); Value extent = b.create<DimOp>(loc, tensor, dim); - b.create<mlir::YieldOp>(loc, extent); + b.create<tensor::YieldOp>(loc, extent); }); return success(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index c4a8a0155f50..e1be47f54798 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1392,9 +1392,8 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { return getResult(); } - // Fold dim to the operand of dynamic_tensor_from_elements. - if (auto fromElements = - dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) { + // Fold dim to the operand of tensor.generate. 
+ if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { auto resultType = fromElements.getResult().getType().cast<RankedTensorType>(); // The case where the type encodes the size of the dimension is handled @@ -1734,258 +1733,6 @@ LogicalResult DmaWaitOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// DynamicTensorFromElementsOp -//===----------------------------------------------------------------------===// - -static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser, - OperationState &result) { - // Parse operands. - SmallVector<OpAsmParser::OperandType, 4> dynamicExtents; - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.parseOperandList(dynamicExtents) || - parser.resolveOperands(dynamicExtents, indexTy, result.operands)) - return failure(); - - // Parse body. - Region *body = result.addRegion(); - if (parser.parseRegion(*body, {}, {})) - return failure(); - - // Parse result type. - Type resultType; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType)) - return failure(); - result.addTypes(resultType); - - return success(); -} - -static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) { - p << "dynamic_tensor_from_elements " << op.dynamicExtents(); - p.printRegion(op.body()); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getType(); -} - -static LogicalResult verify(DynamicTensorFromElementsOp op) { - // Ensure that the tensor type has as many dynamic dimensions as are specified - // by the operands. - RankedTensorType resultTy = op.getType().cast<RankedTensorType>(); - if (op.getNumOperands() != resultTy.getNumDynamicDims()) - return op.emitError("must have as many index operands as dynamic extents " - "in the result type"); - - // Ensure that region arguments span the index space. 
- if (!llvm::all_of(op.body().getArgumentTypes(), - [](Type ty) { return ty.isIndex(); })) - return op.emitError("all body arguments must be index"); - if (op.body().getNumArguments() != resultTy.getRank()) - return op.emitError("must have one body argument per input dimension"); - - // Ensure that the region yields an element of the right type. - auto yieldOp = - llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator()); - if (yieldOp.value().getType() != resultTy.getElementType()) - return op.emitOpError( - "body must be terminated with a `yield` operation of the tensor " - "element type"); - - return success(); -} - -void DynamicTensorFromElementsOp::build( - OpBuilder &b, OperationState &result, Type resultTy, - ValueRange dynamicExtents, - function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { - build(b, result, resultTy, dynamicExtents); - - // Build and populate body. - OpBuilder::InsertionGuard guard(b); - Region *bodyRegion = result.regions.front().get(); - auto rank = resultTy.cast<RankedTensorType>().getRank(); - SmallVector<Type, 2> argumentTypes(rank, b.getIndexType()); - Block *bodyBlock = - b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes); - bodyBuilder(b, result.location, bodyBlock->getArguments()); -} - -namespace { - -/// Canonicalizes dynamic_tensor_from_elements operations with a constant -/// operand into the equivalent operation with the operand expressed in the -/// result type, instead. We also insert a type cast to make sure that the -/// resulting IR is still well-typed. 
-struct StaticDynamicTensorFromElements - : public OpRewritePattern<DynamicTensorFromElementsOp> { - using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements, - PatternRewriter &rewriter) const final { - auto resultType = - tensorFromElements.getResult().getType().cast<RankedTensorType>(); - - if (resultType.hasStaticShape()) - return failure(); - - SmallVector<Value, 4> newOperands; - SmallVector<int64_t, 4> newShape; - auto operandsIt = tensorFromElements.dynamicExtents().begin(); - - for (int64_t dim : resultType.getShape()) { - if (dim != RankedTensorType::kDynamicSize) { - newShape.push_back(dim); - continue; - } - APInt index; - if (!matchPattern(*operandsIt, m_ConstantInt(&index))) { - newShape.push_back(RankedTensorType::kDynamicSize); - newOperands.push_back(*operandsIt++); - continue; - } - newShape.push_back(index.getSExtValue()); - operandsIt++; - } - - if (newOperands.size() == tensorFromElements.dynamicExtents().size()) - return failure(); - - auto loc = tensorFromElements.getLoc(); - auto newOp = rewriter.create<DynamicTensorFromElementsOp>( - loc, RankedTensorType::get(newShape, resultType.getElementType()), - newOperands); - rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(), - newOp.body().begin()); - rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType, - newOp); - return success(); - } -}; - -/// Canonicalizes the pattern of the form -/// -/// %tensor = dynamic_tensor_from_elements %x { -/// ^bb0(%arg0: index): // no predecessors -/// <computation> -/// yield %1 : index -/// } : tensor<?xindex> -/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32> -/// -/// to just <computation> with %arg0 replaced by %c0. We only do this if the -/// dynamic_tensor_from_elements operation has no side-effects. 
-struct ExtractFromDynamicTensorFromElements - : public OpRewritePattern<tensor::ExtractOp> { - using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter &rewriter) const final { - auto tensorFromElements = - extract.tensor().getDefiningOp<DynamicTensorFromElementsOp>(); - if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements)) - return failure(); - - BlockAndValueMapping mapping; - Block *body = tensorFromElements.getBody(); - mapping.map(body->getArguments(), extract.indices()); - for (auto &op : body->without_terminator()) - rewriter.clone(op, mapping); - - auto yield = cast<YieldOp>(body->getTerminator()); - - rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value())); - return success(); - } -}; - -/// Canonicalizes the pattern of the form -/// -/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32> -/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> -/// -/// to -/// -/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32> -struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> { - using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter &rewriter) const final { - auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>(); - if (!tensorCast) - return failure(); - - rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(), - extract.indices()); - return success(); - } -}; - -} // namespace - -void DynamicTensorFromElementsOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - // TODO: Move extract patterns to tensor::ExtractOp. 
- results.insert<ExtractFromDynamicTensorFromElements, ExtractFromTensorCast, - StaticDynamicTensorFromElements>(context); -} - -//===----------------------------------------------------------------------===// -// TensorFromElementsOp -//===----------------------------------------------------------------------===// - -void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, - Type elementType, ValueRange elements) { - Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())}, - elementType); - result.addOperands(elements); - result.addTypes(resultTy); -} - -void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, - ValueRange elements) { - assert(!elements.empty() && "expected at least one element"); - build(builder, result, elements.front().getType(), elements); -} - -namespace { - -// Canonicalizes the pattern of the form -// -// %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32> -// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32> -// -// to just %element. 
-struct ExtractElementFromTensorFromElements - : public OpRewritePattern<tensor::ExtractOp> { - using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter &rewriter) const final { - if (extract.indices().size() != 1) - return failure(); - - auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>( - extract.tensor().getDefiningOp()); - if (tensorFromElements == nullptr) - return failure(); - - APInt index; - if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) - return failure(); - rewriter.replaceOp(extract, - tensorFromElements.getOperand(index.getZExtValue())); - return success(); - } -}; - -} // namespace - -void TensorFromElementsOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<ExtractElementFromTensorFromElements>(context); -} - //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index 98792838deff..2a3a464cd0a8 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -35,70 +35,6 @@ class BufferizeDimOp : public OpConversionPattern<DimOp> { }; } // namespace -namespace { -class BufferizeDynamicTensorFromElementsOp - : public OpConversionPattern<DynamicTensorFromElementsOp> { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(DynamicTensorFromElementsOp op, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const final { - // Allocate memory. 
- Location loc = op.getLoc(); - DynamicTensorFromElementsOp::Adaptor transformed(operands); - RankedTensorType tensorType = op.getType().cast<RankedTensorType>(); - MemRefType memrefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value result = - rewriter.create<AllocOp>(loc, memrefType, transformed.dynamicExtents()); - - // Collect loop bounds. - int64_t rank = tensorType.getRank(); - Value zero = rewriter.create<ConstantIndexOp>(loc, 0); - Value one = rewriter.create<ConstantIndexOp>(loc, 1); - SmallVector<Value, 4> lowerBounds(rank, zero); - SmallVector<Value, 4> steps(rank, one); - SmallVector<Value, 4> upperBounds; - int nextDynamicIndex = 0; - for (int i = 0; i < rank; i++) { - Value upperBound = - tensorType.isDynamicDim(i) - ? transformed.dynamicExtents()[nextDynamicIndex++] - : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i)); - upperBounds.push_back(upperBound); - } - - // Generate tensor elements with a parallel loop that stores into - // each element of the resulting memref. - // - // This is a bit tricky. We cannot simply clone the ops because when an op - // is cloned, it must be legalized. However, we want to allow arbitrary ops - // in the body that we don't necessarily have legalization patterns for as - // part of this dialect conversion invocation. - // - // To accomplish this, we use mergeBlockBefore to "move" this op's body - // into the scf.parallel's body. - auto parallel = - rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); - Block *parallelBody = parallel.getBody(); - rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), - parallelBody->getArguments()); - // Replace the inlined yield op with a store op. The scf.parallel's builder - // already populated an scf.yield at the end, so we don't need to worry - // about creating that. 
- Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); - rewriter.setInsertionPointAfter(elementYield); - rewriter.replaceOpWithNewOp<StoreOp>(elementYield, - elementYield->getOperands()[0], result, - parallelBody->getArguments()); - - rewriter.replaceOp(op, {result}); - return success(); - } -}; -} // namespace - namespace { class BufferizeSelectOp : public OpConversionPattern<SelectOp> { public: @@ -117,40 +53,10 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> { }; } // namespace -namespace { -class BufferizeTensorFromElementsOp - : public OpConversionPattern<TensorFromElementsOp> { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(TensorFromElementsOp op, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const override { - int numberOfElements = op.elements().size(); - auto resultType = MemRefType::get( - {numberOfElements}, op.getType().cast<TensorType>().getElementType()); - Value result = rewriter.create<AllocOp>(op.getLoc(), resultType); - for (auto element : llvm::enumerate(op.elements())) { - Value index = - rewriter.create<ConstantIndexOp>(op.getLoc(), element.index()); - rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index); - } - rewriter.replaceOp(op, {result}); - return success(); - } -}; -} // namespace - void mlir::populateStdBufferizePatterns(MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert< - // clang-format off - BufferizeDimOp, - BufferizeDynamicTensorFromElementsOp, - BufferizeSelectOp, - BufferizeTensorFromElementsOp - // clang-format on - >(typeConverter, context); + patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter, context); } namespace { @@ -165,7 +71,6 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> { target.addLegalDialect<scf::SCFDialect>(); populateStdBufferizePatterns(context, typeConverter, patterns); - 
target.addIllegalOp<DynamicTensorFromElementsOp, TensorFromElementsOp>(); // We only bufferize the case of tensor selected type and scalar condition, // as that boils down to a select over memref descriptors (don't need to // touch the data). diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt index 2d5e2fbd6a31..b8fb44a9f4cb 100644 --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -13,5 +13,6 @@ add_mlir_dialect_library(MLIRTensor LINK_LIBS PUBLIC MLIRIR + MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index aaae7fbf807c..e231a3a3b56e 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" @@ -205,6 +207,223 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) { return {}; } +//===----------------------------------------------------------------------===// +// FromElementsOp +//===----------------------------------------------------------------------===// + +void FromElementsOp::build(OpBuilder &builder, OperationState &result, + Type elementType, ValueRange elements) { + Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())}, + elementType); + result.addOperands(elements); + result.addTypes(resultTy); +} + +void FromElementsOp::build(OpBuilder &builder, OperationState &result, + ValueRange elements) { + assert(!elements.empty() && "expected at least one element"); + build(builder, result, elements.front().getType(), elements); +} + +namespace { + +// Canonicalizes the 
pattern of the form +// +// %tensor = tensor.from_elements %element : tensor<1xi32> +// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32> +// +// to just %element. Out-of-bounds constant indices are not folded. +struct ExtractElementFromTensorFromElements + : public OpRewritePattern<tensor::ExtractOp> { + using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter &rewriter) const final { + if (extract.indices().size() != 1) + return failure(); + + auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>(); + if (tensorFromElements == nullptr) + return failure(); + APInt index; + if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)) || + index.uge(tensorFromElements.getNumOperands())) + return failure(); + rewriter.replaceOp(extract, + tensorFromElements.getOperand(index.getZExtValue())); + return success(); + } +}; + +} // namespace + +void FromElementsOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<ExtractElementFromTensorFromElements>(context); +} + +//===----------------------------------------------------------------------===// +// GenerateOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(GenerateOp op) { + // Ensure that the tensor type has as many dynamic dimensions as are specified + // by the operands. + RankedTensorType resultTy = op.getType().cast<RankedTensorType>(); + if (op.getNumOperands() != resultTy.getNumDynamicDims()) + return op.emitError("must have as many index operands as dynamic extents " + "in the result type"); + + // Ensure that region arguments span the index space.
+ if (!llvm::all_of(op.body().getArgumentTypes(), + [](Type ty) { return ty.isIndex(); })) + return op.emitError("all body arguments must be index"); + if (op.body().getNumArguments() != resultTy.getRank()) + return op.emitError("must have one body argument per input dimension"); + + // Ensure that the region yields an element of the right type. + auto yieldOp = + llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator()); + if (yieldOp.value().getType() != resultTy.getElementType()) + return op.emitOpError( + "body must be terminated with a `yield` operation of the tensor " + "element type"); + + return success(); +} + +void GenerateOp::build( + OpBuilder &b, OperationState &result, Type resultTy, + ValueRange dynamicExtents, + function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { + build(b, result, resultTy, dynamicExtents); + + // Build and populate body. + OpBuilder::InsertionGuard guard(b); + Region *bodyRegion = result.regions.front().get(); + auto rank = resultTy.cast<RankedTensorType>().getRank(); + SmallVector<Type, 2> argumentTypes(rank, b.getIndexType()); + Block *bodyBlock = + b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes); + bodyBuilder(b, result.location, bodyBlock->getArguments()); +} + +namespace { + +/// Canonicalizes tensor.generate operations with a constant +/// operand into the equivalent operation with the operand expressed in the +/// result type, instead. We also insert a type cast to make sure that the +/// resulting IR is still well-typed. 
+struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> { + using OpRewritePattern<GenerateOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(GenerateOp tensorFromElements, + PatternRewriter &rewriter) const final { + auto resultType = + tensorFromElements.getResult().getType().cast<RankedTensorType>(); + + if (resultType.hasStaticShape()) + return failure(); + + SmallVector<Value, 4> newOperands; + SmallVector<int64_t, 4> newShape; + auto operandsIt = tensorFromElements.dynamicExtents().begin(); + + for (int64_t dim : resultType.getShape()) { + if (dim != RankedTensorType::kDynamicSize) { + newShape.push_back(dim); + continue; + } + APInt index; + // Only fold non-negative constant extents. A negative constant would + // build an invalid tensor type or, for -1, collide with kDynamicSize + // while its operand is dropped, leaving the new op unverifiable. + if (!matchPattern(*operandsIt, m_ConstantInt(&index)) || + index.isNegative()) { + newShape.push_back(RankedTensorType::kDynamicSize); + newOperands.push_back(*operandsIt++); + continue; + } + newShape.push_back(index.getSExtValue()); + operandsIt++; + } + + if (newOperands.size() == tensorFromElements.dynamicExtents().size()) + return failure(); + + auto loc = tensorFromElements.getLoc(); + auto newOp = rewriter.create<GenerateOp>( + loc, RankedTensorType::get(newShape, resultType.getElementType()), + newOperands); + rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(), + newOp.body().begin()); + rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType, + newOp); + return success(); + } +}; + +/// Canonicalizes the pattern of the form +/// +/// %tensor = tensor.generate %x { +/// ^bb0(%arg0: index): // no predecessors +/// <computation> +/// tensor.yield %1 : index +/// } : tensor<?xindex> +/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex> +/// +/// to just <computation> with %arg0 replaced by %c0. We only do this if the +/// tensor.generate operation has no side-effects.
+struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> { + using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter &rewriter) const final { + auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>(); + if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements)) + return failure(); + + BlockAndValueMapping mapping; + Block *body = tensorFromElements.getBody(); + mapping.map(body->getArguments(), extract.indices()); + for (auto &op : body->without_terminator()) + rewriter.clone(op, mapping); + + auto yield = cast<YieldOp>(body->getTerminator()); + + rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value())); + return success(); + } +}; + +/// Canonicalizes the pattern of the form +/// +/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32> +/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> +/// +/// to +/// +/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32> +struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> { + using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter &rewriter) const final { + auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>(); + if (!tensorCast) + return failure(); + + rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(), + extract.indices()); + return success(); + } +}; + +} // namespace + +void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + // TODO: Move extract patterns to tensor::ExtractOp.
+ results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast, + StaticTensorGenerate>(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp index 05ff96fb8d69..66de78758692 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/Bufferize.h" #include "PassDetail.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" @@ -48,10 +49,97 @@ class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { }; } // namespace +namespace { +/// Bufferizes tensor.from_elements by allocating a 1-D memref with one slot +/// per element and storing each element at its position. +class BufferizeFromElementsOp + : public OpConversionPattern<tensor::FromElementsOp> { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + // Read the elements from the remapped operands (as BufferizeGenerateOp + // does) rather than from the original op's values. + tensor::FromElementsOp::Adaptor transformed(operands); + Location loc = op.getLoc(); + int64_t numberOfElements = transformed.elements().size(); + auto resultType = MemRefType::get( + {numberOfElements}, op.getType().cast<TensorType>().getElementType()); + Value result = rewriter.create<AllocOp>(loc, resultType); + for (auto element : llvm::enumerate(transformed.elements())) { + Value index = rewriter.create<ConstantIndexOp>(loc, element.index()); + rewriter.create<StoreOp>(loc, element.value(), result, index); + } + rewriter.replaceOp(op, {result}); + return success(); + } +}; +} // namespace + +namespace { +class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands, +
ConversionPatternRewriter &rewriter) const final { + // Allocate memory. + Location loc = op.getLoc(); + tensor::GenerateOp::Adaptor transformed(operands); + RankedTensorType tensorType = op.getType().cast<RankedTensorType>(); + MemRefType memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + Value result = + rewriter.create<AllocOp>(loc, memrefType, transformed.dynamicExtents()); + + // Collect loop bounds. + int64_t rank = tensorType.getRank(); + Value zero = rewriter.create<ConstantIndexOp>(loc, 0); + Value one = rewriter.create<ConstantIndexOp>(loc, 1); + SmallVector<Value, 4> lowerBounds(rank, zero); + SmallVector<Value, 4> steps(rank, one); + SmallVector<Value, 4> upperBounds; + int nextDynamicIndex = 0; + for (int i = 0; i < rank; i++) { + Value upperBound = + tensorType.isDynamicDim(i) + ? transformed.dynamicExtents()[nextDynamicIndex++] + : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i)); + upperBounds.push_back(upperBound); + } + + // Generate tensor elements with a parallel loop that stores into + // each element of the resulting memref. + // + // This is a bit tricky. We cannot simply clone the ops because when an op + // is cloned, it must be legalized. However, we want to allow arbitrary ops + // in the body that we don't necessarily have legalization patterns for as + // part of this dialect conversion invocation. + // + // To accomplish this, we use mergeBlockBefore to "move" this op's body + // into the scf.parallel's body. + auto parallel = + rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); + Block *parallelBody = parallel.getBody(); + rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), + parallelBody->getArguments()); + // Replace the inlined yield op with a store op. The scf.parallel's builder + // already populated an scf.yield at the end, so we don't need to worry + // about creating that. 
+ Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); + rewriter.setInsertionPointAfter(elementYield); + rewriter.replaceOpWithNewOp<StoreOp>(elementYield, + elementYield->getOperands()[0], result, + parallelBody->getArguments()); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; +} // namespace + void mlir::populateTensorBufferizePatterns( MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context); + patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp, + BufferizeGenerateOp>(typeConverter, context); } namespace { @@ -62,9 +150,13 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { OwningRewritePatternList patterns; ConversionTarget target(*context); + populateBufferizeMaterializationLegality(target); + populateTensorBufferizePatterns(context, typeConverter, patterns); - target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>(); + target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, + tensor::FromElementsOp, tensor::GenerateOp>(); target.addLegalDialect<StandardOpsDialect>(); + target.addLegalDialect<scf::SCFDialect>(); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt index 141f8caebb57..6d29bd56dca6 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRTensorTransforms LINK_LIBS PUBLIC MLIRIR MLIRPass + MLIRSCF MLIRTensor MLIRTransforms ) diff --git a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h index fd1f1cf22bd6..bd4a61e6b7ee 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h @@ -13,6 
+13,10 @@ namespace mlir { +namespace scf { +class SCFDialect; +} // end namespace scf + #define GEN_PASS_CLASSES #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc" diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 9f7a20ab9de6..2bd4a1d34901 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -87,14 +87,14 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index) // ----- -// Lower `const_shape` to `tensor_from_elements`. +// Lower `const_shape` to `tensor.from_elements`. // CHECK-LABEL: @const_shape // CHECK-SAME: () -> tensor<?xindex> func @const_shape() -> tensor<?xindex> { // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C3:.*]] = constant 3 : index - // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] + // CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]] // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex> // CHECK: return %[[RESULT]] : tensor<?xindex> %shape = shape.const_shape [1, 2, 3] : tensor<?xindex> @@ -107,7 +107,7 @@ func @const_shape() -> tensor<?xindex> { // CHECK-LABEL: func @const_shape_zero_elements // CHECK-SAME: () -> tensor<?xindex> func @const_shape_zero_elements() -> tensor<?xindex> { - // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex> + // CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex> // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex> // CHECK: return %[[RESULT]] : tensor<?xindex> %shape = shape.const_shape [] : tensor<?xindex> @@ -204,7 +204,7 @@ func @shape_of(%arg : tensor<*xf32>) { // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) func @shape_of_unranked(%arg : tensor<*xf32>) { // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32> - // 
CHECK: %[[SHAPE:.*]] = dynamic_tensor_from_elements %[[RANK]] { + // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] { // CHECK: ^bb0(%[[I:.*]]: index): // CHECK: %[[EXTENT:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32> // CHECK: yield %[[EXTENT]] : index @@ -233,7 +233,7 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) { // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex> %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex> return } @@ -244,7 +244,7 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) { // CHECK-LABEL: @shape_of_zero_d // CHECK-SAME: (%[[ARG:.*]]: tensor<f32>) func @shape_of_zero_d(%arg : tensor<f32>) { - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements : tensor<0xindex> %shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex> return } @@ -259,7 +259,7 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) { // CHECK-DAG: %[[C5:.*]] = constant 5 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32> - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex> %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex> return } @@ -321,7 +321,7 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) { // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex> // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], 
%[[ERASED_LHS]] : tensor<?xindex> // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index - // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] { + // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] { // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index): // CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex> @@ -361,7 +361,7 @@ func @broadcast_known_ diff erent_extents(%a : tensor<2xindex>, %b : tensor<3xinde // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex> // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex> // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index - // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] { + // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] { // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index): // CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex> diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir index 4e8f1282c36b..10310542f138 100644 --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -11,56 +11,6 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index { return %0 : index } -// CHECK-LABEL: func @dynamic_tensor_from_elements( -// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, -// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> { -// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex> -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: 
%[[C1:.*]] = constant 1 : index -// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { -// CHECK: %[[ARG_MEMREF:.*]] = tensor_to_memref %[[ARG]] : memref<*xf32> -// CHECK: %[[ELEM:.*]] = dim %[[ARG_MEMREF]], %[[I]] : memref<*xf32> -// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex> -// CHECK: scf.yield -// CHECK: } -// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<?xindex> -// CHECK: return %[[RET]] : tensor<?xindex> -// CHECK: } -func @dynamic_tensor_from_elements(%arg: tensor<*xf32>, %rank: index) -> tensor<?xindex> { - %result = dynamic_tensor_from_elements %rank { - ^bb0(%i : index): - %elem = dim %arg, %i : tensor<*xf32> - yield %elem : index - } : tensor<?xindex> - return %result : tensor<?xindex> -} - -// Additional test that checks the logic for intermixed static and dynamic -// extents. -// -// CHECK-LABEL: func @dynamic_tensor_from_elements_static_and_dynamic( -// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { -// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex> -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[C16:.*]] = constant 16 : index -// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) { -// CHECK: %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index -// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex> -// CHECK: scf.yield -// CHECK: } -// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex> -// CHECK: return %[[RET]] : tensor<16x?xindex> -// CHECK: } -func @dynamic_tensor_from_elements_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> { - %result = dynamic_tensor_from_elements %arg0 { - ^bb0(%i: index, %j: index): - %sum = addi %i, %j : index - yield %sum : index - } : tensor<16x?xindex> - return %result : tensor<16x?xindex> -} - // CHECK-LABEL: func @select( // CHECK-SAME: %[[PRED:.*]]: i1, 
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>, @@ -74,36 +24,3 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> { %0 = select %arg0, %arg1, %arg2 : tensor<f32> return %0 : tensor<f32> } - -// CHECK-LABEL: func @tensor_from_elements( -// CHECK-SAME: %[[ELEM0:.*]]: index, -// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { -// CHECK: %[[MEMREF:.*]] = alloc() -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] -// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] -// CHECK: return %[[RET]] : tensor<2xindex> -func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> { - %0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex> - return %0 : tensor<2xindex> -} - -// The dynamic_tensor_from_elements op needs to put its body into the -// resulting scf.parallel. To handle unknown ops in the body, it cannot clone -// the body because that would require the cloned ops to be legalized -// immediately, which is usually not possible since they might be from various -// other dialects. 
-// -// CHECK-LABEL: func @unknown_ops_in_body -func @unknown_ops_in_body(%arg0: index) -> tensor<?xindex> { - // CHECK-NOT: dynamic_tensor_from_elements - %tensor = dynamic_tensor_from_elements %arg0 { - ^bb0(%iv: index): - // CHECK: test.source - %0 = "test.source"() : () -> index - yield %0 : index - } : tensor<?xindex> - return %tensor : tensor<?xindex> -} diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index e7e4d4f49222..8187c2f3215d 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -59,16 +59,16 @@ func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf return %1 : f32 } -// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx -// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements( +// Test case: Folding of dim(tensor.generate %idx) -> %idx +// CHECK-LABEL: func @dim_of_tensor.generate( // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index // CHECK-NOT: dim // CHECK: return %[[IDX1]] : index -func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index { +func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index { %c3 = constant 3 : index - %0 = dynamic_tensor_from_elements %arg0, %arg1 { + %0 = tensor.generate %arg0, %arg1 { ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): - yield %c3 : index + tensor.yield %c3 : index } : tensor<2x?x4x?x5xindex> %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex> return %1 : index diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir index 48d2ae23466c..2d6e0342786c 100644 --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -16,72 +16,6 @@ func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 { // ----- -func @dynamic_tensor_from_elements(%m : index) - -> tensor<?x3x?xf32> { - // expected-error @+1 {{must have 
as many index operands as dynamic extents in the result type}} - %tnsr = dynamic_tensor_from_elements %m { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor<?x3x?xf32> - return %tnsr : tensor<?x3x?xf32> -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor<?x3x?xf32> { - // expected-error @+1 {{must have one body argument per input dimension}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor<?x3x?xf32> - return %tnsr : tensor<?x3x?xf32> -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor<?x3x?xf32> { - // expected-error @+1 {{all body arguments must be index}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : i64): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor<?x3x?xf32> - return %tnsr : tensor<?x3x?xf32> -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor<?x3x?xf32> { - // expected-error @+2 {{op expects regions to end with 'std.yield', found 'std.return'}} - // expected-note @+1 {{in custom textual format, the absence of terminator implies 'std.yield'}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8.0 : f32 - return %elem : f32 - } : tensor<?x3x?xf32> - return %tnsr : tensor<?x3x?xf32> -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor<?x3x?xf32> { - // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8 : i32 - yield %elem : i32 - } : tensor<?x3x?xf32> - return %tnsr : tensor<?x3x?xf32> -} - -// ----- - func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + 
j)>>) { // expected-error @+1 {{expected a permutation map}} transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir index cd173670ae54..e81d0fa03b7d 100644 --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -32,17 +32,6 @@ func @assert(%arg : i1) { return } -// CHECK-LABEL: @dynamic_tensor_from_elements -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor<?x3x?xf32> { - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor<?x3x?xf32> - return %tnsr : tensor<?x3x?xf32> -} - // CHECK-LABEL: @atan func @atan(%arg : f32) -> f32 { %result = atan %arg : f32 @@ -107,4 +96,3 @@ func @read_global_memref() { %1 = tensor_load %0 : memref<2xf32> return } - diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 0e55040ec116..abc7d2af5676 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -33,14 +33,96 @@ func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } -// CHECK-LABEL: func @extract( +// CHECK-LABEL: func @tensor.extract( // CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> f32 { // CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32> // CHECK: %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32> // CHECK: return %[[RET]] : f32 // CHECK: } -func @extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 { +func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 { %0 = tensor.extract %arg0[%arg1] : tensor<?xf32> return %0 : f32 } + +// CHECK-LABEL: func @tensor.from_elements( +// CHECK-SAME: %[[ELEM0:.*]]: index, +// CHECK-SAME: %[[ELEM1:.*]]: 
index) -> tensor<2xindex> { +// CHECK: %[[MEMREF:.*]] = alloc() +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] +// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] +// CHECK: return %[[RET]] : tensor<2xindex> +func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> { + %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex> + return %0 : tensor<2xindex> +} + +// CHECK-LABEL: func @tensor.generate( +// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, +// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> { +// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex> +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { +// CHECK: %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32> +// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex> +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<?xindex> +// CHECK: return %[[RET]] : tensor<?xindex> +// CHECK: } +func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> { + %result = tensor.generate %dynamic_extent { + ^bb0(%i : index): + %elem = dim %arg, %i : tensor<*xf32> + tensor.yield %elem : index + } : tensor<?xindex> + return %result : tensor<?xindex> +} + +// Additional test that checks the logic for intermixed static and dynamic +// extents. 
+// +// CHECK-LABEL: func @tensor.generate_static_and_dynamic( +// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { +// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex> +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[C16:.*]] = constant 16 : index +// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) { +// CHECK: %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index +// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex> +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex> +// CHECK: return %[[RET]] : tensor<16x?xindex> +// CHECK: } +func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> { + %result = tensor.generate %arg0 { + ^bb0(%i: index, %j: index): + %sum = addi %i, %j : index + tensor.yield %sum : index + } : tensor<16x?xindex> + return %result : tensor<16x?xindex> +} + +// The tensor.generate op needs to put its body into the +// resulting scf.parallel. To handle unknown ops in the body, it cannot clone +// the body because that would require the cloned ops to be legalized +// immediately, which is usually not possible since they might be from various +// other dialects. 
+// +// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body +func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> { + // CHECK-NOT: tensor.generate + %tensor = tensor.generate %arg0 { + ^bb0(%iv: index): + // CHECK: test.source + %0 = "test.source"() : () -> index + tensor.yield %0 : index + } : tensor<?xindex> + return %tensor : tensor<?xindex> +} diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 9dcd4da13cc5..ae145934ef4d 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -107,3 +107,90 @@ func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 { %result = tensor.extract %casted[%c0] : tensor<?xf32> return %result : f32 } + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.from_elements +func @extract_from_tensor.from_elements(%element : index) -> index { + // CHECK-SAME: ([[ARG:%.*]]: index) + %c0 = constant 0 : index + %tensor = tensor.from_elements %element : tensor<1xindex> + %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex> + // CHECK: [[ARG]] : index + return %extracted_element : index +} + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.generate +// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> +func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index { + %size = rank %tensor : tensor<*xf32> + // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]] + %0 = tensor.generate %size { + ^bb0(%arg0: index): + %1 = dim %tensor, %arg0 : tensor<*xf32> + tensor.yield %1 : index + } : tensor<?xindex> + %1 = tensor.extract %0[%idx] : tensor<?xindex> + // CHECK-NEXT: return %[[RES]] + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.generate_2d +// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> +func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index { + 
%size = rank %tensor : tensor<*xf32> + // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]] + // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]] + // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]] + %0 = tensor.generate %size, %size { + ^bb0(%arg0: index, %arg1: index): + %1 = dim %tensor, %arg0 : tensor<*xf32> + %2 = dim %tensor, %arg1 : tensor<*xf32> + %3 = addi %1, %2 : index + tensor.yield %3 : index + } : tensor<?x?xindex> + %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex> + // CHECK-NEXT: return %[[RES]] + return %4 : index +} + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.generate_sideeffects +// CHECK-SAME: %[[IDX:.*]]: index +func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index { + %size = rank %tensor : tensor<*xf32> + %mem = alloc(%size) : memref<?xindex> + // CHECK: %[[DTENSOR:.*]] = tensor.generate + %0 = tensor.generate %size { + ^bb0(%arg0: index): + %1 = dim %tensor, %arg0 : tensor<*xf32> + store %1, %mem[%arg0] : memref<?xindex> + tensor.yield %1 : index + } : tensor<?xindex> + // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]] + %1 = tensor.extract %0[%idx] : tensor<?xindex> + // CHECK-NEXT: return %[[RES]] + return %1 : index +} + +// ----- + +// CHECK-LABEL: @static_tensor.generate +// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index) +func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> { + %c5 = constant 5 : index + // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]] + %0 = tensor.generate %size1, %c5, %size4 { + ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index): + %1 = constant 32 : index + tensor.yield %1 : index + // CHECK: : tensor<3x?x5x7x?xindex> + } : tensor<3x?x?x7x?xindex> + // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> + return %0 : tensor<3x?x?x7x?xindex> +} diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 
cb38ac884bc3..11866990c885 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -13,3 +13,87 @@ func @extract_too_many_indices(%arg0: tensor<?xf32>) { %0 = tensor.extract %arg0[] : tensor<?xf32> return } + +// ----- + +func @tensor.from_elements_wrong_result_type() { + // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} + %c0 = constant 0 : i32 + %0 = tensor.from_elements %c0 : tensor<*xi32> + return +} + +// ----- + +func @tensor.from_elements_wrong_elements_count() { + // expected-error@+2 {{1 operands present, but expected 2}} + %c0 = constant 0 : index + %0 = tensor.from_elements %c0 : tensor<2xindex> + return +} + +// ----- + +func @tensor.generate(%m : index) + -> tensor<?x3x?xf32> { + // expected-error @+1 {{must have as many index operands as dynamic extents in the result type}} + %tnsr = tensor.generate %m { + ^bb0(%i : index, %j : index, %k : index): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : tensor<?x3x?xf32> + return %tnsr : tensor<?x3x?xf32> +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor<?x3x?xf32> { + // expected-error @+1 {{must have one body argument per input dimension}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : tensor<?x3x?xf32> + return %tnsr : tensor<?x3x?xf32> +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor<?x3x?xf32> { + // expected-error @+1 {{all body arguments must be index}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : i64): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : tensor<?x3x?xf32> + return %tnsr : tensor<?x3x?xf32> +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor<?x3x?xf32> { + // expected-error @+2 {{op expects regions to end with 'tensor.yield', found 'std.return'}} + // expected-note @+1 {{in custom textual format, 
the absence of terminator implies 'tensor.yield'}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : index): + %elem = constant 8.0 : f32 + return %elem : f32 + } : tensor<?x3x?xf32> + return %tnsr : tensor<?x3x?xf32> +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor<?x3x?xf32> { + // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : index): + %elem = constant 8 : i32 + tensor.yield %elem : i32 + } : tensor<?x3x?xf32> + return %tnsr : tensor<?x3x?xf32> +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 06db2bb237cd..9b15712058a2 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -21,3 +21,35 @@ func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) { %0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32> return } + +// CHECK-LABEL: func @tensor.from_elements() { +func @tensor.from_elements() { + %c0 = "std.constant"() {value = 0: index} : () -> index + // CHECK: %0 = tensor.from_elements %c0 : tensor<1xindex> + %0 = tensor.from_elements %c0 : tensor<1xindex> + + %c1 = "std.constant"() {value = 1: index} : () -> index + // CHECK: %1 = tensor.from_elements %c0, %c1 : tensor<2xindex> + %1 = tensor.from_elements %c0, %c1 : tensor<2xindex> + + %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32 + // CHECK: [[C0_F32:%.*]] = constant + // CHECK: %2 = tensor.from_elements [[C0_F32]] : tensor<1xf32> + %2 = tensor.from_elements %c0_f32 : tensor<1xf32> + + // CHECK: tensor.from_elements : tensor<0xindex> + %3 = tensor.from_elements : tensor<0xindex> + + return +} + +// CHECK-LABEL: @tensor.generate +func @tensor.generate(%m : index, %n : index) + -> tensor<?x3x?xf32> { + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : index): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : 
tensor<?x3x?xf32> + return %tnsr : tensor<?x3x?xf32> +} diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index 1deeb3ec49d0..0e86050870ff 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -675,27 +675,6 @@ func @calls(%arg0: i32) { return } -// CHECK-LABEL: func @tensor_from_elements() { -func @tensor_from_elements() { - %c0 = "std.constant"() {value = 0: index} : () -> index - // CHECK: %0 = tensor_from_elements %c0 : tensor<1xindex> - %0 = tensor_from_elements %c0 : tensor<1xindex> - - %c1 = "std.constant"() {value = 1: index} : () -> index - // CHECK: %1 = tensor_from_elements %c0, %c1 : tensor<2xindex> - %1 = tensor_from_elements %c0, %c1 : tensor<2xindex> - - %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32 - // CHECK: [[C0_F32:%.*]] = constant - // CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32> - %2 = tensor_from_elements %c0_f32 : tensor<1xf32> - - // CHECK: tensor_from_elements : tensor<0xindex> - %3 = tensor_from_elements : tensor<0xindex> - - return -} - // CHECK-LABEL: func @memref_cast(%arg0 func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) { // CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32> diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 45ebfff34d57..364c9155e2da 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -541,24 +541,6 @@ func @cmpf_canonical_type_mismatch(%a : f32, %b : f64) { // expected-note {{prio // ----- -func @tensor_from_elements_wrong_result_type() { - // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} - %c0 = constant 0 : i32 - %0 = tensor_from_elements %c0 : tensor<*xi32> - return -} - -// ----- - -func @tensor_from_elements_wrong_elements_count() { - // expected-error@+2 {{1 operands present, but expected 2}} - %c0 = constant 0 : index - %0 = tensor_from_elements %c0 : 
tensor<2xindex> - return -} - -// ----- - func @index_cast_index_to_index(%arg0: index) { // expected-error@+1 {{are cast incompatible}} %0 = index_cast %arg0: index to index diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 5b6f8cde9fec..62c07dd8a063 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1032,93 +1032,6 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i // ----- -// CHECK-LABEL: func @extract_from_tensor_from_elements -func @extract_from_tensor_from_elements(%element : index) -> index { - // CHECK-SAME: ([[ARG:%.*]]: index) - %c0 = constant 0 : index - %tensor = tensor_from_elements %element : tensor<1xindex> - %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex> - // CHECK: [[ARG]] : index - return %extracted_element : index -} - -// ----- - -// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements -// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> -func @extract_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> - // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]] - %0 = dynamic_tensor_from_elements %size { - ^bb0(%arg0: index): - %1 = dim %tensor, %arg0 : tensor<*xf32> - yield %1 : index - } : tensor<?xindex> - %1 = tensor.extract %0[%idx] : tensor<?xindex> - // CHECK-NEXT: return %[[RES]] - return %1 : index -} - -// ----- - -// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_2d -// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> -func @extract_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> - // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]] - // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]] - // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]] - %0 = 
dynamic_tensor_from_elements %size, %size { - ^bb0(%arg0: index, %arg1: index): - %1 = dim %tensor, %arg0 : tensor<*xf32> - %2 = dim %tensor, %arg1 : tensor<*xf32> - %3 = addi %1, %2 : index - yield %3 : index - } : tensor<?x?xindex> - %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex> - // CHECK-NEXT: return %[[RES]] - return %4 : index -} - -// ----- - -// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_sideeffects -// CHECK-SAME: %[[IDX:.*]]: index -func @extract_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> - %mem = alloc(%size) : memref<?xindex> - // CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements - %0 = dynamic_tensor_from_elements %size { - ^bb0(%arg0: index): - %1 = dim %tensor, %arg0 : tensor<*xf32> - store %1, %mem[%arg0] : memref<?xindex> - yield %1 : index - } : tensor<?xindex> - // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]] - %1 = tensor.extract %0[%idx] : tensor<?xindex> - // CHECK-NEXT: return %[[RES]] - return %1 : index -} - -// ----- - -// CHECK-LABEL: @static_dynamic_tensor_from_elements -// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index) -func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> { - %c5 = constant 5 : index - // CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]] - %0 = dynamic_tensor_from_elements %size1, %c5, %size4 { - ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index): - %1 = constant 32 : index - yield %1 : index - // CHECK: : tensor<3x?x5x7x?xindex> - } : tensor<3x?x?x7x?xindex> - // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> - return %0 : tensor<3x?x?x7x?xindex> -} - -// ----- - // CHECK-LABEL: func @subtensor // CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index) 
_______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits