From a2912c0fe75a35683ea7c78acddd53328d1a7b77 Mon Sep 17 00:00:00 2001
From: Roman Arzumanyan <r.arzumanyan@visionlabs.ai>
Date: Wed, 13 Sep 2023 13:25:19 +0300
Subject: [PATCH] Add cuCtxGetCurrent function

---
 README                             | 3 +++
 include/ffnvcodec/dynlink_cuda.h   | 9 +++++++++
 include/ffnvcodec/dynlink_loader.h | 2 ++
 3 files changed, 14 insertions(+)

diff --git a/README b/README
index 0dc1e99..a5d41f5 100644
--- a/README
+++ b/README
@@ -5,3 +5,6 @@ Corresponds to Video Codec SDK version 12.0.16.
 Minimum required driver versions:
 Linux: 530.41.03 or newer
 Windows: 531.61 or newer
+
+Minimum required cuda version:
+CUDA 11.4
\ No newline at end of file
diff --git a/include/ffnvcodec/dynlink_cuda.h b/include/ffnvcodec/dynlink_cuda.h
index ca474c6..690766c 100644
--- a/include/ffnvcodec/dynlink_cuda.h
+++ b/include/ffnvcodec/dynlink_cuda.h
@@ -33,6 +33,14 @@
 
 #define CUDA_VERSION 7050
 
+#define CUDAAPI_MAJOR_VERSION 11
+#define CUDAAPI_MINOR_VERSION 4
+
+#define CUDAAPI_VERSION (CUDAAPI_MAJOR_VERSION | (CUDAAPI_MINOR_VERSION << 24))
+
+#define CUDAAPI_CHECK_VERSION(major, minor) \
+    ((major) < CUDAAPI_MAJOR_VERSION || ((major) == CUDAAPI_MAJOR_VERSION && (minor) <= CUDAAPI_MINOR_VERSION))
+
 #if defined(_WIN32) || defined(__CYGWIN__)
 #define CUDAAPI __stdcall
 #else
@@ -428,6 +436,7 @@ typedef CUresult CUDAAPI tcuDeviceGetByPCIBusId(CUdevice* dev, const char* pciBu
 typedef CUresult CUDAAPI tcuDeviceGetPCIBusId(char* pciBusId, int len, CUdevice dev);
 typedef CUresult CUDAAPI tcuDeviceComputeCapability(int *major, int *minor, CUdevice dev);
 typedef CUresult CUDAAPI tcuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
+typedef CUresult CUDAAPI tcuCtxGetCurrent(CUcontext *pctx);
 typedef CUresult CUDAAPI tcuCtxSetLimit(CUlimit limit, size_t value);
 typedef CUresult CUDAAPI tcuCtxPushCurrent_v2(CUcontext pctx);
 typedef CUresult CUDAAPI tcuCtxPopCurrent_v2(CUcontext *pctx);
diff --git a/include/ffnvcodec/dynlink_loader.h b/include/ffnvcodec/dynlink_loader.h
index 2f94c07..1d3f8c5 100644
--- a/include/ffnvcodec/dynlink_loader.h
+++ b/include/ffnvcodec/dynlink_loader.h
@@ -149,6 +149,7 @@ typedef struct CudaFunctions {
     tcuDeviceGetPCIBusId *cuDeviceGetPCIBusId;
     tcuDeviceComputeCapability *cuDeviceComputeCapability;
     tcuCtxCreate_v2 *cuCtxCreate;
+    tcuCtxGetCurrent *cuCtxGetCurrent;
     tcuCtxSetLimit *cuCtxSetLimit;
     tcuCtxPushCurrent_v2 *cuCtxPushCurrent;
     tcuCtxPopCurrent_v2 *cuCtxPopCurrent;
@@ -315,6 +316,7 @@ static inline int cuda_load_functions(CudaFunctions **functions, void *logctx)
     LOAD_SYMBOL(cuDeviceGetName, tcuDeviceGetName, "cuDeviceGetName");
     LOAD_SYMBOL(cuDeviceComputeCapability, tcuDeviceComputeCapability, "cuDeviceComputeCapability");
     LOAD_SYMBOL(cuCtxCreate, tcuCtxCreate_v2, "cuCtxCreate_v2");
+    LOAD_SYMBOL(cuCtxGetCurrent, tcuCtxGetCurrent, "cuCtxGetCurrent");
     LOAD_SYMBOL(cuCtxSetLimit, tcuCtxSetLimit, "cuCtxSetLimit");
     LOAD_SYMBOL(cuCtxPushCurrent, tcuCtxPushCurrent_v2, "cuCtxPushCurrent_v2");
     LOAD_SYMBOL(cuCtxPopCurrent, tcuCtxPopCurrent_v2, "cuCtxPopCurrent_v2");
-- 
2.39.2 (Apple Git-143)

