junaire updated this revision to Diff 435781.
junaire added a comment.

Restore the PTU in Interpreter::Undo; this should fix the broken tests.
Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D126682/new/

https://reviews.llvm.org/D126682

Files:
  clang/include/clang/Interpreter/Interpreter.h
  clang/lib/Interpreter/IncrementalExecutor.cpp
  clang/lib/Interpreter/IncrementalExecutor.h
  clang/lib/Interpreter/IncrementalParser.cpp
  clang/lib/Interpreter/IncrementalParser.h
  clang/lib/Interpreter/Interpreter.cpp
  clang/tools/clang-repl/ClangRepl.cpp
  clang/unittests/Interpreter/InterpreterTest.cpp
Index: clang/unittests/Interpreter/InterpreterTest.cpp
===================================================================
--- clang/unittests/Interpreter/InterpreterTest.cpp
+++ clang/unittests/Interpreter/InterpreterTest.cpp
@@ -248,4 +248,44 @@
   EXPECT_EQ(42, fn(NewA));
 }
 
+TEST(InterpreterTest, UndoBasic) {
+  Args ExtraArgs = {"-Xclang", "-diagnostic-log-file", "-Xclang", "-"};
+
+  // Create the diagnostic engine with unowned consumer.
+  std::string DiagnosticOutput;
+  llvm::raw_string_ostream DiagnosticsOS(DiagnosticOutput);
+  auto DiagPrinter = std::make_unique<TextDiagnosticPrinter>(
+      DiagnosticsOS, new DiagnosticOptions());
+
+  auto Interp = createInterpreter(ExtraArgs, DiagPrinter.get());
+
+  // Nothing has been parsed yet, so there is nothing to undo.
+  EXPECT_EQ("Operation failed. Too many undos",
+            llvm::toString(Interp->Undo()));
+
+  auto R1 = Interp->Parse("int x = 42;");
+  EXPECT_TRUE(!!R1);
+
+  llvm::cantFail(Interp->Undo());
+
+  // 'x' was undone, so it can be declared again.
+  auto R2 = Interp->Parse("int x = 24;");
+  EXPECT_TRUE(!!R2);
+
+  auto R3 = Interp->Parse("int foo() { return 42;}");
+  EXPECT_TRUE(!!R3);
+
+  auto R4 = Interp->Parse("int bar = foo();");
+  EXPECT_TRUE(!!R4);
+
+  // Undo only 'bar'; 'x' and 'foo' must still be visible afterwards.
+  llvm::cantFail(Interp->Undo());
+
+  auto R5 = Interp->Parse("int bar = 24;");
+  EXPECT_TRUE(!!R5);
+
+  auto R6 = Interp->Parse("int baz = foo() + x;");
+  EXPECT_TRUE(!!R6);
+}
+
 } // end anonymous namespace
Index: clang/tools/clang-repl/ClangRepl.cpp
===================================================================
--- clang/tools/clang-repl/ClangRepl.cpp
+++ clang/tools/clang-repl/ClangRepl.cpp
@@ -92,8 +92,16 @@
     llvm::LineEditor LE("clang-repl");
     // FIXME: Add LE.setListCompleter
     while (llvm::Optional<std::string> Line = LE.readLine()) {
-      if (*Line == "quit")
+      if (*Line == R"(%quit)")
         break;
+      if (*Line == R"(%undo)") {
+        // A user-requested undo can fail (e.g. nothing left to undo); report
+        // it like any other input error instead of aborting the REPL.
+        if (auto Err = Interp->Undo())
+          llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
+        continue;
+      }
+
       if (auto Err = Interp->ParseAndExecute(*Line))
         llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
     }
Index: clang/lib/Interpreter/Interpreter.cpp
===================================================================
--- clang/lib/Interpreter/Interpreter.cpp
+++ clang/lib/Interpreter/Interpreter.cpp
@@ -218,8 +218,7 @@
     if (Err)
       return Err;
   }
