Add functions to allow vhost buffers to be placed in kernel space and
for the vhost driver to be controlled from a kernel driver after initial
setup by userspace.

The kernel control is only possible on new /dev/vhost-*-kernel devices,
and on these devices userspace cannot write to the iotlb, nor can it
control the placement and attributes of the virtqueues, nor start/stop
the virtqueue handling after the kernel starts using it.

Signed-off-by: Vincent Whitchurch <[email protected]>
---
 drivers/vhost/common.c | 201 +++++++++++++++++++++++++++++++++++++++++
 drivers/vhost/vhost.c  |  92 +++++++++++++++++--
 drivers/vhost/vhost.h  |   3 +
 include/linux/vhost.h  |  23 +++++
 4 files changed, 310 insertions(+), 9 deletions(-)
 create mode 100644 include/linux/vhost.h

diff --git a/drivers/vhost/common.c b/drivers/vhost/common.c
index a5722ad65e24..f9758920a33a 100644
--- a/drivers/vhost/common.c
+++ b/drivers/vhost/common.c
@@ -25,7 +25,9 @@
 struct vhost_ops;
 
 struct vhost {
+       char kernelname[128]; /* storage for kernelmisc.name, formatted as "<ops->name>-kernel" */
        struct miscdevice misc;
+       struct miscdevice kernelmisc; /* misc device registered with vhost_kernel_fops */
        const struct vhost_ops *ops;
 };
 
@@ -46,6 +48,24 @@ static int vhost_open(struct inode *inode, struct file *file)
        return 0;
 }
 
+/*
+ * open() handler for the "-kernel" misc device.
+ *
+ * On entry file->private_data is read as the miscdevice, from which the
+ * owning struct vhost is recovered.  The backend's ->open() creates the
+ * vhost_dev, which is then flagged as a kernel device and stored in
+ * file->private_data for the other file operations.
+ *
+ * Returns 0 on success or the error encoded in the pointer returned by
+ * ->open().
+ */
+static int vhost_kernel_open(struct inode *inode, struct file *file)
+{
+       struct miscdevice *md = file->private_data;
+       struct vhost *owner = container_of(md, struct vhost, kernelmisc);
+       struct vhost_dev *vdev = owner->ops->open(owner);
+
+       if (IS_ERR(vdev))
+               return PTR_ERR(vdev);
+
+       vdev->kernel = true;
+       vdev->file = file;
+       vdev->vhost = owner;
+       file->private_data = vdev;
+
+       return 0;
+}
+
 static int vhost_release(struct inode *inode, struct file *file)
 {
        struct vhost_dev *dev = file->private_data;
@@ -69,6 +89,46 @@ static long vhost_ioctl(struct file *file, unsigned int 
ioctl, unsigned long arg
        return ret;
 }
 
+/*
+ * ioctl() handler for the "-kernel" device.
+ *
+ * Ioctls that configure virtqueue placement/attributes, the memory table
+ * or logging are refused with -EPERM.  Everything else is forwarded to the
+ * backend's ->ioctl() under dev->mutex, but only until a kernel user has
+ * attached to the device.
+ */
+static long vhost_kernel_ioctl(struct file *file, unsigned int ioctl,
+                              unsigned long arg)
+{
+       struct vhost_dev *vdev = file->private_data;
+       struct vhost *owner = vdev->vhost;
+       long rc;
+
+       switch (ioctl) {
+       case VHOST_SET_VRING_NUM:
+       case VHOST_SET_VRING_ADDR:
+       case VHOST_SET_VRING_BASE:
+       case VHOST_SET_VRING_ENDIAN:
+       case VHOST_SET_MEM_TABLE:
+       case VHOST_SET_LOG_BASE:
+       case VHOST_SET_LOG_FD:
+               /* Only the kernel is allowed to control virtqueue attributes */
+               return -EPERM;
+       default:
+               break;
+       }
+
+       mutex_lock(&vdev->mutex);
+
+       /*
+        * Userspace must finish all required setup on the vhost device
+        * _before_ asking the kernel to start using it.
+        *
+        * ->kernel_attached is never cleared; to attach again, userspace
+        * has to open the device again.
+        */
+       if (vdev->kernel_attached)
+               rc = -EPERM;
+       else
+               rc = owner->ops->ioctl(vdev, ioctl, arg);
+
+       mutex_unlock(&vdev->mutex);
+
+       return rc;
+}
+
 static ssize_t vhost_read_iter(struct kiocb *iocb, struct iov_iter *to)
 {
        struct file *file = iocb->ki_filp;
@@ -105,6 +165,129 @@ static const struct file_operations vhost_fops = {
        .poll           = vhost_poll,
 };
 
+/*
+ * File operations for the "-kernel" device node.  Also used by
+ * vhost_dev_get() to recognise files that belong to this device type.
+ */
+static const struct file_operations vhost_kernel_fops = {
+       .owner          = THIS_MODULE,
+       /* lifecycle */
+       .open           = vhost_kernel_open,
+       .release        = vhost_release,
+       /* control path */
+       .unlocked_ioctl = vhost_kernel_ioctl,
+       .compat_ioctl   = compat_ptr_ioctl,
+       .llseek         = noop_llseek,
+};
+
+/*
+ * Lock the mutex of every virtqueue of @d, in ascending index order.  The
+ * queue index is passed as the subclass argument of mutex_lock_nested().
+ */
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+       int idx = 0;
+
+       while (idx < d->nvqs) {
+               mutex_lock_nested(&d->vqs[idx]->mutex, idx);
+               idx++;
+       }
+}
+
+/* Release the mutex of every virtqueue of @d, in ascending index order. */
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+       int idx = 0;
+
+       while (idx < d->nvqs) {
+               mutex_unlock(&d->vqs[idx]->mutex);
+               idx++;
+       }
+}
+
+/*
+ * vhost_dev_get - attach a kernel user to the vhost device behind @fd
+ *
+ * Looks up @fd, checks that it refers to a "-kernel" vhost device and, with
+ * the device and all virtqueue mutexes held, verifies that the device can
+ * be handed over before marking it as kernel_attached.
+ *
+ * On success the file reference taken here is kept; it is dropped by
+ * vhost_dev_put().
+ *
+ * Returns the vhost_dev, or an ERR_PTR():
+ *   -EBADF   @fd is not an open file
+ *   -EINVAL  @fd is not a "-kernel" vhost device, or the device has no iotlb
+ *   -EBUSY   a kernel user is already attached, or a virtqueue has
+ *            private_data set
+ *   or the error returned by vhost_dev_check_owner().
+ */
+struct vhost_dev *vhost_dev_get(int fd)
+{
+       struct vhost_dev *vdev;
+       struct file *filp;
+       int idx;
+       int err;
+
+       filp = fget(fd);
+       if (!filp)
+               return ERR_PTR(-EBADF);
+
+       if (filp->f_op != &vhost_kernel_fops) {
+               fput(filp);
+               return ERR_PTR(-EINVAL);
+       }
+
+       vdev = filp->private_data;
+
+       mutex_lock(&vdev->mutex);
+       vhost_dev_lock_vqs(vdev);
+
+       err = vhost_dev_check_owner(vdev);
+       if (err)
+               goto unlock;
+
+       if (vdev->kernel_attached) {
+               err = -EBUSY;
+               goto unlock;
+       }
+
+       if (!vdev->iotlb) {
+               err = -EINVAL;
+               goto unlock;
+       }
+
+       for (idx = 0; idx < vdev->nvqs; idx++) {
+               if (vdev->vqs[idx]->private_data) {
+                       err = -EBUSY;
+                       goto unlock;
+               }
+       }
+
+       vdev->kernel_attached = true;
+
+unlock:
+       vhost_dev_unlock_vqs(vdev);
+       mutex_unlock(&vdev->mutex);
+
+       if (err) {
+               /* Attach failed: give back the reference taken by fget(). */
+               fput(filp);
+               return ERR_PTR(err);
+       }
+
+       return vdev;
+}
+EXPORT_SYMBOL_GPL(vhost_dev_get);
+
+/*
+ * Ask the backend to start virtqueue @idx of @dev.  The backend's
+ * ->start_vq() is invoked with dev->mutex held.
+ */
+void vhost_dev_start_vq(struct vhost_dev *dev, u16 idx)
+{
+       const struct vhost_ops *ops = dev->vhost->ops;
+
+       mutex_lock(&dev->mutex);
+       ops->start_vq(dev, idx);
+       mutex_unlock(&dev->mutex);
+}
+EXPORT_SYMBOL_GPL(vhost_dev_start_vq);
+
+/*
+ * Ask the backend to stop virtqueue @idx of @dev.  The backend's
+ * ->stop_vq() is invoked with dev->mutex held.
+ */
+void vhost_dev_stop_vq(struct vhost_dev *dev, u16 idx)
+{
+       const struct vhost_ops *ops = dev->vhost->ops;
+
+       mutex_lock(&dev->mutex);
+       ops->stop_vq(dev, idx);
+       mutex_unlock(&dev->mutex);
+}
+EXPORT_SYMBOL_GPL(vhost_dev_stop_vq);
+
+/*
+ * Drop the reference on dev->file, i.e. the one vhost_dev_get() keeps on
+ * success.  The caller is expected to have stopped the virtqueues already.
+ */
+void vhost_dev_put(struct vhost_dev *dev)
+{
+       struct file *filp = dev->file;
+
+       fput(filp);
+}
+EXPORT_SYMBOL_GPL(vhost_dev_put);
+
+/*
+ * A backend supports kernel control when CONFIG_VHOST_KERNEL is enabled and
+ * it implements both the ->start_vq() and ->stop_vq() callbacks.
+ */
+static bool vhost_kernel_supported(const struct vhost_ops *ops)
+{
+       return IS_ENABLED(CONFIG_VHOST_KERNEL) &&
+              ops->start_vq && ops->stop_vq;
+}
+
 struct vhost *vhost_register(const struct vhost_ops *ops)
 {
        struct vhost *vhost;
@@ -125,12 +308,30 @@ struct vhost *vhost_register(const struct vhost_ops *ops)
                return ERR_PTR(ret);
        }
 
