summaryrefslogtreecommitdiffstats
path: root/net/tls/tls_main.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--net/tls/tls_main.c58
1 files changed, 39 insertions, 19 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 311cec8e533d..78cb4a584080 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -55,8 +55,10 @@ enum {
static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
+static struct proto *saved_tcpv4_prot;
+static DEFINE_MUTEX(tcpv4_prot_mutex);
static LIST_HEAD(device_list);
-static DEFINE_MUTEX(device_mutex);
+static DEFINE_SPINLOCK(device_spinlock);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops;
@@ -538,11 +540,14 @@ static struct tls_context *create_ctx(struct sock *sk)
struct inet_connection_sock *icsk = inet_csk(sk);
struct tls_context *ctx;
- ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+ ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
if (!ctx)
return NULL;
icsk->icsk_ulp_data = ctx;
+ ctx->setsockopt = sk->sk_prot->setsockopt;
+ ctx->getsockopt = sk->sk_prot->getsockopt;
+ ctx->sk_proto_close = sk->sk_prot->close;
return ctx;
}
@@ -552,7 +557,7 @@ static int tls_hw_prot(struct sock *sk)
struct tls_device *dev;
int rc = 0;
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
if (dev->feature && dev->feature(dev)) {
ctx = create_ctx(sk);
@@ -570,7 +575,7 @@ static int tls_hw_prot(struct sock *sk)
}
}
out:
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
return rc;
}
@@ -579,12 +584,17 @@ static void tls_hw_unhash(struct sock *sk)
struct tls_context *ctx = tls_get_ctx(sk);
struct tls_device *dev;
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
- if (dev->unhash)
+ if (dev->unhash) {
+ kref_get(&dev->kref);
+ spin_unlock_bh(&device_spinlock);
dev->unhash(dev, sk);
+ kref_put(&dev->kref, dev->release);
+ spin_lock_bh(&device_spinlock);
+ }
}
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
ctx->unhash(sk);
}
@@ -595,12 +605,17 @@ static int tls_hw_hash(struct sock *sk)
int err;
err = ctx->hash(sk);
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
- if (dev->hash)
+ if (dev->hash) {
+ kref_get(&dev->kref);
+ spin_unlock_bh(&device_spinlock);
err |= dev->hash(dev, sk);
+ kref_put(&dev->kref, dev->release);
+ spin_lock_bh(&device_spinlock);
+ }
}
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
if (err)
tls_hw_unhash(sk);
@@ -675,9 +690,6 @@ static int tls_init(struct sock *sk)
rc = -ENOMEM;
goto out;
}
- ctx->setsockopt = sk->sk_prot->setsockopt;
- ctx->getsockopt = sk->sk_prot->getsockopt;
- ctx->sk_proto_close = sk->sk_prot->close;
/* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
if (ip_ver == TLSV6 &&
@@ -690,6 +702,16 @@ static int tls_init(struct sock *sk)
mutex_unlock(&tcpv6_prot_mutex);
}
+ if (ip_ver == TLSV4 &&
+ unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+ mutex_lock(&tcpv4_prot_mutex);
+ if (likely(sk->sk_prot != saved_tcpv4_prot)) {
+ build_protos(tls_prots[TLSV4], sk->sk_prot);
+ smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+ }
+ mutex_unlock(&tcpv4_prot_mutex);
+ }
+
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
update_sk_prot(sk, ctx);
@@ -699,17 +721,17 @@ out:
void tls_register_device(struct tls_device *device)
{
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_add_tail(&device->dev_list, &device_list);
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
}
EXPORT_SYMBOL(tls_register_device);
void tls_unregister_device(struct tls_device *device)
{
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_del(&device->dev_list);
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
}
EXPORT_SYMBOL(tls_unregister_device);
@@ -721,8 +743,6 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
static int __init tls_register(void)
{
- build_protos(tls_prots[TLSV4], &tcp_prot);
-
tls_sw_proto_ops = inet_stream_ops;
tls_sw_proto_ops.splice_read = tls_sw_splice_read;