diff options
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- | drivers/vhost/net.c | 370 | ||||
-rw-r--r-- | drivers/vhost/scsi.c | 16 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 71 | ||||
-rw-r--r-- | drivers/vhost/vhost.h | 11 |
4 files changed, 337 insertions, 131 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index 29756d88799b..4e656f89cb22 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -78,6 +78,10 @@ enum { }; enum { + VHOST_NET_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) +}; + +enum { VHOST_NET_VQ_RX = 0, VHOST_NET_VQ_TX = 1, VHOST_NET_VQ_MAX = 2, @@ -94,7 +98,7 @@ struct vhost_net_ubuf_ref { struct vhost_virtqueue *vq; }; -#define VHOST_RX_BATCH 64 +#define VHOST_NET_BATCH 64 struct vhost_net_buf { void **queue; int tail; @@ -168,7 +172,7 @@ static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq) rxq->head = 0; rxq->tail = ptr_ring_consume_batched(nvq->rx_ring, rxq->queue, - VHOST_RX_BATCH); + VHOST_NET_BATCH); return rxq->tail; } @@ -396,13 +400,10 @@ static inline unsigned long busy_clock(void) return local_clock() >> 10; } -static bool vhost_can_busy_poll(struct vhost_dev *dev, - unsigned long endtime) +static bool vhost_can_busy_poll(unsigned long endtime) { - return likely(!need_resched()) && - likely(!time_after(busy_clock(), endtime)) && - likely(!signal_pending(current)) && - !vhost_has_work(dev); + return likely(!need_resched() && !time_after(busy_clock(), endtime) && + !signal_pending(current)); } static void vhost_net_disable_vq(struct vhost_net *n, @@ -431,21 +432,42 @@ static int vhost_net_enable_vq(struct vhost_net *n, return vhost_poll_start(poll, sock->file); } +static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq) +{ + struct vhost_virtqueue *vq = &nvq->vq; + struct vhost_dev *dev = vq->dev; + + if (!nvq->done_idx) + return; + + vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx); + nvq->done_idx = 0; +} + static int vhost_net_tx_get_vq_desc(struct vhost_net *net, - struct vhost_virtqueue *vq, - struct iovec iov[], unsigned int iov_size, - unsigned int *out_num, unsigned int *in_num) + struct vhost_net_virtqueue *nvq, + unsigned int *out_num, unsigned int *in_num, + bool *busyloop_intr) { + struct vhost_virtqueue *vq = &nvq->vq; unsigned long uninitialized_var(endtime); int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), out_num, in_num, NULL, NULL); if (r == vq->num && vq->busyloop_timeout) { + if (!vhost_sock_zcopy(vq->private_data)) + vhost_net_signal_used(nvq); preempt_disable(); endtime = busy_clock() + vq->busyloop_timeout; - while (vhost_can_busy_poll(vq->dev, endtime) && - vhost_vq_avail_empty(vq->dev, vq)) + while (vhost_can_busy_poll(endtime)) { + if (vhost_has_work(vq->dev)) { + *busyloop_intr = true; + break; + } + if (!vhost_vq_avail_empty(vq->dev, vq)) + break; cpu_relax(); + } preempt_enable(); r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), out_num, in_num, NULL, NULL); @@ -463,9 +485,62 @@ static bool vhost_exceeds_maxpend(struct vhost_net *net) min_t(unsigned int, VHOST_MAX_PEND, vq->num >> 2); } -/* Expects to be always run from workqueue - which acts as - * read-size critical section for our kind of RCU. */ -static void handle_tx(struct vhost_net *net) +static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter, + size_t hdr_size, int out) +{ + /* Skip header. TODO: support TSO. */ + size_t len = iov_length(vq->iov, out); + + iov_iter_init(iter, WRITE, vq->iov, out, len); + iov_iter_advance(iter, hdr_size); + + return iov_iter_count(iter); +} + +static bool vhost_exceeds_weight(int pkts, int total_len) +{ + return total_len >= VHOST_NET_WEIGHT || + pkts >= VHOST_NET_PKT_WEIGHT; +} + +static int get_tx_bufs(struct vhost_net *net, + struct vhost_net_virtqueue *nvq, + struct msghdr *msg, + unsigned int *out, unsigned int *in, + size_t *len, bool *busyloop_intr) +{ + struct vhost_virtqueue *vq = &nvq->vq; + int ret; + + ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr); + + if (ret < 0 || ret == vq->num) + return ret; + + if (*in) { + vq_err(vq, "Unexpected descriptor format for TX: out %d, int %d\n", + *out, *in); + return -EFAULT; + } + + /* Sanity check */ + *len = init_iov_iter(vq, &msg->msg_iter, nvq->vhost_hlen, *out); + if (*len == 0) { + vq_err(vq, "Unexpected header len for TX: %zd expected %zd\n", + *len, nvq->vhost_hlen); + return -EFAULT; + } + + return ret; +} + +static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len) +{ + return total_len < VHOST_NET_WEIGHT && + !vhost_vq_avail_empty(vq->dev, vq); +} + +static void handle_tx_copy(struct vhost_net *net, struct socket *sock) { struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; struct vhost_virtqueue *vq = &nvq->vq; @@ -480,67 +555,103 @@ static void handle_tx(struct vhost_net *net) }; size_t len, total_len = 0; int err; - size_t hdr_size; - struct socket *sock; - struct vhost_net_ubuf_ref *uninitialized_var(ubufs); - bool zcopy, zcopy_used; int sent_pkts = 0; - mutex_lock(&vq->mutex); - sock = vq->private_data; - if (!sock) - goto out; + for (;;) { + bool busyloop_intr = false; - if (!vq_iotlb_prefetch(vq)) - goto out; + head = get_tx_bufs(net, nvq, &msg, &out, &in, &len, + &busyloop_intr); + /* On error, stop handling until the next kick. */ + if (unlikely(head < 0)) + break; + /* Nothing new? Wait for eventfd to tell us they refilled. */ + if (head == vq->num) { + if (unlikely(busyloop_intr)) { + vhost_poll_queue(&vq->poll); + } else if (unlikely(vhost_enable_notify(&net->dev, + vq))) { + vhost_disable_notify(&net->dev, vq); + continue; + } + break; + } - vhost_disable_notify(&net->dev, vq); - vhost_net_disable_vq(net, vq); + vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head); + vq->heads[nvq->done_idx].len = 0; - hdr_size = nvq->vhost_hlen; - zcopy = nvq->ubufs; + total_len += len; + if (tx_can_batch(vq, total_len)) + msg.msg_flags |= MSG_MORE; + else + msg.msg_flags &= ~MSG_MORE; + + /* TODO: Check specific error and bomb out unless ENOBUFS? */ + err = sock->ops->sendmsg(sock, &msg, len); + if (unlikely(err < 0)) { + vhost_discard_vq_desc(vq, 1); + vhost_net_enable_vq(net, vq); + break; + } + if (err != len) + pr_debug("Truncated TX packet: len %d != %zd\n", + err, len); + if (++nvq->done_idx >= VHOST_NET_BATCH) + vhost_net_signal_used(nvq); + if (vhost_exceeds_weight(++sent_pkts, total_len)) { + vhost_poll_queue(&vq->poll); + break; + } + } + + vhost_net_signal_used(nvq); +} + +static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) +{ + struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; + struct vhost_virtqueue *vq = &nvq->vq; + unsigned out, in; + int head; + struct msghdr msg = { + .msg_name = NULL, + .msg_namelen = 0, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = MSG_DONTWAIT, + }; + size_t len, total_len = 0; + int err; + struct vhost_net_ubuf_ref *uninitialized_var(ubufs); + bool zcopy_used; + int sent_pkts = 0; for (;;) { - /* Release DMAs done buffers first */ - if (zcopy) - vhost_zerocopy_signal_used(net, vq); + bool busyloop_intr; + /* Release DMAs done buffers first */ + vhost_zerocopy_signal_used(net, vq); - head = vhost_net_tx_get_vq_desc(net, vq, vq->iov, - ARRAY_SIZE(vq->iov), - &out, &in); + busyloop_intr = false; + head = get_tx_bufs(net, nvq, &msg, &out, &in, &len, + &busyloop_intr); /* On error, stop handling until the next kick. */ if (unlikely(head < 0)) break; /* Nothing new? Wait for eventfd to tell us they refilled. */ if (head == vq->num) { - if (unlikely(vhost_enable_notify(&net->dev, vq))) { + if (unlikely(busyloop_intr)) { + vhost_poll_queue(&vq->poll); + } else if (unlikely(vhost_enable_notify(&net->dev, vq))) { vhost_disable_notify(&net->dev, vq); continue; } break; } - if (in) { - vq_err(vq, "Unexpected descriptor format for TX: " - "out %d, int %d\n", out, in); - break; - } - /* Skip header. TODO: support TSO. */ - len = iov_length(vq->iov, out); - iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len); - iov_iter_advance(&msg.msg_iter, hdr_size); - /* Sanity check */ - if (!msg_data_left(&msg)) { - vq_err(vq, "Unexpected header len for TX: " - "%zd expected %zd\n", - len, hdr_size); - break; - } - len = msg_data_left(&msg); - zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN - && !vhost_exceeds_maxpend(net) - && vhost_net_tx_select_zcopy(net); + zcopy_used = len >= VHOST_GOODCOPY_LEN + && !vhost_exceeds_maxpend(net) + && vhost_net_tx_select_zcopy(net); /* use msg_control to pass vhost zerocopy ubuf info to skb */ if (zcopy_used) { @@ -562,10 +673,8 @@ static void handle_tx(struct vhost_net *net) msg.msg_control = NULL; ubufs = NULL; } - total_len += len; - if (total_len < VHOST_NET_WEIGHT && - !vhost_vq_avail_empty(&net->dev, vq) && + if (tx_can_batch(vq, total_len) && likely(!vhost_exceeds_maxpend(net))) { msg.msg_flags |= MSG_MORE; } else { @@ -592,12 +701,37 @@ static void handle_tx(struct vhost_net *net) else vhost_zerocopy_signal_used(net, vq); vhost_net_tx_packet(net); - if (unlikely(total_len >= VHOST_NET_WEIGHT) || - unlikely(++sent_pkts >= VHOST_NET_PKT_WEIGHT)) { + if (unlikely(vhost_exceeds_weight(++sent_pkts, total_len))) { vhost_poll_queue(&vq->poll); break; } } +} + +/* Expects to be always run from workqueue - which acts as + * read-size critical section for our kind of RCU. */ +static void handle_tx(struct vhost_net *net) +{ + struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; + struct vhost_virtqueue *vq = &nvq->vq; + struct socket *sock; + + mutex_lock(&vq->mutex); + sock = vq->private_data; + if (!sock) + goto out; + + if (!vq_iotlb_prefetch(vq)) + goto out; + + vhost_disable_notify(&net->dev, vq); + vhost_net_disable_vq(net, vq); + + if (vhost_sock_zcopy(sock)) + handle_tx_zerocopy(net, sock); + else + handle_tx_copy(net, sock); + out: mutex_unlock(&vq->mutex); } @@ -633,53 +767,50 @@ static int sk_has_rx_data(struct sock *sk) return skb_queue_empty(&sk->sk_receive_queue); } -static void vhost_rx_signal_used(struct vhost_net_virtqueue *nvq) +static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, + bool *busyloop_intr) { - struct vhost_virtqueue *vq = &nvq->vq; - struct vhost_dev *dev = vq->dev; - - if (!nvq->done_idx) - return; - - vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx); - nvq->done_idx = 0; -} - -static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk) -{ - struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX]; - struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; - struct vhost_virtqueue *vq = &nvq->vq; + struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX]; + struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX]; + struct vhost_virtqueue *rvq = &rnvq->vq; + struct vhost_virtqueue *tvq = &tnvq->vq; unsigned long uninitialized_var(endtime); - int len = peek_head_len(rvq, sk); + int len = peek_head_len(rnvq, sk); - if (!len && vq->busyloop_timeout) { + if (!len && tvq->busyloop_timeout) { /* Flush batched heads first */ - vhost_rx_signal_used(rvq); + vhost_net_signal_used(rnvq); /* Both tx vq and rx socket were polled here */ - mutex_lock_nested(&vq->mutex, 1); - vhost_disable_notify(&net->dev, vq); + mutex_lock_nested(&tvq->mutex, 1); + vhost_disable_notify(&net->dev, tvq); preempt_disable(); - endtime = busy_clock() + vq->busyloop_timeout; + endtime = busy_clock() + tvq->busyloop_timeout; - while (vhost_can_busy_poll(&net->dev, endtime) && - !sk_has_rx_data(sk) && - vhost_vq_avail_empty(&net->dev, vq)) + while (vhost_can_busy_poll(endtime)) { + if (vhost_has_work(&net->dev)) { + *busyloop_intr = true; + break; + } + if ((sk_has_rx_data(sk) && + !vhost_vq_avail_empty(&net->dev, rvq)) || + !vhost_vq_avail_empty(&net->dev, tvq)) + break; cpu_relax(); + } preempt_enable(); - if (!vhost_vq_avail_empty(&net->dev, vq)) - vhost_poll_queue(&vq->poll); - else if (unlikely(vhost_enable_notify(&net->dev, vq))) { - vhost_disable_notify(&net->dev, vq); - vhost_poll_queue(&vq->poll); + if (!vhost_vq_avail_empty(&net->dev, tvq)) { + vhost_poll_queue(&tvq->poll); + } else if (unlikely(vhost_enable_notify(&net->dev, tvq))) { + vhost_disable_notify(&net->dev, tvq); + vhost_poll_queue(&tvq->poll); } - mutex_unlock(&vq->mutex); + mutex_unlock(&tvq->mutex); - len = peek_head_len(rvq, sk); + len = peek_head_len(rnvq, sk); } return len; @@ -786,6 +917,7 @@ static void handle_rx(struct vhost_net *net) s16 headcount; size_t vhost_hlen, sock_hlen; size_t vhost_len, sock_len; + bool busyloop_intr = false; struct socket *sock; struct iov_iter fixup; __virtio16 num_buffers; @@ -809,7 +941,8 @@ static void handle_rx(struct vhost_net *net) vq->log : NULL; mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF); - while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) { + while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk, + &busyloop_intr))) { sock_len += sock_hlen; vhost_len = sock_len + vhost_hlen; headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx, @@ -820,7 +953,9 @@ static void handle_rx(struct vhost_net *net) goto out; /* OK, now we need to know about added descriptors. */ if (!headcount) { - if (unlikely(vhost_enable_notify(&net->dev, vq))) { + if (unlikely(busyloop_intr)) { + vhost_poll_queue(&vq->poll); + } else if (unlikely(vhost_enable_notify(&net->dev, vq))) { /* They have slipped one in as we were * doing that: check again. */ vhost_disable_notify(&net->dev, vq); @@ -830,6 +965,7 @@ static void handle_rx(struct vhost_net *net) * they refilled. */ goto out; } + busyloop_intr = false; if (nvq->rx_ring) msg.msg_control = vhost_net_buf_consume(&nvq->rxq); /* On overrun, truncate and discard */ @@ -885,20 +1021,22 @@ static void handle_rx(struct vhost_net *net) goto out; } nvq->done_idx += headcount; - if (nvq->done_idx > VHOST_RX_BATCH) - vhost_rx_signal_used(nvq); + if (nvq->done_idx > VHOST_NET_BATCH) + vhost_net_signal_used(nvq); if (unlikely(vq_log)) vhost_log_write(vq, vq_log, log, vhost_len); total_len += vhost_len; - if (unlikely(total_len >= VHOST_NET_WEIGHT) || - unlikely(++recv_pkts >= VHOST_NET_PKT_WEIGHT)) { + if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) { vhost_poll_queue(&vq->poll); goto out; } } - vhost_net_enable_vq(net, vq); + if (unlikely(busyloop_intr)) + vhost_poll_queue(&vq->poll); + else + vhost_net_enable_vq(net, vq); out: - vhost_rx_signal_used(nvq); + vhost_net_signal_used(nvq); mutex_unlock(&vq->mutex); } @@ -951,7 +1089,7 @@ static int vhost_net_open(struct inode *inode, struct file *f) return -ENOMEM; } - queue = kmalloc_array(VHOST_RX_BATCH, sizeof(void *), + queue = kmalloc_array(VHOST_NET_BATCH, sizeof(void *), GFP_KERNEL); if (!queue) { kfree(vqs); @@ -1265,6 +1403,21 @@ done: return err; } +static int vhost_net_set_backend_features(struct vhost_net *n, u64 features) +{ + int i; + + mutex_lock(&n->dev.mutex); + for (i = 0; i < VHOST_NET_VQ_MAX; ++i) { + mutex_lock(&n->vqs[i].vq.mutex); + n->vqs[i].vq.acked_backend_features = features; + mutex_unlock(&n->vqs[i].vq.mutex); + } + mutex_unlock(&n->dev.mutex); + + return 0; +} + static int vhost_net_set_features(struct vhost_net *n, u64 features) { size_t vhost_hlen, sock_hlen, hdr_len; @@ -1355,6 +1508,17 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl, if (features & ~VHOST_NET_FEATURES) return -EOPNOTSUPP; return vhost_net_set_features(n, features); + case VHOST_GET_BACKEND_FEATURES: + features = VHOST_NET_BACKEND_FEATURES; + if (copy_to_user(featurep, &features, sizeof(features))) + return -EFAULT; + return 0; + case VHOST_SET_BACKEND_FEATURES: + if (copy_from_user(&features, featurep, sizeof(features))) + return -EFAULT; + if (features & ~VHOST_NET_BACKEND_FEATURES) + return -EOPNOTSUPP; + return vhost_net_set_backend_features(n, features); case VHOST_RESET_OWNER: return vhost_net_reset_owner(n); case VHOST_SET_OWNER: diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c index 8c32cf58d6fa..c24bb690680b 100644 --- a/drivers/vhost/scsi.c +++ b/drivers/vhost/scsi.c @@ -46,7 +46,6 @@ #include <linux/virtio_scsi.h> #include <linux/llist.h> #include <linux/bitmap.h> -#include <linux/percpu_ida.h> #include "vhost.h" @@ -324,7 +323,7 @@ static void vhost_scsi_release_cmd(struct se_cmd *se_cmd) } vhost_scsi_put_inflight(tv_cmd->inflight); - percpu_ida_free(&se_sess->sess_tag_pool, se_cmd->map_tag); + target_free_tag(se_sess, se_cmd); } static u32 vhost_scsi_sess_get_index(struct se_session *se_sess) @@ -567,7 +566,7 @@ vhost_scsi_get_tag(struct vhost_virtqueue *vq, struct vhost_scsi_tpg *tpg, struct se_session *se_sess; struct scatterlist *sg, *prot_sg; struct page **pages; - int tag; + int tag, cpu; tv_nexus = tpg->tpg_nexus; if (!tv_nexus) { @@ -576,7 +575,7 @@ vhost_scsi_get_tag(struct vhost_virtqueue *vq, struct vhost_scsi_tpg *tpg, } se_sess = tv_nexus->tvn_se_sess; - tag = percpu_ida_alloc(&se_sess->sess_tag_pool, TASK_RUNNING); + tag = sbitmap_queue_get(&se_sess->sess_tag_pool, &cpu); if (tag < 0) { pr_err("Unable to obtain tag for vhost_scsi_cmd\n"); return ERR_PTR(-ENOMEM); @@ -591,6 +590,7 @@ vhost_scsi_get_tag(struct vhost_virtqueue *vq, struct vhost_scsi_tpg *tpg, cmd->tvc_prot_sgl = prot_sg; cmd->tvc_upages = pages; cmd->tvc_se_cmd.map_tag = tag; + cmd->tvc_se_cmd.map_cpu = cpu; cmd->tvc_tag = scsi_tag; cmd->tvc_lun = lun; cmd->tvc_task_attr = task_attr; @@ -1738,7 +1738,7 @@ static int vhost_scsi_make_nexus(struct vhost_scsi_tpg *tpg, * struct se_node_acl for the vhost_scsi struct se_portal_group with * the SCSI Initiator port name of the passed configfs group 'name'. */ - tv_nexus->tvn_se_sess = target_alloc_session(&tpg->se_tpg, + tv_nexus->tvn_se_sess = target_setup_session(&tpg->se_tpg, VHOST_SCSI_DEFAULT_TAGS, sizeof(struct vhost_scsi_cmd), TARGET_PROT_DIN_PASS | TARGET_PROT_DOUT_PASS, @@ -1797,7 +1797,7 @@ static int vhost_scsi_drop_nexus(struct vhost_scsi_tpg *tpg) /* * Release the SCSI I_T Nexus to the emulated vhost Target Port */ - transport_deregister_session(tv_nexus->tvn_se_sess); + target_remove_session(se_sess); tpg->tpg_nexus = NULL; mutex_unlock(&tpg->tv_tpg_mutex); @@ -1912,9 +1912,7 @@ static struct configfs_attribute *vhost_scsi_tpg_attrs[] = { }; static struct se_portal_group * -vhost_scsi_make_tpg(struct se_wwn *wwn, - struct config_group *group, - const char *name) +vhost_scsi_make_tpg(struct se_wwn *wwn, const char *name) { struct vhost_scsi_tport *tport = container_of(wwn, struct vhost_scsi_tport, tport_wwn); diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index ed3114556fda..96c1d8400822 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -315,6 +315,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->log_addr = -1ull; vq->private_data = NULL; vq->acked_features = 0; + vq->acked_backend_features = 0; vq->log_base = NULL; vq->error_ctx = NULL; vq->kick = NULL; @@ -1027,28 +1028,40 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, ssize_t vhost_chr_write_iter(struct vhost_dev *dev, struct iov_iter *from) { - struct vhost_msg_node node; - unsigned size = sizeof(struct vhost_msg); - size_t ret; - int err; + struct vhost_iotlb_msg msg; + size_t offset; + int type, ret; - if (iov_iter_count(from) < size) - return 0; - ret = copy_from_iter(&node.msg, size, from); - if (ret != size) + ret = copy_from_iter(&type, sizeof(type), from); + if (ret != sizeof(type)) goto done; - switch (node.msg.type) { + switch (type) { case VHOST_IOTLB_MSG: - err = vhost_process_iotlb_msg(dev, &node.msg.iotlb); - if (err) - ret = err; + /* There maybe a hole after type for V1 message type, + * so skip it here. + */ + offset = offsetof(struct vhost_msg, iotlb) - sizeof(int); + break; + case VHOST_IOTLB_MSG_V2: + offset = sizeof(__u32); break; default: ret = -EINVAL; - break; + goto done; + } + + iov_iter_advance(from, offset); + ret = copy_from_iter(&msg, sizeof(msg), from); + if (ret != sizeof(msg)) + goto done; + if (vhost_process_iotlb_msg(dev, &msg)) { + ret = -EFAULT; + goto done; } + ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) : + sizeof(struct vhost_msg_v2); done: return ret; } @@ -1107,13 +1120,28 @@ ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to, finish_wait(&dev->wait, &wait); if (node) { - ret = copy_to_iter(&node->msg, size, to); + struct vhost_iotlb_msg *msg; + void *start = &node->msg; + + switch (node->msg.type) { + case VHOST_IOTLB_MSG: + size = sizeof(node->msg); + msg = &node->msg.iotlb; + break; + case VHOST_IOTLB_MSG_V2: + size = sizeof(node->msg_v2); + msg = &node->msg_v2.iotlb; + break; + default: + BUG(); + break; + } - if (ret != size || node->msg.type != VHOST_IOTLB_MISS) { + ret = copy_to_iter(start, size, to); + if (ret != size || msg->type != VHOST_IOTLB_MISS) { kfree(node); return ret; } - vhost_enqueue_msg(dev, &dev->pending_list, node); } @@ -1126,12 +1154,19 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access) struct vhost_dev *dev = vq->dev; struct vhost_msg_node *node; struct vhost_iotlb_msg *msg; + bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2); - node = vhost_new_msg(vq, VHOST_IOTLB_MISS); + node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG); if (!node) return -ENOMEM; - msg = &node->msg.iotlb; + if (v2) { + node->msg_v2.type = VHOST_IOTLB_MSG_V2; + msg = &node->msg_v2.iotlb; + } else { + msg = &node->msg.iotlb; + } + msg->type = VHOST_IOTLB_MISS; msg->iova = iova; msg->perm = access; diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 6c844b90a168..466ef7542291 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -132,6 +132,7 @@ struct vhost_virtqueue { struct vhost_umem *iotlb; void *private_data; u64 acked_features; + u64 acked_backend_features; /* Log write descriptors */ void __user *log_base; struct vhost_log *log; @@ -147,7 +148,10 @@ struct vhost_virtqueue { }; struct vhost_msg_node { - struct vhost_msg msg; + union { + struct vhost_msg msg; + struct vhost_msg_v2 msg_v2; + }; struct vhost_virtqueue *vq; struct list_head node; }; @@ -238,6 +242,11 @@ static inline bool vhost_has_feature(struct vhost_virtqueue *vq, int bit) return vq->acked_features & (1ULL << bit); } +static inline bool vhost_backend_has_feature(struct vhost_virtqueue *vq, int bit) +{ + return vq->acked_backend_features & (1ULL << bit); +} + #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY static inline bool vhost_is_little_endian(struct vhost_virtqueue *vq) { |