-  // FIXME: Add a callback to retain the llvm::Module once the JIT is done.
-  if (auto Err = IncrExecutor->addModule(std::move(T.TheModule)))
+  if (auto Err = IncrExecutor->addModule(T))
     return Err;
 
   if (auto Err = IncrExecutor->runCtors())
@@ -257,3 +256,23 @@
 
   return IncrExecutor->getSymbolAddress(Name, IncrementalExecutor::LinkerName);
 }
+
+llvm::Error Interpreter::Undo(unsigned N) {
+  std::list<PartialTranslationUnit> &PTUs = IncrParser->getPTUs();
+  if (N > PTUs.size())
+    return llvm::make_error<llvm::StringError>("Operation failed. "
+                                               "Too many undos",
+                                               std::error_code());
+  for (unsigned I = 0; I < N; I++) {
+    // IncrExecutor can be null (e.g. if only Parse() has been called so far);
+    // then nothing was JIT-ed and there is nothing to remove.
+    if (IncrExecutor) {
+      if (llvm::Error Err = IncrExecutor->removeModule(PTUs.back()))
+        return Err;
+    }
+
+    IncrParser->Restore(PTUs.back());
+    PTUs.pop_back();
+  }
+  return llvm::Error::success();
+}
Index: clang/lib/Interpreter/IncrementalParser.h
===================================================================
--- clang/lib/Interpreter/IncrementalParser.h
+++ clang/lib/Interpreter/IncrementalParser.h
@@ -72,6 +72,13 @@
   ///\returns the mangled name of a \c GD.
   llvm::StringRef GetMangledName(GlobalDecl GD) const;
 
+  /// Remove every declaration introduced by \p PTU from the translation
+  /// unit's lookup table so that its names can be declared again.
+  void Restore(PartialTranslationUnit &PTU);
+
+  /// \returns the parsed inputs; the most recent one is at the back.
+  std::list<PartialTranslationUnit> &getPTUs() { return PTUs; }
+
 private:
   llvm::Expected<PartialTranslationUnit &> ParseOrWrapTopLevelDecl();
 };
Index: clang/lib/Interpreter/IncrementalParser.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalParser.cpp
+++ clang/lib/Interpreter/IncrementalParser.cpp
@@ -293,6 +293,33 @@
   return PTU;
 }
 
