Author: Wen-Heng (Jack) Chung Date: 2020-06-05T22:18:20-05:00 New Revision: a1e3fec79420164b7cd398872d525f03c4436e96
URL: https://github.com/llvm/llvm-project/commit/a1e3fec79420164b7cd398872d525f03c4436e96 DIFF: https://github.com/llvm/llvm-project/commit/a1e3fec79420164b7cd398872d525f03c4436e96.diff LOG: Generalized op transformation logic for output tensor. Add more op lowering test cases. Added: mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir Modified: mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp Removed: ################################################################################ diff --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp index 46083be58a35..f1d8e914c3ec 100644 --- a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp +++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp @@ -66,7 +66,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { llvm::SmallVector<NamedAttribute, 3> transformedFilterAttrs; - // TBD: set layout attribute. + // set layout attribute. // Weight tensor transformation: // - Part 1: Merge non-K dimensions to dimension 0, name it as gemmK. // - Part 2: PassThrough K dimension to dimension 1, name it as gemmM. @@ -414,7 +414,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { StringAttr::get("wo", op.getContext()) }, op.getContext())); transformedInputAttrs.push_back(transformedInputImmLayoutAttr); - // TBD: set output_layout attribute. + // set output_layout attribute. auto transformedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout", ArrayAttr::get({ StringAttr::get("gemmK", op.getContext()), @@ -442,49 +442,59 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { llvm::SmallVector<NamedAttribute, 3> transformedOutputAttrs; - // TBD: set layout attribute. - // TBD: Part 1: Passthrough. - llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs; - transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); - transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext()))); - transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); - transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", - ArrayAttr::get({ - IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), - }, op.getContext()))); - transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", - ArrayAttr::get({ - StringAttr::get("ko", op.getContext()) - }, op.getContext()))); + // set layout attribute. + // Weight tensor transformation: + // - Part 1: PassThrough K dimension to dimension 0, name it as gemmM. + // - Part 2: Merge non-K dimensions to dimension 1, name it as gemmN. + { + llvm::SmallVector<IntegerAttr, 3> nonKDims; + IntegerAttr kDim; + llvm::SmallVector<StringAttr, 3> nonKDimNames; + StringAttr kDimName; + for (unsigned i = 0; i < outputLayoutAttr.size(); ++i) { + if (auto strAttr = outputLayoutAttr.getValue()[i].dyn_cast<StringAttr>()) { + if (strAttr.getValue() == "ko") { + kDim = IntegerAttr::get(IntegerType::get(32, op.getContext()), i); + kDimName = StringAttr::get(strAttr.getValue(), op.getContext()); + } else { + nonKDims.push_back(IntegerAttr::get(IntegerType::get(32, op.getContext()), i)); + nonKDimNames.push_back(StringAttr::get(strAttr.getValue(), op.getContext())); + } + } + } + + // Part 1: Passthrough. + llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs; + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({kDim}, op.getContext()))); + transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({kDimName}, op.getContext()))); - // TBD: Part 2: Merge. - llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs; - transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); - transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext()))); - transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); - transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", - ArrayAttr::get({ - IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), - IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), - IntegerAttr::get(IntegerType::get(32, op.getContext()), 3), - }, op.getContext()))); - transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", - ArrayAttr::get({ StringAttr::get("no", op.getContext()), - StringAttr::get("ho", op.getContext()), - StringAttr::get("wo", op.getContext()) - }, op.getContext()))); + // Part 2: Merge. + llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs; + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get(ArrayRef<Attribute>(nonKDims.begin(), nonKDims.end()), op.getContext()))); + transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get(ArrayRef<Attribute>(nonKDimNames.begin(), nonKDimNames.end()), op.getContext()))); - auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout", - ArrayAttr::get({ - DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()), - DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext()) - }, op.getContext())); - transformedOutputAttrs.push_back(transformedOutputLayoutAttr); + auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout", + ArrayAttr::get({ + DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()), + DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext()) + }, op.getContext())); + transformedOutputAttrs.push_back(transformedOutputLayoutAttr); + } // set source_layout attribute. auto outputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", outputLayoutAttr); transformedOutputAttrs.push_back(outputSrcLayoutAttr); - // TBD: set output_layout attribute. + // set output_layout attribute. auto transformedOutputOutputLayoutAttr = rewriter.getNamedAttr("output_layout", ArrayAttr::get({ StringAttr::get("gemmM", op.getContext()), @@ -492,7 +502,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { }, op.getContext())); transformedOutputAttrs.push_back(transformedOutputOutputLayoutAttr); - // TBD: set gridwise_gemm_argument_pos attribute. + // set gridwise_gemm_argument_pos attribute. auto outputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position", IntegerAttr::get(IntegerType::get(32, op.getContext()), 2)); transformedOutputAttrs.push_back(outputGridwiseGemmArgPosAttr); diff --git a/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir new file mode 100644 index 000000000000..4f6222237d59 --- /dev/null +++ b/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -miopen-lowering %s | FileCheck %s + +func @miopen_conv2d_ckyx_cnhw_knhw(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { + miopen.conv2d(%filter, %input, %output) { + filter_layout = ["c", "k", "y", "x"], + input_layout = ["ci", "ni", "hi", "wi"], + output_layout = ["ko", "no", "ho", "wo"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0] + } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32> + return +} +// CHECK-LABEL: func @miopen_conv2d +// CHECK-NOT: miopen.conv2d +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.gridwise_gemm diff --git a/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir b/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir new file mode 100644 index 000000000000..d5d0d9836bbd --- /dev/null +++ b/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -miopen-lowering %s | FileCheck %s + +func @miopen_conv2d_cyxk_chwn_khwn(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { + miopen.conv2d(%filter, %input, %output) { + filter_layout = ["c", "y", "x", "k"], + input_layout = ["ci", "hi", "wi", "ni"], + output_layout = ["ko", "ho", "wo", "no"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0] + } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32> + return +} +// CHECK-LABEL: func @miopen_conv2d +// CHECK-NOT: miopen.conv2d +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.gridwise_gemm diff --git a/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir new file mode 100644 index 000000000000..cec6463d783c --- /dev/null +++ b/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -miopen-lowering %s | FileCheck %s + +func @miopen_conv2d_cyxk_cnhw_knhw(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { + miopen.conv2d(%filter, %input, %output) { + filter_layout = ["c", "y", "x", "k"], + input_layout = ["ci", "ni", "hi", "wi"], + output_layout = ["ko", "no", "ho", "wo"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0] + } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32> + return +} +// CHECK-LABEL: func @miopen_conv2d +// CHECK-NOT: miopen.conv2d +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.gridwise_gemm _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits