Author: Lei Zhang Date: 2021-01-22T13:09:33-05:00 New Revision: e27197f3605450c372ddc71922d0e9982b30e115
URL: https://github.com/llvm/llvm-project/commit/e27197f3605450c372ddc71922d0e9982b30e115 DIFF: https://github.com/llvm/llvm-project/commit/e27197f3605450c372ddc71922d0e9982b30e115.diff LOG: [mlir][spirv] Define spv.IsNan/spv.IsInf and add lowerings spv.Ordered/spv.Unordered are meant for OpenCL Kernel capability. For Vulkan Shader capability, we should use spv.IsNan to check whether a number is NaN. Add a new pattern for converting `std.cmpf ord|uno` to spv.IsNan and bump the pattern converting to spv.Ordered/spv.Unordered to a higher benefit. The SPIR-V target environment will properly select between these two patterns. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D95237 Added: Modified: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir mlir/test/Dialect/SPIRV/IR/logical-ops.mlir mlir/test/Target/SPIRV/logical-ops.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index c369304cf18b..347b65a7739e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3216,6 +3216,8 @@ def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; +def SPV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>; +def SPV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>; def SPV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>; def SPV_OC_OpUnordered : I32EnumAttrCase<"OpUnordered", 163>; def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; 
@@ -3332,15 +3334,15 @@ def SPV_OpcodeAttr : SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, - SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpOrdered, SPV_OC_OpUnordered, - SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, - SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, - SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, - SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, - SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, - SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, - SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, + SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, + SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, + SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, + SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, + SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, + SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index 0516e70f87c4..019b63f3a582 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ 
b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -41,6 +41,11 @@ class SPV_LogicalUnaryOp<string mnemonic, Type operandType, SameOperandsAndResultShape])> { let parser = [{ return ::parseLogicalUnaryOp(parser, result); }]; let printer = [{ return ::printLogicalOp(getOperation(), p); }]; + + let builders = [ + OpBuilderDAG<(ins "Value":$value), + [{::buildLogicalUnaryOp($_builder, $_state, value);}]> + ]; } // ----- @@ -507,6 +512,70 @@ def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual", // ----- +def SPV_IsInfOp : SPV_LogicalUnaryOp<"IsInf", SPV_Float, []> { + let summary = "Result is true if x is an IEEE Inf, otherwise result is false."; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + x must be a scalar or vector of floating-point type. It must have the + same number of components as Result Type. + + Results are computed per component. + + <!-- End of AutoGen section --> + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + isinf-op ::= ssa-id `=` `spv.IsInf` ssa-use + `:` float-scalar-vector-type + ``` + + #### Example: + + ```mlir + %2 = spv.IsInf %0: f32 + %3 = spv.IsInf %1: vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_IsNanOp : SPV_LogicalUnaryOp<"IsNan", SPV_Float, []> { + let summary = [{ + Result is true if x is an IEEE NaN, otherwise result is false. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + x must be a scalar or vector of floating-point type. It must have the + same number of components as Result Type. + + Results are computed per component. 
+ + <!-- End of AutoGen section --> + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + isnan-op ::= ssa-id `=` `spv.IsNan` ssa-use + `:` float-scalar-vector-type + ``` + + #### Example: + + ```mlir + %2 = spv.IsNan %0: f32 + %3 = spv.IsNan %1: vector<4xf32> + ``` + }]; +} + +// ----- + def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 95bb0eca4496..041495e2b7cb 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -386,6 +386,28 @@ class CmpFOpPattern final : public OpConversionPattern<CmpFOp> { ConversionPatternRewriter &rewriter) const override; }; +/// Converts `cmpf ord|uno` NaN checks to spv.Ordered/spv.Unordered. This +/// pattern requires Kernel capability. +class CmpFOpNanKernelPattern final : public OpConversionPattern<CmpFOp> { +public: + using OpConversionPattern<CmpFOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts `cmpf ord|uno` NaN checks to spv.IsNan plus logical ops. This +/// pattern does not require additional capability. +class CmpFOpNanNonePattern final : public OpConversionPattern<CmpFOp> { +public: + using OpConversionPattern<CmpFOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts integer compare operation on i1 type operands to SPIR-V ops. 
class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> { public: @@ -730,7 +752,6 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands, DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); - DISPATCH(CmpFPredicate::ORD, spirv::OrderedOp); // Unordered. DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); @@ -738,7 +759,6 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands, DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); - DISPATCH(CmpFPredicate::UNO, spirv::UnorderedOp); #undef DISPATCH @@ -748,6 +768,47 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands, return failure(); } +LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( + CmpFOp cmpFOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + CmpFOpAdaptor cmpFOpOperands(operands); + + if (cmpFOp.getPredicate() == CmpFPredicate::ORD) { + rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, cmpFOpOperands.lhs(), + cmpFOpOperands.rhs()); + return success(); + } + + if (cmpFOp.getPredicate() == CmpFPredicate::UNO) { + rewriter.replaceOpWithNewOp<spirv::UnorderedOp>( + cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs()); + return success(); + } + + return failure(); +} + +LogicalResult CmpFOpNanNonePattern::matchAndRewrite( + CmpFOp cmpFOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (cmpFOp.getPredicate() != CmpFPredicate::ORD && + cmpFOp.getPredicate() != CmpFPredicate::UNO) + return failure(); + + CmpFOpAdaptor cmpFOpOperands(operands); + Location loc = cmpFOp.getLoc(); + + Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.lhs()); + Value rhsIsNan = 
rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.rhs()); + + Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); + if (cmpFOp.getPredicate() == CmpFPredicate::ORD) + replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); + + rewriter.replaceOp(cmpFOp, replace); + return success(); +} + //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// @@ -1102,7 +1163,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, SignedRemIOpPattern, XOrOpPattern, // Comparison patterns - BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern, + BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, // Constant patterns ConstantCompositeOpPattern, ConstantScalarOpPattern, @@ -1124,5 +1185,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>, TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter, context); + + // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel + // capability is available. 
+ patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context, + /*benefit=*/2); } } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 3d99696d6882..4506447b0503 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -900,6 +900,16 @@ static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, state.addOperands({lhs, rhs}); } +static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state, + Value value) { + Type boolType = builder.getI1Type(); + if (auto vecType = value.getType().dyn_cast<VectorType>()) + boolType = VectorType::get(vecType.getShape(), boolType); + state.addTypes(boolType); + + state.addOperands(value); +} + //===----------------------------------------------------------------------===// // spv.AccessChainOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index a33db1dd42cf..8ae93c2e4b9b 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -301,6 +301,7 @@ func @cmpf(%arg0 : f32, %arg1 : f32) { // ----- +// With Kernel capability, we can convert NaN check to spv.Ordered/spv.Unordered. module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, {}> } { @@ -318,6 +319,31 @@ func @cmpf(%arg0 : f32, %arg1 : f32) { // ----- +// Without Kernel capability, we need to convert NaN check to spv.IsNan. 
+module attributes { + spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}> +} { + +// CHECK-LABEL: @cmpf +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func @cmpf(%arg0 : f32, %arg1 : f32) { + // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32 + // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32 + // CHECK-NEXT: %[[OR:.+]] = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1 + // CHECK-NEXT: %{{.+}} = spv.LogicalNot %[[OR]] : i1 + %0 = cmpf ord, %arg0, %arg1 : f32 + + // CHECK-NEXT: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32 + // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32 + // CHECK-NEXT: %{{.+}} = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1 + %1 = cmpf uno, %arg0, %arg1 : f32 + return +} + +} // end module + +// ----- + //===----------------------------------------------------------------------===// // std.cmpi //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index baf8b45d7eaf..b2c34b85f194 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -32,6 +32,40 @@ func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi // ----- +//===----------------------------------------------------------------------===// +// spv.IsInf +//===----------------------------------------------------------------------===// + +func @isinf_scalar(%arg0: f32) -> i1 { + // CHECK: spv.IsInf {{.*}} : f32 + %0 = spv.IsInf %arg0 : f32 + return %0 : i1 +} + +func @isinf_vector(%arg0: vector<2xf32>) -> vector<2xi1> { + // CHECK: spv.IsInf {{.*}} : vector<2xf32> + %0 = spv.IsInf %arg0 : vector<2xf32> + return %0 : vector<2xi1> +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.IsNan +//===----------------------------------------------------------------------===// + +func @isnan_scalar(%arg0: f32) -> 
i1 { + // CHECK: spv.IsNan {{.*}} : f32 + %0 = spv.IsNan %arg0 : f32 + return %0 : i1 +} + +func @isnan_vector(%arg0: vector<2xf32>) -> vector<2xi1> { + // CHECK: spv.IsNan {{.*}} : vector<2xf32> + %0 = spv.IsNan %arg0 : vector<2xf32> + return %0 : vector<2xi1> +} + //===----------------------------------------------------------------------===// // spv.LogicalAnd //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir index 000cf49d733a..bd92074de39f 100644 --- a/mlir/test/Target/SPIRV/logical-ops.mlir +++ b/mlir/test/Target/SPIRV/logical-ops.mlir @@ -80,6 +80,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> { %13 = spv.Ordered %arg0, %arg1 : f32 // CHECK: spv.Unordered %14 = spv.Unordered %arg0, %arg1 : f32 + // CHECK: spv.IsNan + %15 = spv.IsNan %arg0 : f32 + // CHECK: spv.IsInf + %16 = spv.IsInf %arg1 : f32 spv.Return } } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits