summaryrefslogtreecommitdiffstats
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls.h5
-rw-r--r--net/tls/tls_device.c22
-rw-r--r--net/tls/tls_main.c3
-rw-r--r--net/tls/tls_strp.c185
-rw-r--r--net/tls/tls_sw.c4
5 files changed, 168 insertions, 51 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h
index 804c3880d028..0672acab2773 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx)
return ctx->strp.msg_ready;
}
+static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx)
+{
+ return ctx->strp.mixed_decrypted;
+}
+
#ifdef CONFIG_TLS_DEVICE
int tls_device_init(void);
void tls_device_cleanup(void);
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index a7cc4f9faac2..bf69c9d6d06c 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb = tls_strp_msg(sw_ctx);
struct strp_msg *rxm = strp_msg(skb);
- int is_decrypted = skb->decrypted;
- int is_encrypted = !is_decrypted;
- struct sk_buff *skb_iter;
- int left;
-
- left = rxm->full_len - skb->len;
- /* Check if all the data is decrypted already */
- skb_iter = skb_shinfo(skb)->frag_list;
- while (skb_iter && left > 0) {
- is_decrypted &= skb_iter->decrypted;
- is_encrypted &= !skb_iter->decrypted;
-
- left -= skb_iter->len;
- skb_iter = skb_iter->next;
+ int is_decrypted, is_encrypted;
+
+ if (!tls_strp_msg_mixed_decrypted(sw_ctx)) {
+ is_decrypted = skb->decrypted;
+ is_encrypted = !is_decrypted;
+ } else {
+ is_decrypted = 0;
+ is_encrypted = 0;
}
trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index b32c112984dd..f2e7302a4d96 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -111,7 +111,8 @@ int wait_on_pending_writer(struct sock *sk, long *timeo)
break;
}
- if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
+ if (sk_wait_event(sk, timeo,
+ !READ_ONCE(sk->sk_write_pending), &wait))
break;
}
remove_wait_queue(sk_sleep(sk), &wait);
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
index 955ac3e0bf4d..da95abbb7ea3 100644
--- a/net/tls/tls_strp.c
+++ b/net/tls/tls_strp.c
@@ -29,34 +29,50 @@ static void tls_strp_anchor_free(struct tls_strparser *strp)
struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
- shinfo->frag_list = NULL;
+ if (!strp->copy_mode)
+ shinfo->frag_list = NULL;
consume_skb(strp->anchor);
strp->anchor = NULL;
}
-/* Create a new skb with the contents of input copied to its page frags */
-static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
+static struct sk_buff *
+tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb,
+ int offset, int len)
{
- struct strp_msg *rxm;
struct sk_buff *skb;
- int i, err, offset;
+ int i, err;
- skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
+ skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER,
&err, strp->sk->sk_allocation);
if (!skb)
return NULL;
- offset = strp->stm.offset;
for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
- WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
+ WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
skb_frag_address(frag),
skb_frag_size(frag)));
offset += skb_frag_size(frag);
}
- skb_copy_header(skb, strp->anchor);
+ skb->len = len;
+ skb->data_len = len;
+ skb_copy_header(skb, in_skb);
+ return skb;
+}
+
+/* Create a new skb with the contents of input copied to its page frags */
+static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
+{
+ struct strp_msg *rxm;
+ struct sk_buff *skb;
+
+ skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset,
+ strp->stm.full_len);
+ if (!skb)
+ return NULL;
+
rxm = strp_msg(skb);
rxm->offset = 0;
return skb;
@@ -180,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
for (i = 0; i < shinfo->nr_frags; i++)
__skb_frag_unref(&shinfo->frags[i], false);
shinfo->nr_frags = 0;
+ if (strp->copy_mode) {
+ kfree_skb_list(shinfo->frag_list);
+ shinfo->frag_list = NULL;
+ }
strp->copy_mode = 0;
+ strp->mixed_decrypted = 0;
}
-static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
- unsigned int offset, size_t in_len)
+static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
+ struct sk_buff *in_skb, unsigned int offset,
+ size_t in_len)
{
- struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
- struct sk_buff *skb;
- skb_frag_t *frag;
size_t len, chunk;
+ skb_frag_t *frag;
int sz;
- if (strp->msg_ready)
- return 0;
-
- skb = strp->anchor;
frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
len = in_len;
@@ -208,19 +224,26 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
skb_frag_size(frag),
chunk));
- sz = tls_rx_msg_size(strp, strp->anchor);
- if (sz < 0) {
- desc->error = sz;
- return 0;
- }
-
- /* We may have over-read, sz == 0 is guaranteed under-read */
- if (sz > 0)
- chunk = min_t(size_t, chunk, sz - skb->len);
-
skb->len += chunk;
skb->data_len += chunk;
skb_frag_size_add(frag, chunk);
+
+ sz = tls_rx_msg_size(strp, skb);
+ if (sz < 0)
+ return sz;
+
+ /* We may have over-read, sz == 0 is guaranteed under-read */
+ if (unlikely(sz && sz < skb->len)) {
+ int over = skb->len - sz;
+
+ WARN_ON_ONCE(over > chunk);
+ skb->len -= over;
+ skb->data_len -= over;
+ skb_frag_size_add(frag, -over);
+
+ chunk -= over;
+ }
+
frag++;
len -= chunk;
offset += chunk;
@@ -247,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
offset += chunk;
}
- if (strp->stm.full_len == skb->len) {
+read_done:
+ return in_len - len;
+}
+
+static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb,
+ struct sk_buff *in_skb, unsigned int offset,
+ size_t in_len)
+{
+ struct sk_buff *nskb, *first, *last;
+ struct skb_shared_info *shinfo;
+ size_t chunk;
+ int sz;
+
+ if (strp->stm.full_len)
+ chunk = strp->stm.full_len - skb->len;
+ else
+ chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
+ chunk = min(chunk, in_len);
+
+ nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk);
+ if (!nskb)
+ return -ENOMEM;
+
+ shinfo = skb_shinfo(skb);
+ if (!shinfo->frag_list) {
+ shinfo->frag_list = nskb;
+ nskb->prev = nskb;
+ } else {
+ first = shinfo->frag_list;
+ last = first->prev;
+ last->next = nskb;
+ first->prev = nskb;
+ }
+
+ skb->len += chunk;
+ skb->data_len += chunk;
+
+ if (!strp->stm.full_len) {
+ sz = tls_rx_msg_size(strp, skb);
+ if (sz < 0)
+ return sz;
+
+ /* We may have over-read, sz == 0 is guaranteed under-read */
+ if (unlikely(sz && sz < skb->len)) {
+ int over = skb->len - sz;
+
+ WARN_ON_ONCE(over > chunk);
+ skb->len -= over;
+ skb->data_len -= over;
+ __pskb_trim(nskb, nskb->len - over);
+
+ chunk -= over;
+ }
+
+ strp->stm.full_len = sz;
+ }
+
+ return chunk;
+}
+
+static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
+ unsigned int offset, size_t in_len)
+{
+ struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
+ struct sk_buff *skb;
+ int ret;
+
+ if (strp->msg_ready)
+ return 0;
+
+ skb = strp->anchor;
+ if (!skb->len)
+ skb_copy_decrypted(skb, in_skb);
+ else
+ strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb);
+
+ if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted)
+ ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len);
+ else
+ ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len);
+ if (ret < 0) {
+ desc->error = ret;
+ ret = 0;
+ }
+
+ if (strp->stm.full_len && strp->stm.full_len == skb->len) {
desc->count = 0;
strp->msg_ready = 1;
tls_rx_msg_ready(strp);
}
-read_done:
- return in_len - len;
+ return ret;
}
static int tls_strp_read_copyin(struct tls_strparser *strp)
@@ -315,15 +422,19 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
return 0;
}
-static bool tls_strp_check_no_dup(struct tls_strparser *strp)
+static bool tls_strp_check_queue_ok(struct tls_strparser *strp)
{
unsigned int len = strp->stm.offset + strp->stm.full_len;
- struct sk_buff *skb;
+ struct sk_buff *first, *skb;
u32 seq;
- skb = skb_shinfo(strp->anchor)->frag_list;
- seq = TCP_SKB_CB(skb)->seq;
+ first = skb_shinfo(strp->anchor)->frag_list;
+ skb = first;
+ seq = TCP_SKB_CB(first)->seq;
+ /* Make sure there's no duplicate data in the queue,
+ * and the decrypted status matches.
+ */
while (skb->len < len) {
seq += skb->len;
len -= skb->len;
@@ -331,6 +442,8 @@ static bool tls_strp_check_no_dup(struct tls_strparser *strp)
if (TCP_SKB_CB(skb)->seq != seq)
return false;
+ if (skb_cmp_decrypted(first, skb))
+ return false;
}
return true;
@@ -411,7 +524,7 @@ static int tls_strp_read_sock(struct tls_strparser *strp)
return tls_strp_read_copy(strp, true);
}
- if (!tls_strp_check_no_dup(strp))
+ if (!tls_strp_check_queue_ok(strp))
return tls_strp_read_copy(strp, false);
strp->msg_ready = 1;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 635b8bf6b937..6e6a7c37d685 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -2304,10 +2304,14 @@ static void tls_data_ready(struct sock *sk)
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock;
+ gfp_t alloc_save;
trace_sk_data_ready(sk);
+ alloc_save = sk->sk_allocation;
+ sk->sk_allocation = GFP_ATOMIC;
tls_strp_data_ready(&ctx->strp);
+ sk->sk_allocation = alloc_save;
psock = sk_psock_get(sk);
if (psock) {