diff options
Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
-rw-r--r-- | net/vmw_vsock/af_vsock.c | 397 |
1 files changed, 295 insertions, 102 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 582a3e4dfce2..74db4cd637a7 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -126,19 +126,18 @@ static struct proto vsock_proto = { */ #define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) -static const struct vsock_transport *transport; +#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128 + +/* Transport used for host->guest communication */ +static const struct vsock_transport *transport_h2g; +/* Transport used for guest->host communication */ +static const struct vsock_transport *transport_g2h; +/* Transport used for DGRAM communication */ +static const struct vsock_transport *transport_dgram; static DEFINE_MUTEX(vsock_register_mutex); -/**** EXPORTS ****/ - -/* Get the ID of the local context. This is transport dependent. */ - -int vm_sockets_get_local_cid(void) -{ - return transport->get_local_cid(); -} -EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid); - /**** UTILS ****/ /* Each bound VSocket is stored in the bind hash table and each connected @@ -188,7 +187,7 @@ static int vsock_auto_bind(struct vsock_sock *vsk) return __vsock_bind(sk, &local_addr); } -static int __init vsock_init_tables(void) +static void vsock_init_tables(void) { int i; @@ -197,7 +196,6 @@ static int __init vsock_init_tables(void) for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) INIT_LIST_HEAD(&vsock_connected_table[i]); - return 0; } static void __vsock_insert_bound(struct list_head *list, @@ -230,9 +228,15 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr) { struct vsock_sock *vsk; - list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) - if (addr->svm_port == vsk->local_addr.svm_port) + list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) { + if (vsock_addr_equals_addr(addr, &vsk->local_addr)) + return sk_vsock(vsk); + + if (addr->svm_port == 
vsk->local_addr.svm_port && + (vsk->local_addr.svm_cid == VMADDR_CID_ANY || + addr->svm_cid == VMADDR_CID_ANY)) return sk_vsock(vsk); + } return NULL; } @@ -382,6 +386,88 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected) } EXPORT_SYMBOL_GPL(vsock_enqueue_accept); +static void vsock_deassign_transport(struct vsock_sock *vsk) +{ + if (!vsk->transport) + return; + + vsk->transport->destruct(vsk); + module_put(vsk->transport->module); + vsk->transport = NULL; +} + +/* Assign a transport to a socket and call the .init transport callback. + * + * Note: for stream socket this must be called when vsk->remote_addr is set + * (e.g. during the connect() or when a connection request on a listener + * socket is received). + * The vsk->remote_addr is used to decide which transport to use: + * - remote CID <= VMADDR_CID_HOST will use guest->host transport; + * - remote CID == local_cid (guest->host transport) will use guest->host + * transport for loopback (host->guest transports don't support loopback); + * - remote CID > VMADDR_CID_HOST will use host->guest transport; + */ +int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + const struct vsock_transport *new_transport; + struct sock *sk = sk_vsock(vsk); + unsigned int remote_cid = vsk->remote_addr.svm_cid; + int ret; + + switch (sk->sk_type) { + case SOCK_DGRAM: + new_transport = transport_dgram; + break; + case SOCK_STREAM: + if (remote_cid <= VMADDR_CID_HOST || + (transport_g2h && + remote_cid == transport_g2h->get_local_cid())) + new_transport = transport_g2h; + else + new_transport = transport_h2g; + break; + default: + return -ESOCKTNOSUPPORT; + } + + if (vsk->transport) { + if (vsk->transport == new_transport) + return 0; + + vsk->transport->release(vsk); + vsock_deassign_transport(vsk); + } + + /* We increase the module refcnt to prevent the transport unloading + * while there are open sockets assigned to it. 
+ */ + if (!new_transport || !try_module_get(new_transport->module)) + return -ENODEV; + + ret = new_transport->init(vsk, psk); + if (ret) { + module_put(new_transport->module); + return ret; + } + + vsk->transport = new_transport; + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_assign_transport); + +bool vsock_find_cid(unsigned int cid) +{ + if (transport_g2h && cid == transport_g2h->get_local_cid()) + return true; + + if (transport_h2g && cid == VMADDR_CID_HOST) + return true; + + return false; +} +EXPORT_SYMBOL_GPL(vsock_find_cid); + static struct sock *vsock_dequeue_accept(struct sock *listener) { struct vsock_sock *vlistener; @@ -418,7 +504,12 @@ static bool vsock_is_pending(struct sock *sk) static int vsock_send_shutdown(struct sock *sk, int mode) { - return transport->shutdown(vsock_sk(sk), mode); + struct vsock_sock *vsk = vsock_sk(sk); + + if (!vsk->transport) + return -ENODEV; + + return vsk->transport->shutdown(vsk, mode); } static void vsock_pending_work(struct work_struct *work) @@ -439,7 +530,7 @@ static void vsock_pending_work(struct work_struct *work) if (vsock_is_pending(sk)) { vsock_remove_pending(listener, sk); - listener->sk_ack_backlog--; + sk_acceptq_removed(listener); } else if (!vsk->rejected) { /* We are not on the pending list and accept() did not reject * us, so we must have been accepted by our user process. We @@ -528,13 +619,12 @@ static int __vsock_bind_stream(struct vsock_sock *vsk, static int __vsock_bind_dgram(struct vsock_sock *vsk, struct sockaddr_vm *addr) { - return transport->dgram_bind(vsk, addr); + return vsk->transport->dgram_bind(vsk, addr); } static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) { struct vsock_sock *vsk = vsock_sk(sk); - u32 cid; int retval; /* First ensure this socket isn't already bound. 
*/ @@ -544,10 +634,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) /* Now bind to the provided address or select appropriate values if * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that * like AF_INET prevents binding to a non-local IP address (in most - * cases), we only allow binding to the local CID. + * cases), we only allow binding to a local CID. */ - cid = transport->get_local_cid(); - if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY) + if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid)) return -EADDRNOTAVAIL; switch (sk->sk_socket->type) { @@ -571,12 +660,12 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) static void vsock_connect_timeout(struct work_struct *work); -struct sock *__vsock_create(struct net *net, - struct socket *sock, - struct sock *parent, - gfp_t priority, - unsigned short type, - int kern) +static struct sock *__vsock_create(struct net *net, + struct socket *sock, + struct sock *parent, + gfp_t priority, + unsigned short type, + int kern) { struct sock *sk; struct vsock_sock *psk; @@ -620,28 +709,24 @@ struct sock *__vsock_create(struct net *net, vsk->trusted = psk->trusted; vsk->owner = get_cred(psk->owner); vsk->connect_timeout = psk->connect_timeout; + vsk->buffer_size = psk->buffer_size; + vsk->buffer_min_size = psk->buffer_min_size; + vsk->buffer_max_size = psk->buffer_max_size; } else { vsk->trusted = capable(CAP_NET_ADMIN); vsk->owner = get_current_cred(); vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT; + vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE; + vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE; + vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE; } - if (transport->init(vsk, psk) < 0) { - sk_free(sk); - return NULL; - } - - if (sock) - vsock_insert_unbound(vsk); - return sk; } -EXPORT_SYMBOL_GPL(__vsock_create); static void __vsock_release(struct sock *sk, int level) { if (sk) { - struct sk_buff *skb; struct sock 
*pending; struct vsock_sock *vsk; @@ -651,7 +736,10 @@ static void __vsock_release(struct sock *sk, int level) /* The release call is supposed to use lock_sock_nested() * rather than lock_sock(), if a sock lock should be acquired. */ - transport->release(vsk); + if (vsk->transport) + vsk->transport->release(vsk); + else if (sk->sk_type == SOCK_STREAM) + vsock_remove_sock(vsk); /* When "level" is SINGLE_DEPTH_NESTING, use the nested * version to avoid the warning "possible recursive locking @@ -662,8 +750,7 @@ static void __vsock_release(struct sock *sk, int level) sock_orphan(sk); sk->sk_shutdown = SHUTDOWN_MASK; - while ((skb = skb_dequeue(&sk->sk_receive_queue))) - kfree_skb(skb); + skb_queue_purge(&sk->sk_receive_queue); /* Clean up any sockets that never were accepted. */ while ((pending = vsock_dequeue_accept(sk)) != NULL) { @@ -680,7 +767,7 @@ static void vsock_sk_destruct(struct sock *sk) { struct vsock_sock *vsk = vsock_sk(sk); - transport->destruct(vsk); + vsock_deassign_transport(vsk); /* When clearing these addresses, there's no need to set the family and * possibly register the address family with the kernel. 
@@ -702,15 +789,22 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) return err; } +struct sock *vsock_create_connected(struct sock *parent) +{ + return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL, + parent->sk_type, 0); +} +EXPORT_SYMBOL_GPL(vsock_create_connected); + s64 vsock_stream_has_data(struct vsock_sock *vsk) { - return transport->stream_has_data(vsk); + return vsk->transport->stream_has_data(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_data); s64 vsock_stream_has_space(struct vsock_sock *vsk) { - return transport->stream_has_space(vsk); + return vsk->transport->stream_has_space(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_space); @@ -879,6 +973,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; } else if (sock->type == SOCK_STREAM) { + const struct vsock_transport *transport = vsk->transport; lock_sock(sk); /* Listening sockets that have connections in their accept @@ -889,7 +984,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, mask |= EPOLLIN | EPOLLRDNORM; /* If there is something in the queue then we can read. 
*/ - if (transport->stream_is_active(vsk) && + if (transport && transport->stream_is_active(vsk) && !(sk->sk_shutdown & RCV_SHUTDOWN)) { bool data_ready_now = false; int ret = transport->notify_poll_in( @@ -954,6 +1049,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, struct sock *sk; struct vsock_sock *vsk; struct sockaddr_vm *remote_addr; + const struct vsock_transport *transport; if (msg->msg_flags & MSG_OOB) return -EOPNOTSUPP; @@ -962,6 +1058,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, err = 0; sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; lock_sock(sk); @@ -1046,8 +1143,8 @@ static int vsock_dgram_connect(struct socket *sock, if (err) goto out; - if (!transport->dgram_allow(remote_addr->svm_cid, - remote_addr->svm_port)) { + if (!vsk->transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { err = -EINVAL; goto out; } @@ -1063,7 +1160,9 @@ out: static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags) { - return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags); + struct vsock_sock *vsk = vsock_sk(sock->sk); + + return vsk->transport->dgram_dequeue(vsk, msg, len, flags); } static const struct proto_ops vsock_dgram_ops = { @@ -1089,6 +1188,8 @@ static const struct proto_ops vsock_dgram_ops = { static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) { + const struct vsock_transport *transport = vsk->transport; + if (!transport->cancel_pkt) return -EOPNOTSUPP; @@ -1125,6 +1226,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, int err; struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; struct sockaddr_vm *remote_addr; long timeout; DEFINE_WAIT(wait); @@ -1159,19 +1261,26 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, goto out; } + /* Set the remote address that we are connecting to. 
*/ + memcpy(&vsk->remote_addr, remote_addr, + sizeof(vsk->remote_addr)); + + err = vsock_assign_transport(vsk, NULL); + if (err) + goto out; + + transport = vsk->transport; + /* The hypervisor and well-known contexts do not have socket * endpoints. */ - if (!transport->stream_allow(remote_addr->svm_cid, + if (!transport || + !transport->stream_allow(remote_addr->svm_cid, remote_addr->svm_port)) { err = -ENETUNREACH; goto out; } - /* Set the remote address that we are connecting to. */ - memcpy(&vsk->remote_addr, remote_addr, - sizeof(vsk->remote_addr)); - err = vsock_auto_bind(vsk); if (err) goto out; @@ -1301,7 +1410,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags, err = -listener->sk_err; if (connected) { - listener->sk_ack_backlog--; + sk_acceptq_removed(listener); lock_sock_nested(connected, SINGLE_DEPTH_NESTING); vconnected = vsock_sk(connected); @@ -1366,6 +1475,23 @@ out: return err; } +static void vsock_update_buffer_size(struct vsock_sock *vsk, + const struct vsock_transport *transport, + u64 val) +{ + if (val > vsk->buffer_max_size) + val = vsk->buffer_max_size; + + if (val < vsk->buffer_min_size) + val = vsk->buffer_min_size; + + if (val != vsk->buffer_size && + transport && transport->notify_buffer_size) + transport->notify_buffer_size(vsk, &val); + + vsk->buffer_size = val; +} + static int vsock_stream_setsockopt(struct socket *sock, int level, int optname, @@ -1375,6 +1501,7 @@ static int vsock_stream_setsockopt(struct socket *sock, int err; struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; u64 val; if (level != AF_VSOCK) @@ -1395,23 +1522,26 @@ static int vsock_stream_setsockopt(struct socket *sock, err = 0; sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; lock_sock(sk); switch (optname) { case SO_VM_SOCKETS_BUFFER_SIZE: COPY_IN(val); - transport->set_buffer_size(vsk, val); + vsock_update_buffer_size(vsk, transport, val); break; case SO_VM_SOCKETS_BUFFER_MAX_SIZE: 
COPY_IN(val); - transport->set_max_buffer_size(vsk, val); + vsk->buffer_max_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); break; case SO_VM_SOCKETS_BUFFER_MIN_SIZE: COPY_IN(val); - transport->set_min_buffer_size(vsk, val); + vsk->buffer_min_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); break; case SO_VM_SOCKETS_CONNECT_TIMEOUT: { @@ -1478,17 +1608,17 @@ static int vsock_stream_getsockopt(struct socket *sock, switch (optname) { case SO_VM_SOCKETS_BUFFER_SIZE: - val = transport->get_buffer_size(vsk); + val = vsk->buffer_size; COPY_OUT(val); break; case SO_VM_SOCKETS_BUFFER_MAX_SIZE: - val = transport->get_max_buffer_size(vsk); + val = vsk->buffer_max_size; COPY_OUT(val); break; case SO_VM_SOCKETS_BUFFER_MIN_SIZE: - val = transport->get_min_buffer_size(vsk); + val = vsk->buffer_min_size; COPY_OUT(val); break; @@ -1519,6 +1649,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, { struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; ssize_t total_written; long timeout; int err; @@ -1527,6 +1658,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; total_written = 0; err = 0; @@ -1548,7 +1680,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, goto out; } - if (sk->sk_state != TCP_ESTABLISHED || + if (!transport || sk->sk_state != TCP_ESTABLISHED || !vsock_addr_bound(&vsk->local_addr)) { err = -ENOTCONN; goto out; @@ -1658,6 +1790,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, { struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; int err; size_t target; ssize_t copied; @@ -1668,11 +1801,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, sk = sock->sk; vsk = vsock_sk(sk); + transport = vsk->transport; err = 0; lock_sock(sk); - if (sk->sk_state != 
TCP_ESTABLISHED) { + if (!transport || sk->sk_state != TCP_ESTABLISHED) { /* Recvmsg is supposed to return 0 if a peer performs an * orderly shutdown. Differentiate between that case and when a * peer has not connected or a local shutdown occured with the @@ -1846,6 +1980,10 @@ static const struct proto_ops vsock_stream_ops = { static int vsock_create(struct net *net, struct socket *sock, int protocol, int kern) { + struct vsock_sock *vsk; + struct sock *sk; + int ret; + if (!sock) return -EINVAL; @@ -1865,7 +2003,23 @@ static int vsock_create(struct net *net, struct socket *sock, sock->state = SS_UNCONNECTED; - return __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern) ? 0 : -ENOMEM; + sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern); + if (!sk) + return -ENOMEM; + + vsk = vsock_sk(sk); + + if (sock->type == SOCK_DGRAM) { + ret = vsock_assign_transport(vsk, NULL); + if (ret < 0) { + sock_put(sk); + return ret; + } + } + + vsock_insert_unbound(vsk); + + return 0; } static const struct net_proto_family vsock_family_ops = { @@ -1878,11 +2032,20 @@ static long vsock_dev_do_ioctl(struct file *filp, unsigned int cmd, void __user *ptr) { u32 __user *p = ptr; + u32 cid = VMADDR_CID_ANY; int retval = 0; switch (cmd) { case IOCTL_VM_SOCKETS_GET_LOCAL_CID: - if (put_user(transport->get_local_cid(), p) != 0) + /* To be compatible with the VMCI behavior, we prioritize the + * guest CID instead of well-known host CID (VMADDR_CID_HOST). 
+ */ + if (transport_g2h) + cid = transport_g2h->get_local_cid(); + else if (transport_h2g) + cid = transport_h2g->get_local_cid(); + + if (put_user(cid, p) != 0) retval = -EFAULT; break; @@ -1922,24 +2085,13 @@ static struct miscdevice vsock_device = { .fops = &vsock_device_ops, }; -int __vsock_core_init(const struct vsock_transport *t, struct module *owner) +static int __init vsock_init(void) { - int err = mutex_lock_interruptible(&vsock_register_mutex); + int err = 0; - if (err) - return err; - - if (transport) { - err = -EBUSY; - goto err_busy; - } - - /* Transport must be the owner of the protocol so that it can't - * unload while there are open sockets. - */ - vsock_proto.owner = owner; - transport = t; + vsock_init_tables(); + vsock_proto.owner = THIS_MODULE; vsock_device.minor = MISC_DYNAMIC_MINOR; err = misc_register(&vsock_device); if (err) { @@ -1960,7 +2112,6 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner) goto err_unregister_proto; } - mutex_unlock(&vsock_register_mutex); return 0; err_unregister_proto: @@ -1968,44 +2119,86 @@ err_unregister_proto: err_deregister_misc: misc_deregister(&vsock_device); err_reset_transport: - transport = NULL; -err_busy: - mutex_unlock(&vsock_register_mutex); return err; } -EXPORT_SYMBOL_GPL(__vsock_core_init); -void vsock_core_exit(void) +static void __exit vsock_exit(void) { - mutex_lock(&vsock_register_mutex); - misc_deregister(&vsock_device); sock_unregister(AF_VSOCK); proto_unregister(&vsock_proto); - - /* We do not want the assignment below re-ordered. */ - mb(); - transport = NULL; - - mutex_unlock(&vsock_register_mutex); } -EXPORT_SYMBOL_GPL(vsock_core_exit); -const struct vsock_transport *vsock_core_get_transport(void) +const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk) { - /* vsock_register_mutex not taken since only the transport uses this - * function and only while registered. 
- */ - return transport; + return vsk->transport; } EXPORT_SYMBOL_GPL(vsock_core_get_transport); -static void __exit vsock_exit(void) +int vsock_core_register(const struct vsock_transport *t, int features) +{ + const struct vsock_transport *t_h2g, *t_g2h, *t_dgram; + int err = mutex_lock_interruptible(&vsock_register_mutex); + + if (err) + return err; + + t_h2g = transport_h2g; + t_g2h = transport_g2h; + t_dgram = transport_dgram; + + if (features & VSOCK_TRANSPORT_F_H2G) { + if (t_h2g) { + err = -EBUSY; + goto err_busy; + } + t_h2g = t; + } + + if (features & VSOCK_TRANSPORT_F_G2H) { + if (t_g2h) { + err = -EBUSY; + goto err_busy; + } + t_g2h = t; + } + + if (features & VSOCK_TRANSPORT_F_DGRAM) { + if (t_dgram) { + err = -EBUSY; + goto err_busy; + } + t_dgram = t; + } + + transport_h2g = t_h2g; + transport_g2h = t_g2h; + transport_dgram = t_dgram; + +err_busy: + mutex_unlock(&vsock_register_mutex); + return err; +} +EXPORT_SYMBOL_GPL(vsock_core_register); + +void vsock_core_unregister(const struct vsock_transport *t) { - /* Do nothing. This function makes this module removable. */ + mutex_lock(&vsock_register_mutex); + + if (transport_h2g == t) + transport_h2g = NULL; + + if (transport_g2h == t) + transport_g2h = NULL; + + if (transport_dgram == t) + transport_dgram = NULL; + + mutex_unlock(&vsock_register_mutex); } +EXPORT_SYMBOL_GPL(vsock_core_unregister); -module_init(vsock_init_tables); +module_init(vsock_init); module_exit(vsock_exit); MODULE_AUTHOR("VMware, Inc."); |