+       if (vhost_kernel_supported(ops)) {
+               snprintf(vhost->kernelname, sizeof(vhost->kernelname),
+                        "%s-kernel", ops->name);
+
+               vhost->kernelmisc.minor = MISC_DYNAMIC_MINOR;
+               vhost->kernelmisc.name = vhost->kernelname;
+               vhost->kernelmisc.fops = &vhost_kernel_fops;
+
+               ret = misc_register(&vhost->kernelmisc);
+               if (ret) {
+                       misc_deregister(&vhost->misc);
+                       kfree(vhost);
+                       return ERR_PTR(ret);
+               }
+       }
+
        return vhost;
 }
 EXPORT_SYMBOL_GPL(vhost_register);
 
 void vhost_unregister(struct vhost *vhost)
 {
+       if (vhost_kernel_supported(vhost->ops)) /* same condition under which vhost_register() registers kernelmisc */
+               misc_deregister(&vhost->kernelmisc);
        misc_deregister(&vhost->misc);
        kfree(vhost);
 }
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 9d6496b7ad85..56a69ecfd910 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -486,6 +486,7 @@ void vhost_dev_init(struct vhost_dev *dev,
        dev->mm = NULL;
        dev->worker = NULL;
        dev->kernel = false;
+       dev->kernel_attached = false;
        dev->iov_limit = iov_limit;
        dev->weight = weight;
        dev->byte_weight = byte_weight;
@@ -1329,6 +1330,30 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
 
        return ret;
 }
