summaryrefslogtreecommitdiffstats
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c42
1 files changed, 27 insertions, 15 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 3a5f81a66d34..9f7942cbcbb2 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -295,11 +295,8 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
{
int i;
- for (i = 0; i < d->nvqs; ++i) {
- mutex_lock(&d->vqs[i]->mutex);
+ for (i = 0; i < d->nvqs; ++i)
__vhost_vq_meta_reset(d->vqs[i]);
- mutex_unlock(&d->vqs[i]->mutex);
- }
}
static void vhost_vq_reset(struct vhost_dev *dev,
@@ -658,7 +655,7 @@ static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
a + (unsigned long)log_base > ULONG_MAX)
return false;
- return access_ok(VERIFY_WRITE, log_base + a,
+ return access_ok(log_base + a,
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}
@@ -684,7 +681,7 @@ static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
return false;
- if (!access_ok(VERIFY_WRITE, (void __user *)a,
+ if (!access_ok((void __user *)a,
node->size))
return false;
else if (log_all && !log_access_ok(log_base,
@@ -895,6 +892,20 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
#define vhost_get_used(vq, x, ptr) \
vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_lock_nested(&d->vqs[i]->mutex, i);
+}
+
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_unlock(&d->vqs[i]->mutex);
+}
+
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
@@ -944,10 +955,7 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d,
if (msg->iova <= vq_msg->iova &&
msg->iova + msg->size - 1 >= vq_msg->iova &&
vq_msg->type == VHOST_IOTLB_MISS) {
- mutex_lock(&node->vq->mutex);
vhost_poll_queue(&node->vq->poll);
- mutex_unlock(&node->vq->mutex);
-
list_del(&node->node);
kfree(node);
}
@@ -965,10 +973,10 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access)
return false;
if ((access & VHOST_ACCESS_RO) &&
- !access_ok(VERIFY_READ, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
if ((access & VHOST_ACCESS_WO) &&
- !access_ok(VERIFY_WRITE, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
return true;
}
@@ -979,6 +987,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
int ret = 0;
mutex_lock(&dev->mutex);
+ vhost_dev_lock_vqs(dev);
switch (msg->type) {
case VHOST_IOTLB_UPDATE:
if (!dev->iotlb) {
@@ -1012,6 +1021,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
break;
}
+ vhost_dev_unlock_vqs(dev);
mutex_unlock(&dev->mutex);
return ret;
@@ -1175,10 +1185,10 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
{
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
- return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
- access_ok(VERIFY_READ, avail,
+ return access_ok(desc, num * sizeof *desc) &&
+ access_ok(avail,
sizeof *avail + num * sizeof *avail->ring + s) &&
- access_ok(VERIFY_WRITE, used,
+ access_ok(used,
sizeof *used + num * sizeof *used->ring + s);
}
@@ -1804,7 +1814,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
goto err;
vq->signalled_used_valid = false;
if (!vq->iotlb &&
- !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+ !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
r = -EFAULT;
goto err;
}
@@ -2223,6 +2233,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
return -EFAULT;
}
if (unlikely(vq->log_used)) {
+ /* Make sure used idx is seen before log. */
+ smp_wmb();
/* Log used index update. */
log_write(vq->log_base,
vq->log_addr + offsetof(struct vring_used, idx),