summaryrefslogtreecommitdiffstats
path: root/net/vmw_vsock/af_vsock.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
-rw-r--r--net/vmw_vsock/af_vsock.c397
1 files changed, 295 insertions, 102 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 582a3e4dfce2..74db4cd637a7 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -126,19 +126,18 @@ static struct proto vsock_proto = {
*/
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
-static const struct vsock_transport *transport;
+#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128
+
+/* Transport used for host->guest communication */
+static const struct vsock_transport *transport_h2g;
+/* Transport used for guest->host communication */
+static const struct vsock_transport *transport_g2h;
+/* Transport used for DGRAM communication */
+static const struct vsock_transport *transport_dgram;
static DEFINE_MUTEX(vsock_register_mutex);
-/**** EXPORTS ****/
-
-/* Get the ID of the local context. This is transport dependent. */
-
-int vm_sockets_get_local_cid(void)
-{
- return transport->get_local_cid();
-}
-EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
-
/**** UTILS ****/
/* Each bound VSocket is stored in the bind hash table and each connected
@@ -188,7 +187,7 @@ static int vsock_auto_bind(struct vsock_sock *vsk)
return __vsock_bind(sk, &local_addr);
}
-static int __init vsock_init_tables(void)
+static void vsock_init_tables(void)
{
int i;
@@ -197,7 +196,6 @@ static int __init vsock_init_tables(void)
for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
INIT_LIST_HEAD(&vsock_connected_table[i]);
- return 0;
}
static void __vsock_insert_bound(struct list_head *list,
@@ -230,9 +228,15 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
{
struct vsock_sock *vsk;
- list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
- if (addr->svm_port == vsk->local_addr.svm_port)
+ list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
+ if (vsock_addr_equals_addr(addr, &vsk->local_addr))
+ return sk_vsock(vsk);
+
+ if (addr->svm_port == vsk->local_addr.svm_port &&
+ (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
+ addr->svm_cid == VMADDR_CID_ANY))
return sk_vsock(vsk);
+ }
return NULL;
}
@@ -382,6 +386,88 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
}
EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
+static void vsock_deassign_transport(struct vsock_sock *vsk)
+{
+ if (!vsk->transport)
+ return;
+
+ vsk->transport->destruct(vsk);
+ module_put(vsk->transport->module);
+ vsk->transport = NULL;
+}
+
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for stream socket this must be called when vsk->remote_addr is set
+ * (e.g. during the connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ * - remote CID <= VMADDR_CID_HOST will use guest->host transport;
+ * - remote CID == local_cid (guest->host transport) will use guest->host
+ * transport for loopback (host->guest transports don't support loopback);
+ * - remote CID > VMADDR_CID_HOST will use host->guest transport;
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+ const struct vsock_transport *new_transport;
+ struct sock *sk = sk_vsock(vsk);
+ unsigned int remote_cid = vsk->remote_addr.svm_cid;
+ int ret;
+
+ switch (sk->sk_type) {
+ case SOCK_DGRAM:
+ new_transport = transport_dgram;
+ break;
+ case SOCK_STREAM:
+ if (remote_cid <= VMADDR_CID_HOST ||
+ (transport_g2h &&
+ remote_cid == transport_g2h->get_local_cid()))
+ new_transport = transport_g2h;
+ else
+ new_transport = transport_h2g;
+ break;
+ default:
+ return -ESOCKTNOSUPPORT;
+ }
+
+ if (vsk->transport) {
+ if (vsk->transport == new_transport)
+ return 0;
+
+ vsk->transport->release(vsk);
+ vsock_deassign_transport(vsk);
+ }
+
+ /* We increase the module refcnt to prevent the transport unloading
+ * while there are open sockets assigned to it.
+ */
+ if (!new_transport || !try_module_get(new_transport->module))
+ return -ENODEV;
+
+ ret = new_transport->init(vsk, psk);
+ if (ret) {
+ module_put(new_transport->module);
+ return ret;
+ }
+
+ vsk->transport = new_transport;
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
+
+bool vsock_find_cid(unsigned int cid)
+{
+ if (transport_g2h && cid == transport_g2h->get_local_cid())
+ return true;
+
+ if (transport_h2g && cid == VMADDR_CID_HOST)
+ return true;
+
+ return false;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+
static struct sock *vsock_dequeue_accept(struct sock *listener)
{
struct vsock_sock *vlistener;
@@ -418,7 +504,12 @@ static bool vsock_is_pending(struct sock *sk)
static int vsock_send_shutdown(struct sock *sk, int mode)
{
- return transport->shutdown(vsock_sk(sk), mode);
+ struct vsock_sock *vsk = vsock_sk(sk);
+
+ if (!vsk->transport)
+ return -ENODEV;
+
+ return vsk->transport->shutdown(vsk, mode);
}
static void vsock_pending_work(struct work_struct *work)
@@ -439,7 +530,7 @@ static void vsock_pending_work(struct work_struct *work)
if (vsock_is_pending(sk)) {
vsock_remove_pending(listener, sk);
- listener->sk_ack_backlog--;
+ sk_acceptq_removed(listener);
} else if (!vsk->rejected) {
/* We are not on the pending list and accept() did not reject
* us, so we must have been accepted by our user process. We
@@ -528,13 +619,12 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
static int __vsock_bind_dgram(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
- return transport->dgram_bind(vsk, addr);
+ return vsk->transport->dgram_bind(vsk, addr);
}
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
{
struct vsock_sock *vsk = vsock_sk(sk);
- u32 cid;
int retval;
/* First ensure this socket isn't already bound. */
@@ -544,10 +634,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
/* Now bind to the provided address or select appropriate values if
* none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that
* like AF_INET prevents binding to a non-local IP address (in most
- * cases), we only allow binding to the local CID.
+ * cases), we only allow binding to a local CID.
*/
- cid = transport->get_local_cid();
- if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
+ if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
return -EADDRNOTAVAIL;
switch (sk->sk_socket->type) {
@@ -571,12 +660,12 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
static void vsock_connect_timeout(struct work_struct *work);
-struct sock *__vsock_create(struct net *net,
- struct socket *sock,
- struct sock *parent,
- gfp_t priority,
- unsigned short type,
- int kern)
+static struct sock *__vsock_create(struct net *net,
+ struct socket *sock,
+ struct sock *parent,
+ gfp_t priority,
+ unsigned short type,
+ int kern)
{
struct sock *sk;
struct vsock_sock *psk;
@@ -620,28 +709,24 @@ struct sock *__vsock_create(struct net *net,
vsk->trusted = psk->trusted;
vsk->owner = get_cred(psk->owner);
vsk->connect_timeout = psk->connect_timeout;
+ vsk->buffer_size = psk->buffer_size;
+ vsk->buffer_min_size = psk->buffer_min_size;
+ vsk->buffer_max_size = psk->buffer_max_size;
} else {
vsk->trusted = capable(CAP_NET_ADMIN);
vsk->owner = get_current_cred();
vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
+ vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE;
+ vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE;
+ vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
}
- if (transport->init(vsk, psk) < 0) {
- sk_free(sk);
- return NULL;
- }
-
- if (sock)
- vsock_insert_unbound(vsk);
-
return sk;
}
-EXPORT_SYMBOL_GPL(__vsock_create);
static void __vsock_release(struct sock *sk, int level)
{
if (sk) {
- struct sk_buff *skb;
struct sock *pending;
struct vsock_sock *vsk;
@@ -651,7 +736,10 @@ static void __vsock_release(struct sock *sk, int level)
/* The release call is supposed to use lock_sock_nested()
* rather than lock_sock(), if a sock lock should be acquired.
*/
- transport->release(vsk);
+ if (vsk->transport)
+ vsk->transport->release(vsk);
+ else if (sk->sk_type == SOCK_STREAM)
+ vsock_remove_sock(vsk);
/* When "level" is SINGLE_DEPTH_NESTING, use the nested
* version to avoid the warning "possible recursive locking
@@ -662,8 +750,7 @@ static void __vsock_release(struct sock *sk, int level)
sock_orphan(sk);
sk->sk_shutdown = SHUTDOWN_MASK;
- while ((skb = skb_dequeue(&sk->sk_receive_queue)))
- kfree_skb(skb);
+ skb_queue_purge(&sk->sk_receive_queue);
/* Clean up any sockets that never were accepted. */
while ((pending = vsock_dequeue_accept(sk)) != NULL) {
@@ -680,7 +767,7 @@ static void vsock_sk_destruct(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
- transport->destruct(vsk);
+ vsock_deassign_transport(vsk);
/* When clearing these addresses, there's no need to set the family and
* possibly register the address family with the kernel.
@@ -702,15 +789,22 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
return err;
}
+struct sock *vsock_create_connected(struct sock *parent)
+{
+ return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL,
+ parent->sk_type, 0);
+}
+EXPORT_SYMBOL_GPL(vsock_create_connected);
+
s64 vsock_stream_has_data(struct vsock_sock *vsk)
{
- return transport->stream_has_data(vsk);
+ return vsk->transport->stream_has_data(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_data);
s64 vsock_stream_has_space(struct vsock_sock *vsk)
{
- return transport->stream_has_space(vsk);
+ return vsk->transport->stream_has_space(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_space);
@@ -879,6 +973,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
} else if (sock->type == SOCK_STREAM) {
+ const struct vsock_transport *transport = vsk->transport;
lock_sock(sk);
/* Listening sockets that have connections in their accept
@@ -889,7 +984,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
mask |= EPOLLIN | EPOLLRDNORM;
/* If there is something in the queue then we can read. */
- if (transport->stream_is_active(vsk) &&
+ if (transport && transport->stream_is_active(vsk) &&
!(sk->sk_shutdown & RCV_SHUTDOWN)) {
bool data_ready_now = false;
int ret = transport->notify_poll_in(
@@ -954,6 +1049,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
struct sock *sk;
struct vsock_sock *vsk;
struct sockaddr_vm *remote_addr;
+ const struct vsock_transport *transport;
if (msg->msg_flags & MSG_OOB)
return -EOPNOTSUPP;
@@ -962,6 +1058,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
err = 0;
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
lock_sock(sk);
@@ -1046,8 +1143,8 @@ static int vsock_dgram_connect(struct socket *sock,
if (err)
goto out;
- if (!transport->dgram_allow(remote_addr->svm_cid,
- remote_addr->svm_port)) {
+ if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
+ remote_addr->svm_port)) {
err = -EINVAL;
goto out;
}
@@ -1063,7 +1160,9 @@ out:
static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
size_t len, int flags)
{
- return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags);
+ struct vsock_sock *vsk = vsock_sk(sock->sk);
+
+ return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
}
static const struct proto_ops vsock_dgram_ops = {
@@ -1089,6 +1188,8 @@ static const struct proto_ops vsock_dgram_ops = {
static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
{
+ const struct vsock_transport *transport = vsk->transport;
+
if (!transport->cancel_pkt)
return -EOPNOTSUPP;
@@ -1125,6 +1226,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
int err;
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
struct sockaddr_vm *remote_addr;
long timeout;
DEFINE_WAIT(wait);
@@ -1159,19 +1261,26 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
goto out;
}
+ /* Set the remote address that we are connecting to. */
+ memcpy(&vsk->remote_addr, remote_addr,
+ sizeof(vsk->remote_addr));
+
+ err = vsock_assign_transport(vsk, NULL);
+ if (err)
+ goto out;
+
+ transport = vsk->transport;
+
/* The hypervisor and well-known contexts do not have socket
* endpoints.
*/
- if (!transport->stream_allow(remote_addr->svm_cid,
+ if (!transport ||
+ !transport->stream_allow(remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -ENETUNREACH;
goto out;
}
- /* Set the remote address that we are connecting to. */
- memcpy(&vsk->remote_addr, remote_addr,
- sizeof(vsk->remote_addr));
-
err = vsock_auto_bind(vsk);
if (err)
goto out;
@@ -1301,7 +1410,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
err = -listener->sk_err;
if (connected) {
- listener->sk_ack_backlog--;
+ sk_acceptq_removed(listener);
lock_sock_nested(connected, SINGLE_DEPTH_NESTING);
vconnected = vsock_sk(connected);
@@ -1366,6 +1475,23 @@ out:
return err;
}
+static void vsock_update_buffer_size(struct vsock_sock *vsk,
+ const struct vsock_transport *transport,
+ u64 val)
+{
+ if (val > vsk->buffer_max_size)
+ val = vsk->buffer_max_size;
+
+ if (val < vsk->buffer_min_size)
+ val = vsk->buffer_min_size;
+
+ if (val != vsk->buffer_size &&
+ transport && transport->notify_buffer_size)
+ transport->notify_buffer_size(vsk, &val);
+
+ vsk->buffer_size = val;
+}
+
static int vsock_stream_setsockopt(struct socket *sock,
int level,
int optname,
@@ -1375,6 +1501,7 @@ static int vsock_stream_setsockopt(struct socket *sock,
int err;
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
u64 val;
if (level != AF_VSOCK)
@@ -1395,23 +1522,26 @@ static int vsock_stream_setsockopt(struct socket *sock,
err = 0;
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
lock_sock(sk);
switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE:
COPY_IN(val);
- transport->set_buffer_size(vsk, val);
+ vsock_update_buffer_size(vsk, transport, val);
break;
case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
COPY_IN(val);
- transport->set_max_buffer_size(vsk, val);
+ vsk->buffer_max_size = val;
+ vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
break;
case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
COPY_IN(val);
- transport->set_min_buffer_size(vsk, val);
+ vsk->buffer_min_size = val;
+ vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
break;
case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
@@ -1478,17 +1608,17 @@ static int vsock_stream_getsockopt(struct socket *sock,
switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE:
- val = transport->get_buffer_size(vsk);
+ val = vsk->buffer_size;
COPY_OUT(val);
break;
case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
- val = transport->get_max_buffer_size(vsk);
+ val = vsk->buffer_max_size;
COPY_OUT(val);
break;
case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
- val = transport->get_min_buffer_size(vsk);
+ val = vsk->buffer_min_size;
COPY_OUT(val);
break;
@@ -1519,6 +1649,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
{
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
ssize_t total_written;
long timeout;
int err;
@@ -1527,6 +1658,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
total_written = 0;
err = 0;
@@ -1548,7 +1680,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
- if (sk->sk_state != TCP_ESTABLISHED ||
+ if (!transport || sk->sk_state != TCP_ESTABLISHED ||
!vsock_addr_bound(&vsk->local_addr)) {
err = -ENOTCONN;
goto out;
@@ -1658,6 +1790,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
{
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
int err;
size_t target;
ssize_t copied;
@@ -1668,11 +1801,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
sk = sock->sk;
vsk = vsock_sk(sk);
+ transport = vsk->transport;
err = 0;
lock_sock(sk);
- if (sk->sk_state != TCP_ESTABLISHED) {
+ if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
* orderly shutdown. Differentiate between that case and when a
* peer has not connected or a local shutdown occured with the
@@ -1846,6 +1980,10 @@ static const struct proto_ops vsock_stream_ops = {
static int vsock_create(struct net *net, struct socket *sock,
int protocol, int kern)
{
+ struct vsock_sock *vsk;
+ struct sock *sk;
+ int ret;
+
if (!sock)
return -EINVAL;
@@ -1865,7 +2003,23 @@ static int vsock_create(struct net *net, struct socket *sock,
sock->state = SS_UNCONNECTED;
- return __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern) ? 0 : -ENOMEM;
+ sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
+ if (!sk)
+ return -ENOMEM;
+
+ vsk = vsock_sk(sk);
+
+ if (sock->type == SOCK_DGRAM) {
+ ret = vsock_assign_transport(vsk, NULL);
+ if (ret < 0) {
+ sock_put(sk);
+ return ret;
+ }
+ }
+
+ vsock_insert_unbound(vsk);
+
+ return 0;
}
static const struct net_proto_family vsock_family_ops = {
@@ -1878,11 +2032,20 @@ static long vsock_dev_do_ioctl(struct file *filp,
unsigned int cmd, void __user *ptr)
{
u32 __user *p = ptr;
+ u32 cid = VMADDR_CID_ANY;
int retval = 0;
switch (cmd) {
case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
- if (put_user(transport->get_local_cid(), p) != 0)
+ /* To be compatible with the VMCI behavior, we prioritize the
+ * guest CID instead of well-know host CID (VMADDR_CID_HOST).
+ */
+ if (transport_g2h)
+ cid = transport_g2h->get_local_cid();
+ else if (transport_h2g)
+ cid = transport_h2g->get_local_cid();
+
+ if (put_user(cid, p) != 0)
retval = -EFAULT;
break;
@@ -1922,24 +2085,13 @@ static struct miscdevice vsock_device = {
.fops = &vsock_device_ops,
};
-int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
+static int __init vsock_init(void)
{
- int err = mutex_lock_interruptible(&vsock_register_mutex);
+ int err = 0;
- if (err)
- return err;
-
- if (transport) {
- err = -EBUSY;
- goto err_busy;
- }
-
- /* Transport must be the owner of the protocol so that it can't
- * unload while there are open sockets.
- */
- vsock_proto.owner = owner;
- transport = t;
+ vsock_init_tables();
+ vsock_proto.owner = THIS_MODULE;
vsock_device.minor = MISC_DYNAMIC_MINOR;
err = misc_register(&vsock_device);
if (err) {
@@ -1960,7 +2112,6 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
goto err_unregister_proto;
}
- mutex_unlock(&vsock_register_mutex);
return 0;
err_unregister_proto:
@@ -1968,44 +2119,86 @@ err_unregister_proto:
err_deregister_misc:
misc_deregister(&vsock_device);
err_reset_transport:
- transport = NULL;
-err_busy:
- mutex_unlock(&vsock_register_mutex);
return err;
}
-EXPORT_SYMBOL_GPL(__vsock_core_init);
-void vsock_core_exit(void)
+static void __exit vsock_exit(void)
{
- mutex_lock(&vsock_register_mutex);
-
misc_deregister(&vsock_device);
sock_unregister(AF_VSOCK);
proto_unregister(&vsock_proto);
-
- /* We do not want the assignment below re-ordered. */
- mb();
- transport = NULL;
-
- mutex_unlock(&vsock_register_mutex);
}
-EXPORT_SYMBOL_GPL(vsock_core_exit);
-const struct vsock_transport *vsock_core_get_transport(void)
+const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
{
- /* vsock_register_mutex not taken since only the transport uses this
- * function and only while registered.
- */
- return transport;
+ return vsk->transport;
}
EXPORT_SYMBOL_GPL(vsock_core_get_transport);
-static void __exit vsock_exit(void)
+int vsock_core_register(const struct vsock_transport *t, int features)
+{
+ const struct vsock_transport *t_h2g, *t_g2h, *t_dgram;
+ int err = mutex_lock_interruptible(&vsock_register_mutex);
+
+ if (err)
+ return err;
+
+ t_h2g = transport_h2g;
+ t_g2h = transport_g2h;
+ t_dgram = transport_dgram;
+
+ if (features & VSOCK_TRANSPORT_F_H2G) {
+ if (t_h2g) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_h2g = t;
+ }
+
+ if (features & VSOCK_TRANSPORT_F_G2H) {
+ if (t_g2h) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_g2h = t;
+ }
+
+ if (features & VSOCK_TRANSPORT_F_DGRAM) {
+ if (t_dgram) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_dgram = t;
+ }
+
+ transport_h2g = t_h2g;
+ transport_g2h = t_g2h;
+ transport_dgram = t_dgram;
+
+err_busy:
+ mutex_unlock(&vsock_register_mutex);
+ return err;
+}
+EXPORT_SYMBOL_GPL(vsock_core_register);
+
+void vsock_core_unregister(const struct vsock_transport *t)
{
- /* Do nothing. This function makes this module removable. */
+ mutex_lock(&vsock_register_mutex);
+
+ if (transport_h2g == t)
+ transport_h2g = NULL;
+
+ if (transport_g2h == t)
+ transport_g2h = NULL;
+
+ if (transport_dgram == t)
+ transport_dgram = NULL;
+
+ mutex_unlock(&vsock_register_mutex);
}
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
-module_init(vsock_init_tables);
+module_init(vsock_init);
module_exit(vsock_exit);
MODULE_AUTHOR("VMware, Inc.");