diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 114 |
1 files changed, 82 insertions, 32 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index e05557d52999..8b5a1b33d0fe 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -60,22 +60,25 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync, return 0; } +static void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) +{ + INIT_LIST_HEAD(&work->node); + work->fn = fn; + init_waitqueue_head(&work->done); + work->flushing = 0; + work->queue_seq = work->done_seq = 0; +} + /* Init poll structure */ void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, unsigned long mask, struct vhost_dev *dev) { - struct vhost_work *work = &poll->work; - init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup); init_poll_funcptr(&poll->table, vhost_poll_func); poll->mask = mask; poll->dev = dev; - INIT_LIST_HEAD(&work->node); - work->fn = fn; - init_waitqueue_head(&work->done); - work->flushing = 0; - work->queue_seq = work->done_seq = 0; + vhost_work_init(&poll->work, fn); } /* Start polling a file. We add ourselves to file's wait queue. The caller must @@ -95,35 +98,38 @@ void vhost_poll_stop(struct vhost_poll *poll) remove_wait_queue(poll->wqh, &poll->wait); } -/* Flush any work that has been scheduled. When calling this, don't hold any - * locks that are also used by the callback. */ -void vhost_poll_flush(struct vhost_poll *poll) +static void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) { - struct vhost_work *work = &poll->work; unsigned seq; int left; int flushing; - spin_lock_irq(&poll->dev->work_lock); + spin_lock_irq(&dev->work_lock); seq = work->queue_seq; work->flushing++; - spin_unlock_irq(&poll->dev->work_lock); + spin_unlock_irq(&dev->work_lock); wait_event(work->done, ({ - spin_lock_irq(&poll->dev->work_lock); + spin_lock_irq(&dev->work_lock); left = seq - work->done_seq <= 0; - spin_unlock_irq(&poll->dev->work_lock); + spin_unlock_irq(&dev->work_lock); left; })); - spin_lock_irq(&poll->dev->work_lock); + spin_lock_irq(&dev->work_lock); flushing = --work->flushing; - spin_unlock_irq(&poll->dev->work_lock); + spin_unlock_irq(&dev->work_lock); BUG_ON(flushing < 0); } -void vhost_poll_queue(struct vhost_poll *poll) +/* Flush any work that has been scheduled. When calling this, don't hold any + * locks that are also used by the callback. */ +void vhost_poll_flush(struct vhost_poll *poll) +{ + vhost_work_flush(poll->dev, &poll->work); +} + +static inline void vhost_work_queue(struct vhost_dev *dev, + struct vhost_work *work) { - struct vhost_dev *dev = poll->dev; - struct vhost_work *work = &poll->work; unsigned long flags; spin_lock_irqsave(&dev->work_lock, flags); @@ -135,6 +141,11 @@ void vhost_poll_queue(struct vhost_poll *poll) spin_unlock_irqrestore(&dev->work_lock, flags); } +void vhost_poll_queue(struct vhost_poll *poll) +{ + vhost_work_queue(poll->dev, &poll->work); +} + static void vhost_vq_reset(struct vhost_dev *dev, struct vhost_virtqueue *vq) { @@ -236,6 +247,29 @@ long vhost_dev_check_owner(struct vhost_dev *dev) return dev->mm == current->mm ? 0 : -EPERM; } +struct vhost_attach_cgroups_struct { + struct vhost_work work; + struct task_struct *owner; + int ret; +}; + +static void vhost_attach_cgroups_work(struct vhost_work *work) +{ + struct vhost_attach_cgroups_struct *s; + s = container_of(work, struct vhost_attach_cgroups_struct, work); + s->ret = cgroup_attach_task_all(s->owner, current); +} + +static int vhost_attach_cgroups(struct vhost_dev *dev) +{ + struct vhost_attach_cgroups_struct attach; + attach.owner = current; + vhost_work_init(&attach.work, vhost_attach_cgroups_work); + vhost_work_queue(dev, &attach.work); + vhost_work_flush(dev, &attach.work); + return attach.ret; +} + /* Caller should have device mutex */ static long vhost_dev_set_owner(struct vhost_dev *dev) { @@ -255,14 +289,16 @@ static long vhost_dev_set_owner(struct vhost_dev *dev) } dev->worker = worker; - err = cgroup_attach_task_current_cg(worker); + wake_up_process(worker); /* avoid contributing to loadavg */ + + err = vhost_attach_cgroups(dev); if (err) goto err_cgroup; - wake_up_process(worker); /* avoid contributing to loadavg */ return 0; err_cgroup: kthread_stop(worker); + dev->worker = NULL; err_worker: if (dev->mm) mmput(dev->mm); @@ -284,7 +320,7 @@ long vhost_dev_reset_owner(struct vhost_dev *dev) vhost_dev_cleanup(dev); memory->nregions = 0; - dev->memory = memory; + RCU_INIT_POINTER(dev->memory, memory); return 0; } @@ -316,14 +352,18 @@ void vhost_dev_cleanup(struct vhost_dev *dev) fput(dev->log_file); dev->log_file = NULL; /* No one will access memory at this point */ - kfree(dev->memory); - dev->memory = NULL; + kfree(rcu_dereference_protected(dev->memory, + lockdep_is_held(&dev->mutex))); + RCU_INIT_POINTER(dev->memory, NULL); if (dev->mm) mmput(dev->mm); dev->mm = NULL; WARN_ON(!list_empty(&dev->work_list)); - kthread_stop(dev->worker); + if (dev->worker) { + kthread_stop(dev->worker); + dev->worker = NULL; + } } static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) @@ -401,14 +441,22 @@ static int vq_access_ok(unsigned int num, /* Caller should have device mutex but not vq mutex */ int vhost_log_access_ok(struct vhost_dev *dev) { - return memory_access_ok(dev, dev->memory, 1); + struct vhost_memory *mp; + + mp = rcu_dereference_protected(dev->memory, + lockdep_is_held(&dev->mutex)); + return memory_access_ok(dev, mp, 1); } /* Verify access for write logging. */ /* Caller should have vq mutex and device mutex */ static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base) { - return vq_memory_access_ok(log_base, vq->dev->memory, + struct vhost_memory *mp; + + mp = rcu_dereference_protected(vq->dev->memory, + lockdep_is_held(&vq->mutex)); + return vq_memory_access_ok(log_base, mp, vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) && (!vq->log_used || log_access_ok(log_base, vq->log_addr, sizeof *vq->used + @@ -448,7 +496,8 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) kfree(newmem); return -EFAULT; } - oldmem = d->memory; + oldmem = rcu_dereference_protected(d->memory, + lockdep_is_held(&d->mutex)); rcu_assign_pointer(d->memory, newmem); synchronize_rcu(); kfree(oldmem); @@ -819,11 +868,12 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, if (r < 0) return r; len -= l; - if (!len) + if (!len) { + if (vq->log_ctx) + eventfd_signal(vq->log_ctx, 1); return 0; + } } - if (vq->log_ctx) - eventfd_signal(vq->log_ctx, 1); /* Length written exceeds what we have stored. This is a bug. */ BUG(); return 0; |