+
+/**
+ * vhost_dev_iotlb_update - add a mapping to the device IOTLB from the kernel
+ * @dev: vhost device whose IOTLB is updated
+ * @iova: first I/O virtual address of the range
+ * @size: length of the range in bytes; must be non-zero
+ * @kaddr: address the range maps to, passed through to
+ *         vhost_iotlb_add_range()
+ * @perm: access permissions, passed through to vhost_iotlb_add_range()
+ *
+ * The range [@iova, @iova + @size - 1] is added with the device mutex and
+ * all virtqueue mutexes held.
+ *
+ * Return: 0 on success; -EINVAL if @size is zero, if the range wraps around
+ * the 64-bit address space, or if the device has no IOTLB; -ENOMEM if the
+ * range could not be added.
+ */
+int vhost_dev_iotlb_update(struct vhost_dev *dev, u64 iova, u64 size,
+                          u64 kaddr, unsigned int perm)
+{
+       u64 last = iova + size - 1;
+       int ret = 0;
+
+       /*
+        * Reject empty and wrapping ranges up front.  Without this check a
+        * zero @size at @iova 0 turns into the range [0, U64_MAX], and an
+        * overflowing @iova + @size yields a last address below @iova.
+        */
+       if (!size || last < iova)
+               return -EINVAL;
+
+       mutex_lock(&dev->mutex);
+       vhost_dev_lock_vqs(dev);
+
+       if (!dev->iotlb) {
+               ret = -EINVAL;
+               goto out_unlock;
+       }
+
+       if (vhost_iotlb_add_range(dev->iotlb, iova, last, kaddr, perm))
+               ret = -ENOMEM;
+
+out_unlock:
+       vhost_dev_unlock_vqs(dev);
+       mutex_unlock(&dev->mutex);
+
+       return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_dev_iotlb_update);
+
 ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
                             struct iov_iter *from)
 {
@@ -1677,27 +1702,35 @@ static long vhost_set_memory(struct vhost_dev *d, 
struct vhost_memory __user *m)
        return -EFAULT;
 }
 
-static long vhost_vring_set_num(struct vhost_dev *d,
+static int __vhost_vring_set_num(struct vhost_dev *d,
                                struct vhost_virtqueue *vq,
-                               void __user *argp)
+                               unsigned int num)
 {
-       struct vhost_vring_state s;
-
        /* Resizing ring with an active backend?
         * You don't want to do that. */
        if (vq->private_data)
                return -EBUSY;
 
-       if (copy_from_user(&s, argp, sizeof s))
-               return -EFAULT;
-
-       if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
+       if (!num || num > 0xffff || (num & (num - 1)))
                return -EINVAL;
-       vq->num = s.num;
+
+       vq->num = num;
 
        return 0;
 }
 
+/*
+ * Userspace entry point for setting the ring size: copies the
+ * vhost_vring_state from @argp and applies its num field through
+ * __vhost_vring_set_num().  Returns -EFAULT if the copy fails.
+ */
+static long vhost_vring_set_num(struct vhost_dev *d,
+                               struct vhost_virtqueue *vq,
+                               void __user *argp)
+{
+       struct vhost_vring_state state;
+
+       if (copy_from_user(&state, argp, sizeof(state)))
+               return -EFAULT;
+
+       return __vhost_vring_set_num(d, vq, state.num);
+}
+
 static long vhost_vring_set_addr(struct vhost_dev *d,
                                 struct vhost_virtqueue *vq,
                                 void __user *argp)
@@ -1750,6 +1783,47 @@ static long vhost_vring_set_addr(struct vhost_dev *d,
        return 0;
 }
 
+/*
+ * In-kernel counterpart of the ring-size ioctl: sets the size of
+ * virtqueue @idx of @dev to @num under that virtqueue's mutex.
+ *
+ * Returns -ENOBUFS if @idx is not a valid virtqueue index, otherwise the
+ * result of __vhost_vring_set_num().
+ */
+int vhost_dev_set_vring_num(struct vhost_dev *dev, unsigned int idx,
+                           unsigned int num)
+{
+       struct vhost_virtqueue *vq;
+       int rc = -ENOBUFS;
+
+       if (idx < dev->nvqs) {
+               vq = dev->vqs[idx];
+
+               mutex_lock(&vq->mutex);
+               rc = __vhost_vring_set_num(dev, vq, num);
+               mutex_unlock(&vq->mutex);
+       }
+
+       return rc;
+}
+EXPORT_SYMBOL_GPL(vhost_dev_set_vring_num);
+
+/*
+ * Point virtqueue @idx of @dev at kernel-provided ring areas.
+ *
+ * Stores @desc, @avail and @used in vq->kern and resets both
+ * last_avail_idx and avail_idx to 0, all under the virtqueue's mutex.
+ *
+ * Returns 0 on success or -ENOBUFS if @idx is not a valid virtqueue index.
+ */
+int vhost_dev_set_num_addr(struct vhost_dev *dev, unsigned int idx, void *desc,
+                          void *avail, void *used)
+{
+       struct vhost_virtqueue *vq;
+
+       if (idx >= dev->nvqs)
+               return -ENOBUFS;
+
+       vq = dev->vqs[idx];
+
+       mutex_lock(&vq->mutex);
+       vq->last_avail_idx = 0;
+       vq->avail_idx = 0;
+       vq->kern.used = used;
+       vq->kern.avail = avail;
+       vq->kern.desc = desc;
+       mutex_unlock(&vq->mutex);
+
+       return 0;
+}
+EXPORT_SYMBOL_GPL(vhost_dev_set_num_addr);
+
 static long vhost_vring_set_num_addr(struct vhost_dev *d,
                                     struct vhost_virtqueue *vq,
                                     unsigned int ioctl,
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 408ff243ed31..6cd5d6b0d644 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -23,6 +23,8 @@ struct vhost_ops {
        struct vhost_dev * (*open)(struct vhost *vhost);
        long (*ioctl)(struct vhost_dev *dev, unsigned int ioctl, unsigned long 
arg);
        void (*release)(struct vhost_dev *dev);
+       void (*start_vq)(struct vhost_dev *dev, u16 idx);
+       void (*stop_vq)(struct vhost_dev *dev, u16 idx);
 };
 
 struct vhost *vhost_register(const struct vhost_ops *ops);
@@ -191,6 +193,7 @@ struct vhost_dev {
        u64 kcov_handle;
        bool use_worker;
        bool kernel;
+       bool kernel_attached;
        int (*msg_handler)(struct vhost_dev *dev,
                           struct vhost_iotlb_msg *msg);
 };
diff --git a/include/linux/vhost.h b/include/linux/vhost.h
new file mode 100644
index 000000000000..cdfe244c776b
--- /dev/null
+++ b/include/linux/vhost.h
@@ -0,0 +1,23 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+#ifndef _INCLUDE_LINUX_VHOST_H
+#define _INCLUDE_LINUX_VHOST_H
+
+#include <uapi/linux/vhost.h>
+
+/*
+ * In-kernel interface for driving a vhost device that userspace opened via
+ * a "-kernel" vhost device node and handed over by file descriptor.
+ */
+
+struct vhost_dev;
+
+/*
+ * Attach to / detach from the vhost device behind @fd.  vhost_dev_get()
+ * returns an ERR_PTR() on failure; vhost_dev_put() drops the file reference
+ * kept by a successful vhost_dev_get().
+ */
+struct vhost_dev *vhost_dev_get(int fd);
+void vhost_dev_put(struct vhost_dev *dev);
+
+/* Virtqueue configuration; both return -ENOBUFS for an invalid @idx. */
+int vhost_dev_set_vring_num(struct vhost_dev *dev,
+                           unsigned int idx, unsigned int num);
+int vhost_dev_set_num_addr(struct vhost_dev *dev, unsigned int idx,
+                          void *desc, void *avail, void *used);
+
+/* Start/stop handling of virtqueue @idx through the backend callbacks. */
+void vhost_dev_start_vq(struct vhost_dev *dev, u16 idx);
+void vhost_dev_stop_vq(struct vhost_dev *dev, u16 idx);
+
+/* Add the range [@iova, @iova + @size - 1] to the device IOTLB. */
+int vhost_dev_iotlb_update(struct vhost_dev *dev, u64 iova, u64 size,
+                          u64 kaddr, unsigned int perm);
+
+#endif
-- 
2.28.0

_______________________________________________
Virtualization mailing list
[email protected]
https://lists.linuxfoundation.org/mailman/listinfo/virtualization

Reply via email to