diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 57 |
1 files changed, 25 insertions, 32 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index a23870cbbf91..c90f4374442a 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -18,7 +18,6 @@ #include <linux/mmu_context.h> #include <linux/miscdevice.h> #include <linux/mutex.h> -#include <linux/rcupdate.h> #include <linux/poll.h> #include <linux/file.h> #include <linux/highmem.h> @@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->call_ctx = NULL; vq->call = NULL; vq->log_ctx = NULL; + vq->memory = NULL; } static int vhost_worker(void *data) @@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); /* Caller should have device mutex */ void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) { + int i; + vhost_dev_cleanup(dev, true); /* Restore memory to default empty mapping. */ memory->nregions = 0; - RCU_INIT_POINTER(dev->memory, memory); + dev->memory = memory; + /* We don't need VQ locks below since vhost_dev_cleanup makes sure + * VQs aren't running. + */ + for (i = 0; i < dev->nvqs; ++i) + dev->vqs[i]->memory = memory; } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); @@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) fput(dev->log_file); dev->log_file = NULL; /* No one will access memory at this point */ - kfree(rcu_dereference_protected(dev->memory, - locked == - lockdep_is_held(&dev->mutex))); - RCU_INIT_POINTER(dev->memory, NULL); + kfree(dev->memory); + dev->memory = NULL; WARN_ON(!list_empty(&dev->work_list)); if (dev->worker) { kthread_stop(dev->worker); @@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, /* Caller should have device mutex but not vq mutex */ int vhost_log_access_ok(struct vhost_dev *dev) { - struct vhost_memory *mp; - - mp = rcu_dereference_protected(dev->memory, - lockdep_is_held(&dev->mutex)); - return memory_access_ok(dev, mp, 1); + return memory_access_ok(dev, dev->memory, 1); } EXPORT_SYMBOL_GPL(vhost_log_access_ok); @@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok); static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base) { - struct vhost_memory *mp; size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; - mp = rcu_dereference_protected(vq->dev->memory, - lockdep_is_held(&vq->mutex)); - return vq_memory_access_ok(log_base, mp, + return vq_memory_access_ok(log_base, vq->memory, vhost_has_feature(vq, VHOST_F_LOG_ALL)) && (!vq->log_used || log_access_ok(log_base, vq->log_addr, sizeof *vq->used + @@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) kfree(newmem); return -EFAULT; } - oldmem = rcu_dereference_protected(d->memory, - lockdep_is_held(&d->mutex)); - rcu_assign_pointer(d->memory, newmem); + oldmem = d->memory; + d->memory = newmem; - /* All memory accesses are done under some VQ mutex. - * So below is a faster equivalent of synchronize_rcu() - */ + /* All memory accesses are done under some VQ mutex. */ for (i = 0; i < d->nvqs; ++i) { mutex_lock(&d->vqs[i]->mutex); + d->vqs[i]->memory = newmem; mutex_unlock(&d->vqs[i]->mutex); } kfree(oldmem); @@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq) } EXPORT_SYMBOL_GPL(vhost_init_used); -static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, +static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, struct iovec iov[], int iov_size) { const struct vhost_memory_region *reg; @@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, u64 s = 0; int ret = 0; - rcu_read_lock(); - - mem = rcu_dereference(dev->memory); + mem = vq->memory; while ((u64)len > s) { u64 size; if (unlikely(ret >= iov_size)) { @@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, ++ret; } - rcu_read_unlock(); return ret; } @@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc) return next; } -static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, +static int get_indirect(struct vhost_virtqueue *vq, struct iovec iov[], unsigned int iov_size, unsigned int *out_num, unsigned int *in_num, struct vhost_log *log, unsigned int *log_num, @@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, return -EINVAL; } - ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, + ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect, UIO_MAXIOV); if (unlikely(ret < 0)) { vq_err(vq, "Translation failure %d in indirect.\n", ret); @@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, return -EINVAL; } - ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, + ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count, iov_size - iov_count); if (unlikely(ret < 0)) { vq_err(vq, "Translation failure %d indirect idx %d\n", @@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, * This function returns the descriptor number found, or vq->num (which is * never a valid descriptor number) if none was found. A negative code is * returned on error. */ -int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, +int vhost_get_vq_desc(struct vhost_virtqueue *vq, struct iovec iov[], unsigned int iov_size, unsigned int *out_num, unsigned int *in_num, struct vhost_log *log, unsigned int *log_num) @@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, return -EFAULT; } if (desc.flags & VRING_DESC_F_INDIRECT) { - ret = get_indirect(dev, vq, iov, iov_size, + ret = get_indirect(vq, iov, iov_size, out_num, in_num, log, log_num, &desc); if (unlikely(ret < 0)) { @@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, continue; } - ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, + ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count, iov_size - iov_count); if (unlikely(ret < 0)) { vq_err(vq, "Translation failure %d descriptor idx %d\n", |