https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/116172
The following functionality is duplicated in multiple places: parsing an APFloat from either a floating-point literal or an integer literal in hexadecimal representation (bit pattern). Move it into a common helper function. NFC apart from the slightly changed error messages. Depends on #116171. >From 51530aeea8c18804034881c87236d1ab5ceb274f Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Thu, 14 Nov 2024 07:43:08 +0100 Subject: [PATCH] [mlir][Parser] Deduplicate fp parsing functionality --- mlir/lib/AsmParser/AsmParserImpl.h | 33 ++------- mlir/lib/AsmParser/AttributeParser.cpp | 71 ++++---------------- mlir/lib/AsmParser/Parser.cpp | 23 +++++++ mlir/lib/AsmParser/Parser.h | 6 ++ mlir/test/IR/invalid-builtin-attributes.mlir | 10 +-- 5 files changed, 56 insertions(+), 87 deletions(-) diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 1e6cbc0ec51beb..bbd70d5980f8fe 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -288,32 +288,13 @@ class AsmParserImpl : public BaseT { bool isNegative = parser.consumeIf(Token::minus); Token curTok = parser.getToken(); auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); }; - - // Check for a floating point value. - if (curTok.is(Token::floatliteral)) { - auto val = curTok.getFloatingPointValue(); - if (!val) - return emitErrorAtTok() << "floating point value too large"; - parser.consumeToken(Token::floatliteral); - result = APFloat(isNegative ? -*val : *val); - bool losesInfo; - result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo); - return success(); - } - - // Check for a hexadecimal float value.
- if (curTok.is(Token::integer)) { - FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral( - emitErrorAtTok, curTok, isNegative, semantics); - if (failed(apResult)) - return failure(); - - result = *apResult; - parser.consumeToken(Token::integer); - return success(); - } - - return emitErrorAtTok() << "expected floating point literal"; + FailureOr<APFloat> apResult = + parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics); + if (failed(apResult)) + return failure(); + parser.consumeToken(); + result = *apResult; + return success(); } /// Parse a floating point value from the stream. diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index ba9be3b030453a..9ebada076cd042 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -658,36 +658,12 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; const Token &token = signAndToken.second; - - // Handle hexadecimal float literals. - if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) { - auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; - FailureOr<APFloat> result = parseFloatFromIntegerLiteral( - emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics()); - if (failed(result)) - return failure(); - - floatValues.push_back(*result); - continue; - } - - // Check to see if any decimal integers or booleans were parsed. - if (!token.is(Token::floatliteral)) - return p.emitError() - << "expected floating-point elements, but parsed integer"; - - // Build the float values from tokens. - auto val = token.getFloatingPointValue(); - if (!val) - return p.emitError("floating point value too large for attribute"); - - APFloat apVal(isNegative ? 
-*val : *val); - if (!eltTy.isF64()) { - bool unused; - apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - } - floatValues.push_back(apVal); + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr<APFloat> result = parseFloatFromLiteral( + emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics()); + if (failed(result)) + return failure(); + floatValues.push_back(*result); } return success(); } @@ -905,34 +881,15 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { bool isNegative = p.consumeIf(Token::minus); - Token token = p.getToken(); - std::optional<APFloat> result; - auto floatType = cast<FloatType>(type); - if (p.consumeIf(Token::integer)) { - // Parse an integer literal as a float. - auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; - FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral( - emitErrorAtTok, token, isNegative, floatType.getFloatSemantics()); - if (failed(fromIntLit)) - return failure(); - result = *fromIntLit; - } else if (p.consumeIf(Token::floatliteral)) { - // Parse a floating point literal. - std::optional<double> val = token.getFloatingPointValue(); - if (!val) - return failure(); - result = APFloat(isNegative ? 
-*val : *val); - if (!type.isF64()) { - bool unused; - result->convert(floatType.getFloatSemantics(), - APFloat::rmNearestTiesToEven, &unused); - } - } else { - return p.emitError("expected integer or floating point literal"); - } - - append(result->bitcastToAPInt()); + auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); }; + FailureOr<APFloat> fromIntLit = + parseFloatFromLiteral(emitErrorAtTok, token, isNegative, + cast<FloatType>(type).getFloatSemantics()); + if (failed(fromIntLit)) + return failure(); + p.consumeToken(); + append(fromIntLit->bitcastToAPInt()); return success(); } diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index ac7eec931b1250..15f3dd7a66c358 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -99,6 +99,29 @@ FailureOr<APFloat> detail::parseFloatFromIntegerLiteral( return APFloat(semantics, truncatedValue); } +FailureOr<APFloat> +detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics) { + // Check for a floating point value. + if (tok.is(Token::floatliteral)) { + auto val = tok.getFloatingPointValue(); + if (!val) + return emitError() << "floating point value too large"; + + APFloat result(isNegative ? -*val : *val); + bool unused; + result.convert(semantics, APFloat::rmNearestTiesToEven, &unused); + return result; + } + + // Check for a hexadecimal float value. 
+ if (tok.is(Token::integer)) + return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics); + + return emitError() << "expected floating point literal"; +} + //===----------------------------------------------------------------------===// // CodeComplete //===----------------------------------------------------------------------===// diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index fa29264ffe506a..ab445476a91923 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -22,6 +22,12 @@ parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics); +/// Parse a floating point value from a literal. +FailureOr<APFloat> +parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError, + const Token &tok, bool isNegative, + const llvm::fltSemantics &semantics); + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir index 431c7b12b8f5fe..5098fe751fd01f 100644 --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () { // ----- func.func @elementsattr_floattype2() -> () { - // expected-error@+1 {{expected floating-point elements, but parsed integer}} + // expected-error@below {{unexpected decimal integer literal for a floating point value}} + // expected-note@below {{add a trailing dot to make the literal a float}} "foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> () } @@ -138,21 +139,22 @@ func.func @float_in_int_tensor() { // ----- func.func @float_in_bool_tensor() { - // expected-error @+1 {{expected integer elements, but parsed floating-point}} + // expected-error@below {{expected integer 
elements, but parsed floating-point}} "foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> () } // ----- func.func @decimal_int_in_float_tensor() { - // expected-error @+1 {{expected floating-point elements, but parsed integer}} + // expected-error@below {{unexpected decimal integer literal for a floating point value}} + // expected-note@below {{add a trailing dot to make the literal a float}} "foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> () } // ----- func.func @bool_in_float_tensor() { - // expected-error @+1 {{expected floating-point elements, but parsed integer}} + // expected-error @+1 {{expected floating point literal}} "foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> () } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits