diff options
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- | drivers/vhost/net.c | 326 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 26 |
2 files changed, 259 insertions, 93 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index 4e656f89cb22..ab11b2bee273 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -116,6 +116,8 @@ struct vhost_net_virtqueue { * For RX, number of batched heads */ int done_idx; + /* Number of XDP frames batched */ + int batched_xdp; /* an array of userspace buffers info */ struct ubuf_info *ubuf_info; /* Reference counting for outstanding ubufs. @@ -123,6 +125,8 @@ struct vhost_net_virtqueue { struct vhost_net_ubuf_ref *ubufs; struct ptr_ring *rx_ring; struct vhost_net_buf rxq; + /* Batched XDP buffs */ + struct xdp_buff *xdp; }; struct vhost_net { @@ -338,6 +342,11 @@ static bool vhost_sock_zcopy(struct socket *sock) sock_flag(sock->sk, SOCK_ZEROCOPY); } +static bool vhost_sock_xdp(struct socket *sock) +{ + return sock_flag(sock->sk, SOCK_XDP); +} + /* In case of DMA done not in order in lower device driver for some reason. * upend_idx is used to track end of used idx, done_idx is used to track head * of used idx. Once lower device DMA done contiguously, we will signal KVM @@ -444,32 +453,120 @@ static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq) nvq->done_idx = 0; } +static void vhost_tx_batch(struct vhost_net *net, + struct vhost_net_virtqueue *nvq, + struct socket *sock, + struct msghdr *msghdr) +{ + struct tun_msg_ctl ctl = { + .type = TUN_MSG_PTR, + .num = nvq->batched_xdp, + .ptr = nvq->xdp, + }; + int err; + + if (nvq->batched_xdp == 0) + goto signal_used; + + msghdr->msg_control = &ctl; + err = sock->ops->sendmsg(sock, msghdr, 0); + if (unlikely(err < 0)) { + vq_err(&nvq->vq, "Fail to batch sending packets\n"); + return; + } + +signal_used: + vhost_net_signal_used(nvq); + nvq->batched_xdp = 0; +} + +static int sock_has_rx_data(struct socket *sock) +{ + if (unlikely(!sock)) + return 0; + + if (sock->ops->peek_len) + return sock->ops->peek_len(sock); + + return skb_queue_empty(&sock->sk->sk_receive_queue); +} + +static void vhost_net_busy_poll_try_queue(struct vhost_net *net, + struct vhost_virtqueue *vq) +{ + if (!vhost_vq_avail_empty(&net->dev, vq)) { + vhost_poll_queue(&vq->poll); + } else if (unlikely(vhost_enable_notify(&net->dev, vq))) { + vhost_disable_notify(&net->dev, vq); + vhost_poll_queue(&vq->poll); + } +} + +static void vhost_net_busy_poll(struct vhost_net *net, + struct vhost_virtqueue *rvq, + struct vhost_virtqueue *tvq, + bool *busyloop_intr, + bool poll_rx) +{ + unsigned long busyloop_timeout; + unsigned long endtime; + struct socket *sock; + struct vhost_virtqueue *vq = poll_rx ? tvq : rvq; + + mutex_lock_nested(&vq->mutex, poll_rx ? VHOST_NET_VQ_TX: VHOST_NET_VQ_RX); + vhost_disable_notify(&net->dev, vq); + sock = rvq->private_data; + + busyloop_timeout = poll_rx ? rvq->busyloop_timeout: + tvq->busyloop_timeout; + + preempt_disable(); + endtime = busy_clock() + busyloop_timeout; + + while (vhost_can_busy_poll(endtime)) { + if (vhost_has_work(&net->dev)) { + *busyloop_intr = true; + break; + } + + if ((sock_has_rx_data(sock) && + !vhost_vq_avail_empty(&net->dev, rvq)) || + !vhost_vq_avail_empty(&net->dev, tvq)) + break; + + cpu_relax(); + } + + preempt_enable(); + + if (poll_rx || sock_has_rx_data(sock)) + vhost_net_busy_poll_try_queue(net, vq); + else if (!poll_rx) /* On tx here, sock has no rx data. */ + vhost_enable_notify(&net->dev, rvq); + + mutex_unlock(&vq->mutex); +} + static int vhost_net_tx_get_vq_desc(struct vhost_net *net, - struct vhost_net_virtqueue *nvq, + struct vhost_net_virtqueue *tnvq, unsigned int *out_num, unsigned int *in_num, - bool *busyloop_intr) + struct msghdr *msghdr, bool *busyloop_intr) { - struct vhost_virtqueue *vq = &nvq->vq; - unsigned long uninitialized_var(endtime); - int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), + struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX]; + struct vhost_virtqueue *rvq = &rnvq->vq; + struct vhost_virtqueue *tvq = &tnvq->vq; + + int r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov), out_num, in_num, NULL, NULL); - if (r == vq->num && vq->busyloop_timeout) { - if (!vhost_sock_zcopy(vq->private_data)) - vhost_net_signal_used(nvq); - preempt_disable(); - endtime = busy_clock() + vq->busyloop_timeout; - while (vhost_can_busy_poll(endtime)) { - if (vhost_has_work(vq->dev)) { - *busyloop_intr = true; - break; - } - if (!vhost_vq_avail_empty(vq->dev, vq)) - break; - cpu_relax(); - } - preempt_enable(); - r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), + if (r == tvq->num && tvq->busyloop_timeout) { + /* Flush batched packets first */ + if (!vhost_sock_zcopy(tvq->private_data)) + vhost_tx_batch(net, tnvq, tvq->private_data, msghdr); + + vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, false); + + r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov), out_num, in_num, NULL, NULL); } @@ -512,7 +609,7 @@ static int get_tx_bufs(struct vhost_net *net, struct vhost_virtqueue *vq = &nvq->vq; int ret; - ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr); + ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg, busyloop_intr); if (ret < 0 || ret == vq->num) return ret; @@ -540,6 +637,80 @@ static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len) !vhost_vq_avail_empty(vq->dev, vq); } +#define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD) + +static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, + struct iov_iter *from) +{ + struct vhost_virtqueue *vq = &nvq->vq; + struct socket *sock = vq->private_data; + struct page_frag *alloc_frag = ¤t->task_frag; + struct virtio_net_hdr *gso; + struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp]; + struct tun_xdp_hdr *hdr; + size_t len = iov_iter_count(from); + int headroom = vhost_sock_xdp(sock) ? XDP_PACKET_HEADROOM : 0; + int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + int pad = SKB_DATA_ALIGN(VHOST_NET_RX_PAD + headroom + nvq->sock_hlen); + int sock_hlen = nvq->sock_hlen; + void *buf; + int copied; + + if (unlikely(len < nvq->sock_hlen)) + return -EFAULT; + + if (SKB_DATA_ALIGN(len + pad) + + SKB_DATA_ALIGN(sizeof(struct skb_shared_info)) > PAGE_SIZE) + return -ENOSPC; + + buflen += SKB_DATA_ALIGN(len + pad); + alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES); + if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL))) + return -ENOMEM; + + buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset; + copied = copy_page_from_iter(alloc_frag->page, + alloc_frag->offset + + offsetof(struct tun_xdp_hdr, gso), + sock_hlen, from); + if (copied != sock_hlen) + return -EFAULT; + + hdr = buf; + gso = &hdr->gso; + + if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) && + vhost16_to_cpu(vq, gso->csum_start) + + vhost16_to_cpu(vq, gso->csum_offset) + 2 > + vhost16_to_cpu(vq, gso->hdr_len)) { + gso->hdr_len = cpu_to_vhost16(vq, + vhost16_to_cpu(vq, gso->csum_start) + + vhost16_to_cpu(vq, gso->csum_offset) + 2); + + if (vhost16_to_cpu(vq, gso->hdr_len) > len) + return -EINVAL; + } + + len -= sock_hlen; + copied = copy_page_from_iter(alloc_frag->page, + alloc_frag->offset + pad, + len, from); + if (copied != len) + return -EFAULT; + + xdp->data_hard_start = buf; + xdp->data = buf + pad; + xdp->data_end = xdp->data + len; + hdr->buflen = buflen; + + get_page(alloc_frag->page); + alloc_frag->offset += buflen; + + ++nvq->batched_xdp; + + return 0; +} + static void handle_tx_copy(struct vhost_net *net, struct socket *sock) { struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; @@ -556,10 +727,14 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) size_t len, total_len = 0; int err; int sent_pkts = 0; + bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX); for (;;) { bool busyloop_intr = false; + if (nvq->done_idx == VHOST_NET_BATCH) + vhost_tx_batch(net, nvq, sock, &msg); + head = get_tx_bufs(net, nvq, &msg, &out, &in, &len, &busyloop_intr); /* On error, stop handling until the next kick. */ @@ -577,14 +752,34 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) break; } - vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head); - vq->heads[nvq->done_idx].len = 0; - total_len += len; - if (tx_can_batch(vq, total_len)) - msg.msg_flags |= MSG_MORE; - else - msg.msg_flags &= ~MSG_MORE; + + /* For simplicity, TX batching is only enabled if + * sndbuf is unlimited. + */ + if (sock_can_batch) { + err = vhost_net_build_xdp(nvq, &msg.msg_iter); + if (!err) { + goto done; + } else if (unlikely(err != -ENOSPC)) { + vhost_tx_batch(net, nvq, sock, &msg); + vhost_discard_vq_desc(vq, 1); + vhost_net_enable_vq(net, vq); + break; + } + + /* We can't build XDP buff, go for single + * packet path but let's flush batched + * packets. + */ + vhost_tx_batch(net, nvq, sock, &msg); + msg.msg_control = NULL; + } else { + if (tx_can_batch(vq, total_len)) + msg.msg_flags |= MSG_MORE; + else + msg.msg_flags &= ~MSG_MORE; + } /* TODO: Check specific error and bomb out unless ENOBUFS? */ err = sock->ops->sendmsg(sock, &msg, len); @@ -596,15 +791,17 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) if (err != len) pr_debug("Truncated TX packet: len %d != %zd\n", err, len); - if (++nvq->done_idx >= VHOST_NET_BATCH) - vhost_net_signal_used(nvq); +done: + vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head); + vq->heads[nvq->done_idx].len = 0; + ++nvq->done_idx; if (vhost_exceeds_weight(++sent_pkts, total_len)) { vhost_poll_queue(&vq->poll); break; } } - vhost_net_signal_used(nvq); + vhost_tx_batch(net, nvq, sock, &msg); } static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) @@ -620,6 +817,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) .msg_controllen = 0, .msg_flags = MSG_DONTWAIT, }; + struct tun_msg_ctl ctl; size_t len, total_len = 0; int err; struct vhost_net_ubuf_ref *uninitialized_var(ubufs); @@ -664,8 +862,10 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) ubuf->ctx = nvq->ubufs; ubuf->desc = nvq->upend_idx; refcount_set(&ubuf->refcnt, 1); - msg.msg_control = ubuf; - msg.msg_controllen = sizeof(ubuf); + msg.msg_control = &ctl; + ctl.type = TUN_MSG_UBUF; + ctl.ptr = ubuf; + msg.msg_controllen = sizeof(ctl); ubufs = nvq->ubufs; atomic_inc(&ubufs->refcount); nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV; @@ -716,7 +916,7 @@ static void handle_tx(struct vhost_net *net) struct vhost_virtqueue *vq = &nvq->vq; struct socket *sock; - mutex_lock(&vq->mutex); + mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_TX); sock = vq->private_data; if (!sock) goto out; @@ -757,16 +957,6 @@ static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk) return len; } -static int sk_has_rx_data(struct sock *sk) -{ - struct socket *sock = sk->sk_socket; - - if (sock->ops->peek_len) - return sock->ops->peek_len(sock); - - return skb_queue_empty(&sk->sk_receive_queue); -} - static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, bool *busyloop_intr) { @@ -774,41 +964,13 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX]; struct vhost_virtqueue *rvq = &rnvq->vq; struct vhost_virtqueue *tvq = &tnvq->vq; - unsigned long uninitialized_var(endtime); int len = peek_head_len(rnvq, sk); - if (!len && tvq->busyloop_timeout) { + if (!len && rvq->busyloop_timeout) { /* Flush batched heads first */ vhost_net_signal_used(rnvq); /* Both tx vq and rx socket were polled here */ - mutex_lock_nested(&tvq->mutex, 1); - vhost_disable_notify(&net->dev, tvq); - - preempt_disable(); - endtime = busy_clock() + tvq->busyloop_timeout; - - while (vhost_can_busy_poll(endtime)) { - if (vhost_has_work(&net->dev)) { - *busyloop_intr = true; - break; - } - if ((sk_has_rx_data(sk) && - !vhost_vq_avail_empty(&net->dev, rvq)) || - !vhost_vq_avail_empty(&net->dev, tvq)) - break; - cpu_relax(); - } - - preempt_enable(); - - if (!vhost_vq_avail_empty(&net->dev, tvq)) { - vhost_poll_queue(&tvq->poll); - } else if (unlikely(vhost_enable_notify(&net->dev, tvq))) { - vhost_disable_notify(&net->dev, tvq); - vhost_poll_queue(&tvq->poll); - } - - mutex_unlock(&tvq->mutex); + vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, true); len = peek_head_len(rnvq, sk); } @@ -923,7 +1085,7 @@ static void handle_rx(struct vhost_net *net) __virtio16 num_buffers; int recv_pkts = 0; - mutex_lock_nested(&vq->mutex, 0); + mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_RX); sock = vq->private_data; if (!sock) goto out; @@ -1078,6 +1240,7 @@ static int vhost_net_open(struct inode *inode, struct file *f) struct vhost_dev *dev; struct vhost_virtqueue **vqs; void **queue; + struct xdp_buff *xdp; int i; n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL); @@ -1098,6 +1261,15 @@ static int vhost_net_open(struct inode *inode, struct file *f) } n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue; + xdp = kmalloc_array(VHOST_NET_BATCH, sizeof(*xdp), GFP_KERNEL); + if (!xdp) { + kfree(vqs); + kvfree(n); + kfree(queue); + return -ENOMEM; + } + n->vqs[VHOST_NET_VQ_TX].xdp = xdp; + dev = &n->dev; vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq; vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq; @@ -1108,6 +1280,7 @@ static int vhost_net_open(struct inode *inode, struct file *f) n->vqs[i].ubuf_info = NULL; n->vqs[i].upend_idx = 0; n->vqs[i].done_idx = 0; + n->vqs[i].batched_xdp = 0; n->vqs[i].vhost_hlen = 0; n->vqs[i].sock_hlen = 0; n->vqs[i].rx_ring = NULL; @@ -1191,6 +1364,7 @@ static int vhost_net_release(struct inode *inode, struct file *f) * since jobs can re-queue themselves. */ vhost_net_flush(n); kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue); + kfree(n->vqs[VHOST_NET_VQ_TX].xdp); kfree(n->dev.vqs); kvfree(n); return 0; diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index b13c6b4b2c66..3a5f81a66d34 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -30,6 +30,7 @@ #include <linux/sched/mm.h> #include <linux/sched/signal.h> #include <linux/interval_tree_generic.h> +#include <linux/nospec.h> #include "vhost.h" @@ -294,8 +295,11 @@ static void vhost_vq_meta_reset(struct vhost_dev *d) { int i; - for (i = 0; i < d->nvqs; ++i) + for (i = 0; i < d->nvqs; ++i) { + mutex_lock(&d->vqs[i]->mutex); __vhost_vq_meta_reset(d->vqs[i]); + mutex_unlock(&d->vqs[i]->mutex); + } } static void vhost_vq_reset(struct vhost_dev *dev, @@ -891,20 +895,6 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, #define vhost_get_used(vq, x, ptr) \ vhost_get_user(vq, x, ptr, VHOST_ADDR_USED) -static void vhost_dev_lock_vqs(struct vhost_dev *d) -{ - int i = 0; - for (i = 0; i < d->nvqs; ++i) - mutex_lock_nested(&d->vqs[i]->mutex, i); -} - -static void vhost_dev_unlock_vqs(struct vhost_dev *d) -{ - int i = 0; - for (i = 0; i < d->nvqs; ++i) - mutex_unlock(&d->vqs[i]->mutex); -} - static int vhost_new_umem_range(struct vhost_umem *umem, u64 start, u64 size, u64 end, u64 userspace_addr, int perm) @@ -954,7 +944,10 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d, if (msg->iova <= vq_msg->iova && msg->iova + msg->size - 1 >= vq_msg->iova && vq_msg->type == VHOST_IOTLB_MISS) { + mutex_lock(&node->vq->mutex); vhost_poll_queue(&node->vq->poll); + mutex_unlock(&node->vq->mutex); + list_del(&node->node); kfree(node); } @@ -986,7 +979,6 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, int ret = 0; mutex_lock(&dev->mutex); - vhost_dev_lock_vqs(dev); switch (msg->type) { case VHOST_IOTLB_UPDATE: if (!dev->iotlb) { @@ -1020,7 +1012,6 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, break; } - vhost_dev_unlock_vqs(dev); mutex_unlock(&dev->mutex); return ret; @@ -1397,6 +1388,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg if (idx >= d->nvqs) return -ENOBUFS; + idx = array_index_nospec(idx, d->nvqs); vq = d->vqs[idx]; mutex_lock(&vq->mutex); |