summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/tls/tls_main.c55
1 files changed, 55 insertions, 0 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index d152a00a7a27..48f1c26459d0 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -261,6 +261,33 @@ void tls_ctx_free(struct tls_context *ctx)
kfree(ctx);
}
+static void tls_ctx_free_deferred(struct work_struct *gc)
+{
+ struct tls_context *ctx = container_of(gc, struct tls_context, gc);
+
+ /* Ensure any remaining work items are completed. The sk will
+ * already have lost its tls_ctx reference by the time we get
+ * here so no xmit operation will actually be performed.
+ */
+ if (ctx->tx_conf == TLS_SW) {
+ tls_sw_cancel_work_tx(ctx);
+ tls_sw_free_ctx_tx(ctx);
+ }
+
+ if (ctx->rx_conf == TLS_SW) {
+ tls_sw_strparser_done(ctx);
+ tls_sw_free_ctx_rx(ctx);
+ }
+
+ tls_ctx_free(ctx);
+}
+
+static void tls_ctx_free_wq(struct tls_context *ctx)
+{
+ INIT_WORK(&ctx->gc, tls_ctx_free_deferred);
+ schedule_work(&ctx->gc);
+}
+
static void tls_sk_proto_cleanup(struct sock *sk,
struct tls_context *ctx, long timeo)
{
@@ -288,6 +315,26 @@ static void tls_sk_proto_cleanup(struct sock *sk,
#endif
}
+static void tls_sk_proto_unhash(struct sock *sk)
+{
+ struct inet_connection_sock *icsk = inet_csk(sk);
+ long timeo = sock_sndtimeo(sk, 0);
+ struct tls_context *ctx;
+
+ if (unlikely(!icsk->icsk_ulp_data)) {
+ if (sk->sk_prot->unhash)
+ sk->sk_prot->unhash(sk);
+ }
+
+ ctx = tls_get_ctx(sk);
+ tls_sk_proto_cleanup(sk, ctx, timeo);
+ icsk->icsk_ulp_data = NULL;
+
+ if (ctx->sk_proto->unhash)
+ ctx->sk_proto->unhash(sk);
+ tls_ctx_free_wq(ctx);
+}
+
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
void (*sk_proto_close)(struct sock *sk, long timeout);
@@ -305,6 +352,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
tls_sk_proto_cleanup(sk, ctx, timeo);
+ sk->sk_prot = ctx->sk_proto;
release_sock(sk);
if (ctx->tx_conf == TLS_SW)
tls_sw_free_ctx_tx(ctx);
@@ -608,6 +656,7 @@ static struct tls_context *create_ctx(struct sock *sk)
ctx->setsockopt = sk->sk_prot->setsockopt;
ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close;
+ ctx->unhash = sk->sk_prot->unhash;
return ctx;
}
@@ -731,6 +780,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
+ prot[TLS_BASE][TLS_BASE].unhash = tls_sk_proto_unhash;
prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
@@ -748,16 +798,20 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
#ifdef CONFIG_TLS_DEVICE
prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_HW][TLS_BASE].unhash = base->unhash;
prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;
prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+ prot[TLS_HW][TLS_SW].unhash = base->unhash;
prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+ prot[TLS_BASE][TLS_HW].unhash = base->unhash;
prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+ prot[TLS_SW][TLS_HW].unhash = base->unhash;
prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
@@ -794,6 +848,7 @@ static int tls_init(struct sock *sk)
tls_build_proto(sk);
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
+ ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx);
out:
return rc;