diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 232 |
1 files changed, 198 insertions, 34 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 0b99783083f6..e05557d52999 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -17,12 +17,13 @@ #include <linux/mm.h> #include <linux/miscdevice.h> #include <linux/mutex.h> -#include <linux/workqueue.h> #include <linux/rcupdate.h> #include <linux/poll.h> #include <linux/file.h> #include <linux/highmem.h> #include <linux/slab.h> +#include <linux/kthread.h> +#include <linux/cgroup.h> #include <linux/net.h> #include <linux/if_packet.h> @@ -37,8 +38,6 @@ enum { VHOST_MEMORY_F_LOG = 0x1, }; -static struct workqueue_struct *vhost_workqueue; - static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, poll_table *pt) { @@ -52,23 +51,31 @@ static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync, void *key) { - struct vhost_poll *poll; - poll = container_of(wait, struct vhost_poll, wait); + struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait); + if (!((unsigned long)key & poll->mask)) return 0; - queue_work(vhost_workqueue, &poll->work); + vhost_poll_queue(poll); return 0; } /* Init poll structure */ -void vhost_poll_init(struct vhost_poll *poll, work_func_t func, - unsigned long mask) +void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, + unsigned long mask, struct vhost_dev *dev) { - INIT_WORK(&poll->work, func); + struct vhost_work *work = &poll->work; + init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup); init_poll_funcptr(&poll->table, vhost_poll_func); poll->mask = mask; + poll->dev = dev; + + INIT_LIST_HEAD(&work->node); + work->fn = fn; + init_waitqueue_head(&work->done); + work->flushing = 0; + work->queue_seq = work->done_seq = 0; } /* Start polling a file. We add ourselves to file's wait queue. The caller must @@ -92,12 +99,40 @@ void vhost_poll_stop(struct vhost_poll *poll) * locks that are also used by the callback. */ void vhost_poll_flush(struct vhost_poll *poll) { - flush_work(&poll->work); + struct vhost_work *work = &poll->work; + unsigned seq; + int left; + int flushing; + + spin_lock_irq(&poll->dev->work_lock); + seq = work->queue_seq; + work->flushing++; + spin_unlock_irq(&poll->dev->work_lock); + wait_event(work->done, ({ + spin_lock_irq(&poll->dev->work_lock); + left = seq - work->done_seq <= 0; + spin_unlock_irq(&poll->dev->work_lock); + left; + })); + spin_lock_irq(&poll->dev->work_lock); + flushing = --work->flushing; + spin_unlock_irq(&poll->dev->work_lock); + BUG_ON(flushing < 0); } void vhost_poll_queue(struct vhost_poll *poll) { - queue_work(vhost_workqueue, &poll->work); + struct vhost_dev *dev = poll->dev; + struct vhost_work *work = &poll->work; + unsigned long flags; + + spin_lock_irqsave(&dev->work_lock, flags); + if (list_empty(&work->node)) { + list_add_tail(&work->node, &dev->work_list); + work->queue_seq++; + wake_up_process(dev->worker); + } + spin_unlock_irqrestore(&dev->work_lock, flags); } static void vhost_vq_reset(struct vhost_dev *dev, @@ -114,7 +149,8 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->used_flags = 0; vq->log_used = false; vq->log_addr = -1ull; - vq->hdr_size = 0; + vq->vhost_hlen = 0; + vq->sock_hlen = 0; vq->private_data = NULL; vq->log_base = NULL; vq->error_ctx = NULL; @@ -125,10 +161,51 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->log_ctx = NULL; } +static int vhost_worker(void *data) +{ + struct vhost_dev *dev = data; + struct vhost_work *work = NULL; + unsigned uninitialized_var(seq); + + for (;;) { + /* mb paired w/ kthread_stop */ + set_current_state(TASK_INTERRUPTIBLE); + + spin_lock_irq(&dev->work_lock); + if (work) { + work->done_seq = seq; + if (work->flushing) + wake_up_all(&work->done); + } + + if (kthread_should_stop()) { + spin_unlock_irq(&dev->work_lock); + __set_current_state(TASK_RUNNING); + return 0; + } + if (!list_empty(&dev->work_list)) { + work = list_first_entry(&dev->work_list, + struct vhost_work, node); + list_del_init(&work->node); + seq = work->queue_seq; + } else + work = NULL; + spin_unlock_irq(&dev->work_lock); + + if (work) { + __set_current_state(TASK_RUNNING); + work->fn(work); + } else + schedule(); + + } +} + long vhost_dev_init(struct vhost_dev *dev, struct vhost_virtqueue *vqs, int nvqs) { int i; + dev->vqs = vqs; dev->nvqs = nvqs; mutex_init(&dev->mutex); @@ -136,6 +213,9 @@ long vhost_dev_init(struct vhost_dev *dev, dev->log_file = NULL; dev->memory = NULL; dev->mm = NULL; + spin_lock_init(&dev->work_lock); + INIT_LIST_HEAD(&dev->work_list); + dev->worker = NULL; for (i = 0; i < dev->nvqs; ++i) { dev->vqs[i].dev = dev; @@ -143,9 +223,9 @@ long vhost_dev_init(struct vhost_dev *dev, vhost_vq_reset(dev, dev->vqs + i); if (dev->vqs[i].handle_kick) vhost_poll_init(&dev->vqs[i].poll, - dev->vqs[i].handle_kick, - POLLIN); + dev->vqs[i].handle_kick, POLLIN, dev); } + return 0; } @@ -159,12 +239,36 @@ long vhost_dev_check_owner(struct vhost_dev *dev) /* Caller should have device mutex */ static long vhost_dev_set_owner(struct vhost_dev *dev) { + struct task_struct *worker; + int err; /* Is there an owner already? */ - if (dev->mm) - return -EBUSY; + if (dev->mm) { + err = -EBUSY; + goto err_mm; + } /* No owner, become one */ dev->mm = get_task_mm(current); + worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid); + if (IS_ERR(worker)) { + err = PTR_ERR(worker); + goto err_worker; + } + + dev->worker = worker; + err = cgroup_attach_task_current_cg(worker); + if (err) + goto err_cgroup; + wake_up_process(worker); /* avoid contributing to loadavg */ + return 0; +err_cgroup: + kthread_stop(worker); +err_worker: + if (dev->mm) + mmput(dev->mm); + dev->mm = NULL; +err_mm: + return err; } /* Caller should have device mutex */ @@ -217,6 +321,9 @@ void vhost_dev_cleanup(struct vhost_dev *dev) if (dev->mm) mmput(dev->mm); dev->mm = NULL; + + WARN_ON(!list_empty(&dev->work_list)); + kthread_stop(dev->worker); } static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) @@ -237,8 +344,8 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, { int i; - if (!mem) - return 0; + if (!mem) + return 0; for (i = 0; i < mem->nregions; ++i) { struct vhost_memory_region *m = mem->regions + i; @@ -995,9 +1102,9 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, } /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */ -void vhost_discard_vq_desc(struct vhost_virtqueue *vq) +void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n) { - vq->last_avail_idx--; + vq->last_avail_idx -= n; } /* After we've used one of their buffers, we tell them about it. We'll then @@ -1042,6 +1149,67 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len) return 0; } +static int __vhost_add_used_n(struct vhost_virtqueue *vq, + struct vring_used_elem *heads, + unsigned count) +{ + struct vring_used_elem __user *used; + int start; + + start = vq->last_used_idx % vq->num; + used = vq->used->ring + start; + if (copy_to_user(used, heads, count * sizeof *used)) { + vq_err(vq, "Failed to write used"); + return -EFAULT; + } + if (unlikely(vq->log_used)) { + /* Make sure data is seen before log. */ + smp_wmb(); + /* Log used ring entry write. */ + log_write(vq->log_base, + vq->log_addr + + ((void __user *)used - (void __user *)vq->used), + count * sizeof *used); + } + vq->last_used_idx += count; + return 0; +} + +/* After we've used one of their buffers, we tell them about it. We'll then + * want to notify the guest, using eventfd. */ +int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, + unsigned count) +{ + int start, n, r; + + start = vq->last_used_idx % vq->num; + n = vq->num - start; + if (n < count) { + r = __vhost_add_used_n(vq, heads, n); + if (r < 0) + return r; + heads += n; + count -= n; + } + r = __vhost_add_used_n(vq, heads, count); + + /* Make sure buffer is written before we update index. */ + smp_wmb(); + if (put_user(vq->last_used_idx, &vq->used->idx)) { + vq_err(vq, "Failed to increment used idx"); + return -EFAULT; + } + if (unlikely(vq->log_used)) { + /* Log used index update. */ + log_write(vq->log_base, + vq->log_addr + offsetof(struct vring_used, idx), + sizeof vq->used->idx); + if (vq->log_ctx) + eventfd_signal(vq->log_ctx, 1); + } + return r; +} + /* This actually signals the guest, using eventfd. */ void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq) { @@ -1076,6 +1244,15 @@ void vhost_add_used_and_signal(struct vhost_dev *dev, vhost_signal(dev, vq); } +/* multi-buffer version of vhost_add_used_and_signal */ +void vhost_add_used_and_signal_n(struct vhost_dev *dev, + struct vhost_virtqueue *vq, + struct vring_used_elem *heads, unsigned count) +{ + vhost_add_used_n(vq, heads, count); + vhost_signal(dev, vq); +} + /* OK, now we need to know about added descriptors. */ bool vhost_enable_notify(struct vhost_virtqueue *vq) { @@ -1100,7 +1277,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq) return false; } - return avail_idx != vq->last_avail_idx; + return avail_idx != vq->avail_idx; } /* We don't need to be notified again. */ @@ -1115,16 +1292,3 @@ void vhost_disable_notify(struct vhost_virtqueue *vq) vq_err(vq, "Failed to enable notification at %p: %d\n", &vq->used->flags, r); } - -int vhost_init(void) -{ - vhost_workqueue = create_singlethread_workqueue("vhost"); - if (!vhost_workqueue) - return -ENOMEM; - return 0; -} - -void vhost_cleanup(void) -{ - destroy_workqueue(vhost_workqueue); -} |