https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/67917
>From b5d134c88a04c524b1d9120a1c1a5dae3722904c Mon Sep 17 00:00:00 2001 From: Yingwei Zheng <dtcxzyw2...@gmail.com> Date: Sun, 1 Oct 2023 22:17:35 +0800 Subject: [PATCH 1/2] [ConstantRange] Handle `Intrinsic::cttz` and `Intrinsic::ctpop` --- llvm/include/llvm/IR/ConstantRange.h | 7 + llvm/lib/IR/ConstantRange.cpp | 127 ++++++++++++++++++ .../CorrelatedValuePropagation/range.ll | 54 ++++++++ llvm/unittests/IR/ConstantRangeTest.cpp | 20 +++ 4 files changed, 208 insertions(+) diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h index ca36732e4e2e8c2..e718e6e7e3403de 100644 --- a/llvm/include/llvm/IR/ConstantRange.h +++ b/llvm/include/llvm/IR/ConstantRange.h @@ -530,6 +530,13 @@ class [[nodiscard]] ConstantRange { /// ignoring a possible zero value contained in the input range. ConstantRange ctlz(bool ZeroIsPoison = false) const; + /// Calculate cttz range. If \p ZeroIsPoison is set, the range is computed + /// ignoring a possible zero value contained in the input range. + ConstantRange cttz(bool ZeroIsPoison = false) const; + + /// Calculate ctpop range. + ConstantRange ctpop() const; + /// Represents whether an operation on the given constant range is known to /// always or never overflow. enum class OverflowResult { diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index 3d71b20f7e853e0..f34a2749543c321 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -949,6 +949,8 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) { case Intrinsic::smax: case Intrinsic::abs: case Intrinsic::ctlz: + case Intrinsic::cttz: + case Intrinsic::ctpop: return true; default: return false; @@ -986,6 +988,15 @@ ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID, assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean"); return Ops[0].ctlz(ZeroIsPoison->getBoolValue()); } + case Intrinsic::cttz: { + const APInt *ZeroIsPoison = Ops[1].getSingleElement(); + assert(ZeroIsPoison && "Must be known (immarg)"); + assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean"); + return Ops[0].cttz(ZeroIsPoison->getBoolValue()); + } + case Intrinsic::ctpop: { + return Ops[0].ctpop(); + } default: assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported"); llvm_unreachable("Unsupported intrinsic"); @@ -1735,6 +1746,122 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const { return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()), APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1)); } +static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower, + const APInt &Upper) { + assert(Lower.ule(Upper)); + unsigned BitWidth = Lower.getBitWidth(); + if (Lower == Upper) + return ConstantRange::getEmpty(BitWidth); + if (Lower + 1 == Upper) + return ConstantRange(APInt(BitWidth, Lower.countr_zero())); + if (Lower.isZero()) + return ConstantRange(APInt::getZero(BitWidth), + APInt(BitWidth, BitWidth + 1)); + + // Calculate longest common prefix. + unsigned LCPLength = (Lower ^ (Upper - 1)).countl_zero(); + // If Lower is {LCP, 000...}, the maximum is Lower.countr_zero(). + // Otherwise, the maximum is BitWidth - LCPLength - 1 ({LCP, 100...}). + return ConstantRange( + APInt::getZero(BitWidth), + APInt(BitWidth, std::max(BitWidth - LCPLength, Lower.countr_zero() + 1))); +} + +ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const { + if (isEmptySet()) + return getEmpty(); + + APInt Zero = APInt::getZero(getBitWidth()); + + if (ZeroIsPoison && contains(Zero)) { + // ZeroIsPoison is set, and zero is contained. We discern three cases, in + // which a zero can appear: + // 1) Lower is zero, handling cases of kind [0, 1), [0, 2), etc. + // 2) Upper is zero, wrapped set, handling cases of kind [3, 0], etc. + // 3) Zero contained in a wrapped set, e.g., [3, 2), [3, 1), etc. + + if (getLower().isZero()) { + if ((getUpper() - 1).isZero()) { + // We have in input interval of kind [0, 1). In this case we cannot + // really help but return empty-set. + return getEmpty(); + } + + // Compute the resulting range by excluding zero from Lower. + return getUnsignedCountTrailingZerosRange(getLower() + 1, getUpper()); + } else if ((getUpper() - 1).isZero()) { + // Compute the resulting range by excluding zero from Upper. + return ConstantRange( + Zero, APInt(getBitWidth(), + (getUnsignedMax() - getLower() + 1).logBase2() + 1)); + } else { + ConstantRange CR1( + Zero, APInt(getBitWidth(), + (getUnsignedMax() - getLower() + 1).logBase2() + 1)); + ConstantRange CR2 = getUnsignedCountTrailingZerosRange( + APInt(getBitWidth(), 1), getUpper()); + return CR1.unionWith(CR2); + } + } + + if (isFullSet()) { + return getNonEmpty(Zero, APInt(getBitWidth(), getBitWidth() + 1)); + } + if (!isUpperWrapped()) { + return getUnsignedCountTrailingZerosRange(getLower(), getUpper()); + } + ConstantRange CR1( + Zero, + APInt(getBitWidth(), (getUnsignedMax() - getLower() + 1).logBase2() + 1)); + ConstantRange CR2 = getUnsignedCountTrailingZerosRange(Zero, getUpper()); + return CR1.unionWith(CR2); +} + +static ConstantRange getUnsignedPopCountRange(const APInt &Lower, + const APInt &Upper) { + assert(Lower.ule(Upper)); + unsigned BitWidth = Lower.getBitWidth(); + if (Lower == Upper) + return ConstantRange::getEmpty(BitWidth); + if (Lower + 1 == Upper) + return ConstantRange(APInt(BitWidth, Lower.popcount())); + + APInt Max = Upper - 1; + // Calculate longest common prefix. + unsigned LCPLength = (Lower ^ Max).countl_zero(); + unsigned LCPPopCount = Lower.getHiBits(LCPLength).popcount(); + // If Lower is {LCP, 000...}, the minimum is the popcount of LCP. + // Otherwise, the minimum is the popcount of LCP + 1. + unsigned MinBits = + LCPPopCount + (Lower.countr_zero() < BitWidth - LCPLength ? 1 : 0); + // If Max is {LCP, 111...}, the maximum is the popcount of LCP + (BitWidth - + // length of LCP). + // Otherwise, the minimum is the popcount of LCP + (BitWidth - + // length of LCP - 1). + unsigned MaxBits = LCPPopCount + (BitWidth - LCPLength) + + (Max.countr_one() >= BitWidth - LCPLength ? 1 : 0); + return ConstantRange(APInt(BitWidth, MinBits), APInt(BitWidth, MaxBits)); +} + +ConstantRange ConstantRange::ctpop() const { + if (isEmptySet()) + return getEmpty(); + + unsigned BitWidth = getBitWidth(); + APInt Zero = APInt::getZero(BitWidth); + if (isFullSet()) { + return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1)); + } + if (!isUpperWrapped()) { + return getUnsignedPopCountRange(getLower(), getUpper()); + } + ConstantRange CR1 = ConstantRange( + APInt(BitWidth, + BitWidth - (getUnsignedMax() - getLower() + 1).logBase2()), + APInt(BitWidth, BitWidth + 1)); // [lower, intmax] + ConstantRange CR2 = getUnsignedPopCountRange(Zero, getUpper()); // [0, upper) + return CR1.unionWith(CR2); +} ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow( const ConstantRange &Other) const { diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll index 7e89f864c8110ee..182a0bbef255de8 100644 --- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll @@ -1010,6 +1010,60 @@ else: ret i1 %res2 } +define i1 @cttz_fold(i16 %x) { +; CHECK-LABEL: @cttz_fold( +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256 +; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]] +; CHECK: if: +; CHECK-NEXT: [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true) +; CHECK-NEXT: ret i1 false +; CHECK: else: +; CHECK-NEXT: [[CTTZ2:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true) +; CHECK-NEXT: [[RES2:%.*]] = icmp ult i16 [[CTTZ2]], 8 +; CHECK-NEXT: ret i1 [[RES2]] +; + %cmp = icmp ult i16 %x, 256 + br i1 %cmp, label %if, label %else + +if: + %cttz = call i16 @llvm.cttz.i16(i16 %x, i1 true) + %res = icmp uge i16 %cttz, 8 + ret i1 %res + +else: + %cttz2 = call i16 @llvm.cttz.i16(i16 %x, i1 true) + %res2 = icmp ult i16 %cttz2, 8 + ret i1 %res2 +} + +define i1 @ctpop_fold(i16 %x) { +; CHECK-LABEL: @ctpop_fold( +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256 +; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]] +; CHECK: if: +; CHECK-NEXT: [[CTPOP:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]]) +; CHECK-NEXT: ret i1 true +; CHECK: else: +; CHECK-NEXT: [[CTPOP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]]) +; CHECK-NEXT: [[RES2:%.*]] = icmp ugt i16 [[CTPOP2]], 8 +; CHECK-NEXT: ret i1 [[RES2]] +; + %cmp = icmp ult i16 %x, 256 + br i1 %cmp, label %if, label %else + +if: + %ctpop = call i16 @llvm.ctpop.i16(i16 %x) + %res = icmp ule i16 %ctpop, 8 + ret i1 %res + +else: + %ctpop2 = call i16 @llvm.ctpop.i16(i16 %x) + %res2 = icmp ugt i16 %ctpop2, 8 + ret i1 %res2 +} + declare i16 @llvm.ctlz.i16(i16, i1) +declare i16 @llvm.cttz.i16(i16, i1) +declare i16 @llvm.ctpop.i16(i16) declare i16 @llvm.abs.i16(i16, i1) declare void @llvm.assume(i1) diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp index 1cb358a26062ca5..e505af5d3275ef2 100644 --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -2438,6 +2438,26 @@ TEST_F(ConstantRangeTest, Ctlz) { }); } +TEST_F(ConstantRangeTest, Cttz) { + TestUnaryOpExhaustive( + [](const ConstantRange &CR) { return CR.cttz(); }, + [](const APInt &N) { return APInt(N.getBitWidth(), N.countr_zero()); }); + + TestUnaryOpExhaustive( + [](const ConstantRange &CR) { return CR.cttz(/*ZeroIsPoison=*/true); }, + [](const APInt &N) -> std::optional<APInt> { + if (N.isZero()) + return std::nullopt; + return APInt(N.getBitWidth(), N.countr_zero()); + }); +} + +TEST_F(ConstantRangeTest, Ctpop) { + TestUnaryOpExhaustive( + [](const ConstantRange &CR) { return CR.ctpop(); }, + [](const APInt &N) { return APInt(N.getBitWidth(), N.popcount()); }); +} + TEST_F(ConstantRangeTest, castOps) { ConstantRange A(APInt(16, 66), APInt(16, 128)); ConstantRange FpToI8 = A.castOp(Instruction::FPToSI, 8); >From 96f416fa255e78c3536ad10d1a004ebdb160c964 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng <dtcxzyw2...@gmail.com> Date: Thu, 5 Oct 2023 20:46:50 +0800 Subject: [PATCH 2/2] [ConstantRange] Handle `Intrinsic::cttz` --- llvm/include/llvm/IR/ConstantRange.h | 3 - llvm/lib/IR/ConstantRange.cpp | 50 ---------------- .../CorrelatedValuePropagation/range.ll | 60 +------------------ llvm/unittests/IR/ConstantRangeTest.cpp | 6 -- 4 files changed, 3 insertions(+), 116 deletions(-) diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h index e718e6e7e3403de..efbbb6c99bff792 100644 --- a/llvm/include/llvm/IR/ConstantRange.h +++ b/llvm/include/llvm/IR/ConstantRange.h @@ -534,9 +534,6 @@ class [[nodiscard]] ConstantRange { /// ignoring a possible zero value contained in the input range. ConstantRange cttz(bool ZeroIsPoison = false) const; - /// Calculate ctpop range. - ConstantRange ctpop() const; - /// Represents whether an operation on the given constant range is known to /// always or never overflow. enum class OverflowResult { diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index f34a2749543c321..8586dea7fc324ce 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -950,7 +950,6 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) { case Intrinsic::abs: case Intrinsic::ctlz: case Intrinsic::cttz: - case Intrinsic::ctpop: return true; default: return false; @@ -994,9 +993,6 @@ ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID, assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean"); return Ops[0].cttz(ZeroIsPoison->getBoolValue()); } - case Intrinsic::ctpop: { - return Ops[0].ctpop(); - } default: assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported"); llvm_unreachable("Unsupported intrinsic"); @@ -1817,52 +1813,6 @@ ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const { return CR1.unionWith(CR2); } -static ConstantRange getUnsignedPopCountRange(const APInt &Lower, - const APInt &Upper) { - assert(Lower.ule(Upper)); - unsigned BitWidth = Lower.getBitWidth(); - if (Lower == Upper) - return ConstantRange::getEmpty(BitWidth); - if (Lower + 1 == Upper) - return ConstantRange(APInt(BitWidth, Lower.popcount())); - - APInt Max = Upper - 1; - // Calculate longest common prefix. - unsigned LCPLength = (Lower ^ Max).countl_zero(); - unsigned LCPPopCount = Lower.getHiBits(LCPLength).popcount(); - // If Lower is {LCP, 000...}, the minimum is the popcount of LCP. - // Otherwise, the minimum is the popcount of LCP + 1. - unsigned MinBits = - LCPPopCount + (Lower.countr_zero() < BitWidth - LCPLength ? 1 : 0); - // If Max is {LCP, 111...}, the maximum is the popcount of LCP + (BitWidth - - // length of LCP). - // Otherwise, the minimum is the popcount of LCP + (BitWidth - - // length of LCP - 1). - unsigned MaxBits = LCPPopCount + (BitWidth - LCPLength) + - (Max.countr_one() >= BitWidth - LCPLength ? 1 : 0); - return ConstantRange(APInt(BitWidth, MinBits), APInt(BitWidth, MaxBits)); -} - -ConstantRange ConstantRange::ctpop() const { - if (isEmptySet()) - return getEmpty(); - - unsigned BitWidth = getBitWidth(); - APInt Zero = APInt::getZero(BitWidth); - if (isFullSet()) { - return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1)); - } - if (!isUpperWrapped()) { - return getUnsignedPopCountRange(getLower(), getUpper()); - } - ConstantRange CR1 = ConstantRange( - APInt(BitWidth, - BitWidth - (getUnsignedMax() - getLower() + 1).logBase2()), - APInt(BitWidth, BitWidth + 1)); // [lower, intmax] - ConstantRange CR2 = getUnsignedPopCountRange(Zero, getUpper()); // [0, upper) - return CR1.unionWith(CR2); -} - ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow( const ConstantRange &Other) const { if (isEmptySet() || Other.isEmptySet()) diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll index 9b65136fdef748e..44cf914c981e3fb 100644 --- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll @@ -1016,8 +1016,7 @@ define i1 @cttz_fold(i16 %x) { ; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]] ; CHECK: if: ; CHECK-NEXT: [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true) -; CHECK-NEXT: [[RES:%.*]] = icmp uge i16 [[CTTZ]], 8 -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: ret i1 false ; CHECK: else: ; CHECK-NEXT: ret i1 false ; @@ -1033,14 +1032,13 @@ else: ret i1 false } -define i1 @cttz_nofold(i16 %x) { +define i1 @cttz_nofold1(i16 %x) { ; CHECK-LABEL: @cttz_nofold( ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256 ; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]] ; CHECK: if: ; CHECK-NEXT: [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true) -; CHECK-NEXT: [[RES:%.*]] = icmp uge i16 [[CTTZ]], 9 -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: ret i1 false ; CHECK: else: ; CHECK-NEXT: ret i1 false ; @@ -1102,58 +1100,6 @@ else: ret i1 true } -define i1 @cttz_fold(i16 %x) { -; CHECK-LABEL: @cttz_fold( -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256 -; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]] -; CHECK: if: -; CHECK-NEXT: [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true) -; CHECK-NEXT: ret i1 false -; CHECK: else: -; CHECK-NEXT: [[CTTZ2:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true) -; CHECK-NEXT: [[RES2:%.*]] = icmp ult i16 [[CTTZ2]], 8 -; CHECK-NEXT: ret i1 [[RES2]] -; - %cmp = icmp ult i16 %x, 256 - br i1 %cmp, label %if, label %else - -if: - %cttz = call i16 @llvm.cttz.i16(i16 %x, i1 true) - %res = icmp uge i16 %cttz, 8 - ret i1 %res - -else: - %cttz2 = call i16 @llvm.cttz.i16(i16 %x, i1 true) - %res2 = icmp ult i16 %cttz2, 8 - ret i1 %res2 -} - -define i1 @ctpop_fold(i16 %x) { -; CHECK-LABEL: @ctpop_fold( -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256 -; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]] -; CHECK: if: -; CHECK-NEXT: [[CTPOP:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]]) -; CHECK-NEXT: ret i1 true -; CHECK: else: -; CHECK-NEXT: [[CTPOP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]]) -; CHECK-NEXT: [[RES2:%.*]] = icmp ugt i16 [[CTPOP2]], 8 -; CHECK-NEXT: ret i1 [[RES2]] -; - %cmp = icmp ult i16 %x, 256 - br i1 %cmp, label %if, label %else - -if: - %ctpop = call i16 @llvm.ctpop.i16(i16 %x) - %res = icmp ule i16 %ctpop, 8 - ret i1 %res - -else: - %ctpop2 = call i16 @llvm.ctpop.i16(i16 %x) - %res2 = icmp ugt i16 %ctpop2, 8 - ret i1 %res2 -} - declare i16 @llvm.ctlz.i16(i16, i1) declare i16 @llvm.cttz.i16(i16, i1) declare i16 @llvm.ctpop.i16(i16) diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp index e505af5d3275ef2..8cdcd989f24e4f9 100644 --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -2452,12 +2452,6 @@ TEST_F(ConstantRangeTest, Cttz) { }); } -TEST_F(ConstantRangeTest, Ctpop) { - TestUnaryOpExhaustive( - [](const ConstantRange &CR) { return CR.ctpop(); }, - [](const APInt &N) { return APInt(N.getBitWidth(), N.popcount()); }); -} - TEST_F(ConstantRangeTest, castOps) { ConstantRange A(APInt(16, 66), APInt(16, 128)); ConstantRange FpToI8 = A.castOp(Instruction::FPToSI, 8); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits