================ @@ -1145,31 +1169,71 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute { auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache()); ChangeStatus Change = ChangeStatus::UNCHANGED; + Function *F = getAssociatedFunction(); + + const auto *AAFlatWorkGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>( + *this, IRPosition::function(*F), DepClassTy::REQUIRED); + if (!AAFlatWorkGroupSize || !AAFlatWorkGroupSize->isValidState()) { + LLVM_DEBUG( + dbgs() << '[' << getName() + << "] AAAMDFlatWorkGroupSize is unavailable or invalid.\n"); + return ChangeStatus::UNCHANGED; + } + + if (AAFlatWorkGroupSize->isAtInitialState()) { + LLVM_DEBUG(dbgs() << '[' << getName() + << "] AAAMDFlatWorkGroupSize is still at initial " + "state. Skip the update.\n"); + return ChangeStatus::UNCHANGED; + } + + auto CurrentWorkGroupSize = std::make_pair( + AAFlatWorkGroupSize->getAssumed().getLower().getZExtValue(), + AAFlatWorkGroupSize->getAssumed().getUpper().getZExtValue() - 1); + + auto DoUpdate = [&](std::pair<unsigned, unsigned> WavesPerEU, + std::pair<unsigned, unsigned> FlatWorkGroupSize) { + auto [Min, Max] = + InfoCache.getEffectiveWavesPerEU(*F, WavesPerEU, FlatWorkGroupSize); + ConstantRange CR(APInt(32, Min), APInt(32, Max + 1)); + IntegerRangeState IRS(CR); + Change |= clampStateAndIndicateChange(this->getState(), IRS); + }; + + // // We need to clamp once if we are not at initial state, because + // // AAAMDFlatWorkGroupSize could be updated in last iteration. 
+ if (!isAtInitialState()) { + auto CurrentWavesPerEU = + std::make_pair(getAssumed().getLower().getZExtValue(), + getAssumed().getUpper().getZExtValue() - 1); + DoUpdate(CurrentWavesPerEU, CurrentWorkGroupSize); + } + auto CheckCallSite = [&](AbstractCallSite CS) { Function *Caller = CS.getInstruction()->getFunction(); - Function *Func = getAssociatedFunction(); + LLVM_DEBUG(dbgs() << '[' << getName() << "] Call " << Caller->getName() - << "->" << Func->getName() << '\n'); + << "->" << F->getName() << '\n'); - const auto *CallerInfo = A.getAAFor<AAAMDWavesPerEU>( + const auto *AAWavesPerEU = A.getAAFor<AAAMDWavesPerEU>( *this, IRPosition::function(*Caller), DepClassTy::REQUIRED); - const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>( - *this, IRPosition::function(*Func), DepClassTy::REQUIRED); - if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState() || - !AssumedGroupSize->isValidState()) + if (!AAWavesPerEU || !AAWavesPerEU->isValidState()) { + LLVM_DEBUG(dbgs() << '[' << getName() << "] Caller " + << Caller->getName() + << " is unavailable or invalid.\n"); return false; + } + if (AAWavesPerEU->isAtInitialState()) { ---------------- arsenm wrote:
Same as above

https://github.com/llvm/llvm-project/pull/114726
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits