Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:20-05:00
New Revision: a1e3fec79420164b7cd398872d525f03c4436e96

URL: https://github.com/llvm/llvm-project/commit/a1e3fec79420164b7cd398872d525f03c4436e96
DIFF: https://github.com/llvm/llvm-project/commit/a1e3fec79420164b7cd398872d525f03c4436e96.diff

LOG: Generalized op transformation logic for output tensor.

Add more op lowering test cases.

Added: 
    mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir
    mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir
    mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir

Modified: 
    mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
index 46083be58a35..f1d8e914c3ec 100644
--- a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
@@ -66,7 +66,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
 
     llvm::SmallVector<NamedAttribute, 3> transformedFilterAttrs;
 
-    // TBD: set layout attribute.
+    // set layout attribute.
     // Weight tensor transformation:
     // - Part 1: Merge non-K dimensions to dimension 0, name it as gemmK.
     // - Part 2: PassThrough K dimension to dimension 1, name it as gemmM.
@@ -414,7 +414,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
                                                             
StringAttr::get("wo", op.getContext())
                                                         }, op.getContext()));
     transformedInputAttrs.push_back(transformedInputImmLayoutAttr);
-    // TBD: set output_layout attribute.
+    // set output_layout attribute.
     auto transformedInputOutputLayoutAttr = 
rewriter.getNamedAttr("output_layout",
                                                         ArrayAttr::get({
                                                             
StringAttr::get("gemmK", op.getContext()),
@@ -442,49 +442,59 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
 
     llvm::SmallVector<NamedAttribute, 3> transformedOutputAttrs;
 
-    // TBD: set layout attribute.
-    // TBD: Part 1: Passthrough.
-    llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs;
-    
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", 
ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, 
op.getContext())));
-    transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", 
ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
-    
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation",
 StringAttr::get("PassThrough", op.getContext())));
-    
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
-                                                ArrayAttr::get({
-                                                    
IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
-                                                }, op.getContext())));
-    
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
-                                                ArrayAttr::get({
-                                                    StringAttr::get("ko", 
op.getContext())
-                                                }, op.getContext())));
+    // set layout attribute.
+    // Output tensor transformation:
+    // - Part 1: PassThrough K dimension to dimension 0, name it as gemmM.
+    // - Part 2: Merge non-K dimensions to dimension 1, name it as gemmN.
+    {
+      llvm::SmallVector<IntegerAttr, 3> nonKDims;
+      IntegerAttr kDim;
+      llvm::SmallVector<StringAttr, 3> nonKDimNames;
+      StringAttr kDimName;
+      for (unsigned i = 0; i < outputLayoutAttr.size(); ++i) {
+        if (auto strAttr = 
outputLayoutAttr.getValue()[i].dyn_cast<StringAttr>()) {
+          if (strAttr.getValue() == "ko") {
+            kDim = IntegerAttr::get(IntegerType::get(32, op.getContext()), i);
+            kDimName = StringAttr::get(strAttr.getValue(), op.getContext());
+          } else {
+            nonKDims.push_back(IntegerAttr::get(IntegerType::get(32, 
op.getContext()), i));
+            nonKDimNames.push_back(StringAttr::get(strAttr.getValue(), 
op.getContext()));
+          }
+        }
+      }
+ 
+      // Part 1: Passthrough.
+      llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs;
+      
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", 
ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, 
op.getContext())));
+      
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", 
ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
+      
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation",
 StringAttr::get("PassThrough", op.getContext())));
+      
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                  ArrayAttr::get({kDim}, 
op.getContext())));
+      
transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                  ArrayAttr::get({kDimName}, 
op.getContext())));
 
-    // TBD: Part 2: Merge.
-    llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs;
-    
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", 
ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, 
op.getContext())));
-    transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", 
ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
-    
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation",
 StringAttr::get("Merge", op.getContext())));
-    
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
-                                                ArrayAttr::get({
-                                                    
IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
-                                                    
IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
-                                                    
IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
-                                                }, op.getContext())));
-    
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
-                                                ArrayAttr::get({               
                                                                              
StringAttr::get("no", op.getContext()),
-                                                    StringAttr::get("ho", 
op.getContext()),
-                                                    StringAttr::get("wo", 
op.getContext())
-                                                }, op.getContext())));
+      // Part 2: Merge.
+      llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs;
+      
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", 
ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, 
op.getContext())));
+      
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", 
ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
+      
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation",
 StringAttr::get("Merge", op.getContext())));
+      
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                  
ArrayAttr::get(ArrayRef<Attribute>(nonKDims.begin(), nonKDims.end()), 
op.getContext())));
+      
transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                  
ArrayAttr::get(ArrayRef<Attribute>(nonKDimNames.begin(), nonKDimNames.end()), 
op.getContext())));
 
-    auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout",
-                                                             ArrayAttr::get({
-                                                                 
DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()),
-                                                                 
DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext())
-                                                             }, 
op.getContext()));
-    transformedOutputAttrs.push_back(transformedOutputLayoutAttr);
+      auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout",
+                                                               ArrayAttr::get({
+                                                                   
DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()),
+                                                                   
DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext())
+                                                               }, 
op.getContext()));
+      transformedOutputAttrs.push_back(transformedOutputLayoutAttr);
+    }
 
     // set source_layout attribute.
     auto outputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", 
outputLayoutAttr);
     transformedOutputAttrs.push_back(outputSrcLayoutAttr);
-    // TBD: set output_layout attribute.
+    // set output_layout attribute.
     auto transformedOutputOutputLayoutAttr = 
rewriter.getNamedAttr("output_layout",
                                                         ArrayAttr::get({
                                                             
StringAttr::get("gemmM", op.getContext()),
@@ -492,7 +502,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
                                                         }, op.getContext()));
     transformedOutputAttrs.push_back(transformedOutputOutputLayoutAttr);
 
-    // TBD: set gridwise_gemm_argument_pos attribute.
+    // set gridwise_gemm_argument_pos attribute.
     auto outputGridwiseGemmArgPosAttr = 
rewriter.getNamedAttr("gridwise_gemm_argument_position", 
                                                              
IntegerAttr::get(IntegerType::get(32, op.getContext()), 2));
     transformedOutputAttrs.push_back(outputGridwiseGemmArgPosAttr);

diff  --git a/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir
new file mode 100644
index 000000000000..4f6222237d59
--- /dev/null
+++ b/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -miopen-lowering %s | FileCheck %s
+
+func @miopen_conv2d_ckyx_cnhw_knhw(%filter : memref<?x?x?x?xf32>, %input : 
memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+  miopen.conv2d(%filter, %input, %output) {
+    filter_layout = ["c", "k", "y", "x"],
+    input_layout = ["ci", "ni", "hi", "wi"],
+    output_layout = ["ko", "no", "ho", "wo"],
+    dilations = [1, 1],
+    strides = [1, 1],
+    padding = [0, 0]
+  } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @miopen_conv2d
+//  CHECK-NOT: miopen.conv2d
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.gridwise_gemm

diff  --git a/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir b/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir
new file mode 100644
index 000000000000..d5d0d9836bbd
--- /dev/null
+++ b/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -miopen-lowering %s | FileCheck %s
+
+func @miopen_conv2d_cyxk_chwn_khwn(%filter : memref<?x?x?x?xf32>, %input : 
memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+  miopen.conv2d(%filter, %input, %output) {
+    filter_layout = ["c", "y", "x", "k"],
+    input_layout = ["ci", "hi", "wi", "ni"],
+    output_layout = ["ko", "ho", "wo", "no"],
+    dilations = [1, 1],
+    strides = [1, 1],
+    padding = [0, 0]
+  } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @miopen_conv2d
+//  CHECK-NOT: miopen.conv2d
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.gridwise_gemm

diff  --git a/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir
new file mode 100644
index 000000000000..cec6463d783c
--- /dev/null
+++ b/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -miopen-lowering %s | FileCheck %s
+
+func @miopen_conv2d_cyxk_cnhw_knhw(%filter : memref<?x?x?x?xf32>, %input : 
memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+  miopen.conv2d(%filter, %input, %output) {
+    filter_layout = ["c", "y", "x", "k"],
+    input_layout = ["ci", "ni", "hi", "wi"],
+    output_layout = ["ko", "no", "ho", "wo"],
+    dilations = [1, 1],
+    strides = [1, 1],
+    padding = [0, 0]
+  } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @miopen_conv2d
+//  CHECK-NOT: miopen.conv2d
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.transform
+//  CHECK-NEXT: miopen.gridwise_gemm


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to