Author: Lei Zhang Date: 2021-01-14T09:57:56-05:00 New Revision: 6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218
URL: https://github.com/llvm/llvm-project/commit/6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218 DIFF: https://github.com/llvm/llvm-project/commit/6b9fa8a50d0f9e1e54f238b1c50fee8ff7011218.diff LOG: [mlir][linalg] Add docstring support for named op spec Depends on D94335 Reviewed By: nicolasvasilache, hanchung Differential Revision: https://reviews.llvm.org/D94548 Added: Modified: mlir/docs/Dialects/Linalg.md mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp Removed: ################################################################################ diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md index 1f8ef3c4021b..a5caabd212b4 100644 --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -608,10 +608,18 @@ semantics: perform multiple updates. 2. Each tensor may only be used with a single indexing expression. +A `"""`-wrapped doc string can be attached to the named op. It should contain a +one-line summary first, followed by a longer description. + The following specification may be used to define a named `batchmatmul` op: ``` -def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { +def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) +"""Batch matrix-multiply operation. + +This operation performs batch matrix-multiply over ... 
+""" +{ C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n))); } ``` diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc index 1ce2d2ac9418..226a09669b1c 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -125,3 +125,22 @@ def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) O(n, h, w, f) = std_addf<kh, kw>(std_mulf( I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c))); } + +// ODS-LABEL: def Test6Op +// ODS: let summary = [{ My magic op. }]; +// ODS-NEXT: let description = [{ +// ODS-NEXT: It has two inputs. +// ODS-NEXT: It has one output. +// ODS-NEXT: }]; +// +ods_def<Test6Op>: +def test6(A: f32(M, K), B: f32(K)) -> (C: f32(M)) +""" +My magic op. + +It has two inputs. +It has one output. +""" +{ + C(m) = std_addf<k>(std_mulf(A(m, k), B(k))); +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index cb7bfd2c9c4d..f4b7f9f9323a 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" @@ -85,6 +86,7 @@ class Token { // Tokens with no info. colon, comma, + doc_str, equal, gt, l_brace, @@ -183,6 +185,9 @@ class Lexer { // Lex an integer. Token lexInteger(const char *tokStart); + // Lex a string. + Token lexString(const char *tokStart); + // Skip a comment line, starting with a '//'. 
void skipComment(); @@ -287,6 +292,8 @@ Token Lexer::lexToken() { return formToken(Token::Kind::star, tokStart); case '?': return formToken(Token::Kind::question, tokStart); + case '"': + return lexString(tokStart); case '/': if (*curPtr == '/') { skipComment(); @@ -333,6 +340,36 @@ Token Lexer::lexInteger(const char *tokStart) { return Token(Token::Kind::integer, str); } +Token Lexer::lexString(const char *tokStart) { + assert(curPtr[-1] == '"'); + + if (*curPtr == '"' && *(curPtr + 1) == '"') { + curPtr += 2; + while (true) { + switch (*curPtr++) { + case '"': + if (*curPtr == '"' && *(curPtr + 1) == '"') { + Token token(Token::Kind::doc_str, + StringRef(tokStart + 3, curPtr - tokStart - 4)); + curPtr += 2; + return token; + } + continue; + case 0: + // If this is a random nul character in the middle of the doc string, + // just include it. If it is the end of file, then it is an error. + if (curPtr - 1 != curBuffer.end()) + continue; + return emitError(curPtr - 1, "expected '\"\"\"' to end doc string"); + default: + continue; + } + } + } + + return emitError(curPtr - 1, "expected '\"\"\"' to start doc string"); +} + /// Skip a comment line, starting with a '//'. void Lexer::skipComment() { // Advance over the second '/' in a '//' comment. @@ -1134,6 +1171,8 @@ class TCParser { /// Attributes are per TC def. std::map<std::string, RegisteredAttr> registeredAttrs; + StringRef docString; + Parser &parser; }; } // namespace @@ -1655,6 +1694,14 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { return failure(); } + // Parse optional doc string + if (parser.curToken.is(Token::Kind::doc_str)) { + docString = parser.curToken.getSpelling(); + parser.consumeToken(); + LLVM_DEBUG(llvm::dbgs() + << "parsed doc string: '''" << docString << "'''\n"); + } + // Since we don't declare symbols separately, we discover them eagerly: each // newly encountered id in a tensor shape expression is treated as a new // symbolic. 
At this point, all tensors have been parsed and all the symbols @@ -1755,9 +1802,10 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, AttrSizedOperandSegments, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, SingleBlockImplicitTerminator<"YieldOp">]> { + {2} let arguments = (ins Variadic<AnyShaped>:$inputs, - Variadic<AnyShaped>:$outputs{4} + Variadic<AnyShaped>:$outputs{3} ); let results = (outs Variadic<AnyRankedTensor>:$result_tensors); let regions = (region AnyRegion:$region); @@ -1818,23 +1866,30 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; } // Generic methods. - static unsigned getNumRegionArgs() {{ return {5}; } + static unsigned getNumRegionArgs() {{ return {4}; } std::string getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } }]; })FMT"; - unsigned nInputs = 0, nOutputs = 0; - for (auto &t : registeredTensors) { - if (t.getValue().isOutput) - nOutputs++; - else - nInputs++; + std::string doc; + + if (!docString.empty()) { + const char *docFmt = R"FMT( + let summary = [{ {0} }]; + let description = [{ + {1} + }]; + )FMT"; + + StringRef summary, description; + std::tie(summary, description) = docString.trim().split('\n'); + doc = llvm::formatv(docFmt, summary.trim(), description.trim()); } - os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs, - attrList, state.orderedTensorArgs.size()); + os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList, + state.orderedTensorArgs.size()); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`. _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits