================ @@ -541,25 +531,83 @@ class DevelopmentModeEvictionAdvisorAnalysis final Log = std::make_unique<Logger>(std::move(OS), LFS, Reward, /*IncludeReward*/ true); - return false; + return; + } + + // support for isa<> and dyn_cast. + static bool classof(const RegAllocEvictionAdvisorProvider *R) { + return R->getAdvisorMode() == AdvisorMode::Development; + } + + void logRewardIfNeeded(const MachineFunction &MF, + llvm::function_ref<float()> GetReward) override { + if (!Log || !Log->hasAnyObservationForContext(MF.getName())) + return; + // The function pass manager would run all the function passes for a + // function, so we assume the last context belongs to this function. If + // this invariant ever changes, we can implement at that time switching + // contexts. At this point, it'd be an error + if (Log->currentContext() != MF.getName()) { + MF.getFunction().getContext().emitError( + "The training log context shouldn't have had changed."); + } + if (Log->hasObservationInProgress()) + Log->logReward<float>(GetReward()); } std::unique_ptr<RegAllocEvictionAdvisor> - getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override { + getAdvisor(const MachineFunction &MF, const RAGreedy &RA, + MachineBlockFrequencyInfo *MBFI, MachineLoopInfo *Loops) override { if (!Runner) return nullptr; if (Log) Log->switchContext(MF.getName()); + assert(MBFI && Loops && + "Invalid provider state: must have analysis available"); return std::make_unique<DevelopmentModeEvictAdvisor>( - MF, RA, Runner.get(), - getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI(), - getAnalysis<MachineLoopInfoWrapperPass>().getLI(), Log.get()); + MF, RA, Runner.get(), *MBFI, *Loops, Log.get()); } +private: + std::vector<TensorSpec> InputFeatures; + std::vector<TensorSpec> TrainingInputFeatures; + std::unique_ptr<MLModelRunner> Runner; std::unique_ptr<Logger> Log; }; +class DevelopmentModeEvictionAdvisorAnalysisLegacy final + : public RegAllocEvictionAdvisorAnalysisLegacy { +public: 
+ DevelopmentModeEvictionAdvisorAnalysisLegacy() + : RegAllocEvictionAdvisorAnalysisLegacy(AdvisorMode::Development) {} + + bool doInitialization(Module &M) override { + Provider = std::make_unique<DevelopmentModeEvictionAdvisorProvider>( + M.getContext()); + return false; + } + + void logRewardIfNeeded(const MachineFunction &MF, + llvm::function_ref<float()> GetReward) override { + Provider->logRewardIfNeeded(MF, GetReward); + } + + // support for isa<> and dyn_cast. + static bool classof(const RegAllocEvictionAdvisorAnalysisLegacy *R) { + return R->getAdvisorMode() == AdvisorMode::Development; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<MachineBlockFrequencyInfoWrapperPass>(); + AU.addRequired<MachineLoopInfoWrapperPass>(); + RegAllocEvictionAdvisorAnalysisLegacy::getAnalysisUsage(AU); + } + +private: + // std::unique_ptr<DevelopmentModeEvictionAdvisorProvider> Provider; ---------------- arsenm wrote:
```suggestion ``` This is commented-out code; please remove it. https://github.com/llvm/llvm-project/pull/117309 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits