summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--net/sunrpc/Makefile2
-rw-r--r--net/sunrpc/auth.c2
-rw-r--r--net/sunrpc/auth_tls.c175
-rw-r--r--net/sunrpc/clnt.c22
-rw-r--r--net/sunrpc/rpcb_clnt.c39
-rw-r--r--net/sunrpc/sysfs.c1
-rw-r--r--net/sunrpc/sysfs.h7
-rw-r--r--net/sunrpc/xprtsock.c434
8 files changed, 658 insertions, 24 deletions
diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile
index 1c8de397d6ad..f89c10fe7e6a 100644
--- a/net/sunrpc/Makefile
+++ b/net/sunrpc/Makefile
@@ -9,7 +9,7 @@ obj-$(CONFIG_SUNRPC_GSS) += auth_gss/
obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/
sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \
- auth.o auth_null.o auth_unix.o \
+ auth.o auth_null.o auth_tls.o auth_unix.o \
svc.o svcsock.o svcauth.o svcauth_unix.o \
addr.o rpcb_clnt.o timer.o xdr.o \
sunrpc_syms.o cache.o rpc_pipe.o sysfs.o \
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c
index fb75a883503f..2f16f9d17966 100644
--- a/net/sunrpc/auth.c
+++ b/net/sunrpc/auth.c
@@ -32,7 +32,7 @@ static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
[RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
[RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
- NULL, /* others can be loadable modules */
+ [RPC_AUTH_TLS] = (const struct rpc_authops __force __rcu *)&authtls_ops,
};
static LIST_HEAD(cred_unused);
diff --git a/net/sunrpc/auth_tls.c b/net/sunrpc/auth_tls.c
new file mode 100644
index 000000000000..de7678f8a23d
--- /dev/null
+++ b/net/sunrpc/auth_tls.c
@@ -0,0 +1,175 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (c) 2021, 2022 Oracle. All rights reserved.
+ *
+ * The AUTH_TLS credential is used only to probe a remote peer
+ * for RPC-over-TLS support.
+ */
+
+#include <linux/types.h>
+#include <linux/module.h>
+#include <linux/sunrpc/clnt.h>
+
+static const char *starttls_token = "STARTTLS";
+static const size_t starttls_len = 8;
+
+static struct rpc_auth tls_auth;
+static struct rpc_cred tls_cred;
+
+static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
+ const void *obj)
+{
+}
+
+static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
+ void *obj)
+{
+ return 0;
+}
+
+static const struct rpc_procinfo rpcproc_tls_probe = {
+ .p_encode = tls_encode_probe,
+ .p_decode = tls_decode_probe,
+};
+
+static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
+{
+ task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
+ rpc_call_start(task);
+}
+
+static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
+{
+}
+
+static const struct rpc_call_ops rpc_tls_probe_ops = {
+ .rpc_call_prepare = rpc_tls_probe_call_prepare,
+ .rpc_call_done = rpc_tls_probe_call_done,
+};
+
+static int tls_probe(struct rpc_clnt *clnt)
+{
+ struct rpc_message msg = {
+ .rpc_proc = &rpcproc_tls_probe,
+ };
+ struct rpc_task_setup task_setup_data = {
+ .rpc_client = clnt,
+ .rpc_message = &msg,
+ .rpc_op_cred = &tls_cred,
+ .callback_ops = &rpc_tls_probe_ops,
+ .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
+ };
+ struct rpc_task *task;
+ int status;
+
+ task = rpc_run_task(&task_setup_data);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+ status = task->tk_status;
+ rpc_put_task(task);
+ return status;
+}
+
+static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
+ struct rpc_clnt *clnt)
+{
+ refcount_inc(&tls_auth.au_count);
+ return &tls_auth;
+}
+
+static void tls_destroy(struct rpc_auth *auth)
+{
+}
+
+static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
+ struct auth_cred *acred, int flags)
+{
+ return get_rpccred(&tls_cred);
+}
+
+static void tls_destroy_cred(struct rpc_cred *cred)
+{
+}
+
+static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
+{
+ return 1;
+}
+
+static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ __be32 *p;
+
+ p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
+ if (!p)
+ return -EMSGSIZE;
+ /* Credential */
+ *p++ = rpc_auth_tls;
+ *p++ = xdr_zero;
+ /* Verifier */
+ *p++ = rpc_auth_null;
+ *p = xdr_zero;
+ return 0;
+}
+
+static int tls_refresh(struct rpc_task *task)
+{
+ set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
+ return 0;
+}
+
+static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
+{
+ __be32 *p;
+ void *str;
+
+ p = xdr_inline_decode(xdr, XDR_UNIT);
+ if (!p)
+ return -EIO;
+ if (*p != rpc_auth_null)
+ return -EIO;
+ if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
+ return -EIO;
+ if (memcmp(str, starttls_token, starttls_len))
+ return -EIO;
+ return 0;
+}
+
+const struct rpc_authops authtls_ops = {
+ .owner = THIS_MODULE,
+ .au_flavor = RPC_AUTH_TLS,
+ .au_name = "NULL",
+ .create = tls_create,
+ .destroy = tls_destroy,
+ .lookup_cred = tls_lookup_cred,
+ .ping = tls_probe,
+};
+
+static struct rpc_auth tls_auth = {
+ .au_cslack = NUL_CALLSLACK,
+ .au_rslack = NUL_REPLYSLACK,
+ .au_verfsize = NUL_REPLYSLACK,
+ .au_ralign = NUL_REPLYSLACK,
+ .au_ops = &authtls_ops,
+ .au_flavor = RPC_AUTH_TLS,
+ .au_count = REFCOUNT_INIT(1),
+};
+
+static const struct rpc_credops tls_credops = {
+ .cr_name = "AUTH_TLS",
+ .crdestroy = tls_destroy_cred,
+ .crmatch = tls_match,
+ .crmarshal = tls_marshal,
+ .crwrap_req = rpcauth_wrap_req_encode,
+ .crrefresh = tls_refresh,
+ .crvalidate = tls_validate,
+ .crunwrap_resp = rpcauth_unwrap_resp_decode,
+};
+
+static struct rpc_cred tls_cred = {
+ .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru),
+ .cr_auth = &tls_auth,
+ .cr_ops = &tls_credops,
+ .cr_count = REFCOUNT_INIT(2),
+ .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
+};
diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c
index d2ee56634308..d7c697af3762 100644
--- a/net/sunrpc/clnt.c
+++ b/net/sunrpc/clnt.c
@@ -385,6 +385,7 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args,
if (!clnt)
goto out_err;
clnt->cl_parent = parent ? : clnt;
+ clnt->cl_xprtsec = args->xprtsec;
err = rpc_alloc_clid(clnt);
if (err)
@@ -434,7 +435,7 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args,
if (parent)
refcount_inc(&parent->cl_count);
- trace_rpc_clnt_new(clnt, xprt, program->name, args->servername);
+ trace_rpc_clnt_new(clnt, xprt, args);
return clnt;
out_no_path:
@@ -532,6 +533,7 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args)
.addrlen = args->addrsize,
.servername = args->servername,
.bc_xprt = args->bc_xprt,
+ .xprtsec = args->xprtsec,
};
char servername[48];
struct rpc_clnt *clnt;
@@ -565,8 +567,12 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args)
servername[0] = '\0';
switch (args->address->sa_family) {
case AF_LOCAL:
- snprintf(servername, sizeof(servername), "%s",
- sun->sun_path);
+ if (sun->sun_path[0])
+ snprintf(servername, sizeof(servername), "%s",
+ sun->sun_path);
+ else
+ snprintf(servername, sizeof(servername), "@%s",
+ sun->sun_path+1);
break;
case AF_INET:
snprintf(servername, sizeof(servername), "%pI4",
@@ -727,6 +733,7 @@ int rpc_switch_client_transport(struct rpc_clnt *clnt,
struct rpc_clnt *parent;
int err;
+ args->xprtsec = clnt->cl_xprtsec;
xprt = xprt_create_transport(args);
if (IS_ERR(xprt))
return PTR_ERR(xprt);
@@ -1717,6 +1724,11 @@ call_start(struct rpc_task *task)
trace_rpc_request(task);
+ if (task->tk_client->cl_shutdown) {
+ rpc_call_rpcerror(task, -EIO);
+ return;
+ }
+
/* Increment call count (version might not be valid for ping) */
if (clnt->cl_program->version[clnt->cl_vers])
clnt->cl_program->version[clnt->cl_vers]->counts[idx]++;
@@ -2826,6 +2838,9 @@ static int rpc_ping(struct rpc_clnt *clnt)
struct rpc_task *task;
int status;
+ if (clnt->cl_auth->au_ops->ping)
+ return clnt->cl_auth->au_ops->ping(clnt);
+
task = rpc_call_null_helper(clnt, NULL, NULL, 0, NULL, NULL);
if (IS_ERR(task))
return PTR_ERR(task);
@@ -3046,6 +3061,7 @@ int rpc_clnt_add_xprt(struct rpc_clnt *clnt,
if (!xprtargs->ident)
xprtargs->ident = ident;
+ xprtargs->xprtsec = clnt->cl_xprtsec;
xprt = xprt_create_transport(xprtargs);
if (IS_ERR(xprt)) {
ret = PTR_ERR(xprt);
diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c
index 5a8e6d46809a..5988a5c5ff3f 100644
--- a/net/sunrpc/rpcb_clnt.c
+++ b/net/sunrpc/rpcb_clnt.c
@@ -36,6 +36,7 @@
#include "netns.h"
#define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock"
+#define RPCBIND_SOCK_ABSTRACT_NAME "\0/run/rpcbind.sock"
#define RPCBIND_PROGRAM (100000u)
#define RPCBIND_PORT (111u)
@@ -216,21 +217,22 @@ static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt,
sn->rpcb_users = 1;
}
+/* Evaluate to actual length of the `sockaddr_un' structure. */
+# define SUN_LEN(ptr) (offsetof(struct sockaddr_un, sun_path) \
+ + 1 + strlen((ptr)->sun_path + 1))
+
/*
* Returns zero on success, otherwise a negative errno value
* is returned.
*/
-static int rpcb_create_local_unix(struct net *net)
+static int rpcb_create_af_local(struct net *net,
+ const struct sockaddr_un *addr)
{
- static const struct sockaddr_un rpcb_localaddr_rpcbind = {
- .sun_family = AF_LOCAL,
- .sun_path = RPCBIND_SOCK_PATHNAME,
- };
struct rpc_create_args args = {
.net = net,
.protocol = XPRT_TRANSPORT_LOCAL,
- .address = (struct sockaddr *)&rpcb_localaddr_rpcbind,
- .addrsize = sizeof(rpcb_localaddr_rpcbind),
+ .address = (struct sockaddr *)addr,
+ .addrsize = SUN_LEN(addr),
.servername = "localhost",
.program = &rpcb_program,
.version = RPCBVERS_2,
@@ -269,6 +271,26 @@ out:
return result;
}
+static int rpcb_create_local_abstract(struct net *net)
+{
+ static const struct sockaddr_un rpcb_localaddr_abstract = {
+ .sun_family = AF_LOCAL,
+ .sun_path = RPCBIND_SOCK_ABSTRACT_NAME,
+ };
+
+ return rpcb_create_af_local(net, &rpcb_localaddr_abstract);
+}
+
+static int rpcb_create_local_unix(struct net *net)
+{
+ static const struct sockaddr_un rpcb_localaddr_unix = {
+ .sun_family = AF_LOCAL,
+ .sun_path = RPCBIND_SOCK_PATHNAME,
+ };
+
+ return rpcb_create_af_local(net, &rpcb_localaddr_unix);
+}
+
/*
* Returns zero on success, otherwise a negative errno value
* is returned.
@@ -332,7 +354,8 @@ int rpcb_create_local(struct net *net)
if (rpcb_get_local(net))
goto out;
- if (rpcb_create_local_unix(net) != 0)
+ if (rpcb_create_local_abstract(net) != 0 &&
+ rpcb_create_local_unix(net) != 0)
result = rpcb_create_local_net(net);
out:
diff --git a/net/sunrpc/sysfs.c b/net/sunrpc/sysfs.c
index 0d0db4e1064e..5c8ecdaaa985 100644
--- a/net/sunrpc/sysfs.c
+++ b/net/sunrpc/sysfs.c
@@ -239,6 +239,7 @@ static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj,
if (!xprt)
return 0;
if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP ||
+ xprt->xprt_class->ident == XPRT_TRANSPORT_TCP_TLS ||
xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) {
xprt_put(xprt);
return -EOPNOTSUPP;
diff --git a/net/sunrpc/sysfs.h b/net/sunrpc/sysfs.h
index 6620cebd1037..d2dd77a0a0e9 100644
--- a/net/sunrpc/sysfs.h
+++ b/net/sunrpc/sysfs.h
@@ -5,13 +5,6 @@
#ifndef __SUNRPC_SYSFS_H
#define __SUNRPC_SYSFS_H
-struct rpc_sysfs_client {
- struct kobject kobject;
- struct net *net;
- struct rpc_clnt *clnt;
- struct rpc_xprt_switch *xprt_switch;
-};
-
struct rpc_sysfs_xprt_switch {
struct kobject kobject;
struct net *net;
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
index 5f9030b81c9e..9f010369100a 100644
--- a/net/sunrpc/xprtsock.c
+++ b/net/sunrpc/xprtsock.c
@@ -47,6 +47,9 @@
#include <net/checksum.h>
#include <net/udp.h>
#include <net/tcp.h>
+#include <net/tls.h>
+#include <net/handshake.h>
+
#include <linux/bvec.h>
#include <linux/highmem.h>
#include <linux/uio.h>
@@ -96,6 +99,7 @@ static struct ctl_table_header *sunrpc_table_header;
static struct xprt_class xs_local_transport;
static struct xprt_class xs_udp_transport;
static struct xprt_class xs_tcp_transport;
+static struct xprt_class xs_tcp_tls_transport;
static struct xprt_class xs_bc_tcp_transport;
/*
@@ -187,6 +191,11 @@ static struct ctl_table xs_tunables_table[] = {
*/
#define XS_IDLE_DISC_TO (5U * 60 * HZ)
+/*
+ * TLS handshake timeout.
+ */
+#define XS_TLS_HANDSHAKE_TO (10U * HZ)
+
#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
# undef RPC_DEBUG_DATA
# define RPCDBG_FACILITY RPCDBG_TRANS
@@ -253,7 +262,12 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt)
switch (sap->sa_family) {
case AF_LOCAL:
sun = xs_addr_un(xprt);
- strscpy(buf, sun->sun_path, sizeof(buf));
+ if (sun->sun_path[0]) {
+ strscpy(buf, sun->sun_path, sizeof(buf));
+ } else {
+ buf[0] = '@';
+ strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1);
+ }
xprt->address_strings[RPC_DISPLAY_ADDR] =
kstrdup(buf, GFP_KERNEL);
break;
@@ -342,13 +356,56 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp)
return want;
}
+static int
+xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
+ struct cmsghdr *cmsg, int ret)
+{
+ if (cmsg->cmsg_level == SOL_TLS &&
+ cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
+ u8 content_type = *((u8 *)CMSG_DATA(cmsg));
+
+ switch (content_type) {
+ case TLS_RECORD_TYPE_DATA:
+ /* TLS sets EOR at the end of each application data
+ * record, even though there might be more frames
+ * waiting to be decrypted.
+ */
+ msg->msg_flags &= ~MSG_EOR;
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ ret = -ENOTCONN;
+ break;
+ default:
+ ret = -EAGAIN;
+ }
+ }
+ return ret;
+}
+
+static int
+xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags)
+{
+ union {
+ struct cmsghdr cmsg;
+ u8 buf[CMSG_SPACE(sizeof(u8))];
+ } u;
+ int ret;
+
+ msg->msg_control = &u;
+ msg->msg_controllen = sizeof(u);
+ ret = sock_recvmsg(sock, msg, flags);
+ if (msg->msg_controllen != sizeof(u))
+ ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret);
+ return ret;
+}
+
static ssize_t
xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
{
ssize_t ret;
if (seek != 0)
iov_iter_advance(&msg->msg_iter, seek);
- ret = sock_recvmsg(sock, msg, flags);
+ ret = xs_sock_recv_cmsg(sock, msg, flags);
return ret > 0 ? ret + seek : ret;
}
@@ -374,7 +431,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags,
size_t count)
{
iov_iter_discard(&msg->msg_iter, ITER_DEST, count);
- return sock_recvmsg(sock, msg, flags);
+ return xs_sock_recv_cmsg(sock, msg, flags);
}
#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
@@ -695,6 +752,8 @@ static void xs_poll_check_readable(struct sock_xprt *transport)
{
clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
+ if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
+ return;
if (!xs_poll_socket_readable(transport))
return;
if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state))
@@ -1191,6 +1250,8 @@ static void xs_reset_transport(struct sock_xprt *transport)
if (atomic_read(&transport->xprt.swapper))
sk_clear_memalloc(sk);
+ tls_handshake_cancel(sk);
+
kernel_sock_shutdown(sock, SHUT_RDWR);
mutex_lock(&transport->recv_mutex);
@@ -1380,6 +1441,10 @@ static void xs_data_ready(struct sock *sk)
trace_xs_data_ready(xprt);
transport->old_data_ready(sk);
+
+ if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
+ return;
+
/* Any data means we had a useful conversation, so
* then we don't need to delay the next reconnect
*/
@@ -2360,6 +2425,267 @@ out_unlock:
current_restore_flags(pflags, PF_MEMALLOC);
}
+/*
+ * Transfer the connected socket to @upper_transport, then mark that
+ * xprt CONNECTED.
+ */
+static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt,
+ struct sock_xprt *upper_transport)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+
+ if (!upper_transport->inet) {
+ struct socket *sock = lower_transport->sock;
+ struct sock *sk = sock->sk;
+
+ /* Avoid temporary address, they are bad for long-lived
+ * connections such as NFS mounts.
+ * RFC4941, section 3.6 suggests that:
+ * Individual applications, which have specific
+ * knowledge about the normal duration of connections,
+ * MAY override this as appropriate.
+ */
+ if (xs_addr(upper_xprt)->sa_family == PF_INET6)
+ ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC);
+
+ xs_tcp_set_socket_timeouts(upper_xprt, sock);
+ tcp_sock_set_nodelay(sk);
+
+ lock_sock(sk);
+
+ /* @sk is already connected, so it now has the RPC callbacks.
+ * Reach into @lower_transport to save the original ones.
+ */
+ upper_transport->old_data_ready = lower_transport->old_data_ready;
+ upper_transport->old_state_change = lower_transport->old_state_change;
+ upper_transport->old_write_space = lower_transport->old_write_space;
+ upper_transport->old_error_report = lower_transport->old_error_report;
+ sk->sk_user_data = upper_xprt;
+
+ /* socket options */
+ sock_reset_flag(sk, SOCK_LINGER);
+
+ xprt_clear_connected(upper_xprt);
+
+ upper_transport->sock = sock;
+ upper_transport->inet = sk;
+ upper_transport->file = lower_transport->file;
+
+ release_sock(sk);
+
+ /* Reset lower_transport before shutting down its clnt */
+ mutex_lock(&lower_transport->recv_mutex);
+ lower_transport->inet = NULL;
+ lower_transport->sock = NULL;
+ lower_transport->file = NULL;
+
+ xprt_clear_connected(lower_xprt);
+ xs_sock_reset_connection_flags(lower_xprt);
+ xs_stream_reset_connect(lower_transport);
+ mutex_unlock(&lower_transport->recv_mutex);
+ }
+
+ if (!xprt_bound(upper_xprt))
+ return -ENOTCONN;
+
+ xs_set_memalloc(upper_xprt);
+
+ if (!xprt_test_and_set_connected(upper_xprt)) {
+ upper_xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+
+ upper_xprt->stat.connect_count++;
+ upper_xprt->stat.connect_time += (long)jiffies -
+ upper_xprt->stat.connect_start;
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ return 0;
+}
+
+/**
+ * xs_tls_handshake_done - TLS handshake completion handler
+ * @data: address of xprt to wake
+ * @status: status of handshake
+ * @peerid: serial number of key containing the remote's identity
+ *
+ */
+static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
+{
+ struct rpc_xprt *lower_xprt = data;
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+
+ lower_transport->xprt_err = status ? -EACCES : 0;
+ complete(&lower_transport->handshake_done);
+ xprt_put(lower_xprt);
+}
+
+static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct tls_handshake_args args = {
+ .ta_sock = lower_transport->sock,
+ .ta_done = xs_tls_handshake_done,
+ .ta_data = xprt_get(lower_xprt),
+ .ta_peername = lower_xprt->servername,
+ };
+ struct sock *sk = lower_transport->inet;
+ int rc;
+
+ init_completion(&lower_transport->handshake_done);
+ set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+ lower_transport->xprt_err = -ETIMEDOUT;
+ switch (xprtsec->policy) {
+ case RPC_XPRTSEC_TLS_ANON:
+ rc = tls_client_hello_anon(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ case RPC_XPRTSEC_TLS_X509:
+ args.ta_my_cert = xprtsec->cert_serial;
+ args.ta_my_privkey = xprtsec->privkey_serial;
+ rc = tls_client_hello_x509(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ default:
+ rc = -EACCES;
+ goto out_put_xprt;
+ }
+
+ rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
+ XS_TLS_HANDSHAKE_TO);
+ if (rc <= 0) {
+ if (!tls_handshake_cancel(sk)) {
+ if (rc == 0)
+ rc = -ETIMEDOUT;
+ goto out_put_xprt;
+ }
+ }
+
+ rc = lower_transport->xprt_err;
+
+out:
+ xs_stream_reset_connect(lower_transport);
+ clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+ return rc;
+
+out_put_xprt:
+ xprt_put(lower_xprt);
+ goto out;
+}
+
+/**
+ * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket
+ * @work: queued work item
+ *
+ * Invoked by a work queue tasklet.
+ *
+ * For RPC-with-TLS, there is a two-stage connection process.
+ *
+ * The "upper-layer xprt" is visible to the RPC consumer. Once it has
+ * been marked connected, the consumer knows that a TCP connection and
+ * a TLS session have been established.
+ *
+ * A "lower-layer xprt", created in this function, handles the mechanics
+ * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and
+ * then driving the TLS handshake. Once all that is complete, the upper
+ * layer xprt is marked connected.
+ */
+static void xs_tcp_tls_setup_socket(struct work_struct *work)
+{
+ struct sock_xprt *upper_transport =
+ container_of(work, struct sock_xprt, connect_worker.work);
+ struct rpc_clnt *upper_clnt = upper_transport->clnt;
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+ struct rpc_create_args args = {
+ .net = upper_xprt->xprt_net,
+ .protocol = upper_xprt->prot,
+ .address = (struct sockaddr *)&upper_xprt->addr,
+ .addrsize = upper_xprt->addrlen,
+ .timeout = upper_clnt->cl_timeout,
+ .servername = upper_xprt->servername,
+ .program = upper_clnt->cl_program,
+ .prognumber = upper_clnt->cl_prog,
+ .version = upper_clnt->cl_vers,
+ .authflavor = RPC_AUTH_TLS,
+ .cred = upper_clnt->cl_cred,
+ .xprtsec = {
+ .policy = RPC_XPRTSEC_NONE,
+ },
+ };
+ unsigned int pflags = current->flags;
+ struct rpc_clnt *lower_clnt;
+ struct rpc_xprt *lower_xprt;
+ int status;
+
+ if (atomic_read(&upper_xprt->swapper))
+ current->flags |= PF_MEMALLOC;
+
+ xs_stream_start_connect(upper_transport);
+
+ /* This implicitly sends an RPC_AUTH_TLS probe */
+ lower_clnt = rpc_create(&args);
+ if (IS_ERR(lower_clnt)) {
+ trace_rpc_tls_unavailable(upper_clnt, upper_xprt);
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+ xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt));
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ goto out_unlock;
+ }
+
+ /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on
+ * the lower xprt.
+ */
+ rcu_read_lock();
+ lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
+ rcu_read_unlock();
+ status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
+ if (status) {
+ trace_rpc_tls_not_started(upper_clnt, upper_xprt);
+ goto out_close;
+ }
+
+ status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
+ if (status)
+ goto out_close;
+
+ trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
+ if (!xprt_test_and_set_connected(upper_xprt)) {
+ upper_xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+
+ upper_xprt->stat.connect_count++;
+ upper_xprt->stat.connect_time += (long)jiffies -
+ upper_xprt->stat.connect_start;
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ rpc_shutdown_client(lower_clnt);
+
+out_unlock:
+ current_restore_flags(pflags, PF_MEMALLOC);
+ upper_transport->clnt = NULL;
+ xprt_unlock_connect(upper_xprt, upper_transport);
+ return;
+
+out_close:
+ rpc_shutdown_client(lower_clnt);
+
+ /* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
+ * Wake them first here to ensure they get our tk_status code.
+ */
+ xprt_wake_pending_tasks(upper_xprt, status);
+ xs_tcp_force_close(upper_xprt);
+ xprt_clear_connecting(upper_xprt);
+ goto out_unlock;
+}
+
/**
* xs_connect - connect a socket to a remote endpoint
* @xprt: pointer to transport structure
@@ -2391,6 +2717,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task)
} else
dprintk("RPC: xs_connect scheduled xprt %p\n", xprt);
+ transport->clnt = task->tk_client;
queue_delayed_work(xprtiod_workqueue,
&transport->connect_worker,
delay);
@@ -2858,7 +3185,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args)
switch (sun->sun_family) {
case AF_LOCAL:
- if (sun->sun_path[0] != '/') {
+ if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') {
dprintk("RPC: bad AF_LOCAL address: %s\n",
sun->sun_path);
ret = ERR_PTR(-EINVAL);
@@ -3045,6 +3372,94 @@ out_err:
}
/**
+ * xs_setup_tcp_tls - Set up transport to use a TCP with TLS
+ * @args: rpc transport creation arguments
+ *
+ */
+static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args)
+{
+ struct sockaddr *addr = args->dstaddr;
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+ struct rpc_xprt *ret;
+ unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries;
+
+ if (args->flags & XPRT_CREATE_INFINITE_SLOTS)
+ max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT;
+
+ xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
+ max_slot_table_size);
+ if (IS_ERR(xprt))
+ return xprt;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+
+ xprt->prot = IPPROTO_TCP;
+ xprt->xprt_class = &xs_tcp_transport;
+ xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
+
+ xprt->bind_timeout = XS_BIND_TO;
+ xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+ xprt->idle_timeout = XS_IDLE_DISC_TO;
+
+ xprt->ops = &xs_tcp_ops;
+ xprt->timeout = &xs_tcp_default_timeout;
+
+ xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+ xprt->connect_timeout = xprt->timeout->to_initval *
+ (xprt->timeout->to_retries + 1);
+
+ INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
+ INIT_WORK(&transport->error_worker, xs_error_handle);
+
+ switch (args->xprtsec.policy) {
+ case RPC_XPRTSEC_TLS_ANON:
+ case RPC_XPRTSEC_TLS_X509:
+ xprt->xprtsec = args->xprtsec;
+ INIT_DELAYED_WORK(&transport->connect_worker,
+ xs_tcp_tls_setup_socket);
+ break;
+ default:
+ ret = ERR_PTR(-EACCES);
+ goto out_err;
+ }
+
+ switch (addr->sa_family) {
+ case AF_INET:
+ if (((struct sockaddr_in *)addr)->sin_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
+ break;
+ case AF_INET6:
+ if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
+ break;
+ default:
+ ret = ERR_PTR(-EAFNOSUPPORT);
+ goto out_err;
+ }
+
+ if (xprt_bound(xprt))
+ dprintk("RPC: set up xprt to %s (port %s) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+ else
+ dprintk("RPC: set up xprt to %s (autobind) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+ if (try_module_get(THIS_MODULE))
+ return xprt;
+ ret = ERR_PTR(-EINVAL);
+out_err:
+ xs_xprt_free(xprt);
+ return ret;
+}
+
+/**
* xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket
* @args: rpc transport creation arguments
*
@@ -3153,6 +3568,15 @@ static struct xprt_class xs_tcp_transport = {
.netid = { "tcp", "tcp6", "" },
};
+static struct xprt_class xs_tcp_tls_transport = {
+ .list = LIST_HEAD_INIT(xs_tcp_tls_transport.list),
+ .name = "tcp-with-tls",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_TCP_TLS,
+ .setup = xs_setup_tcp_tls,
+ .netid = { "tcp", "tcp6", "" },
+};
+
static struct xprt_class xs_bc_tcp_transport = {
.list = LIST_HEAD_INIT(xs_bc_tcp_transport.list),
.name = "tcp NFSv4.1 backchannel",
@@ -3174,6 +3598,7 @@ int init_socket_xprt(void)
xprt_register_transport(&xs_local_transport);
xprt_register_transport(&xs_udp_transport);
xprt_register_transport(&xs_tcp_transport);
+ xprt_register_transport(&xs_tcp_tls_transport);
xprt_register_transport(&xs_bc_tcp_transport);
return 0;
@@ -3193,6 +3618,7 @@ void cleanup_socket_xprt(void)
xprt_unregister_transport(&xs_local_transport);
xprt_unregister_transport(&xs_udp_transport);
xprt_unregister_transport(&xs_tcp_transport);
+ xprt_unregister_transport(&xs_tcp_tls_transport);
xprt_unregister_transport(&xs_bc_tcp_transport);
}