Author: Wen-Heng (Jack) Chung Date: 2020-06-05T22:18:20-05:00 New Revision: 01659f382d0ab554d59bfd793a0f9153613c7e20
URL: https://github.com/llvm/llvm-project/commit/01659f382d0ab554d59bfd793a0f9153613c7e20 DIFF: https://github.com/llvm/llvm-project/commit/01659f382d0ab554d59bfd793a0f9153613c7e20.diff LOG: Generalized op transformation logic for weight tensor. Add test cases. Added: mlir/test/Dialect/MIOpen/lowering_kcyx_nchw_nkhw.mlir mlir/test/Dialect/MIOpen/lowering_kyxc_nhwc_nhwk.mlir Modified: mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp Removed: ################################################################################ diff --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp index 0378a4c113bf..46083be58a35 100644 --- a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp +++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp @@ -67,44 +67,53 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> { llvm::SmallVector<NamedAttribute, 3> transformedFilterAttrs; // TBD: set layout attribute. - // TBD: Merge part. - llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart1Specs; - transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); - transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext()))); - transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); - transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", - ArrayAttr::get({ - IntegerAttr::get(IntegerType::get(32, op.getContext()), 1), - IntegerAttr::get(IntegerType::get(32, op.getContext()), 2), - IntegerAttr::get(IntegerType::get(32, op.getContext()), 3), - }, op.getContext()))); - transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", - ArrayAttr::get({ - StringAttr::get("c", op.getContext()), - StringAttr::get("y", op.getContext()), - StringAttr::get("x", op.getContext()) - }, op.getContext()))); - - // TBD: Passthrough part. - llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart2Specs; - transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); - transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext()))); - transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); - transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", - ArrayAttr::get({ - IntegerAttr::get(IntegerType::get(32, op.getContext()), 0), - }, op.getContext()))); - transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", - ArrayAttr::get({ - StringAttr::get("k", op.getContext()) - }, op.getContext()))); - - auto transformedFilterLayoutAttr = rewriter.getNamedAttr("layout", - ArrayAttr::get({ - DictionaryAttr::get(transformedFilterLayoutPart1Specs, op.getContext()), - DictionaryAttr::get(transformedFilterLayoutPart2Specs, op.getContext()) - }, op.getContext())); - transformedFilterAttrs.push_back(transformedFilterLayoutAttr); + // Weight tensor transformation: + // - Part 1: Merge non-K dimensions to dimension 0, name it as gemmK. + // - Part 2: PassThrough K dimension to dimension 1, name it as gemmM. + { + llvm::SmallVector<IntegerAttr, 3> nonKDims; + IntegerAttr kDim; + llvm::SmallVector<StringAttr, 3> nonKDimNames; + StringAttr kDimName; + for (unsigned i = 0; i < filterLayoutAttr.size(); ++i) { + if (auto strAttr = filterLayoutAttr.getValue()[i].dyn_cast<StringAttr>()) { + if (strAttr.getValue() == "k") { + kDim = IntegerAttr::get(IntegerType::get(32, op.getContext()), i); + kDimName = StringAttr::get(strAttr.getValue(), op.getContext()); + } else { + nonKDims.push_back(IntegerAttr::get(IntegerType::get(32, op.getContext()), i)); + nonKDimNames.push_back(StringAttr::get(strAttr.getValue(), op.getContext())); + } + } + } + + // Part 1: Merge part. + llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart1Specs; + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get(ArrayRef<Attribute>(nonKDims.begin(), nonKDims.end()), op.getContext()))); + transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get(ArrayRef<Attribute>(nonKDimNames.begin(), nonKDimNames.end()), op.getContext()))); + + // Part 2: Passthrough part. + llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart2Specs; + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions", + ArrayAttr::get({kDim}, op.getContext()))); + transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names", + ArrayAttr::get({kDimName}, op.getContext()))); + + auto transformedFilterLayoutAttr = rewriter.getNamedAttr("layout", + ArrayAttr::get({ + DictionaryAttr::get(transformedFilterLayoutPart1Specs, op.getContext()), + DictionaryAttr::get(transformedFilterLayoutPart2Specs, op.getContext()) + }, op.getContext())); + transformedFilterAttrs.push_back(transformedFilterLayoutAttr); + } // set source_layout attribute. auto filterSrcLayoutAttr = rewriter.getNamedAttr("source_layout", filterLayoutAttr); diff --git a/mlir/test/Dialect/MIOpen/lowering_kcyx_nchw_nkhw.mlir b/mlir/test/Dialect/MIOpen/lowering_kcyx_nchw_nkhw.mlir new file mode 100644 index 000000000000..1f880ee37ef1 --- /dev/null +++ b/mlir/test/Dialect/MIOpen/lowering_kcyx_nchw_nkhw.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -miopen-lowering %s | FileCheck %s + +func @miopen_conv2d_kcyx_nchw_nkhw(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { + miopen.conv2d(%filter, %input, %output) { + filter_layout = ["k", "c", "y", "x"], + input_layout = ["ni", "ci", "hi", "wi"], + output_layout = ["no", "ko", "ho", "wo"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0] + } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32> + return +} +// CHECK-LABEL: func @miopen_conv2d +// CHECK-NOT: miopen.conv2d +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.gridwise_gemm diff --git a/mlir/test/Dialect/MIOpen/lowering_kyxc_nhwc_nhwk.mlir b/mlir/test/Dialect/MIOpen/lowering_kyxc_nhwc_nhwk.mlir new file mode 100644 index 000000000000..8450b72ec78a --- /dev/null +++ b/mlir/test/Dialect/MIOpen/lowering_kyxc_nhwc_nhwk.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -miopen-lowering -split-input-file %s | FileCheck %s + +func @miopen_conv2d_kyxc_nhwc_nhwk(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { + miopen.conv2d(%filter, %input, %output) { + filter_layout = ["k", "y", "x", "c"], + input_layout = ["ni", "hi", "wi", "ci"], + output_layout = ["no", "ho", "wo", "ko"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0] + } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32> + return +} +// CHECK-LABEL: func @miopen_conv2d +// CHECK-NOT: miopen.conv2d +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.gridwise_gemm _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits