diff options
-rw-r--r-- | drivers/crypto/chelsio/chtls/chtls_cm.c | 4 | ||||
-rw-r--r-- | net/tls/tls_main.c | 19 |
2 files changed, 21 insertions, 2 deletions
diff --git a/drivers/crypto/chelsio/chtls/chtls_cm.c b/drivers/crypto/chelsio/chtls/chtls_cm.c index 59b75299fcbc..2e11b0d6a9bb 100644 --- a/drivers/crypto/chelsio/chtls/chtls_cm.c +++ b/drivers/crypto/chelsio/chtls/chtls_cm.c @@ -24,6 +24,7 @@ #include <net/inet_common.h> #include <net/tcp.h> #include <net/dst.h> +#include <net/tls.h> #include "chtls.h" #include "chtls_cm.h" @@ -1015,6 +1016,7 @@ static struct sock *chtls_recv_sock(struct sock *lsk, { struct inet_sock *newinet; const struct iphdr *iph; + struct tls_context *ctx; struct net_device *ndev; struct chtls_sock *csk; struct dst_entry *dst; @@ -1063,6 +1065,8 @@ static struct sock *chtls_recv_sock(struct sock *lsk, oreq->ts_recent = PASS_OPEN_TID_G(ntohl(req->tos_stid)); sk_setup_caps(newsk, dst); + ctx = tls_get_ctx(lsk); + newsk->sk_destruct = ctx->sk_destruct; csk->sk = newsk; csk->passive_reap_next = oreq; csk->tx_chan = cxgb4_port_chan(ndev); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index fc97a105ebc2..d36d095cbcf0 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -266,8 +266,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) lock_sock(sk); sk_proto_close = ctx->sk_proto_close; - if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) || - (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) { + if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) + goto skip_tx_cleanup; + + if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) { free_ctx = true; goto skip_tx_cleanup; } @@ -579,6 +581,17 @@ static void tls_build_proto(struct sock *sk) } } +static void tls_hw_sk_destruct(struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + + ctx->sk_destruct(sk); + /* Free ctx */ + kfree(ctx); + icsk->icsk_ulp_data = NULL; +} + static int tls_hw_prot(struct sock *sk) { struct tls_context *ctx; @@ -597,6 +610,8 @@ static int tls_hw_prot(struct sock *sk) ctx->hash = sk->sk_prot->hash; ctx->unhash = sk->sk_prot->unhash; ctx->sk_proto_close = sk->sk_prot->close; + ctx->sk_destruct = sk->sk_destruct; + sk->sk_destruct = tls_hw_sk_destruct; ctx->rx_conf = TLS_HW_RECORD; ctx->tx_conf = TLS_HW_RECORD; update_sk_prot(sk, ctx); |