diff options
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 48f1c26459d0..f208f8455ef2 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -328,7 +328,10 @@ static void tls_sk_proto_unhash(struct sock *sk) ctx = tls_get_ctx(sk); tls_sk_proto_cleanup(sk, ctx, timeo); + write_lock_bh(&sk->sk_callback_lock); icsk->icsk_ulp_data = NULL; + sk->sk_prot = ctx->sk_proto; + write_unlock_bh(&sk->sk_callback_lock); if (ctx->sk_proto->unhash) ctx->sk_proto->unhash(sk); @@ -337,7 +340,7 @@ static void tls_sk_proto_unhash(struct sock *sk) static void tls_sk_proto_close(struct sock *sk, long timeout) { - void (*sk_proto_close)(struct sock *sk, long timeout); + struct inet_connection_sock *icsk = inet_csk(sk); struct tls_context *ctx = tls_get_ctx(sk); long timeo = sock_sndtimeo(sk, 0); bool free_ctx; @@ -347,12 +350,15 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) lock_sock(sk); free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; - sk_proto_close = ctx->sk_proto_close; if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) tls_sk_proto_cleanup(sk, ctx, timeo); + write_lock_bh(&sk->sk_callback_lock); + if (free_ctx) + icsk->icsk_ulp_data = NULL; sk->sk_prot = ctx->sk_proto; + write_unlock_bh(&sk->sk_callback_lock); release_sock(sk); if (ctx->tx_conf == TLS_SW) tls_sw_free_ctx_tx(ctx); @@ -360,7 +366,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) tls_sw_strparser_done(ctx); if (ctx->rx_conf == TLS_SW) tls_sw_free_ctx_rx(ctx); - sk_proto_close(sk, timeout); + ctx->sk_proto_close(sk, timeout); if (free_ctx) tls_ctx_free(ctx); @@ -827,7 +833,7 @@ static int tls_init(struct sock *sk) int rc = 0; if (tls_hw_prot(sk)) - goto out; + return 0; /* The TLS ulp is currently supported only for TCP sockets * in ESTABLISHED state. @@ -838,22 +844,38 @@ static int tls_init(struct sock *sk) if (sk->sk_state != TCP_ESTABLISHED) return -ENOTSUPP; + tls_build_proto(sk); + /* allocate tls context */ + write_lock_bh(&sk->sk_callback_lock); ctx = create_ctx(sk); if (!ctx) { rc = -ENOMEM; goto out; } - tls_build_proto(sk); ctx->tx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE; ctx->sk_proto = sk->sk_prot; update_sk_prot(sk, ctx); out: + write_unlock_bh(&sk->sk_callback_lock); return rc; } +static void tls_update(struct sock *sk, struct proto *p) +{ + struct tls_context *ctx; + + ctx = tls_get_ctx(sk); + if (likely(ctx)) { + ctx->sk_proto_close = p->close; + ctx->sk_proto = p; + } else { + sk->sk_prot = p; + } +} + void tls_register_device(struct tls_device *device) { spin_lock_bh(&device_spinlock); @@ -874,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .name = "tls", .owner = THIS_MODULE, .init = tls_init, + .update = tls_update, }; static int __init tls_register(void) |