Stefan Gränitz <[email protected]>,
Stefan Gränitz <[email protected]>,
Stefan Gränitz <[email protected]>,
Stefan Gränitz <[email protected]>,
Stefan Gränitz <[email protected]>,
Stefan Gränitz <[email protected]>,
Stefan Gränitz <[email protected]>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/[email protected]>


================
@@ -0,0 +1,501 @@
+//===- tools/plugins-shlib/pypass.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Passes/PassPlugin.h"
+#include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <filesystem>
+#include <memory>
+#include <optional>
+#include <string>
+
+using namespace llvm;
+
+static cl::opt<std::string>
+    DylibPath("pypass-dylib", cl::desc("Path to the Python shared library"),
+              cl::init(""));
+
+static cl::opt<std::string>
+    ScriptPath("pypass-script", cl::desc("Path to the Python script to run"),
+               cl::init(""));
+
+static std::string findPython() {
+  if (!DylibPath.empty())
+    return DylibPath;
+  if (const char *Path = std::getenv("LLVM_PYPASS_DYLIB"))
+    return std::string(Path);
+  // TODO: Run Python from PATH and use a script to query the shared lib
+  return std::string{};
+}
+
+static std::string findScript() {
+  if (!ScriptPath.empty())
+    return ScriptPath;
+  if (const char *Path = std::getenv("LLVM_PYPASS_SCRIPT"))
+    return std::string(Path);
+  return std::string{};
+}
+
+struct PythonAPI {
+  using Py_InitializeEx_t = void(int);
+  using Py_FinalizeEx_t = int();
+  using Py_DecRef_t = void(void *);
+  using Py_IncRef_t = void(void *);
+  using PyDict_GetItemString_t = void *(void *, const char *);
+  using PyDict_New_t = void *();
+  using PyDict_SetItemString_t = int(void *, const char *, void *);
+  using PyErr_Print_t = void();
+  using PyGILStateEnsure_t = int();
+  using PyGILStateRelease_t = void(int);
+  using PyImport_AddModule_t = void *(const char *);
+  using PyImport_ImportModule_t = void *(const char *);
+  using PyLong_FromVoidPtr_t = void *(void *);
+  using PyUnicode_FromString_t = void *(const char *);
+  using PyModule_GetDict_t = void *(void *);
+  using PyObject_CallObject_t = void *(void *, void *);
+  using PyObject_GetAttrString_t = void *(void *, const char *);
+  using PyObject_IsTrue_t = int(void *);
+  using PyTuple_SetItem_t = int(void *, long, void *);
+  using PyTuple_New_t = void *(long);
+  using PyTypeObject_t = void *;
+
+  // pylifecycle.h
+  Py_InitializeEx_t *Py_InitializeEx;
+  Py_FinalizeEx_t *Py_FinalizeEx;
+
+  // pystate.h
+  PyGILStateEnsure_t *PyGILState_Ensure;
+  PyGILStateRelease_t *PyGILState_Release;
+
+  // pythonrun.h
+  PyErr_Print_t *PyErr_Print;
+
+  // import.h
+  PyImport_AddModule_t *PyImport_AddModule;
+  PyImport_ImportModule_t *PyImport_ImportModule;
+
+  // object.h
+  PyObject_IsTrue_t *PyObject_IsTrue;
+  PyObject_GetAttrString_t *PyObject_GetAttrString;
+  Py_IncRef_t *Py_IncRef;
+  Py_DecRef_t *Py_DecRef;
+
+  // moduleobject.h
+  PyModule_GetDict_t *PyModule_GetDict;
+
+  // dictobject.h
+  PyDict_GetItemString_t *PyDict_GetItemString;
+  PyDict_SetItemString_t *PyDict_SetItemString;
+  PyDict_New_t *PyDict_New;
+
+  // abstract.h
+  PyObject_CallObject_t *PyObject_CallObject;
+
+  // longobject.h
+  PyLong_FromVoidPtr_t *PyLong_FromVoidPtr;
+
+  // unicodeobject.h
+  PyUnicode_FromString_t *PyUnicode_FromString;
+
+  // tupleobject.h
+  PyTuple_SetItem_t *PyTuple_SetItem;
+  PyTuple_New_t *PyTuple_New;
+
+  void *PyGlobals;
+  void *PyBuiltins;
+
+private:
+  PythonAPI() : Valid(false) {
+    if (!loadDylib(findPython()))
+      return;
+    if (!resolveSymbols())
+      return;
+    Py_InitializeEx(0);
+    PyBuiltins = PyImport_ImportModule("builtins");
+    PyGlobals = PyModule_GetDict(PyImport_AddModule("__main__"));
+    Valid = true;
+  }
+
+  ~PythonAPI() {
+    if (std::atomic_exchange(&Valid, false)) {
+      Py_DecRef(PyBuiltins);
+      Py_DecRef(PyGlobals);
+      Py_FinalizeEx();
+    }
+  }
+
+  bool loadDylib(std::string Path) {
+    std::string Err;
+    Dylib = sys::DynamicLibrary::getPermanentLibrary(Path.c_str(), &Err);
+    if (!Dylib.isValid()) {
+      errs() << "dlopen for '" << Path << "' failed: " << Err << "\n";
+      return false;
+    }
+
+    return true;
+  }
+
+  bool resolveSymbols() {
+    bool Success = true;
+    Success &= resolve("_Py_IncRef", &Py_IncRef);
+    Success &= resolve("_Py_DecRef", &Py_DecRef);
+    Success &= resolve("Py_InitializeEx", &Py_InitializeEx);
+    Success &= resolve("Py_FinalizeEx", &Py_FinalizeEx);
+    Success &= resolve("PyErr_Print", &PyErr_Print);
+    Success &= resolve("PyGILState_Ensure", &PyGILState_Ensure);
+    Success &= resolve("PyGILState_Release", &PyGILState_Release);
+    Success &= resolve("PyImport_AddModule", &PyImport_AddModule);
+    Success &= resolve("PyImport_ImportModule", &PyImport_ImportModule);
+    Success &= resolve("PyModule_GetDict", &PyModule_GetDict);
+    Success &= resolve("PyDict_GetItemString", &PyDict_GetItemString);
+    Success &= resolve("PyDict_SetItemString", &PyDict_SetItemString);
+    Success &= resolve("PyDict_New", &PyDict_New);
+    Success &= resolve("PyObject_CallObject", &PyObject_CallObject);
+    Success &= resolve("PyObject_GetAttrString", &PyObject_GetAttrString);
+    Success &= resolve("PyObject_IsTrue", &PyObject_IsTrue);
+    Success &= resolve("PyLong_FromVoidPtr", &PyLong_FromVoidPtr);
+    Success &= resolve("PyUnicode_FromString", &PyUnicode_FromString);
+    Success &= resolve("PyTuple_SetItem", &PyTuple_SetItem);
+    Success &= resolve("PyTuple_New", &PyTuple_New);
+    return Success;
+  }
+
+  bool importModule(const char *Name) const {
+    void *Mod = PyImport_ImportModule(Name);
+    if (!Mod) {
+      PyErr_Print();
+      return false;
+    }
+    PyDict_SetItemString(PyGlobals, Name, Mod);
+    Py_DecRef(Mod);
+    return true;
+  }
+
+  bool evaluate(std::string Code, void *Globals) const {
+    void *Exec = PyObject_GetAttrString(PyBuiltins, "exec");
+    void *Args = PyTuple_New(2);
+    if (!Args)
+      return false;
+    if (PyTuple_SetItem(Args, 0, PyUnicode_FromString(Code.c_str())))
+      return false;
+    Py_IncRef(Globals);
+    if (PyTuple_SetItem(Args, 1, Globals))
+      return false;
+
+    // Interpreter is not thread-safe
+    auto GIL = make_scope_exit(
+        [this, Lock = PyGILState_Ensure()]() { PyGILState_Release(Lock); });
+    void *Result = PyObject_CallObject(Exec, Args);
+    Py_DecRef(Args);
+
+    if (Result == nullptr) {
+      PyErr_Print();
+      return false;
+    }
+
+    return true;
+  }
+
+public:
+  static const PythonAPI &instance() {
+    static const PythonAPI PyAPI;
+    return PyAPI;
+  }
+
+  bool isValid() const { return Valid; }
+
+  bool loadScript(const std::string &Path) const {
+    if (!importModule("runpy"))
+      return false;
+    if (!evaluate("globals().update(runpy.run_path('" + Path + "'))",
+                  PyGlobals))
+      return false;
+    if (!PyDict_GetItemString(PyGlobals, "run")) {
+      errs() << "Script defines no run() function: " << Path << "\n";
+      return false;
+    }
+    return true;
+  }
+
+  bool addImportSearchPath(const std::string &Path) const {
+    if (!importModule("sys"))
+      return false;
+    return evaluate("sys.path.append('" + Path + "')", PyGlobals);
+  }
+
+  void *getFunction(std::string Name) const {
+    return PyDict_GetItemString(PyGlobals, Name.c_str());
+  }
+
+  // Run Python function with boolean result
+  bool invoke(void *Fn, void *Args = nullptr) const {
+    // Interpreter is not thread-safe
+    auto GIL = make_scope_exit(
+        [this, Lock = PyGILState_Ensure()]() { PyGILState_Release(Lock); });
+    // If we get no result, there was an error in Python
+    void *Result = PyObject_CallObject(Fn, Args);
+    if (Args)
+      Py_DecRef(Args);
+    if (!Result) {
+      errs() << "PyPassContext error: invoke failed\n";
+      return false;
+    }
+    // If the result is truthy, then it's a yes
+    return PyObject_IsTrue(Result);
+  }
+
+private:
+  sys::DynamicLibrary Dylib;
+  std::atomic<bool> Valid;
+
+  template <typename FnTy> bool resolve(const char *Name, FnTy **Var) {
+    assert(Dylib.isValid() && "dlopen shared library first");
+    assert(*Var == nullptr && "Resolve symbols once");
+    if (void *FnPtr = Dylib.getAddressOfSymbol(Name)) {
+      *Var = reinterpret_cast<FnTy *>(FnPtr);
+      return true;
+    }
+    errs() << "Missing required CPython API symbol '" << Name
+           << "' in: " << DylibPath << "\n";
+    return false;
+  };
+};
+
+class PyPassContext {
+public:
+  PyPassContext(const PythonAPI &PyAPI) : PyAPI(PyAPI) {}
+
+  bool loadScript(const std::string &Path) {
+    // Make relative paths resolve naturally in import statements
+    std::string Dir = std::filesystem::path(Path).parent_path().u8string();
+    if (!PyAPI.addImportSearchPath(Dir)) {
+      errs() << "Failed to add import search path: " << Dir << "\n";
+      return false;
+    }
+
+    return PyAPI.loadScript(Path);
+  }
+
+  bool registerEP(std::string Name) {
+    // Default is no, if the function is not defined
+    if (void *Fn = PyAPI.getFunction("register" + Name))
+      return PyAPI.invoke(Fn);
+    return false;
+  }
+
+  bool run(void *Entity, void *Ctx, const char *Stage) {
+    void *Args = PyAPI.PyTuple_New(3);
+    if (!Args)
+      return false;
+    if (PyAPI.PyTuple_SetItem(Args, 0, PyAPI.PyLong_FromVoidPtr(Entity)) != 0)
----------------
serge-sans-paille wrote:

On each of the error paths here, you're leaking the reference to `Args` and to
the elements it contains, once those have been set.

https://github.com/llvm/llvm-project/pull/171111
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to