diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 97 |
1 files changed, 50 insertions, 47 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 78987e481bc6..c90f4374442a 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -18,7 +18,6 @@ #include <linux/mmu_context.h> #include <linux/miscdevice.h> #include <linux/mutex.h> -#include <linux/rcupdate.h> #include <linux/poll.h> #include <linux/file.h> #include <linux/highmem.h> @@ -191,6 +190,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->log_used = false; vq->log_addr = -1ull; vq->private_data = NULL; + vq->acked_features = 0; vq->log_base = NULL; vq->error_ctx = NULL; vq->error = NULL; @@ -198,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->call_ctx = NULL; vq->call = NULL; vq->log_ctx = NULL; + vq->memory = NULL; } static int vhost_worker(void *data) @@ -415,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); /* Caller should have device mutex */ void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) { + int i; + vhost_dev_cleanup(dev, true); /* Restore memory to default empty mapping. */ memory->nregions = 0; - RCU_INIT_POINTER(dev->memory, memory); + dev->memory = memory; + /* We don't need VQ locks below since vhost_dev_cleanup makes sure + * VQs aren't running. + */ + for (i = 0; i < dev->nvqs; ++i) + dev->vqs[i]->memory = memory; } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); @@ -462,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) fput(dev->log_file); dev->log_file = NULL; /* No one will access memory at this point */ - kfree(rcu_dereference_protected(dev->memory, - locked == - lockdep_is_held(&dev->mutex))); - RCU_INIT_POINTER(dev->memory, NULL); + kfree(dev->memory); + dev->memory = NULL; WARN_ON(!list_empty(&dev->work_list)); if (dev->worker) { kthread_stop(dev->worker); @@ -524,11 +530,13 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, for (i = 0; i < d->nvqs; ++i) { int ok; + bool log; + mutex_lock(&d->vqs[i]->mutex); + log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL); /* If ring is inactive, will check when it's enabled. */ if (d->vqs[i]->private_data) - ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, - log_all); + ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log); else ok = 1; mutex_unlock(&d->vqs[i]->mutex); @@ -538,12 +546,12 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, return 1; } -static int vq_access_ok(struct vhost_dev *d, unsigned int num, +static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, struct vring_desc __user *desc, struct vring_avail __user *avail, struct vring_used __user *used) { - size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; + size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; return access_ok(VERIFY_READ, desc, num * sizeof *desc) && access_ok(VERIFY_READ, avail, sizeof *avail + num * sizeof *avail->ring + s) && @@ -555,26 +563,19 @@ static int vq_access_ok(struct vhost_dev *d, unsigned int num, /* Caller should have device mutex but not vq mutex */ int vhost_log_access_ok(struct vhost_dev *dev) { - struct vhost_memory *mp; - - mp = rcu_dereference_protected(dev->memory, - lockdep_is_held(&dev->mutex)); - return memory_access_ok(dev, mp, 1); + return memory_access_ok(dev, dev->memory, 1); } EXPORT_SYMBOL_GPL(vhost_log_access_ok); /* Verify access for write logging. */ /* Caller should have vq mutex and device mutex */ -static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq, +static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base) { - struct vhost_memory *mp; - size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; + size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; - mp = rcu_dereference_protected(vq->dev->memory, - lockdep_is_held(&vq->mutex)); - return vq_memory_access_ok(log_base, mp, - vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) && + return vq_memory_access_ok(log_base, vq->memory, + vhost_has_feature(vq, VHOST_F_LOG_ALL)) && (!vq->log_used || log_access_ok(log_base, vq->log_addr, sizeof *vq->used + vq->num * sizeof *vq->used->ring + s)); @@ -584,8 +585,8 @@ static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq, /* Caller should have vq mutex and device mutex */ int vhost_vq_access_ok(struct vhost_virtqueue *vq) { - return vq_access_ok(vq->dev, vq->num, vq->desc, vq->avail, vq->used) && - vq_log_access_ok(vq->dev, vq, vq->log_base); + return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && + vq_log_access_ok(vq, vq->log_base); } EXPORT_SYMBOL_GPL(vhost_vq_access_ok); @@ -593,6 +594,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) { struct vhost_memory mem, *newmem, *oldmem; unsigned long size = offsetof(struct vhost_memory, regions); + int i; if (copy_from_user(&mem, m, size)) return -EFAULT; @@ -611,15 +613,19 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) return -EFAULT; } - if (!memory_access_ok(d, newmem, - vhost_has_feature(d, VHOST_F_LOG_ALL))) { + if (!memory_access_ok(d, newmem, 0)) { kfree(newmem); return -EFAULT; } - oldmem = rcu_dereference_protected(d->memory, - lockdep_is_held(&d->mutex)); - rcu_assign_pointer(d->memory, newmem); - synchronize_rcu(); + oldmem = d->memory; + d->memory = newmem; + + /* All memory accesses are done under some VQ mutex. */ + for (i = 0; i < d->nvqs; ++i) { + mutex_lock(&d->vqs[i]->mutex); + d->vqs[i]->memory = newmem; + mutex_unlock(&d->vqs[i]->mutex); + } kfree(oldmem); return 0; } @@ -718,7 +724,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) * If it is not, we don't as size might not have been setup. * We will verify when backend is configured. */ if (vq->private_data) { - if (!vq_access_ok(d, vq->num, + if (!vq_access_ok(vq, vq->num, (void __user *)(unsigned long)a.desc_user_addr, (void __user *)(unsigned long)a.avail_user_addr, (void __user *)(unsigned long)a.used_user_addr)) { @@ -858,7 +864,7 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) vq = d->vqs[i]; mutex_lock(&vq->mutex); /* If ring is inactive, will check when it's enabled. */ - if (vq->private_data && !vq_log_access_ok(d, vq, base)) + if (vq->private_data && !vq_log_access_ok(vq, base)) r = -EFAULT; else vq->log_base = base; @@ -1044,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq) } EXPORT_SYMBOL_GPL(vhost_init_used); -static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, +static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, struct iovec iov[], int iov_size) { const struct vhost_memory_region *reg; @@ -1053,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, u64 s = 0; int ret = 0; - rcu_read_lock(); - - mem = rcu_dereference(dev->memory); + mem = vq->memory; while ((u64)len > s) { u64 size; if (unlikely(ret >= iov_size)) { @@ -1077,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, ++ret; } - rcu_read_unlock(); return ret; } @@ -1102,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc) return next; } -static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, +static int get_indirect(struct vhost_virtqueue *vq, struct iovec iov[], unsigned int iov_size, unsigned int *out_num, unsigned int *in_num, struct vhost_log *log, unsigned int *log_num, @@ -1121,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, return -EINVAL; } - ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, + ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect, UIO_MAXIOV); if (unlikely(ret < 0)) { vq_err(vq, "Translation failure %d in indirect.\n", ret); @@ -1161,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, return -EINVAL; } - ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, + ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count, iov_size - iov_count); if (unlikely(ret < 0)) { vq_err(vq, "Translation failure %d indirect idx %d\n", @@ -1198,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, * This function returns the descriptor number found, or vq->num (which is * never a valid descriptor number) if none was found. A negative code is * returned on error. */ -int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, +int vhost_get_vq_desc(struct vhost_virtqueue *vq, struct iovec iov[], unsigned int iov_size, unsigned int *out_num, unsigned int *in_num, struct vhost_log *log, unsigned int *log_num) @@ -1272,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, return -EFAULT; } if (desc.flags & VRING_DESC_F_INDIRECT) { - ret = get_indirect(dev, vq, iov, iov_size, + ret = get_indirect(vq, iov, iov_size, out_num, in_num, log, log_num, &desc); if (unlikely(ret < 0)) { @@ -1283,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, continue; } - ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, + ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count, iov_size - iov_count); if (unlikely(ret < 0)) { vq_err(vq, "Translation failure %d descriptor idx %d\n", @@ -1426,11 +1429,11 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) * interrupts. */ smp_mb(); - if (vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) && + if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) && unlikely(vq->avail_idx == vq->last_avail_idx)) return true; - if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) { + if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { __u16 flags; if (__get_user(flags, &vq->avail->flags)) { vq_err(vq, "Failed to get flags"); @@ -1491,7 +1494,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY)) return false; vq->used_flags &= ~VRING_USED_F_NO_NOTIFY; - if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) { + if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { r = vhost_update_used_flags(vq); if (r) { vq_err(vq, "Failed to enable notification at %p: %d\n", @@ -1528,7 +1531,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (vq->used_flags & VRING_USED_F_NO_NOTIFY) return; vq->used_flags |= VRING_USED_F_NO_NOTIFY; - if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) { + if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { r = vhost_update_used_flags(vq); if (r) vq_err(vq, "Failed to enable notification at %p: %d\n", |