diff options
Diffstat (limited to 'drivers/vhost/vsock.c')
-rw-r--r-- | drivers/vhost/vsock.c | 631 |
1 files changed, 631 insertions, 0 deletions
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c new file mode 100644 index 000000000000..65b1cf8a06cb --- /dev/null +++ b/drivers/vhost/vsock.c @@ -0,0 +1,631 @@ +/* + * vhost transport for vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * + * This work is licensed under the terms of the GNU GPL, version 2. + */ +#include <linux/miscdevice.h> +#include <linux/module.h> +#include <linux/mutex.h> +#include <net/sock.h> +#include <linux/virtio_vsock.h> +#include <linux/vhost.h> + +#include <net/af_vsock.h> +#include "vhost.h" +#include "vsock.h" + +#define VHOST_VSOCK_DEFAULT_HOST_CID 2 + +static int vhost_transport_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk); + +enum { + VHOST_VSOCK_FEATURES = VHOST_FEATURES, +}; + +/* Used to track all the vhost_vsock instances on the system. */ +static LIST_HEAD(vhost_vsock_list); +static DEFINE_MUTEX(vhost_vsock_mutex); + +struct vhost_vsock_virtqueue { + struct vhost_virtqueue vq; +}; + +struct vhost_vsock { + /* Vhost device */ + struct vhost_dev dev; + /* Vhost vsock virtqueue*/ + struct vhost_vsock_virtqueue vqs[VSOCK_VQ_MAX]; + /* Link to global vhost_vsock_list*/ + struct list_head list; + /* Head for pkt from host to guest */ + struct list_head send_pkt_list; + /* Work item to send pkt */ + struct vhost_work send_pkt_work; + /* Wait queue for send pkt */ + wait_queue_head_t queue_wait; + /* Used for global tx buf limitation */ + u32 total_tx_buf; + /* Guest contex id this vhost_vsock instance handles */ + u32 guest_cid; +}; + +static u32 vhost_transport_get_local_cid(void) +{ + u32 cid = VHOST_VSOCK_DEFAULT_HOST_CID; + return cid; +} + +static struct vhost_vsock *vhost_vsock_get(u32 guest_cid) +{ + struct vhost_vsock *vsock; + + mutex_lock(&vhost_vsock_mutex); + list_for_each_entry(vsock, &vhost_vsock_list, list) { + if (vsock->guest_cid == guest_cid) { + mutex_unlock(&vhost_vsock_mutex); + return vsock; + } + } + mutex_unlock(&vhost_vsock_mutex); + + return NULL; +} + +static void +vhost_transport_do_send_pkt(struct vhost_vsock *vsock, + struct vhost_virtqueue *vq) +{ + bool added = false; + + mutex_lock(&vq->mutex); + vhost_disable_notify(&vsock->dev, vq); + for (;;) { + struct virtio_vsock_pkt *pkt; + struct iov_iter iov_iter; + unsigned out, in; + struct sock *sk; + size_t nbytes; + size_t len; + int head; + + if (list_empty(&vsock->send_pkt_list)) { + vhost_enable_notify(&vsock->dev, vq); + break; + } + + head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), + &out, &in, NULL, NULL); + pr_debug("%s: head = %d\n", __func__, head); + if (head < 0) + break; + + if (head == vq->num) { + if (unlikely(vhost_enable_notify(&vsock->dev, vq))) { + vhost_disable_notify(&vsock->dev, vq); + continue; + } + break; + } + + pkt = list_first_entry(&vsock->send_pkt_list, + struct virtio_vsock_pkt, list); + list_del_init(&pkt->list); + + if (out) { + virtio_transport_free_pkt(pkt); + vq_err(vq, "Expected 0 output buffers, got %u\n", out); + break; + } + + len = iov_length(&vq->iov[out], in); + iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len); + + nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter); + if (nbytes != sizeof(pkt->hdr)) { + virtio_transport_free_pkt(pkt); + vq_err(vq, "Faulted on copying pkt hdr\n"); + break; + } + + nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter); + if (nbytes != pkt->len) { + virtio_transport_free_pkt(pkt); + vq_err(vq, "Faulted on copying pkt buf\n"); + break; + } + + vhost_add_used(vq, head, pkt->len); /* TODO should this be sizeof(pkt->hdr) + pkt->len? */ + added = true; + + virtio_transport_dec_tx_pkt(pkt); + vsock->total_tx_buf -= pkt->len; + + sk = sk_vsock(pkt->trans->vsk); + /* Release refcnt taken in vhost_transport_send_pkt */ + sock_put(sk); + + virtio_transport_free_pkt(pkt); + } + if (added) + vhost_signal(&vsock->dev, vq); + mutex_unlock(&vq->mutex); + + if (added) + wake_up(&vsock->queue_wait); +} + +static void vhost_transport_send_pkt_work(struct vhost_work *work) +{ + struct vhost_virtqueue *vq; + struct vhost_vsock *vsock; + + vsock = container_of(work, struct vhost_vsock, send_pkt_work); + vq = &vsock->vqs[VSOCK_VQ_RX].vq; + + vhost_transport_do_send_pkt(vsock, vq); +} + +static int +vhost_transport_send_pkt(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info) +{ + u32 src_cid, src_port, dst_cid, dst_port; + struct virtio_transport *trans; + struct virtio_vsock_pkt *pkt; + struct vhost_virtqueue *vq; + struct vhost_vsock *vsock; + u32 pkt_len = info->pkt_len; + DEFINE_WAIT(wait); + + src_cid = vhost_transport_get_local_cid(); + src_port = vsk->local_addr.svm_port; + if (!info->remote_cid) { + dst_cid = vsk->remote_addr.svm_cid; + dst_port = vsk->remote_addr.svm_port; + } else { + dst_cid = info->remote_cid; + dst_port = info->remote_port; + } + + /* Find the vhost_vsock according to guest context id */ + vsock = vhost_vsock_get(dst_cid); + if (!vsock) + return -ENODEV; + + trans = vsk->trans; + vq = &vsock->vqs[VSOCK_VQ_RX].vq; + + /* we can send less than pkt_len bytes */ + if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) + pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; + + /* virtio_transport_get_credit might return less than pkt_len credit */ + pkt_len = virtio_transport_get_credit(trans, pkt_len); + + /* Do not send zero length OP_RW pkt*/ + if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) + return pkt_len; + + /* Respect global tx buf limitation */ + mutex_lock(&vq->mutex); + while (pkt_len + vsock->total_tx_buf > VIRTIO_VSOCK_MAX_TX_BUF_SIZE) { + prepare_to_wait_exclusive(&vsock->queue_wait, &wait, + TASK_UNINTERRUPTIBLE); + mutex_unlock(&vq->mutex); + schedule(); + mutex_lock(&vq->mutex); + finish_wait(&vsock->queue_wait, &wait); + } + vsock->total_tx_buf += pkt_len; + mutex_unlock(&vq->mutex); + + pkt = virtio_transport_alloc_pkt(vsk, info, pkt_len, + src_cid, src_port, + dst_cid, dst_port); + if (!pkt) { + mutex_lock(&vq->mutex); + vsock->total_tx_buf -= pkt_len; + mutex_unlock(&vq->mutex); + virtio_transport_put_credit(trans, pkt_len); + return -ENOMEM; + } + + pr_debug("%s:info->pkt_len= %d\n", __func__, pkt_len); + /* Released in vhost_transport_do_send_pkt */ + sock_hold(&trans->vsk->sk); + virtio_transport_inc_tx_pkt(pkt); + + /* Queue it up in vhost work */ + mutex_lock(&vq->mutex); + list_add_tail(&pkt->list, &vsock->send_pkt_list); + vhost_work_queue(&vsock->dev, &vsock->send_pkt_work); + mutex_unlock(&vq->mutex); + + return pkt_len; +} + +static struct virtio_transport_pkt_ops vhost_ops = { + .send_pkt = vhost_transport_send_pkt, +}; + +static struct virtio_vsock_pkt * +vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq, + unsigned int out, unsigned int in) +{ + struct virtio_vsock_pkt *pkt; + struct iov_iter iov_iter; + size_t nbytes; + size_t len; + + if (in != 0) { + vq_err(vq, "Expected 0 input buffers, got %u\n", in); + return NULL; + } + + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return NULL; + + len = iov_length(vq->iov, out); + iov_iter_init(&iov_iter, WRITE, vq->iov, out, len); + + nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter); + if (nbytes != sizeof(pkt->hdr)) { + vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n", + sizeof(pkt->hdr), nbytes); + kfree(pkt); + return NULL; + } + + if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_DGRAM) + pkt->len = le32_to_cpu(pkt->hdr.len) & 0XFFFF; + else if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM) + pkt->len = le32_to_cpu(pkt->hdr.len); + + /* No payload */ + if (!pkt->len) + return pkt; + + /* The pkt is too big */ + if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) { + kfree(pkt); + return NULL; + } + + pkt->buf = kmalloc(pkt->len, GFP_KERNEL); + if (!pkt->buf) { + kfree(pkt); + return NULL; + } + + nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter); + if (nbytes != pkt->len) { + vq_err(vq, "Expected %u byte payload, got %zu bytes\n", + pkt->len, nbytes); + virtio_transport_free_pkt(pkt); + return NULL; + } + + return pkt; +} + +static void vhost_vsock_handle_ctl_kick(struct vhost_work *work) +{ + struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, + poll.work); + struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, + dev); + + pr_debug("%s vq=%p, vsock=%p\n", __func__, vq, vsock); +} + +static void vhost_vsock_handle_tx_kick(struct vhost_work *work) +{ + struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, + poll.work); + struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, + dev); + struct virtio_vsock_pkt *pkt; + int head; + unsigned int out, in; + bool added = false; + u32 len; + + mutex_lock(&vq->mutex); + vhost_disable_notify(&vsock->dev, vq); + for (;;) { + head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), + &out, &in, NULL, NULL); + if (head < 0) + break; + + if (head == vq->num) { + if (unlikely(vhost_enable_notify(&vsock->dev, vq))) { + vhost_disable_notify(&vsock->dev, vq); + continue; + } + break; + } + + pkt = vhost_vsock_alloc_pkt(vq, out, in); + if (!pkt) { + vq_err(vq, "Faulted on pkt\n"); + continue; + } + + len = pkt->len; + + /* Only accept correctly addressed packets */ + if (le32_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid && + le32_to_cpu(pkt->hdr.dst_cid) == vhost_transport_get_local_cid()) + virtio_transport_recv_pkt(pkt); + else + virtio_transport_free_pkt(pkt); + + vhost_add_used(vq, head, len); + added = true; + } + if (added) + vhost_signal(&vsock->dev, vq); + mutex_unlock(&vq->mutex); +} + +static void vhost_vsock_handle_rx_kick(struct vhost_work *work) +{ + struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, + poll.work); + struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, + dev); + + vhost_transport_do_send_pkt(vsock, vq); +} + +static int vhost_vsock_dev_open(struct inode *inode, struct file *file) +{ + struct vhost_virtqueue **vqs; + struct vhost_vsock *vsock; + int ret; + + vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); + if (!vsock) + return -ENOMEM; + + pr_debug("%s:vsock=%p\n", __func__, vsock); + + vqs = kmalloc(VSOCK_VQ_MAX * sizeof(*vqs), GFP_KERNEL); + if (!vqs) { + ret = -ENOMEM; + goto out; + } + + vqs[VSOCK_VQ_CTRL] = &vsock->vqs[VSOCK_VQ_CTRL].vq; + vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX].vq; + vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX].vq; + vsock->vqs[VSOCK_VQ_CTRL].vq.handle_kick = vhost_vsock_handle_ctl_kick; + vsock->vqs[VSOCK_VQ_TX].vq.handle_kick = vhost_vsock_handle_tx_kick; + vsock->vqs[VSOCK_VQ_RX].vq.handle_kick = vhost_vsock_handle_rx_kick; + + vhost_dev_init(&vsock->dev, vqs, VSOCK_VQ_MAX); + + file->private_data = vsock; + init_waitqueue_head(&vsock->queue_wait); + INIT_LIST_HEAD(&vsock->send_pkt_list); + vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work); + + mutex_lock(&vhost_vsock_mutex); + list_add_tail(&vsock->list, &vhost_vsock_list); + mutex_unlock(&vhost_vsock_mutex); + return 0; + +out: + kfree(vsock); + return ret; +} + +static void vhost_vsock_flush(struct vhost_vsock *vsock) +{ + int i; + + for (i = 0; i < VSOCK_VQ_MAX; i++) + vhost_poll_flush(&vsock->vqs[i].vq.poll); + vhost_work_flush(&vsock->dev, &vsock->send_pkt_work); +} + +static int vhost_vsock_dev_release(struct inode *inode, struct file *file) +{ + struct vhost_vsock *vsock = file->private_data; + + mutex_lock(&vhost_vsock_mutex); + list_del(&vsock->list); + mutex_unlock(&vhost_vsock_mutex); + + vhost_dev_stop(&vsock->dev); + vhost_vsock_flush(vsock); + vhost_dev_cleanup(&vsock->dev, false); + kfree(vsock->dev.vqs); + kfree(vsock); + return 0; +} + +static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u32 guest_cid) +{ + struct vhost_vsock *other; + + /* Refuse reserved CIDs */ + if (guest_cid <= VMADDR_CID_HOST) { + return -EINVAL; + } + + /* Refuse if CID is already in use */ + other = vhost_vsock_get(guest_cid); + if (other && other != vsock) { + return -EADDRINUSE; + } + + mutex_lock(&vhost_vsock_mutex); + vsock->guest_cid = guest_cid; + pr_debug("%s:guest_cid=%d\n", __func__, guest_cid); + mutex_unlock(&vhost_vsock_mutex); + + return 0; +} + +static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features) +{ + struct vhost_virtqueue *vq; + int i; + + if (features & ~VHOST_VSOCK_FEATURES) + return -EOPNOTSUPP; + + mutex_lock(&vsock->dev.mutex); + if ((features & (1 << VHOST_F_LOG_ALL)) && + !vhost_log_access_ok(&vsock->dev)) { + mutex_unlock(&vsock->dev.mutex); + return -EFAULT; + } + + for (i = 0; i < VSOCK_VQ_MAX; i++) { + vq = &vsock->vqs[i].vq; + mutex_lock(&vq->mutex); + vq->acked_features = features; + mutex_unlock(&vq->mutex); + } + mutex_unlock(&vsock->dev.mutex); + return 0; +} + +static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl, + unsigned long arg) +{ + struct vhost_vsock *vsock = f->private_data; + void __user *argp = (void __user *)arg; + u64 __user *featurep = argp; + u32 __user *cidp = argp; + u32 guest_cid; + u64 features; + int r; + + switch (ioctl) { + case VHOST_VSOCK_SET_GUEST_CID: + if (get_user(guest_cid, cidp)) + return -EFAULT; + return vhost_vsock_set_cid(vsock, guest_cid); + case VHOST_GET_FEATURES: + features = VHOST_VSOCK_FEATURES; + if (copy_to_user(featurep, &features, sizeof(features))) + return -EFAULT; + return 0; + case VHOST_SET_FEATURES: + if (copy_from_user(&features, featurep, sizeof(features))) + return -EFAULT; + return vhost_vsock_set_features(vsock, features); + default: + mutex_lock(&vsock->dev.mutex); + r = vhost_dev_ioctl(&vsock->dev, ioctl, argp); + if (r == -ENOIOCTLCMD) + r = vhost_vring_ioctl(&vsock->dev, ioctl, argp); + else + vhost_vsock_flush(vsock); + mutex_unlock(&vsock->dev.mutex); + return r; + } +} + +static const struct file_operations vhost_vsock_fops = { + .owner = THIS_MODULE, + .open = vhost_vsock_dev_open, + .release = vhost_vsock_dev_release, + .llseek = noop_llseek, + .unlocked_ioctl = vhost_vsock_dev_ioctl, +}; + +static struct miscdevice vhost_vsock_misc = { + .minor = MISC_DYNAMIC_MINOR, + .name = "vhost-vsock", + .fops = &vhost_vsock_fops, +}; + +static int +vhost_transport_socket_init(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + struct virtio_transport *trans; + int ret; + + ret = virtio_transport_do_socket_init(vsk, psk); + if (ret) + return ret; + + trans = vsk->trans; + trans->ops = &vhost_ops; + + return ret; +} + +static struct vsock_transport vhost_transport = { + .get_local_cid = vhost_transport_get_local_cid, + + .init = vhost_transport_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_bind = virtio_transport_dgram_bind, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + + .set_buffer_size = virtio_transport_set_buffer_size, + .set_min_buffer_size = virtio_transport_set_min_buffer_size, + .set_max_buffer_size = virtio_transport_set_max_buffer_size, + .get_buffer_size = virtio_transport_get_buffer_size, + .get_min_buffer_size = virtio_transport_get_min_buffer_size, + .get_max_buffer_size = virtio_transport_get_max_buffer_size, +}; + +static int __init vhost_vsock_init(void) +{ + int ret; + + ret = vsock_core_init(&vhost_transport); + if (ret < 0) + return ret; + return misc_register(&vhost_vsock_misc); +}; + +static void __exit vhost_vsock_exit(void) +{ + misc_deregister(&vhost_vsock_misc); + vsock_core_exit(); +}; + +module_init(vhost_vsock_init); +module_exit(vhost_vsock_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("vhost transport for vsock "); |