diff options
Diffstat (limited to 'net/handshake')
-rw-r--r-- | net/handshake/Makefile | 2 | ||||
-rw-r--r-- | net/handshake/alert.c | 110 | ||||
-rw-r--r-- | net/handshake/handshake.h | 6 | ||||
-rw-r--r-- | net/handshake/tlshd.c | 23 | ||||
-rw-r--r-- | net/handshake/trace.c | 2 |
5 files changed, 142 insertions, 1 deletions
diff --git a/net/handshake/Makefile b/net/handshake/Makefile index 247d73c6ff6e..ef4d9a2112bd 100644 --- a/net/handshake/Makefile +++ b/net/handshake/Makefile @@ -8,6 +8,6 @@ # obj-y += handshake.o -handshake-y := genl.o netlink.o request.o tlshd.o trace.o +handshake-y := alert.o genl.o netlink.o request.o tlshd.o trace.o obj-$(CONFIG_NET_HANDSHAKE_KUNIT_TEST) += handshake-test.o diff --git a/net/handshake/alert.c b/net/handshake/alert.c new file mode 100644 index 000000000000..329d91984683 --- /dev/null +++ b/net/handshake/alert.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Handle the TLS Alert protocol + * + * Author: Chuck Lever <chuck.lever@oracle.com> + * + * Copyright (c) 2023, Oracle and/or its affiliates. + */ + +#include <linux/types.h> +#include <linux/socket.h> +#include <linux/kernel.h> +#include <linux/module.h> +#include <linux/skbuff.h> +#include <linux/inet.h> + +#include <net/sock.h> +#include <net/handshake.h> +#include <net/tls.h> +#include <net/tls_prot.h> + +#include "handshake.h" + +#include <trace/events/handshake.h> + +/** + * tls_alert_send - send a TLS Alert on a kTLS socket + * @sock: open kTLS socket to send on + * @level: TLS Alert level + * @description: TLS Alert description + * + * Returns zero on success or a negative errno. + */ +int tls_alert_send(struct socket *sock, u8 level, u8 description) +{ + u8 record_type = TLS_RECORD_TYPE_ALERT; + u8 buf[CMSG_SPACE(sizeof(record_type))]; + struct msghdr msg = { 0 }; + struct cmsghdr *cmsg; + struct kvec iov; + u8 alert[2]; + int ret; + + trace_tls_alert_send(sock->sk, level, description); + + alert[0] = level; + alert[1] = description; + iov.iov_base = alert; + iov.iov_len = sizeof(alert); + + memset(buf, 0, sizeof(buf)); + msg.msg_control = buf; + msg.msg_controllen = sizeof(buf); + msg.msg_flags = MSG_DONTWAIT; + + cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_TLS; + cmsg->cmsg_type = TLS_SET_RECORD_TYPE; + cmsg->cmsg_len = CMSG_LEN(sizeof(record_type)); + memcpy(CMSG_DATA(cmsg), &record_type, sizeof(record_type)); + + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, iov.iov_len); + ret = sock_sendmsg(sock, &msg); + return ret < 0 ? ret : 0; +} + +/** + * tls_get_record_type - Look for TLS RECORD_TYPE information + * @sk: socket (for IP address information) + * @cmsg: incoming message to be parsed + * + * Returns zero or a TLS_RECORD_TYPE value. + */ +u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *cmsg) +{ + u8 record_type; + + if (cmsg->cmsg_level != SOL_TLS) + return 0; + if (cmsg->cmsg_type != TLS_GET_RECORD_TYPE) + return 0; + + record_type = *((u8 *)CMSG_DATA(cmsg)); + trace_tls_contenttype(sk, record_type); + return record_type; +} +EXPORT_SYMBOL(tls_get_record_type); + +/** + * tls_alert_recv - Parse TLS Alert messages + * @sk: socket (for IP address information) + * @msg: incoming message to be parsed + * @level: OUT - TLS AlertLevel value + * @description: OUT - TLS AlertDescription value + * + */ +void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, + u8 *level, u8 *description) +{ + const struct kvec *iov; + u8 *data; + + iov = msg->msg_iter.kvec; + data = iov->iov_base; + *level = data[0]; + *description = data[1]; + + trace_tls_alert_recv(sk, *level, *description); +} +EXPORT_SYMBOL(tls_alert_recv); diff --git a/net/handshake/handshake.h b/net/handshake/handshake.h index 4dac965c99df..a48163765a7a 100644 --- a/net/handshake/handshake.h +++ b/net/handshake/handshake.h @@ -41,8 +41,11 @@ struct handshake_req { enum hr_flags_bits { HANDSHAKE_F_REQ_COMPLETED, + HANDSHAKE_F_REQ_SESSION, }; +struct genl_info; + /* Invariants for all handshake requests for one transport layer * security protocol */ @@ -63,6 +66,9 @@ enum hp_flags_bits { HANDSHAKE_F_PROTO_NOTIFY, }; +/* alert.c */ +int tls_alert_send(struct socket *sock, u8 level, u8 description); + /* netlink.c */ int handshake_genl_notify(struct net *net, const struct handshake_proto *proto, gfp_t flags); diff --git a/net/handshake/tlshd.c b/net/handshake/tlshd.c index b735f5cced2f..bbfb4095ddd6 100644 --- a/net/handshake/tlshd.c +++ b/net/handshake/tlshd.c @@ -18,6 +18,7 @@ #include <net/sock.h> #include <net/handshake.h> #include <net/genetlink.h> +#include <net/tls_prot.h> #include <uapi/linux/keyctl.h> #include <uapi/linux/handshake.h> @@ -100,6 +101,9 @@ static void tls_handshake_done(struct handshake_req *req, if (info) tls_handshake_remote_peerids(treq, info); + if (!status) + set_bit(HANDSHAKE_F_REQ_SESSION, &req->hr_flags); + treq->th_consumer_done(treq->th_consumer_data, -status, treq->th_peerid[0]); } @@ -424,3 +428,22 @@ bool tls_handshake_cancel(struct sock *sk) return handshake_req_cancel(sk); } EXPORT_SYMBOL(tls_handshake_cancel); + +/** + * tls_handshake_close - send a Closure alert + * @sock: an open socket + * + */ +void tls_handshake_close(struct socket *sock) +{ + struct handshake_req *req; + + req = handshake_req_hash_lookup(sock->sk); + if (!req) + return; + if (!test_and_clear_bit(HANDSHAKE_F_REQ_SESSION, &req->hr_flags)) + return; + tls_alert_send(sock, TLS_ALERT_LEVEL_WARNING, + TLS_ALERT_DESC_CLOSE_NOTIFY); +} +EXPORT_SYMBOL(tls_handshake_close); diff --git a/net/handshake/trace.c b/net/handshake/trace.c index 1c4d8e27e17a..44432d0857b9 100644 --- a/net/handshake/trace.c +++ b/net/handshake/trace.c @@ -8,8 +8,10 @@ */ #include <linux/types.h> +#include <linux/ipv6.h> #include <net/sock.h> +#include <net/inet_sock.h> #include <net/netlink.h> #include <net/genetlink.h> |