Author: Lei Zhang Date: 2021-01-19T09:14:21-05:00 New Revision: 3a56a96664de955888d63c49a33808e3a1a294d9
URL: https://github.com/llvm/llvm-project/commit/3a56a96664de955888d63c49a33808e3a1a294d9 DIFF: https://github.com/llvm/llvm-project/commit/3a56a96664de955888d63c49a33808e3a1a294d9.diff LOG: [mlir][spirv] Define spv.GLSL.Fma and add lowerings Also changes some rewriter.create + rewriter.replaceOp calls into rewriter.replaceOpWithNewOp calls. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D94965 Added: Modified: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp mlir/test/Conversion/VectorToSPIRV/simple.mlir mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir mlir/test/Target/SPIRV/glsl-ops.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td index a566b7503a15..c34cd98dbb39 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -972,4 +972,44 @@ def SPV_GLSLSClampOp : SPV_GLSLTernaryArithmeticOp<"SClamp", 45, SPV_SignedInt> }]; } +// ----- + +def SPV_GLSLFmaOp : SPV_GLSLTernaryArithmeticOp<"Fma", 50, SPV_Float> { + let summary = "Computes a * b + c."; + + let description = [{ + In uses where this operation is decorated with NoContraction: + + - fma is considered a single operation, whereas the expression a * b + c + is considered two operations. + - The precision of fma can differ from the precision of the expression + a * b + c. + - fma will be computed with the same precision as any other fma decorated + with NoContraction, giving invariant results for the same input values + of a, b, and c. + + Otherwise, in the absence of a NoContraction decoration, there are no + special constraints on the number of operations or difference in precision + between fma and the expression a * b + c. + + The operands must all be a scalar or vector whose component type is + floating-point. 
+ + Result Type and the type of all operands must be the same type. Results + are computed per component. + + <!-- End of AutoGen section --> + ``` + fma-op ::= ssa-id `=` `spv.GLSL.Fma` ssa-use, ssa-use, ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %0 = spv.GLSL.Fma %a, %b, %c : f32 + %1 = spv.GLSL.Fma %a, %b, %c : vector<3xf16> + ``` + }]; +} + #endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 1509836ef2e2..52a35a17869f 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -36,9 +36,8 @@ struct VectorBroadcastConvert final vector::BroadcastOp::Adaptor adaptor(operands); SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), adaptor.source()); - Value construct = rewriter.create<spirv::CompositeConstructOp>( - broadcastOp.getLoc(), broadcastOp.getVectorType(), source); - rewriter.replaceOp(broadcastOp, construct); + rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( + broadcastOp, broadcastOp.getVectorType(), source); return success(); } }; @@ -55,9 +54,23 @@ struct VectorExtractOpConvert final return failure(); vector::ExtractOp::Adaptor adaptor(operands); int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt(); - Value newExtract = rewriter.create<spirv::CompositeExtractOp>( - extractOp.getLoc(), adaptor.vector(), id); - rewriter.replaceOp(extractOp, newExtract); + rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( + extractOp, adaptor.vector(), id); + return success(); + } +}; + +struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if 
(!spirv::CompositeType::isValid(fmaOp.getVectorType())) + return failure(); + vector::FMAOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( + fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); return success(); } }; @@ -74,9 +87,8 @@ struct VectorInsertOpConvert final return failure(); vector::InsertOp::Adaptor adaptor(operands); int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt(); - Value newInsert = rewriter.create<spirv::CompositeInsertOp>( - insertOp.getLoc(), adaptor.source(), adaptor.dest(), id); - rewriter.replaceOp(insertOp, newInsert); + rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( + insertOp, adaptor.source(), adaptor.dest(), id); return success(); } }; @@ -92,10 +104,9 @@ struct VectorExtractElementOpConvert final if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) return failure(); vector::ExtractElementOp::Adaptor adaptor(operands); - Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>( - extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(), + rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( + extractElementOp, extractElementOp.getType(), adaptor.vector(), extractElementOp.position()); - rewriter.replaceOp(extractElementOp, newExtractElement); return success(); } }; @@ -111,10 +122,9 @@ struct VectorInsertElementOpConvert final if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) return failure(); vector::InsertElementOp::Adaptor adaptor(operands); - Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>( - insertElementOp.getLoc(), insertElementOp.getType(), - insertElementOp.dest(), adaptor.source(), insertElementOp.position()); - rewriter.replaceOp(insertElementOp, newInsertElement); + rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( + insertElementOp, insertElementOp.getType(), insertElementOp.dest(), + adaptor.source(), insertElementOp.position()); return 
success(); } }; @@ -124,7 +134,8 @@ struct VectorInsertElementOpConvert final void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert, - VectorInsertOpConvert, VectorExtractElementOpConvert, - VectorInsertElementOpConvert>(typeConverter, context); + patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert, + VectorExtractOpConvert, VectorFmaOpConvert, + VectorInsertOpConvert, VectorInsertElementOpConvert>( + typeConverter, context); } diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir index 3594a6db805e..fddfd911fb19 100644 --- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -57,3 +57,13 @@ func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) { %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> spv.Return } + +// ----- + +// CHECK-LABEL: func @fma +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32> +func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) { + %0 = vector.fma %a, %b, %c: vector<4xf32> + spv.Return +} diff --git a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir index 42377c2277a7..0533396406f7 100644 --- a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir @@ -345,3 +345,23 @@ func @fclamp(%arg0 : i32, %min : i32, %max : i32) -> () { %2 = spv.GLSL.SClamp %arg0, %min, %max : i32 return } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.GLSL.Fma +//===----------------------------------------------------------------------===// + +func @fma(%a : f32, %b : f32, %c : f32) -> () { + // CHECK: spv.GLSL.Fma {{%[^,]*}}, 
{{%[^,]*}}, {{%[^,]*}} : f32 + %2 = spv.GLSL.Fma %a, %b, %c : f32 + return +} + +// ----- + +func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () { + // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32> + %2 = spv.GLSL.Fma %a, %b, %c : vector<3xf32> + return +} diff --git a/mlir/test/Target/SPIRV/glsl-ops.mlir b/mlir/test/Target/SPIRV/glsl-ops.mlir index d635bde9cbf1..4dfd249288b0 100644 --- a/mlir/test/Target/SPIRV/glsl-ops.mlir +++ b/mlir/test/Target/SPIRV/glsl-ops.mlir @@ -48,4 +48,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> { %13 = spv.GLSL.SClamp %arg0, %arg1, %arg2 : si32 spv.Return } + + spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" { + // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 + %13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32 + spv.Return + } } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits