Author: Alexander Pivovarov Date: 2024-07-17T23:33:52-07:00 New Revision: f36331770267501e157ac34afc3ca7d7a0bfb52c
URL: https://github.com/llvm/llvm-project/commit/f36331770267501e157ac34afc3ca7d7a0bfb52c DIFF: https://github.com/llvm/llvm-project/commit/f36331770267501e157ac34afc3ca7d7a0bfb52c.diff LOG: [APFloat] Add support for f8E4M3 IEEE 754 type (#97179) This PR adds `f8E4M3` type to APFloat. `f8E4M3` type follows IEEE 754 convention ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` Related PRs: - [PR-97118](https://github.com/llvm/llvm-project/pull/97118) Add f8E4M3 IEEE 754 type to mlir Added: Modified: clang/include/clang/AST/Stmt.h clang/lib/AST/MicrosoftMangle.cpp llvm/include/llvm/ADT/APFloat.h llvm/lib/Support/APFloat.cpp llvm/unittests/ADT/APFloatTest.cpp Removed: ################################################################################ diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h index e91e89d728ca0..bbd7634bcc3bf 100644 --- a/clang/include/clang/AST/Stmt.h +++ b/clang/include/clang/AST/Stmt.h @@ -460,10 +460,10 @@ class alignas(void *) Stmt { unsigned : NumExprBits; static_assert( - llvm::APFloat::S_MaxSemantics < 16, - "Too many Semantics enum 
values to fit in bitfield of size 4"); + llvm::APFloat::S_MaxSemantics < 32, + "Too many Semantics enum values to fit in bitfield of size 5"); LLVM_PREFERRED_TYPE(llvm::APFloat::Semantics) - unsigned Semantics : 4; // Provides semantics for APFloat construction + unsigned Semantics : 5; // Provides semantics for APFloat construction LLVM_PREFERRED_TYPE(bool) unsigned IsExact : 1; }; diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp index fac14ce1dce8c..4016043df62ed 100644 --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -981,6 +981,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) { case APFloat::S_IEEEquad: Out << 'Y'; break; case APFloat::S_PPCDoubleDouble: Out << 'Z'; break; case APFloat::S_Float8E5M2: + case APFloat::S_Float8E4M3: case APFloat::S_Float8E4M3FN: case APFloat::S_Float8E5M2FNUZ: case APFloat::S_Float8E4M3FNUZ: diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h index db2fa480655c6..bff8e6490d1de 100644 --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -166,6 +166,9 @@ struct APFloatBase { // This format's exponent bias is 16, instead of the 15 (2 ** (5 - 1) - 1) // that IEEE precedent would imply. S_Float8E5M2FNUZ, + // 8-bit floating point number following IEEE-754 conventions with bit + // layout S1E4M3. + S_Float8E4M3, // 8-bit floating point number mostly following IEEE-754 conventions with // bit layout S1E4M3 as described in https://arxiv.org/abs/2209.05433. 
// Unlike IEEE-754 types, there are no infinity values, and NaN is @@ -217,6 +220,7 @@ struct APFloatBase { static const fltSemantics &PPCDoubleDouble() LLVM_READNONE; static const fltSemantics &Float8E5M2() LLVM_READNONE; static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE; + static const fltSemantics &Float8E4M3() LLVM_READNONE; static const fltSemantics &Float8E4M3FN() LLVM_READNONE; static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE; static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE; @@ -638,6 +642,7 @@ class IEEEFloat final : public APFloatBase { APInt convertPPCDoubleDoubleAPFloatToAPInt() const; APInt convertFloat8E5M2APFloatToAPInt() const; APInt convertFloat8E5M2FNUZAPFloatToAPInt() const; + APInt convertFloat8E4M3APFloatToAPInt() const; APInt convertFloat8E4M3FNAPFloatToAPInt() const; APInt convertFloat8E4M3FNUZAPFloatToAPInt() const; APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const; @@ -656,6 +661,7 @@ class IEEEFloat final : public APFloatBase { void initFromPPCDoubleDoubleAPInt(const APInt &api); void initFromFloat8E5M2APInt(const APInt &api); void initFromFloat8E5M2FNUZAPInt(const APInt &api); + void initFromFloat8E4M3APInt(const APInt &api); void initFromFloat8E4M3FNAPInt(const APInt &api); void initFromFloat8E4M3FNUZAPInt(const APInt &api); void initFromFloat8E4M3B11FNUZAPInt(const APInt &api); diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 3664de71d06df..26b4f8e55448f 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -136,6 +136,7 @@ static constexpr fltSemantics semIEEEquad = {16383, -16382, 113, 128}; static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8}; static constexpr fltSemantics semFloat8E5M2FNUZ = { 15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; +static constexpr fltSemantics semFloat8E4M3 = {7, -6, 4, 8}; static constexpr fltSemantics semFloat8E4M3FN = { 8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, 
fltNanEncoding::AllOnes}; static constexpr fltSemantics semFloat8E4M3FNUZ = { @@ -208,6 +209,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) { return Float8E5M2(); case S_Float8E5M2FNUZ: return Float8E5M2FNUZ(); + case S_Float8E4M3: + return Float8E4M3(); case S_Float8E4M3FN: return Float8E4M3FN(); case S_Float8E4M3FNUZ: @@ -246,6 +249,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { return S_Float8E5M2; else if (&Sem == &llvm::APFloat::Float8E5M2FNUZ()) return S_Float8E5M2FNUZ; + else if (&Sem == &llvm::APFloat::Float8E4M3()) + return S_Float8E4M3; else if (&Sem == &llvm::APFloat::Float8E4M3FN()) return S_Float8E4M3FN; else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ()) @@ -276,6 +281,7 @@ const fltSemantics &APFloatBase::PPCDoubleDouble() { } const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; } const fltSemantics &APFloatBase::Float8E5M2FNUZ() { return semFloat8E5M2FNUZ; } +const fltSemantics &APFloatBase::Float8E4M3() { return semFloat8E4M3; } const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; } const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; } const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { @@ -3617,6 +3623,11 @@ APInt IEEEFloat::convertFloat8E5M2FNUZAPFloatToAPInt() const { return convertIEEEFloatToAPInt<semFloat8E5M2FNUZ>(); } +APInt IEEEFloat::convertFloat8E4M3APFloatToAPInt() const { + assert(partCount() == 1); + return convertIEEEFloatToAPInt<semFloat8E4M3>(); +} + APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const { assert(partCount() == 1); return convertIEEEFloatToAPInt<semFloat8E4M3FN>(); @@ -3681,6 +3692,9 @@ APInt IEEEFloat::bitcastToAPInt() const { if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FNUZ) return convertFloat8E5M2FNUZAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3) + return convertFloat8E4M3APFloatToAPInt(); + if (semantics == (const llvm::fltSemantics 
*)&semFloat8E4M3FN) return convertFloat8E4M3FNAPFloatToAPInt(); @@ -3902,6 +3916,10 @@ void IEEEFloat::initFromFloat8E5M2FNUZAPInt(const APInt &api) { initFromIEEEAPInt<semFloat8E5M2FNUZ>(api); } +void IEEEFloat::initFromFloat8E4M3APInt(const APInt &api) { + initFromIEEEAPInt<semFloat8E4M3>(api); +} + void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) { initFromIEEEAPInt<semFloat8E4M3FN>(api); } @@ -3951,6 +3969,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { return initFromFloat8E5M2APInt(api); if (Sem == &semFloat8E5M2FNUZ) return initFromFloat8E5M2FNUZAPInt(api); + if (Sem == &semFloat8E4M3) + return initFromFloat8E4M3APInt(api); if (Sem == &semFloat8E4M3FN) return initFromFloat8E4M3FNAPInt(api); if (Sem == &semFloat8E4M3FNUZ) diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp index 86a25f4394e19..d50bdf4a65dcb 100644 --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -2133,6 +2133,8 @@ TEST(APFloatTest, getZero) { {&APFloat::Float8E5M2(), true, true, {0x80ULL, 0}, 1}, {&APFloat::Float8E5M2FNUZ(), false, false, {0, 0}, 1}, {&APFloat::Float8E5M2FNUZ(), true, false, {0, 0}, 1}, + {&APFloat::Float8E4M3(), false, true, {0, 0}, 1}, + {&APFloat::Float8E4M3(), true, true, {0x80ULL, 0}, 1}, {&APFloat::Float8E4M3FN(), false, true, {0, 0}, 1}, {&APFloat::Float8E4M3FN(), true, true, {0x80ULL, 0}, 1}, {&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1}, @@ -6532,6 +6534,34 @@ TEST(APFloatTest, Float8E5M2ToDouble) { EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); } +TEST(APFloatTest, Float8E4M3ToDouble) { + APFloat One(APFloat::Float8E4M3(), "1.0"); + EXPECT_EQ(1.0, One.convertToDouble()); + APFloat Two(APFloat::Float8E4M3(), "2.0"); + EXPECT_EQ(2.0, Two.convertToDouble()); + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3(), false); + EXPECT_EQ(240.0F, PosLargest.convertToDouble()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3(), 
true); + EXPECT_EQ(-240.0F, NegLargest.convertToDouble()); + APFloat PosSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3(), false); + EXPECT_EQ(0x1.p-6, PosSmallest.convertToDouble()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3(), true); + EXPECT_EQ(-0x1.p-6, NegSmallest.convertToDouble()); + + APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToDouble()); + + APFloat PosInf = APFloat::getInf(APFloat::Float8E4M3()); + EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble()); + APFloat NegInf = APFloat::getInf(APFloat::Float8E4M3(), true); + EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble()); + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3()); + EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); +} + TEST(APFloatTest, Float8E4M3FNToDouble) { APFloat One(APFloat::Float8E4M3FN(), "1.0"); EXPECT_EQ(1.0, One.convertToDouble()); @@ -6846,6 +6876,42 @@ TEST(APFloatTest, Float8E5M2ToFloat) { EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); } +TEST(APFloatTest, Float8E4M3ToFloat) { + APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3()); + APFloat PosZeroToFloat(PosZero.convertToFloat()); + EXPECT_TRUE(PosZeroToFloat.isPosZero()); + APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3(), true); + APFloat NegZeroToFloat(NegZero.convertToFloat()); + EXPECT_TRUE(NegZeroToFloat.isNegZero()); + + APFloat One(APFloat::Float8E4M3(), "1.0"); + EXPECT_EQ(1.0F, One.convertToFloat()); + APFloat Two(APFloat::Float8E4M3(), "2.0"); + EXPECT_EQ(2.0F, Two.convertToFloat()); + + APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3(), false); + EXPECT_EQ(240.0F, PosLargest.convertToFloat()); + APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3(), true); + EXPECT_EQ(-240.0F, NegLargest.convertToFloat()); + APFloat PosSmallest = + 
APFloat::getSmallestNormalized(APFloat::Float8E4M3(), false); + EXPECT_EQ(0x1.p-6, PosSmallest.convertToFloat()); + APFloat NegSmallest = + APFloat::getSmallestNormalized(APFloat::Float8E4M3(), true); + EXPECT_EQ(-0x1.p-6, NegSmallest.convertToFloat()); + + APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3(), false); + EXPECT_TRUE(SmallestDenorm.isDenormal()); + EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToFloat()); + + APFloat PosInf = APFloat::getInf(APFloat::Float8E4M3()); + EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat()); + APFloat NegInf = APFloat::getInf(APFloat::Float8E4M3(), true); + EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat()); + APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3()); + EXPECT_TRUE(std::isnan(QNaN.convertToFloat())); +} + TEST(APFloatTest, Float8E4M3FNToFloat) { APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FN()); APFloat PosZeroToFloat(PosZero.convertToFloat()); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits