diff options
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 439 |
1 files changed, 375 insertions, 64 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 3b75e0dd51a2..a525fc4c2a4b 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -4,6 +4,7 @@ * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved. * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved. * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved. + * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io * * This software is available to you under a choice of one of two * licenses. You may choose to be licensed under the terms of the GNU @@ -258,21 +259,58 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) return sk_msg_clone(sk, msg_pl, msg_en, skip, len); } -static void tls_free_open_rec(struct sock *sk) +static struct tls_rec *tls_get_rec(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct tls_rec *rec = ctx->open_rec; + struct sk_msg *msg_pl, *msg_en; + struct tls_rec *rec; + int mem_size; - /* Return if there is no open record */ + mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); + + rec = kzalloc(mem_size, sk->sk_allocation); if (!rec) - return; + return NULL; + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + + sk_msg_init(msg_pl); + sk_msg_init(msg_en); + + sg_init_table(rec->sg_aead_in, 2); + sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, + sizeof(rec->aad_space)); + sg_unmark_end(&rec->sg_aead_in[1]); + + sg_init_table(rec->sg_aead_out, 2); + sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, + sizeof(rec->aad_space)); + sg_unmark_end(&rec->sg_aead_out[1]); + + return rec; +} + +static void tls_free_rec(struct sock *sk, struct tls_rec *rec) +{ sk_msg_free(sk, &rec->msg_encrypted); sk_msg_free(sk, &rec->msg_plaintext); kfree(rec); } +static void tls_free_open_rec(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; + + if (rec) { + tls_free_rec(sk, rec); + ctx->open_rec = NULL; + } +} + int tls_tx_records(struct sock *sk, int flags) { struct tls_context *tls_ctx = tls_get_ctx(sk); @@ -439,16 +477,135 @@ static int tls_do_encryption(struct sock *sk, return rc; } +static int tls_split_open_record(struct sock *sk, struct tls_rec *from, + struct tls_rec **to, struct sk_msg *msg_opl, + struct sk_msg *msg_oen, u32 split_point, + u32 tx_overhead_size, u32 *orig_end) +{ + u32 i, j, bytes = 0, apply = msg_opl->apply_bytes; + struct scatterlist *sge, *osge, *nsge; + u32 orig_size = msg_opl->sg.size; + struct scatterlist tmp = { }; + struct sk_msg *msg_npl; + struct tls_rec *new; + int ret; + + new = tls_get_rec(sk); + if (!new) + return -ENOMEM; + ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size + + tx_overhead_size, 0); + if (ret < 0) { + tls_free_rec(sk, new); + return ret; + } + + *orig_end = msg_opl->sg.end; + i = msg_opl->sg.start; + sge = sk_msg_elem(msg_opl, i); + while (apply && sge->length) { + if (sge->length > apply) { + u32 len = sge->length - apply; + + get_page(sg_page(sge)); + sg_set_page(&tmp, sg_page(sge), len, + sge->offset + apply); + sge->length = apply; + bytes += apply; + apply = 0; + } else { + apply -= sge->length; + bytes += sge->length; + } + + sk_msg_iter_var_next(i); + if (i == msg_opl->sg.end) + break; + sge = sk_msg_elem(msg_opl, i); + } + + msg_opl->sg.end = i; + msg_opl->sg.curr = i; + msg_opl->sg.copybreak = 0; + msg_opl->apply_bytes = 0; + msg_opl->sg.size = bytes; + + msg_npl = &new->msg_plaintext; + msg_npl->apply_bytes = apply; + msg_npl->sg.size = orig_size - bytes; + + j = msg_npl->sg.start; + nsge = sk_msg_elem(msg_npl, j); + if (tmp.length) { + memcpy(nsge, &tmp, sizeof(*nsge)); + sk_msg_iter_var_next(j); + nsge = sk_msg_elem(msg_npl, j); + } + + osge = sk_msg_elem(msg_opl, i); + while (osge->length) { + memcpy(nsge, osge, sizeof(*nsge)); + sg_unmark_end(nsge); + sk_msg_iter_var_next(i); + sk_msg_iter_var_next(j); + if (i == *orig_end) + break; + osge = sk_msg_elem(msg_opl, i); + nsge = sk_msg_elem(msg_npl, j); + } + + msg_npl->sg.end = j; + msg_npl->sg.curr = j; + msg_npl->sg.copybreak = 0; + + *to = new; + return 0; +} + +static void tls_merge_open_record(struct sock *sk, struct tls_rec *to, + struct tls_rec *from, u32 orig_end) +{ + struct sk_msg *msg_npl = &from->msg_plaintext; + struct sk_msg *msg_opl = &to->msg_plaintext; + struct scatterlist *osge, *nsge; + u32 i, j; + + i = msg_opl->sg.end; + sk_msg_iter_var_prev(i); + j = msg_npl->sg.start; + + osge = sk_msg_elem(msg_opl, i); + nsge = sk_msg_elem(msg_npl, j); + + if (sg_page(osge) == sg_page(nsge) && + osge->offset + osge->length == nsge->offset) { + osge->length += nsge->length; + put_page(sg_page(nsge)); + } + + msg_opl->sg.end = orig_end; + msg_opl->sg.curr = orig_end; + msg_opl->sg.copybreak = 0; + msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size; + msg_opl->sg.size += msg_npl->sg.size; + + sk_msg_free(sk, &to->msg_encrypted); + sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted); + + kfree(from); +} + static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct tls_rec *rec = ctx->open_rec; + struct tls_rec *rec = ctx->open_rec, *tmp = NULL; + u32 i, split_point, uninitialized_var(orig_end); struct sk_msg *msg_pl, *msg_en; struct aead_request *req; + bool split; int rc; - u32 i; if (!rec) return 0; @@ -456,6 +613,18 @@ static int tls_push_record(struct sock *sk, int flags, msg_pl = &rec->msg_plaintext; msg_en = &rec->msg_encrypted; + split_point = msg_pl->apply_bytes; + split = split_point && split_point < msg_pl->sg.size; + if (split) { + rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, + split_point, tls_ctx->tx.overhead_size, + &orig_end); + if (rc < 0) + return rc; + sk_msg_trim(sk, msg_en, msg_pl->sg.size + + tls_ctx->tx.overhead_size); + } + rec->tx_flags = flags; req = &rec->aead_req; @@ -487,57 +656,139 @@ static int tls_push_record(struct sock *sk, int flags, rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i); if (rc < 0) { - if (rc != -EINPROGRESS) + if (rc != -EINPROGRESS) { tls_err_abort(sk, EBADMSG); + if (split) { + tls_ctx->pending_open_record_frags = true; + tls_merge_open_record(sk, rec, tmp, orig_end); + } + } return rc; + } else if (split) { + msg_pl = &tmp->msg_plaintext; + msg_en = &tmp->msg_encrypted; + sk_msg_trim(sk, msg_en, msg_pl->sg.size + + tls_ctx->tx.overhead_size); + tls_ctx->pending_open_record_frags = true; + ctx->open_rec = tmp; } return tls_tx_records(sk, flags); } -static int tls_sw_push_pending_record(struct sock *sk, int flags) -{ - return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA); -} - -static struct tls_rec *get_rec(struct sock *sk) +static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, + bool full_record, u8 record_type, + size_t *copied, int flags) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct sk_msg *msg_pl, *msg_en; + struct sk_msg msg_redir = { }; + struct sk_psock *psock; + struct sock *sk_redir; struct tls_rec *rec; - int mem_size; + int err = 0, send; + bool enospc; + + psock = sk_psock_get(sk); + if (!psock) + return tls_push_record(sk, flags, record_type); +more_data: + enospc = sk_msg_full(msg); + if (psock->eval == __SK_NONE) + psock->eval = sk_psock_msg_verdict(sk, psock, msg); + if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && + !enospc && !full_record) { + err = -ENOSPC; + goto out_err; + } + msg->cork_bytes = 0; + send = msg->sg.size; + if (msg->apply_bytes && msg->apply_bytes < send) + send = msg->apply_bytes; + + switch (psock->eval) { + case __SK_PASS: + err = tls_push_record(sk, flags, record_type); + if (err < 0) { + *copied -= sk_msg_free(sk, msg); + tls_free_open_rec(sk); + goto out_err; + } + break; + case __SK_REDIRECT: + sk_redir = psock->sk_redir; + memcpy(&msg_redir, msg, sizeof(*msg)); + if (msg->apply_bytes < send) + msg->apply_bytes = 0; + else + msg->apply_bytes -= send; + sk_msg_return_zero(sk, msg, send); + msg->sg.size -= send; + release_sock(sk); + err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags); + lock_sock(sk); + if (err < 0) { + *copied -= sk_msg_free_nocharge(sk, &msg_redir); + msg->sg.size = 0; + } + if (msg->sg.size == 0) + tls_free_open_rec(sk); + break; + case __SK_DROP: + default: + sk_msg_free_partial(sk, msg, send); + if (msg->apply_bytes < send) + msg->apply_bytes = 0; + else + msg->apply_bytes -= send; + if (msg->sg.size == 0) + tls_free_open_rec(sk); + *copied -= send; + err = -EACCES; + } - /* Return if we already have an open record */ - if (ctx->open_rec) - return ctx->open_rec; + if (likely(!err)) { + bool reset_eval = !ctx->open_rec; - mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); + rec = ctx->open_rec; + if (rec) { + msg = &rec->msg_plaintext; + if (!msg->apply_bytes) + reset_eval = true; + } + if (reset_eval) { + psock->eval = __SK_NONE; + if (psock->sk_redir) { + sock_put(psock->sk_redir); + psock->sk_redir = NULL; + } + } + if (rec) + goto more_data; + } + out_err: + sk_psock_put(sk, psock); + return err; +} + +static int tls_sw_push_pending_record(struct sock *sk, int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; + struct sk_msg *msg_pl; + size_t copied; - rec = kzalloc(mem_size, sk->sk_allocation); if (!rec) - return NULL; + return 0; msg_pl = &rec->msg_plaintext; - msg_en = &rec->msg_encrypted; - - sk_msg_init(msg_pl); - sk_msg_init(msg_en); - - sg_init_table(rec->sg_aead_in, 2); - sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, - sizeof(rec->aad_space)); - sg_unmark_end(&rec->sg_aead_in[1]); - - sg_init_table(rec->sg_aead_out, 2); - sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, - sizeof(rec->aad_space)); - sg_unmark_end(&rec->sg_aead_out[1]); - - ctx->open_rec = rec; - rec->inplace_crypto = 1; + copied = msg_pl->sg.size; + if (!copied) + return 0; - return rec; + return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA, + &copied, flags); } int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) @@ -589,7 +840,10 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) goto send_end; } - rec = get_rec(sk); + if (ctx->open_rec) + rec = ctx->open_rec; + else + rec = ctx->open_rec = tls_get_rec(sk); if (!rec) { ret = -ENOMEM; goto send_end; @@ -628,6 +882,8 @@ alloc_encrypted: } if (!is_kvec && (full_record || eor) && !async_capable) { + u32 first = msg_pl->sg.end; + ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter, msg_pl, try_to_copy); if (ret) @@ -637,15 +893,27 @@ alloc_encrypted: num_zc++; copied += try_to_copy; - ret = tls_push_record(sk, msg->msg_flags, record_type); + + sk_msg_sg_copy_set(msg_pl, first); + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, + msg->msg_flags); if (ret) { if (ret == -EINPROGRESS) num_async++; + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret == -ENOSPC) + goto rollback_iter; else if (ret != -EAGAIN) goto send_end; } continue; - +rollback_iter: + copied -= try_to_copy; + sk_msg_sg_copy_clear(msg_pl, first); + iov_iter_revert(&msg->msg_iter, + msg_pl->sg.size - orig_size); fallback_to_reg_send: sk_msg_trim(sk, msg_pl, orig_size); } @@ -678,12 +946,19 @@ fallback_to_reg_send: tls_ctx->pending_open_record_frags = true; copied += try_to_copy; if (full_record || eor) { - ret = tls_push_record(sk, msg->msg_flags, record_type); + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, + msg->msg_flags); if (ret) { if (ret == -EINPROGRESS) num_async++; - else if (ret != -EAGAIN) + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret != -EAGAIN) { + if (ret == -ENOSPC) + ret = 0; goto send_end; + } } } @@ -742,10 +1017,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); unsigned char record_type = TLS_RECORD_TYPE_DATA; - size_t orig_size = size; struct sk_msg *msg_pl; struct tls_rec *rec; int num_async = 0; + size_t copied = 0; bool full_record; int record_room; int ret = 0; @@ -778,7 +1053,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, goto sendpage_end; } - rec = get_rec(sk); + if (ctx->open_rec) + rec = ctx->open_rec; + else + rec = ctx->open_rec = tls_get_rec(sk); if (!rec) { ret = -ENOMEM; goto sendpage_end; @@ -788,6 +1066,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, full_record = false; record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; + copied = 0; copy = size; if (copy >= record_room) { copy = record_room; @@ -818,16 +1097,23 @@ alloc_payload: offset += copy; size -= copy; + copied += copy; tls_ctx->pending_open_record_frags = true; if (full_record || eor || sk_msg_full(msg_pl)) { rec->inplace_crypto = 0; - ret = tls_push_record(sk, flags, record_type); + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, flags); if (ret) { if (ret == -EINPROGRESS) num_async++; - else if (ret != -EAGAIN) + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret != -EAGAIN) { + if (ret == -ENOSPC) + ret = 0; goto sendpage_end; + } } } continue; @@ -851,24 +1137,20 @@ wait_for_memory: } } sendpage_end: - if (orig_size > size) - ret = orig_size - size; - else - ret = sk_stream_error(sk, flags, ret); - + ret = sk_stream_error(sk, flags, ret); release_sock(sk); - return ret; + return copied ? copied : ret; } -static struct sk_buff *tls_wait_data(struct sock *sk, int flags, - long timeo, int *err) +static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock, + int flags, long timeo, int *err) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct sk_buff *skb; DEFINE_WAIT_FUNC(wait, woken_wake_function); - while (!(skb = ctx->recv_pkt)) { + while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) { if (sk->sk_err) { *err = sock_error(sk); return NULL; @@ -887,7 +1169,10 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, add_wait_queue(sk_sleep(sk), &wait); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); - sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); + sk_wait_event(sk, &timeo, + ctx->recv_pkt != skb || + !sk_psock_queue_empty(psock), + &wait); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); remove_wait_queue(sk_sleep(sk), &wait); @@ -1164,6 +1449,7 @@ int tls_sw_recvmsg(struct sock *sk, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct sk_psock *psock; unsigned char control; struct strp_msg *rxm; struct sk_buff *skb; @@ -1179,6 +1465,7 @@ int tls_sw_recvmsg(struct sock *sk, if (unlikely(flags & MSG_ERRQUEUE)) return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); + psock = sk_psock_get(sk); lock_sock(sk); target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); @@ -1188,9 +1475,19 @@ int tls_sw_recvmsg(struct sock *sk, bool async = false; int chunk = 0; - skb = tls_wait_data(sk, flags, timeo, &err); - if (!skb) + skb = tls_wait_data(sk, psock, flags, timeo, &err); + if (!skb) { + if (psock) { + int ret = __tcp_bpf_recvmsg(sk, psock, msg, len); + + if (ret > 0) { + copied += ret; + len -= ret; + continue; + } + } goto recv_end; + } rxm = strp_msg(skb); @@ -1296,6 +1593,8 @@ recv_end: } release_sock(sk); + if (psock) + sk_psock_put(sk, psock); return copied ? : err; } @@ -1318,7 +1617,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); - skb = tls_wait_data(sk, flags, timeo, &err); + skb = tls_wait_data(sk, NULL, flags, timeo, &err); if (!skb) goto splice_read_end; @@ -1356,11 +1655,16 @@ bool tls_sw_stream_read(const struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + bool ingress_empty = true; + struct sk_psock *psock; - if (ctx->recv_pkt) - return true; + rcu_read_lock(); + psock = sk_psock(sk); + if (psock) + ingress_empty = list_empty(&psock->ingress_msg); + rcu_read_unlock(); - return false; + return !ingress_empty || ctx->recv_pkt; } static int tls_read_size(struct strparser *strp, struct sk_buff *skb) @@ -1439,8 +1743,15 @@ static void tls_data_ready(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct sk_psock *psock; strp_data_ready(&ctx->strp); + + psock = sk_psock_get(sk); + if (psock && !list_empty(&psock->ingress_msg)) { + ctx->saved_data_ready(sk); + sk_psock_put(sk, psock); + } } void tls_sw_free_resources_tx(struct sock *sk) |