With the new get_region_info_caps op, the core does the copy to/from user
for the info, and the op can return a cap chain through a
struct vfio_info_cap * result.

Reviewed-by: Kevin Tian <[email protected]>
Reviewed-by: Pranjal Shrivastava <[email protected]>
Signed-off-by: Jason Gunthorpe <[email protected]>
---
 drivers/vfio/vfio_main.c | 56 +++++++++++++++++++++++++++++++++++++---
 include/linux/vfio.h     |  4 +++
 2 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c
index f056e82ba35075..48d034aede46fc 100644
--- a/drivers/vfio/vfio_main.c
+++ b/drivers/vfio/vfio_main.c
@@ -1259,6 +1259,57 @@ static int vfio_ioctl_device_feature(struct vfio_device *device,
        }
 }
 
+/*
+ * Drivers implementing get_region_info_caps get a kernel copy of the info;
+ * the core does every user copy and appends the cap chain. Others fall back
+ * to the legacy get_region_info op, which works on the user pointer itself.
+ */
+static long vfio_get_region_info(struct vfio_device *device,
+                                struct vfio_region_info __user *arg)
+{
+       unsigned long minsz = offsetofend(struct vfio_region_info, offset);
+       struct vfio_region_info info = {};
+       struct vfio_info_cap caps = {};
+       int ret;
+
+       if (!device->ops->get_region_info_caps) {
+               if (unlikely(!device->ops->get_region_info))
+                       return -EINVAL;
+               return device->ops->get_region_info(device, arg);
+       }
+
+       if (copy_from_user(&info, arg, minsz))
+               return -EFAULT;
+       if (info.argsz < minsz)
+               return -EINVAL;
+
+       ret = device->ops->get_region_info_caps(device, &info, &caps);
+       if (ret)
+               goto out_free;
+
+       if (caps.size) {
+               info.flags |= VFIO_REGION_INFO_FLAG_CAPS;
+               if (info.argsz < sizeof(info) + caps.size) {
+                       info.argsz = sizeof(info) + caps.size;
+                       info.cap_offset = 0;
+               } else {
+                       vfio_info_cap_shift(&caps, sizeof(info));
+                       if (copy_to_user(arg + 1, caps.buf, caps.size)) {
+                               ret = -EFAULT;
+                               goto out_free;
+                       }
+                       info.cap_offset = sizeof(info);
+               }
+       }
+
+       if (copy_to_user(arg, &info, minsz))
+               ret = -EFAULT;
+
+out_free:
+       kfree(caps.buf);
+       return ret;
+}
+
 static long vfio_device_fops_unl_ioctl(struct file *filep,
                                       unsigned int cmd, unsigned long arg)
 {
@@ -1297,10 +1348,7 @@ static long vfio_device_fops_unl_ioctl(struct file *filep,
                break;
 
        case VFIO_DEVICE_GET_REGION_INFO:
-               if (unlikely(!device->ops->get_region_info))
-                       ret = -EINVAL;
-               else
-                       ret = device->ops->get_region_info(device, uptr);
+               ret = vfio_get_region_info(device, uptr);
                break;
 
        default:
diff --git a/include/linux/vfio.h b/include/linux/vfio.h
index be5fcf8432e8d5..6311ddc837701d 100644
--- a/include/linux/vfio.h
+++ b/include/linux/vfio.h
@@ -21,6 +21,7 @@ struct kvm;
 struct iommufd_ctx;
 struct iommufd_device;
 struct iommufd_access;
+struct vfio_info_cap;
 
 /*
  * VFIO devices can be placed in a set, this allows all devices to share this
@@ -134,6 +135,9 @@ struct vfio_device_ops {
                         unsigned long arg);
        int     (*get_region_info)(struct vfio_device *vdev,
                                   struct vfio_region_info __user *arg);
+       int     (*get_region_info_caps)(struct vfio_device *vdev,
+                                       struct vfio_region_info *info,
+                                       struct vfio_info_cap *caps);
        int     (*mmap)(struct vfio_device *vdev, struct vm_area_struct *vma);
        void    (*request)(struct vfio_device *vdev, unsigned int count);
        int     (*match)(struct vfio_device *vdev, char *buf);
-- 
2.43.0

Reply via email to