+void IncrementalParser::Restore(PartialTranslationUnit &PTU) {
+  TranslationUnitDecl *MostRecentTU = PTU.TUPart;
+  TranslationUnitDecl *FirstTU = MostRecentTU->getFirstDecl();
+  StoredDeclsMap *Map = FirstTU->getLookupPtr();
+  if (!Map)
+    return;
+
+  for (auto I = Map->begin(), E = Map->end(); I != E;) {
+    // Advance first: the current entry may be erased below.
+    auto Cur = I++;
+    StoredDeclsList &List = Cur->second;
+
+    // The lookup result is a view into List, so collect the matches before
+    // calling List.remove(), which mutates the storage being iterated.
+    llvm::SmallVector<NamedDecl *, 4> ToRemove;
+    for (NamedDecl *D : List.getLookupResult())
+      if (D->getTranslationUnitDecl() == MostRecentTU)
+        ToRemove.push_back(D);
+
+    for (NamedDecl *D : ToRemove)
+      List.remove(D);
+
+    if (List.isNull())
+      Map->erase(Cur);
+  }
+}
+
 llvm::StringRef IncrementalParser::GetMangledName(GlobalDecl GD) const {
   CodeGenerator *CG = getCodeGen(Act.get());
   assert(CG);
Index: clang/lib/Interpreter/IncrementalExecutor.h
===================================================================
--- clang/lib/Interpreter/IncrementalExecutor.h
+++ clang/lib/Interpreter/IncrementalExecutor.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_LIB_INTERPRETER_INCREMENTALEXECUTOR_H
 #define LLVM_CLANG_LIB_INTERPRETER_INCREMENTALEXECUTOR_H
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
@@ -29,11 +30,19 @@
 } // namespace llvm
 
 namespace clang {
+
+struct PartialTranslationUnit;
+
 class IncrementalExecutor {
   using CtorDtorIterator = llvm::orc::CtorDtorIterator;
   std::unique_ptr<llvm::orc::LLJIT> Jit;
   llvm::orc::ThreadSafeContext &TSCtx;
 
+  /// The JIT resource tracker of each PTU added via addModule, so that
+  /// removeModule can release that PTU's code again.
+  llvm::DenseMap<const PartialTranslationUnit *, llvm::orc::ResourceTrackerSP>
+      ResourceTrackers;
+
 public:
   enum SymbolNameKind { IRName, LinkerName };
 
@@ -41,7 +50,10 @@
                       const llvm::Triple &Triple);
   ~IncrementalExecutor();
 
-  llvm::Error addModule(std::unique_ptr<llvm::Module> M);
+  /// Hand \p PTU's module (ownership is taken) to the JIT.
+  llvm::Error addModule(PartialTranslationUnit &PTU);
+  /// Remove the code added for \p PTU; succeeds if none was added.
+  llvm::Error removeModule(PartialTranslationUnit &PTU);
   llvm::Error runCtors() const;
   llvm::Expected<llvm::JITTargetAddress>
   getSymbolAddress(llvm::StringRef Name, SymbolNameKind NameKind) const;
Index: clang/lib/Interpreter/IncrementalExecutor.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalExecutor.cpp
+++ clang/lib/Interpreter/IncrementalExecutor.cpp
@@ -12,6 +12,7 @@
 
 #include "IncrementalExecutor.h"
 
+#include "clang/Interpreter/PartialTranslationUnit.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
@@ -52,8 +53,23 @@
 
 IncrementalExecutor::~IncrementalExecutor() {}
 
-llvm::Error IncrementalExecutor::addModule(std::unique_ptr<llvm::Module> M) {
-  return Jit->addIRModule(llvm::orc::ThreadSafeModule(std::move(M), TSCtx));
+llvm::Error IncrementalExecutor::addModule(PartialTranslationUnit &PTU) {
+  llvm::orc::ResourceTrackerSP RT =
+      Jit->getMainJITDylib().createResourceTracker();
+  ResourceTrackers[&PTU] = RT;
+
+  return Jit->addIRModule(RT, {std::move(PTU.TheModule), TSCtx});
+}
+
+llvm::Error IncrementalExecutor::removeModule(PartialTranslationUnit &PTU) {
+  // Look up without operator[] so an unknown PTU doesn't get an empty entry.
+  auto It = ResourceTrackers.find(&PTU);
+  if (It == ResourceTrackers.end())
+    return llvm::Error::success();
+
+  llvm::orc::ResourceTrackerSP RT = std::move(It->second);
+  ResourceTrackers.erase(It);
+  return RT->remove();
 }
 
 llvm::Error IncrementalExecutor::runCtors() const {
Index: clang/include/clang/Interpreter/Interpreter.h
===================================================================
--- clang/include/clang/Interpreter/Interpreter.h
+++ clang/include/clang/Interpreter/Interpreter.h
@@ -69,6 +69,12 @@
     return llvm::Error::success();
   }
 
+  /// Undo the \p N most recent inputs, newest first: their declarations are
+  /// removed from name lookup and any JIT-ed code for them is dropped.
+  /// Fails without undoing anything if \p N exceeds the number of recorded
+  /// inputs; otherwise stops at the first error from the JIT.
+  llvm::Error Undo(unsigned N = 1);
+
   /// \returns the \c JITTargetAddress of a \c GlobalDecl. This interface uses
   /// the CodeGenModule's internal mangling cache to avoid recomputing the
   /// mangled name.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits