Author: Wen-Heng (Jack) Chung Date: 2020-06-05T22:18:19-05:00 New Revision: 05d19a7eb4ea27d4d0e7989b145c2f7458fff54a
URL: https://github.com/llvm/llvm-project/commit/05d19a7eb4ea27d4d0e7989b145c2f7458fff54a DIFF: https://github.com/llvm/llvm-project/commit/05d19a7eb4ea27d4d0e7989b145c2f7458fff54a.diff LOG: Add Op traversing logic into MIOpen dialect -> C++ translator. Added: mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir Modified: mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp Removed: mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir ################################################################################ diff --git a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h index d3e9b8ee09a2..09d2d1166caf 100644 --- a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h +++ b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h @@ -33,7 +33,17 @@ class ModuleOp; /// Convert the given MLIR module into MIOpen C++ . In case of error, report it /// to the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. -std::unique_ptr<llvm::StringRef> translateModuleToMIOpenCPP(ModuleOp m); +std::unique_ptr<llvm::StringRef> translateModuleToMIOpenCpp(ModuleOp m); + +/// Convert the given MLIR module into MIOpen C++ Header. In case of error, report it +/// to the error handler registered with the MLIR context, if any (obtained from +/// the MLIR module), and return `nullptr`. +std::unique_ptr<llvm::StringRef> translateModuleToMIOpenHeader(ModuleOp m); + +/// Convert the given MLIR module into MIOpen C++ Solver. In case of error, report it +/// to the error handler registered with the MLIR context, if any (obtained from +/// the MLIR module), and return `nullptr`. +std::unique_ptr<llvm::StringRef> translateModuleToMIOpenSolver(ModuleOp m); } // namespace mlir diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt b/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt index 855985b4b945..3d37305c60e7 100644 --- a/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt +++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt @@ -5,6 +5,7 @@ add_llvm_library(MLIRMIOpenCpp ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MIOpenOps ) target_link_libraries(MLIRMIOpenCpp + LLVMSupport MLIRIR MLIRMIOpenOps MLIRStandardOps diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp index 5fe33d695cb3..b071565c2b51 100644 --- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp +++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp @@ -13,34 +13,396 @@ #include "mlir/Dialect/MIOpenOps/MIOpenCPP.h" #include "mlir/Dialect/MIOpenOps/MIOpenOps.h" #include "mlir/Dialect/StandardOps/Ops.h" - +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" #include "mlir/Translation.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/ToolOutputFile.h" using namespace mlir; -std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCPP(ModuleOp m) { - // Check constraints: +namespace { + +static constexpr StringLiteral kVarName[3] = {"weight", "input", "output"}; + +static constexpr int kConv2DTensorDimension = 4; + +static constexpr StringLiteral kCppPreamblePart1 = R"( +#include "common_header.hpp" +)"; + +static constexpr StringLiteral kCppPreamblePart2 = R"( +#include "float_types.h" + +extern "C" __global__ +)"; + +static constexpr StringLiteral kCppPreamblePart3 = R"( + (const FLOAT* const __restrict__ p_in_global, + const FLOAT* const __restrict__ p_wei_global, + FLOAT* const __restrict__ p_out_global) +{ + using namespace ck; + + constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H; + constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W; + + constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H; + constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W; + + constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H; + constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W; + + constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H; + constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W; + + constexpr index_t BlockSize = CK_PARAM_TUNABLE_BLOCK_SIZE; + constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE; + + constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK; + constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK; + constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK; + +)"; + +static constexpr StringLiteral kCppInterlude = R"( + using ConvStrides = Sequence<ConvStrideH, ConvStrideW>; + using ConvDilations = Sequence<ConvDilationH, ConvDilationW>; + + using InLeftPads = Sequence<InLeftPadH, InLeftPadW>; + using InRightPads = Sequence<InRightPadH, InRightPadW>; + + // read and calculate tuning parameter + constexpr index_t GemmMPerThreadSubC = CK_PARAM_TUNABLE_GEMM_M_PER_THREAD_SUB_C; + constexpr index_t GemmNPerThreadSubC = CK_PARAM_TUNABLE_GEMM_N_PER_THREAD_SUB_C; + constexpr index_t GemmMLevel0Cluster = CK_PARAM_TUNABLE_GEMM_M_LEVEL0_CLUSTER; + constexpr index_t GemmNLevel0Cluster = CK_PARAM_TUNABLE_GEMM_N_LEVEL0_CLUSTER; + constexpr index_t GemmMLevel1Cluster = CK_PARAM_TUNABLE_GEMM_M_LEVEL1_CLUSTER; + constexpr index_t GemmNLevel1Cluster = CK_PARAM_TUNABLE_GEMM_N_LEVEL1_CLUSTER; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t GemmThreadGemmDataPerReadM = GemmMPerThreadSubC; + constexpr index_t GemmThreadGemmDataPerReadN = GemmNPerThreadSubC; + + // A matrix + constexpr index_t GemmABlockCopyClusterLengths_GemmK = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K; + + constexpr index_t GemmABlockCopyClusterLengths_GemmM = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M; + + constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK = + GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK; + + constexpr index_t GemmABlockCopyThreadSliceLengths_GemmM = + GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = + Sequence<GemmABlockCopyThreadSliceLengths_GemmK, GemmABlockCopyThreadSliceLengths_GemmM>; + + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = + Sequence<GemmABlockCopyClusterLengths_GemmK, GemmABlockCopyClusterLengths_GemmM>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K; + + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M; + + // B matrix + constexpr index_t GemmBBlockCopyClusterLengths_GemmK = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K; + + constexpr index_t GemmBBlockCopyClusterLengths_GemmN = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N; + + constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK = + GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK; + + constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmN = + GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN; + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = + Sequence<GemmBBlockCopyThreadSliceLengths_GemmK, GemmBBlockCopyThreadSliceLengths_GemmN>; + + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = + Sequence<GemmBBlockCopyClusterLengths_GemmK, GemmBBlockCopyClusterLengths_GemmN>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N; + + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N; + + // C matrix + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = + CK_PARAM_TUNABLE_GEMM_C_THREAD_COPY_DST_DATA_PER_WRITE_GEMM_N1; +)"; + +static constexpr StringLiteral kCppEpiloguePart1 = R"( + <GridSize, + BlockSize, + FLOAT, + FLOAT_ACCUM, +)"; + +static constexpr StringLiteral kCppEpiloguePart2 =R"( + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmK, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + + gridwise_conv.Run(p_in_global, p_wei_global, p_out_global); +} +)"; + +void EmitCppPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr) { + output << kCppPreamblePart1; + +// Between Preamble Part 1 and Part 2: +// #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" + output << R"(#include "gridwise_convolution_implicit_gemm_v4r4_)"; + output << layoutStr << ".hpp"; + + output << kCppPreamblePart2; + +// Between Preamble Part 2 and Par 3: +// __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw( + output << R"( + __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_)"; + output << layoutStr; + + output << kCppPreamblePart3; +} + +void EmitCppInterlude(llvm::raw_ostream &output) { + output << kCppInterlude; +} + +void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm::SmallVector<std::string, 3> tensorDescs) { +// Before Part1: +// constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw + output << R"( + constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_)"; + output << layoutStr; + + output << kCppEpiloguePart1; + +// Between Part1 and Part2: +// decltype(in_nchw_desc), +// decltype(wei_kcyx_desc), +// decltype(out_nkhw_desc), + for (auto desc : tensorDescs) { + output << " decltype(" << desc << "),\n"; + } + + output << kCppEpiloguePart2; +} + +void EmitLayoutString(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr, llvm::StringRef prefix, llvm::StringRef suffix, llvm::StringRef delimiter = "") { + for (int i = 0; i < kConv2DTensorDimension; ++i) { + auto attr = layoutArrayAttr[i]; + if (auto strAttr = attr.dyn_cast<StringAttr>()) { + output << prefix << strAttr.getValue() << suffix; + } + if (i < kConv2DTensorDimension - 1) { + output << delimiter; + } + } +} + +void EmitDimensionVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr) { + for (int i = 0; i < kConv2DTensorDimension; ++i) { + auto attr = layoutArrayAttr[i]; + if (auto strAttr = attr.dyn_cast<StringAttr>()) { + output << " const index_t " << strAttr.getValue() << " = CK_PARAM_PROBLEM_"; + + switch (llvm::toUpper(strAttr.getValue()[0])) { + case 'H': + case 'W': + output << llvm::toUpper(strAttr.getValue()[0]); + output << llvm::toUpper(strAttr.getValue()[1]); + break; + default: + output << llvm::toUpper(strAttr.getValue()[0]); + } + output << ";\n"; + } + } +} + +void EmitStrideVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr) { + for (int i = 0; i < kConv2DTensorDimension; ++i) { + auto attr = layoutArrayAttr[i]; + if (auto strAttr = attr.dyn_cast<StringAttr>()) { + output << " const index_t stride_" << strAttr.getValue() << " = "; + + if (i == 0) { + output << "1;\n"; + } else { + auto prevAttr = layoutArrayAttr[i - 1]; + if (auto strPrevAttr = prevAttr.dyn_cast<StringAttr>()) { + output << strPrevAttr.getValue() << " * stride_" << strPrevAttr.getValue() << ";\n"; + } + } + } + } +} + +void ObtainModuleInfo(ModuleOp &m, std::string &layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) { + // (TBD verifiying logic) The Module could contain multiple FuncOp, and inside each FuncOp there + // should be exactly: + // - 3 input arguments + // - 1 result. // - // The Module should only contain 1 function. - // The Function should only contain exactly: // - 0 conv2d op. // - 5 transform ops (1 for filter, 3 for input, 1 for output). // - 1 gridwise gemm op. - m.dump(); - return std::make_unique<llvm::StringRef>("Hello World"); + // Enumerate FuncOp instances inside the ModuleOp. + for (auto f : m.getOps<FuncOp>()) { + int srcLayoutAttrCtr = 0; + llvm::raw_string_ostream los(layoutStr); + + // First iteration. Construct tensor descriptor names. + f.walk([&srcLayoutAttrCtr, &tensorDescs, &los](miopen::TransformOp op) { + // get source_layout attribute. + auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout"); + if (srcLayoutAttr) { + auto srcLayout = srcLayoutAttr.getValue(); + + // Prepare tensor descriptor variable name. + std::string desc{}; + llvm::raw_string_ostream os(desc); + os << kVarName[srcLayoutAttrCtr++] << "_"; + EmitLayoutString(os, srcLayout, "", "", "_"); + os << "_desc"; + os.flush(); + tensorDescs.push_back(desc); + + // Prepare layout string. + if (srcLayoutAttrCtr != 1) + los << "_"; + EmitLayoutString(los, srcLayout, "", ""); + } + }); + los.flush(); + } +} + +} + +std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) { + std::string resultStr; + llvm::raw_string_ostream output(resultStr); + + // Enumerate FuncOp instances inside the ModuleOp. + for (auto f : m.getOps<FuncOp>()) { + std::string layoutStr; + llvm::SmallVector<std::string, 3> tensorDescs; + + // Obtain critical information from ModuleOp. + ObtainModuleInfo(m, layoutStr, tensorDescs); + + int srcLayoutAttrCtr = 0; + + // Start emitting. + + EmitCppPreamble(output, layoutStr); + + f.walk([&output, &srcLayoutAttrCtr, &tensorDescs](miopen::TransformOp op) { + + // get source_layout attribute. + auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout"); + if (srcLayoutAttr) { + auto srcLayout = srcLayoutAttr.getValue(); + output << " // "; + EmitLayoutString(output, srcLayout, "", "", ", "); + output << '\n'; + + EmitDimensionVariables(output, srcLayout); + output << '\n'; + EmitStrideVariables(output, srcLayout); + + output << " constexpr auto " << tensorDescs[srcLayoutAttrCtr++]; + output << " = make_native_tensor_descriptor(Sequence<"; + EmitLayoutString(output, srcLayout, "", "", ", "); + output << ">{}, Sequence<"; + EmitLayoutString(output, srcLayout, "stride_", "", ", "); + output << ">{});\n\n"; + } + + //// get layout attribute. + // TBD not used in emitting C++ source wrapper. + // would be used in emitting C++ header. + //auto layoutAttr = op.getAttrOfType<ArrayAttr>("layout"); + //for (auto layoutSpec : layoutAttr) { + // if (auto layoutSpecDict = layoutSpec.dyn_cast<DictionaryAttr>()) { + // //output << "dimensions: " << layoutSpecDict.get("dimensions") << "\n"; + // //output << "names: " << layoutSpecDict.get("names") << "\n"; + // //output << "source_dimensions: " << layoutSpecDict.get("source_dimensions") << "\n"; + // //output << "source_names: " << layoutSpecDict.get("source_names") << "\n"; + // //output << "transformation: " << layoutSpecDict.get("transformation") << "\n"; + // } + //} + }); + + EmitCppInterlude(output); + + // TBD get tuning parameters. + //f.walk([&output](miopen::GridwiseGemmOp op) { + // // get op name. + // //output << "op name: " << op.getOperationName() << "\n"; + // //op.dump(); + //}); + + EmitCppEpilogue(output, layoutStr, tensorDescs); + } + + output.flush(); + return std::make_unique<llvm::StringRef>(resultStr); } static TranslateFromMLIRRegistration - toCPP("mlir-to-miopencpp", [](ModuleOp module, llvm::raw_ostream &output) { - auto sourceCode = mlir::translateModuleToMIOpenCPP(module); + toCpp("mlir-to-miopen-cpp", [](ModuleOp module, llvm::raw_ostream &output) { + auto sourceCode = mlir::translateModuleToMIOpenCpp(module); if (!sourceCode) return failure(); output << *sourceCode; return success(); }); + +//static TranslateFromMLIRRegistration +// toHeader("mlir-to-miopen-h", [](ModuleOp module, llvm::raw_ostream &output) { +// auto sourceCode = mlir::translateModuleToMIOpenHeader(module); +// if (!sourceCode) +// return failure(); +// +// output << *sourceCode; +// return success(); +// }); + diff --git a/mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir b/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir similarity index 70% rename from mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir rename to mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir index 4b4bc0031717..ca9751191859 100644 --- a/mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir +++ b/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-translate -mlir-to-miopencpp %s | FileCheck %s +// RUN: mlir-translate -mlir-to-miopen-cpp %s | FileCheck %s -// CHECK: Hello World +// CHECK: __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_kcyx_niciwihi_nokohowo func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) { // filter tensor %filter_gemmK_gemmM = miopen.transform(%filter) { @@ -19,7 +19,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<? source_dimensions = [0], source_names = ["n"] } - ] + ], + source_layout = ["k", "c", "y", "x"] } : memref<?x?x?x?xf32> to memref<?x?xf32> // input tensor @@ -30,14 +31,14 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<? names = ["n"], transformation = "passthorugh", source_dimensions = [0], - source_names = ["n"] + source_names = ["ni"] }, { dimensions = [1], names = ["c"], transformation = "passthorugh", source_dimensions = [1], - source_names = ["c"] + source_names = ["ci"] }, { dimensions = [2], @@ -55,7 +56,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<? source_dimensions = [3], source_names = ["wi"] } - ] + ], + source_layout = ["ni", "ci", "wi", "hi"] } : memref<?x?x?x?xf32> to memref<?x?x?x?xf32> %input_n_c_y_ho_x_wo = miopen.transform(%input_n_c_hipad_wipad) { @@ -90,7 +92,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<? source_dimensions = [2], source_names = ["wipad"] } - ] + ], + intermediate_layout = ["n", "c", "hipad", "wipad"] } : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?x?xf32> %input_gemmK_gemmN = miopen.transform(%input_n_c_y_ho_x_wo) { @@ -107,9 +110,10 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<? names = ["gemmN"], transformation = "merge", source_dimensions = [0, 3, 5], - source_names = ["n", "ho", "wo"] + source_names = ["n", "hipad", "wipad"] } - ] + ], + intermediate_layout = ["n", "c", "y", "hipad", "x", "wipad"] } : memref<?x?x?x?x?x?x?xf32> to memref<?x?xf32> // output tensor @@ -120,16 +124,17 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<? names = ["gemmM"], transformation = "passthrough", source_dimensions = [1], - source_names = ["k"] + source_names = ["ko"] }, { dimensions = [1], names = ["gemmN"], transformation = "merge", source_dimensions = [0, 2, 3], - source_names = ["n", "ho", "wo"] + source_names = ["no", "ho", "wo"] } - ] + ], + source_layout = ["no", "ko", "ho", "wo"] } : memref<?x?x?x?xf32> to memref<?x?xf32> // apply gridwise GEMM @@ -143,3 +148,10 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<? return } +// CHECK: constexpr auto weight_k_c_y_x_desc = make_native_tensor_descriptor(Sequence<k, c, y, x>{}, Sequence<stride_k, stride_c, stride_y, stride_x>{}); +// CHECK: constexpr auto input_ni_ci_wi_hi_desc = make_native_tensor_descriptor(Sequence<ni, ci, wi, hi>{}, Sequence<stride_ni, stride_ci, stride_wi, stride_hi>{}); +// CHECK: constexpr auto output_no_ko_ho_wo_desc = make_native_tensor_descriptor(Sequence<no, ko, ho, wo>{}, Sequence<stride_no, stride_ko, stride_ho, stride_wo>{}); +// CHECK: constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_kcyx_niciwihi_nokohowo +// CHECK: decltype(weight_k_c_y_x_desc), +// CHECK: decltype(input_ni_ci_wi_hi_desc), +// CHECK: decltype(output_no_ko_ho_wo_desc), _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits