Author: Lei Zhang Date: 2020-12-23T13:27:31-05:00 New Revision: 42980a789d2212f774dbb12c2555452d328089a6
URL: https://github.com/llvm/llvm-project/commit/42980a789d2212f774dbb12c2555452d328089a6 DIFF: https://github.com/llvm/llvm-project/commit/42980a789d2212f774dbb12c2555452d328089a6.diff LOG: [mlir][spirv] Convert functions returning one value Reviewed By: hanchung, ThomasRaoux Differential Revision: https://reviews.llvm.org/D93468 Added: Modified: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir Removed: ################################################################################ diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index d15623568212..470f4143f2c5 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -924,10 +924,14 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands, LogicalResult ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { - if (returnOp.getNumOperands()) { + if (returnOp.getNumOperands() > 1) return failure(); + + if (returnOp.getNumOperands() == 1) { + rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]); + } else { + rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp); } - rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index b310d5df7b26..9393f3df6425 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -473,23 +473,27 @@ LogicalResult FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); - // TODO: support converting functions with one result. - if (fnType.getNumResults()) + if (fnType.getNumResults() > 1) return failure(); TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); - for (auto argType : enumerate(funcOp.getType().getInputs())) { + for (auto argType : enumerate(fnType.getInputs())) { auto convertedType = typeConverter.convertType(argType.value()); if (!convertedType) return failure(); signatureConverter.addInputs(argType.index(), convertedType); } + Type resultType; + if (fnType.getNumResults() == 1) + resultType = typeConverter.convertType(fnType.getResult(0)); + // Create the converted spv.func op. auto newFuncOp = rewriter.create<spirv::FuncOp>( funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), - llvm::None)); + resultType ? TypeRange(resultType) + : TypeRange())); // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp.getAttrs()) { diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 10e43ef4acd7..850e22465d44 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -954,3 +954,29 @@ func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) { } } // end module + +// ----- + +//===----------------------------------------------------------------------===// +// std.return +//===----------------------------------------------------------------------===// + +module attributes { + spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}> +} { + +// CHECK-LABEL: spv.func @return_one_val +// CHECK-SAME: (%[[ARG:.+]]: f32) +func @return_one_val(%arg0: f32) -> f32 { + // CHECK: spv.ReturnValue %[[ARG]] : f32 + return %arg0: f32 +} + +// Check that multiple-return functions are not converted. +// CHECK-LABEL: func @return_multi_val +func @return_multi_val(%arg0: f32) -> (f32, f32) { + // CHECK: return + return %arg0, %arg0: f32, f32 +} + +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits