diff options
-rw-r--r-- | drivers/vhost/net.c | 8 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 182 | ||||
-rw-r--r-- | drivers/vhost/vhost.h | 27 |
3 files changed, 128 insertions, 89 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index f744eeb3e2b4..a6b270aff9ef 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -1036,20 +1036,20 @@ static long vhost_net_reset_owner(struct vhost_net *n) struct socket *tx_sock = NULL; struct socket *rx_sock = NULL; long err; - struct vhost_memory *memory; + struct vhost_umem *umem; mutex_lock(&n->dev.mutex); err = vhost_dev_check_owner(&n->dev); if (err) goto done; - memory = vhost_dev_reset_owner_prepare(); - if (!memory) { + umem = vhost_dev_reset_owner_prepare(); + if (!umem) { err = -ENOMEM; goto done; } vhost_net_stop(n, &tx_sock, &rx_sock); vhost_net_flush(n); - vhost_dev_reset_owner(&n->dev, memory); + vhost_dev_reset_owner(&n->dev, umem); vhost_net_vq_reset(n); done: mutex_unlock(&n->dev.mutex); diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index e1b5047d8f19..8071f3638db9 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -27,6 +27,7 @@ #include <linux/cgroup.h> #include <linux/module.h> #include <linux/sort.h> +#include <linux/interval_tree_generic.h> #include "vhost.h" @@ -42,6 +43,10 @@ enum { #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) +INTERVAL_TREE_DEFINE(struct vhost_umem_node, + rb, __u64, __subtree_last, + START, LAST, , vhost_umem_interval_tree); + #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) { @@ -297,10 +302,10 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->call_ctx = NULL; vq->call = NULL; vq->log_ctx = NULL; - vq->memory = NULL; vhost_reset_is_le(vq); vhost_disable_cross_endian(vq); vq->busyloop_timeout = 0; + vq->umem = NULL; } static int vhost_worker(void *data) @@ -394,7 +399,7 @@ void vhost_dev_init(struct vhost_dev *dev, mutex_init(&dev->mutex); dev->log_ctx = NULL; dev->log_file = NULL; - dev->memory = NULL; + dev->umem = NULL; dev->mm = NULL; dev->worker = NULL; init_llist_head(&dev->work_list); @@ -499,27 +504,36 @@ err_mm: } EXPORT_SYMBOL_GPL(vhost_dev_set_owner); -struct vhost_memory *vhost_dev_reset_owner_prepare(void) +static void *vhost_kvzalloc(unsigned long size) { - return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL); + void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT); + + if (!n) + n = vzalloc(size); + return n; +} + +struct vhost_umem *vhost_dev_reset_owner_prepare(void) +{ + return vhost_kvzalloc(sizeof(struct vhost_umem)); } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); /* Caller should have device mutex */ -void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) +void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem) { int i; vhost_dev_cleanup(dev, true); /* Restore memory to default empty mapping. */ - memory->nregions = 0; - dev->memory = memory; + INIT_LIST_HEAD(&umem->umem_list); + dev->umem = umem; /* We don't need VQ locks below since vhost_dev_cleanup makes sure * VQs aren't running. */ for (i = 0; i < dev->nvqs; ++i) - dev->vqs[i]->memory = memory; + dev->vqs[i]->umem = umem; } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); @@ -536,6 +550,21 @@ void vhost_dev_stop(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_dev_stop); +static void vhost_umem_clean(struct vhost_umem *umem) +{ + struct vhost_umem_node *node, *tmp; + + if (!umem) + return; + + list_for_each_entry_safe(node, tmp, &umem->umem_list, link) { + vhost_umem_interval_tree_remove(node, &umem->umem_tree); + list_del(&node->link); + kvfree(node); + } + kvfree(umem); +} + /* Caller should have device mutex if and only if locked is set */ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) { @@ -562,8 +591,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) fput(dev->log_file); dev->log_file = NULL; /* No one will access memory at this point */ - kvfree(dev->memory); - dev->memory = NULL; + vhost_umem_clean(dev->umem); + dev->umem = NULL; WARN_ON(!llist_empty(&dev->work_list)); if (dev->worker) { kthread_stop(dev->worker); @@ -589,25 +618,25 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) } /* Caller should have vq mutex and device mutex. */ -static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, +static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem, int log_all) { - int i; + struct vhost_umem_node *node; - if (!mem) + if (!umem) return 0; - for (i = 0; i < mem->nregions; ++i) { - struct vhost_memory_region *m = mem->regions + i; - unsigned long a = m->userspace_addr; - if (m->memory_size > ULONG_MAX) + list_for_each_entry(node, &umem->umem_list, link) { + unsigned long a = node->userspace_addr; + + if (node->size > ULONG_MAX) return 0; else if (!access_ok(VERIFY_WRITE, (void __user *)a, - m->memory_size)) + node->size)) return 0; else if (log_all && !log_access_ok(log_base, - m->guest_phys_addr, - m->memory_size)) + node->start, + node->size)) return 0; } return 1; @@ -615,7 +644,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, /* Can we switch to this memory table? */ /* Caller should have device mutex but not vq mutex */ -static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, +static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem, int log_all) { int i; @@ -628,7 +657,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL); /* If ring is inactive, will check when it's enabled. */ if (d->vqs[i]->private_data) - ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log); + ok = vq_memory_access_ok(d->vqs[i]->log_base, + umem, log); else ok = 1; mutex_unlock(&d->vqs[i]->mutex); @@ -671,7 +701,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, /* Caller should have device mutex but not vq mutex */ int vhost_log_access_ok(struct vhost_dev *dev) { - return memory_access_ok(dev, dev->memory, 1); + return memory_access_ok(dev, dev->umem, 1); } EXPORT_SYMBOL_GPL(vhost_log_access_ok); @@ -682,7 +712,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq, { size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; - return vq_memory_access_ok(log_base, vq->memory, + return vq_memory_access_ok(log_base, vq->umem, vhost_has_feature(vq, VHOST_F_LOG_ALL)) && (!vq->log_used || log_access_ok(log_base, vq->log_addr, sizeof *vq->used + @@ -698,28 +728,12 @@ int vhost_vq_access_ok(struct vhost_virtqueue *vq) } EXPORT_SYMBOL_GPL(vhost_vq_access_ok); -static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2) -{ - const struct vhost_memory_region *r1 = p1, *r2 = p2; - if (r1->guest_phys_addr < r2->guest_phys_addr) - return 1; - if (r1->guest_phys_addr > r2->guest_phys_addr) - return -1; - return 0; -} - -static void *vhost_kvzalloc(unsigned long size) -{ - void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT); - - if (!n) - n = vzalloc(size); - return n; -} - static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) { - struct vhost_memory mem, *newmem, *oldmem; + struct vhost_memory mem, *newmem; + struct vhost_memory_region *region; + struct vhost_umem_node *node; + struct vhost_umem *newumem, *oldumem; unsigned long size = offsetof(struct vhost_memory, regions); int i; @@ -739,24 +753,52 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) kvfree(newmem); return -EFAULT; } - sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions), - vhost_memory_reg_sort_cmp, NULL); - if (!memory_access_ok(d, newmem, 0)) { + newumem = vhost_kvzalloc(sizeof(*newumem)); + if (!newumem) { kvfree(newmem); - return -EFAULT; + return -ENOMEM; + } + + newumem->umem_tree = RB_ROOT; + INIT_LIST_HEAD(&newumem->umem_list); + + for (region = newmem->regions; + region < newmem->regions + mem.nregions; + region++) { + node = vhost_kvzalloc(sizeof(*node)); + if (!node) + goto err; + node->start = region->guest_phys_addr; + node->size = region->memory_size; + node->last = node->start + node->size - 1; + node->userspace_addr = region->userspace_addr; + INIT_LIST_HEAD(&node->link); + list_add_tail(&node->link, &newumem->umem_list); + vhost_umem_interval_tree_insert(node, &newumem->umem_tree); } - oldmem = d->memory; - d->memory = newmem; + + if (!memory_access_ok(d, newumem, 0)) + goto err; + + oldumem = d->umem; + d->umem = newumem; /* All memory accesses are done under some VQ mutex. */ for (i = 0; i < d->nvqs; ++i) { mutex_lock(&d->vqs[i]->mutex); - d->vqs[i]->memory = newmem; + d->vqs[i]->umem = newumem; mutex_unlock(&d->vqs[i]->mutex); } - kvfree(oldmem); + + kvfree(newmem); + vhost_umem_clean(oldumem); return 0; + +err: + vhost_umem_clean(newumem); + kvfree(newmem); + return -EFAULT; } long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) @@ -1059,28 +1101,6 @@ done: } EXPORT_SYMBOL_GPL(vhost_dev_ioctl); -static const struct vhost_memory_region *find_region(struct vhost_memory *mem, - __u64 addr, __u32 len) -{ - const struct vhost_memory_region *reg; - int start = 0, end = mem->nregions; - - while (start < end) { - int slot = start + (end - start) / 2; - reg = mem->regions + slot; - if (addr >= reg->guest_phys_addr) - end = slot; - else - start = slot + 1; - } - - reg = mem->regions + start; - if (addr >= reg->guest_phys_addr && - reg->guest_phys_addr + reg->memory_size > addr) - return reg; - return NULL; -} - /* TODO: This is really inefficient. We need something like get_user() * (instruction directly accesses the data, with an exception table entry * returning -EFAULT). See Documentation/x86/exception-tables.txt. @@ -1231,29 +1251,29 @@ EXPORT_SYMBOL_GPL(vhost_vq_init_access); static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, struct iovec iov[], int iov_size) { - const struct vhost_memory_region *reg; - struct vhost_memory *mem; + const struct vhost_umem_node *node; + struct vhost_umem *umem = vq->umem; struct iovec *_iov; u64 s = 0; int ret = 0; - mem = vq->memory; while ((u64)len > s) { u64 size; if (unlikely(ret >= iov_size)) { ret = -ENOBUFS; break; } - reg = find_region(mem, addr, len); - if (unlikely(!reg)) { + node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, + addr, addr + len - 1); + if (node == NULL || node->start > addr) { ret = -EFAULT; break; } _iov = iov + ret; - size = reg->memory_size - addr + reg->guest_phys_addr; + size = node->size - addr + node->start; _iov->iov_len = min((u64)len - s, size); _iov->iov_base = (void __user *)(unsigned long) - (reg->userspace_addr + addr - reg->guest_phys_addr); + (node->userspace_addr + addr - node->start); s += size; addr += size; ++ret; diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 6690e645d2f8..eaaf6df72218 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -55,6 +55,25 @@ struct vhost_log { u64 len; }; +#define START(node) ((node)->start) +#define LAST(node) ((node)->last) + +struct vhost_umem_node { + struct rb_node rb; + struct list_head link; + __u64 start; + __u64 last; + __u64 size; + __u64 userspace_addr; + __u64 flags_padding; + __u64 __subtree_last; +}; + +struct vhost_umem { + struct rb_root umem_tree; + struct list_head umem_list; +}; + /* The virtqueue structure describes a queue attached to a device. */ struct vhost_virtqueue { struct vhost_dev *dev; @@ -103,7 +122,7 @@ struct vhost_virtqueue { struct iovec *indirect; struct vring_used_elem *heads; /* Protected by virtqueue mutex. */ - struct vhost_memory *memory; + struct vhost_umem *umem; void *private_data; u64 acked_features; /* Log write descriptors */ @@ -121,7 +140,6 @@ struct vhost_virtqueue { }; struct vhost_dev { - struct vhost_memory *memory; struct mm_struct *mm; struct mutex mutex; struct vhost_virtqueue **vqs; @@ -130,14 +148,15 @@ struct vhost_dev { struct eventfd_ctx *log_ctx; struct llist_head work_list; struct task_struct *worker; + struct vhost_umem *umem; }; void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); long vhost_dev_set_owner(struct vhost_dev *dev); bool vhost_dev_has_owner(struct vhost_dev *dev); long vhost_dev_check_owner(struct vhost_dev *); -struct vhost_memory *vhost_dev_reset_owner_prepare(void); -void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *); +struct vhost_umem *vhost_dev_reset_owner_prepare(void); +void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_umem *); void vhost_dev_cleanup(struct vhost_dev *, bool locked); void vhost_dev_stop(struct vhost_dev *); long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp); |