================ @@ -821,6 +826,152 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP, "AAAMDFlatWorkGroupSize is only valid for function position"); } +struct TupleDecIntegerRangeState : public AbstractState { + DecIntegerState<uint32_t> X, Y, Z; + + bool isValidState() const override { + return X.isValidState() && Y.isValidState() && Z.isValidState(); + } + + bool isAtFixpoint() const override { + return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint(); + } + + ChangeStatus indicateOptimisticFixpoint() override { + return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() | + Z.indicateOptimisticFixpoint(); + } + + ChangeStatus indicatePessimisticFixpoint() override { + return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() | + Z.indicatePessimisticFixpoint(); + } + + TupleDecIntegerRangeState operator^=(const TupleDecIntegerRangeState &Other) { + X ^= Other.X; + Y ^= Other.Y; + Z ^= Other.Z; + return *this; + } + + bool operator==(const TupleDecIntegerRangeState &Other) const { + return X == Other.X && Y == Other.Y && Z == Other.Z; + } + + TupleDecIntegerRangeState &getAssumed() { return *this; } + const TupleDecIntegerRangeState &getAssumed() const { return *this; } +}; + +using AAAMDMaxNumWorkgroupsState = + StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t>; + +/// Propagate amdgpu-max-num-workgroups attribute. +struct AAAMDMaxNumWorkgroups + : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> { + using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>; + + AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A) : Base(IRP) {} + + void initialize(Attributor &A) override { + Function *F = getAssociatedFunction(); + auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache()); + + SmallVector<unsigned> MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups(*F); + + // FIXME: What is the interpretation of 0? 
+ for (unsigned &Entry : MaxNumWorkgroups) { + if (Entry == 0) + Entry = std::numeric_limits<uint32_t>::max(); + } + + X.takeKnownMinimum(MaxNumWorkgroups[0]); + Y.takeKnownMinimum(MaxNumWorkgroups[1]); + Z.takeKnownMinimum(MaxNumWorkgroups[2]); + + if (AMDGPU::isEntryFunctionCC(F->getCallingConv())) + indicatePessimisticFixpoint(); + } + + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Change = ChangeStatus::UNCHANGED; + + auto CheckCallSite = [&](AbstractCallSite CS) { + Function *Caller = CS.getInstruction()->getFunction(); + LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call " << Caller->getName() + << "->" << getAssociatedFunction()->getName() << '\n'); + + const auto *CallerInfo = A.getAAFor<AAAMDMaxNumWorkgroups>( + *this, IRPosition::function(*Caller), DepClassTy::REQUIRED); + if (!CallerInfo || !CallerInfo->isValidState()) + return false; + + Change |= + clampStateAndIndicateChange(this->getState(), CallerInfo->getState()); + return true; + }; + + bool AllCallSitesKnown = true; + if (!A.checkForAllCallSites(CheckCallSite, *this, + /*RequireAllCallSites=*/true, + AllCallSitesKnown)) + return indicatePessimisticFixpoint(); + + return Change; + } + + /// Create an abstract attribute view for the position \p IRP. + static AAAMDMaxNumWorkgroups &createForPosition(const IRPosition &IRP, + Attributor &A); + + ChangeStatus manifest(Attributor &A) override { + Function *F = getAssociatedFunction(); + // TODO: Skip adding if worst case? ---------------- arsenm wrote:
Yes — skip adding the attribute in the worst case, i.e. when all three dimensions (X, Y, Z) are UINT32_MAX (uint32_max x 3).

https://github.com/llvm/llvm-project/pull/113018

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits