From 11a328e9d2e420f24609f42f78eefa1e707aabd3 Mon Sep 17 00:00:00 2001
From: perflab1 <perflab1@hyperscale-dp04.nvidia.com>
Date: Wed, 12 Dec 2018 15:35:14 +0300
Subject: [PATCH] Adding HEVC YUV444P decoding support:

---
 libavcodec/cuviddec.c      | 137 ++++++++++++++++++++++++++++++++++++++++++---
 libavutil/hwcontext_cuda.c |   2 +
 2 files changed, 132 insertions(+), 7 deletions(-)

diff --git a/libavcodec/cuviddec.c b/libavcodec/cuviddec.c
index 03589367ce..5e65161924 100644
--- a/libavcodec/cuviddec.c
+++ b/libavcodec/cuviddec.c
@@ -106,6 +106,7 @@ static int CUDAAPI cuvid_handle_video_sequence(void *opaque, CUVIDEOFORMAT* form
     CUVIDDECODECAPS *caps = NULL;
     CUVIDDECODECREATEINFO cuinfo;
     int surface_fmt;
+    int is_hevc_yuv444p;
 
     int old_width = avctx->width;
     int old_height = avctx->height;
@@ -148,17 +149,31 @@ static int CUDAAPI cuvid_handle_video_sequence(void *opaque, CUVIDEOFORMAT* form
     cuinfo.target_rect.right = cuinfo.ulTargetWidth;
     cuinfo.target_rect.bottom = cuinfo.ulTargetHeight;
 
+    is_hevc_yuv444p  = (format->codec  == cudaVideoCodec_HEVC);
+    is_hevc_yuv444p &= (avctx->pix_fmt == AV_PIX_FMT_YUV444P) || 
+                       (avctx->pix_fmt == AV_PIX_FMT_YUV444P10LE) ||
+                       (avctx->pix_fmt == AV_PIX_FMT_YUV444P12LE);
+
     switch (format->bit_depth_luma_minus8) {
     case 0: // 8-bit
-        pix_fmts[1] = AV_PIX_FMT_NV12;
+	if(is_hevc_yuv444p)
+            pix_fmts[1] = AV_PIX_FMT_YUV444P;
+        else
+            pix_fmts[1] = AV_PIX_FMT_NV12;
         caps = &ctx->caps8;
         break;
     case 2: // 10-bit
-        pix_fmts[1] = AV_PIX_FMT_P010;
+        if(is_hevc_yuv444p)
+            pix_fmts[1] = AV_PIX_FMT_YUV444P10LE;
+        else
+            pix_fmts[1] = AV_PIX_FMT_P010;
         caps = &ctx->caps10;
         break;
     case 4: // 12-bit
-        pix_fmts[1] = AV_PIX_FMT_P016;
+         if(is_hevc_yuv444p)
+             pix_fmts[1] = AV_PIX_FMT_YUV444P12LE;
+         else
+             pix_fmts[1] = AV_PIX_FMT_P016;
         caps = &ctx->caps12;
         break;
     default:
@@ -261,8 +276,15 @@ static int CUDAAPI cuvid_handle_video_sequence(void *opaque, CUVIDEOFORMAT* form
         return 0;
     }
 
-    if (format->chroma_format != cudaVideoChromaFormat_420) {
-        av_log(avctx, AV_LOG_ERROR, "Chroma formats other than 420 are not supported\n");
+    /* H264, VP9, VP8, VC1, MPEG1, MPEG2 and MPEG4 support YUV 4:2:0 chroma format only.
+     * HEVC supports YUV 4:2:0 and YUV 4:4:4.
+     */
+    const int supported_chroma_format = (format->codec == cudaVideoCodec_HEVC) ?
+        (format->chroma_format == cudaVideoChromaFormat_420) || (format->chroma_format == cudaVideoChromaFormat_444):
+        (format->chroma_format == cudaVideoChromaFormat_420);
+
+    if(!supported_chroma_format) {
+        av_log(avctx, AV_LOG_ERROR, "Unsopported chroma format\n");
         ctx->internal_error = AVERROR(EINVAL);
         return 0;
     }
@@ -276,12 +298,19 @@ static int CUDAAPI cuvid_handle_video_sequence(void *opaque, CUVIDEOFORMAT* form
     case AV_PIX_FMT_NV12:
         cuinfo.OutputFormat = cudaVideoSurfaceFormat_NV12;
         break;
+    case AV_PIX_FMT_YUV444P:
+        cuinfo.OutputFormat = cudaVideoSurfaceFormat_YUV444;
+        break;
+    case AV_PIX_FMT_YUV444P10LE:
+    case AV_PIX_FMT_YUV444P12LE:
+        cuinfo.OutputFormat = cudaVideoSurfaceFormat_YUV444_16Bit;
+        break;
     case AV_PIX_FMT_P010:
     case AV_PIX_FMT_P016:
         cuinfo.OutputFormat = cudaVideoSurfaceFormat_P016;
         break;
     default:
-        av_log(avctx, AV_LOG_ERROR, "Output formats other than NV12, P010 or P016 are not supported\n");
+        av_log(avctx, AV_LOG_ERROR, "Unsupported output format\n");
         ctx->internal_error = AVERROR(EINVAL);
         return 0;
     }
@@ -576,7 +605,44 @@ static int cuvid_output_frame(AVCodecContext *avctx, AVFrame *frame)
                 goto error;
             }
             av_frame_free(&tmp_frame);
-        } else {
+        }  else if (avctx->pix_fmt == AV_PIX_FMT_YUV444P     ||
+                    avctx->pix_fmt == AV_PIX_FMT_YUV444P10LE ||
+                    avctx->pix_fmt == AV_PIX_FMT_YUV444P12LE   ) 
+        {
+            AVFrame *tmp_frame = av_frame_alloc();
+            if (!tmp_frame) {
+                av_log(avctx, AV_LOG_ERROR, "av_frame_alloc failed\n");
+                ret = AVERROR(ENOMEM);
+                goto error;
+            }
+
+            tmp_frame->format        = AV_PIX_FMT_CUDA;
+            tmp_frame->hw_frames_ctx = av_buffer_ref(ctx->hwframe);
+            tmp_frame->data[0]       = (uint8_t*)mapped_frame;
+            tmp_frame->linesize[0]   = pitch;
+            tmp_frame->data[1]       = (uint8_t*)(mapped_frame + avctx->height * pitch);
+            tmp_frame->linesize[1]   = pitch;
+            tmp_frame->data[2]       = (uint8_t*)(mapped_frame + avctx->height * pitch * 2);
+            tmp_frame->linesize[2]   = pitch;
+            tmp_frame->width         = avctx->width;
+            tmp_frame->height        = avctx->height;
+
+            ret = ff_get_buffer(avctx, frame, 0);
+            if (ret < 0) {
+                av_log(avctx, AV_LOG_ERROR, "ff_get_buffer failed\n");
+                av_frame_free(&tmp_frame);
+                goto error;
+            }
+
+            ret = av_hwframe_transfer_data(frame, tmp_frame, 0);
+            if (ret) {
+                av_log(avctx, AV_LOG_ERROR, "av_hwframe_transfer_data failed\n");
+                av_frame_free(&tmp_frame);
+                goto error;
+            }
+            av_frame_free(&tmp_frame);
+        } 
+        else {
             ret = AVERROR_BUG;
             goto error;
         }
@@ -722,6 +788,20 @@ static int cuvid_test_capabilities(AVCodecContext *avctx,
     ctx->caps8.eChromaFormat = ctx->caps10.eChromaFormat = ctx->caps12.eChromaFormat
         = cudaVideoChromaFormat_420;
 
+#if NVDECAPI_CHECK_VERSION(9, 0)
+    int is_hevc_yuv444p;
+    is_hevc_yuv444p  = (cuparseinfo->CodecType == cudaVideoCodec_HEVC);
+    is_hevc_yuv444p &= (avctx->pix_fmt == AV_PIX_FMT_YUV444P) ||
+                       (avctx->pix_fmt == AV_PIX_FMT_YUV444P10LE) ||
+                       (avctx->pix_fmt == AV_PIX_FMT_YUV444P12LE);
+    
+    if(is_hevc_yuv444p) {
+        ctx->caps8.eChromaFormat  = cudaVideoChromaFormat_444; 
+        ctx->caps10.eChromaFormat = cudaVideoChromaFormat_444; 
+        ctx->caps12.eChromaFormat = cudaVideoChromaFormat_444;
+    }
+#endif
+
     ctx->caps8.nBitDepthMinus8 = 0;
     ctx->caps10.nBitDepthMinus8 = 2;
     ctx->caps12.nBitDepthMinus8 = 4;
@@ -791,11 +871,20 @@ static av_cold int cuvid_decode_init(AVCodecContext *avctx)
     CUcontext dummy;
     const AVBitStreamFilter *bsf;
     int ret = 0;
+    int is_hevc_yuv444p;
 
     enum AVPixelFormat pix_fmts[3] = { AV_PIX_FMT_CUDA,
                                        AV_PIX_FMT_NV12,
                                        AV_PIX_FMT_NONE };
 
+    is_hevc_yuv444p  = (avctx->codec->id == AV_CODEC_ID_HEVC);
+    is_hevc_yuv444p &= (avctx->pix_fmt == AV_PIX_FMT_YUV444P) ||
+                       (avctx->pix_fmt == AV_PIX_FMT_YUV444P10LE) ||
+                       (avctx->pix_fmt == AV_PIX_FMT_YUV444P12LE);
+
+    if(is_hevc_yuv444p)
+        pix_fmts[1] = avctx->pix_fmt;
+
     int probed_width = avctx->coded_width ? avctx->coded_width : 1280;
     int probed_height = avctx->coded_height ? avctx->coded_height : 720;
     int probed_bit_depth = 8;
@@ -1138,8 +1227,42 @@ static const AVCodecHWConfigInternal *cuvid_hw_configs[] = {
     };
 
 #if CONFIG_HEVC_CUVID_DECODER
+#if NVDECAPI_CHECK_VERSION(9, 0)
+static const AVClass hevc_cuvid_class = { 
+    .class_name = "hevc_cuvid", 
+    .item_name = av_default_item_name, 
+    .option = options, 
+    .version = LIBAVUTIL_VERSION_INT, 
+}; 
+
+AVCodec ff_hevc_cuvid_decoder = { 
+    .name           = "hevc_cuvid", 
+    .long_name      = NULL_IF_CONFIG_SMALL("Nvidia CUVID HEVC decoder"), 
+    .type           = AVMEDIA_TYPE_VIDEO, 
+    .id             = AV_CODEC_ID_HEVC, 
+    .priv_data_size = sizeof(CuvidContext), 
+    .priv_class     = &hevc_cuvid_class, 
+    .init           = cuvid_decode_init, 
+    .close          = cuvid_decode_end, 
+    .decode         = cuvid_decode_frame, 
+    .receive_frame  = cuvid_output_frame, 
+    .flush          = cuvid_flush, \
+    .capabilities   = AV_CODEC_CAP_DELAY | AV_CODEC_CAP_AVOID_PROBING | AV_CODEC_CAP_HARDWARE, 
+    .pix_fmts       = (const enum AVPixelFormat[]){ AV_PIX_FMT_CUDA, 
+                                                    AV_PIX_FMT_NV12, 
+                                                    AV_PIX_FMT_P010, 
+                                                    AV_PIX_FMT_P016,
+                                                    AV_PIX_FMT_YUV444P,
+                                                    AV_PIX_FMT_YUV444P10LE,
+                                                    AV_PIX_FMT_YUV444P12LE,
+                                                    AV_PIX_FMT_NONE }, 
+    .hw_configs     = cuvid_hw_configs, 
+    .wrapper_name   = "cuvid", 
+};
+#else
 DEFINE_CUVID_CODEC(hevc, HEVC)
 #endif
+#endif
 
 #if CONFIG_H264_CUVID_DECODER
 DEFINE_CUVID_CODEC(h264, H264)
diff --git a/libavutil/hwcontext_cuda.c b/libavutil/hwcontext_cuda.c
index 540a7610ef..a34d46992c 100644
--- a/libavutil/hwcontext_cuda.c
+++ b/libavutil/hwcontext_cuda.c
@@ -39,6 +39,8 @@ static const enum AVPixelFormat supported_formats[] = {
     AV_PIX_FMT_YUV444P,
     AV_PIX_FMT_P010,
     AV_PIX_FMT_P016,
+    AV_PIX_FMT_YUV444P10LE,
+    AV_PIX_FMT_YUV444P12LE,
     AV_PIX_FMT_YUV444P16,
     AV_PIX_FMT_0RGB32,
     AV_PIX_FMT_0BGR32,
-- 
2.15.1.windows.2

