summaryrefslogtreecommitdiffstats
path: root/net/ipv4
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4')
-rw-r--r--net/ipv4/tcp.c5
-rw-r--r--net/ipv4/tcp_ipv4.c71
-rw-r--r--net/ipv4/tcp_minisocks.c16
-rw-r--r--net/ipv4/tcp_output.c4
4 files changed, 73 insertions, 23 deletions
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 24602a5184b0..001947136b0a 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -4464,11 +4464,8 @@ bool tcp_alloc_md5sig_pool(void)
if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) {
mutex_lock(&tcp_md5sig_mutex);
- if (!tcp_md5sig_pool_populated) {
+ if (!tcp_md5sig_pool_populated)
__tcp_alloc_md5sig_pool();
- if (tcp_md5sig_pool_populated)
- static_branch_inc(&tcp_md5_needed);
- }
mutex_unlock(&tcp_md5sig_mutex);
}
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index c72e53835397..5d83a332f1dd 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1053,7 +1053,7 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req)
* We need to maintain these in the sk structure.
*/
-DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
+DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ);
EXPORT_SYMBOL(tcp_md5_needed);
static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new)
@@ -1166,9 +1166,6 @@ static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
struct tcp_sock *tp = tcp_sk(sk);
struct tcp_md5sig_info *md5sig;
- if (rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)))
- return 0;
-
md5sig = kmalloc(sizeof(*md5sig), gfp);
if (!md5sig)
return -ENOMEM;
@@ -1180,9 +1177,9 @@ static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
}
/* This can be called on a newly created socket, from other files */
-int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
- int family, u8 prefixlen, int l3index, u8 flags,
- const u8 *newkey, u8 newkeylen, gfp_t gfp)
+static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
+ int family, u8 prefixlen, int l3index, u8 flags,
+ const u8 *newkey, u8 newkeylen, gfp_t gfp)
{
/* Add Key to the list */
struct tcp_md5sig_key *key;
@@ -1209,9 +1206,6 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
return 0;
}
- if (tcp_md5sig_info_add(sk, gfp))
- return -ENOMEM;
-
md5sig = rcu_dereference_protected(tp->md5sig_info,
lockdep_sock_is_held(sk));
@@ -1235,8 +1229,59 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
hlist_add_head_rcu(&key->node, &md5sig->head);
return 0;
}
+
+int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
+ int family, u8 prefixlen, int l3index, u8 flags,
+ const u8 *newkey, u8 newkeylen)
+{
+ struct tcp_sock *tp = tcp_sk(sk);
+
+ if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
+ if (tcp_md5sig_info_add(sk, GFP_KERNEL))
+ return -ENOMEM;
+
+ if (!static_branch_inc(&tcp_md5_needed.key)) {
+ struct tcp_md5sig_info *md5sig;
+
+ md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk));
+ rcu_assign_pointer(tp->md5sig_info, NULL);
+ kfree_rcu(md5sig);
+ return -EUSERS;
+ }
+ }
+
+ return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
+ newkey, newkeylen, GFP_KERNEL);
+}
EXPORT_SYMBOL(tcp_md5_do_add);
+int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
+ int family, u8 prefixlen, int l3index,
+ struct tcp_md5sig_key *key)
+{
+ struct tcp_sock *tp = tcp_sk(sk);
+
+ if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
+ if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
+ return -ENOMEM;
+
+ if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) {
+ struct tcp_md5sig_info *md5sig;
+
+ md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk));
+ net_warn_ratelimited("Too many TCP-MD5 keys in the system\n");
+ rcu_assign_pointer(tp->md5sig_info, NULL);
+ kfree_rcu(md5sig);
+ return -EUSERS;
+ }
+ }
+
+ return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index,
+ key->flags, key->key, key->keylen,
+ sk_gfp_mask(sk, GFP_ATOMIC));
+}
+EXPORT_SYMBOL(tcp_md5_key_copy);
+
int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
u8 prefixlen, int l3index, u8 flags)
{
@@ -1323,7 +1368,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
return -EINVAL;
return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
- cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
+ cmd.tcpm_key, cmd.tcpm_keylen);
}
static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
@@ -1580,8 +1625,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
* memory, then we end up not copying the key
* across. Shucks.
*/
- tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
- key->key, key->keylen, GFP_ATOMIC);
+ tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key);
sk_gso_disable(newsk);
}
#endif
@@ -2273,6 +2317,7 @@ void tcp_v4_destroy_sock(struct sock *sk)
tcp_clear_md5_list(sk);
kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu);
tp->md5sig_info = NULL;
+ static_branch_slow_dec_deferred(&tcp_md5_needed);
}
#endif
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index c375f603a16c..6908812d50d3 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -291,13 +291,19 @@ void tcp_time_wait(struct sock *sk, int state, int timeo)
*/
do {
tcptw->tw_md5_key = NULL;
- if (static_branch_unlikely(&tcp_md5_needed)) {
+ if (static_branch_unlikely(&tcp_md5_needed.key)) {
struct tcp_md5sig_key *key;
key = tp->af_specific->md5_lookup(sk, sk);
if (key) {
tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
- BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
+ if (!tcptw->tw_md5_key)
+ break;
+ BUG_ON(!tcp_alloc_md5sig_pool());
+ if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) {
+ kfree(tcptw->tw_md5_key);
+ tcptw->tw_md5_key = NULL;
+ }
}
}
} while (0);
@@ -337,11 +343,13 @@ EXPORT_SYMBOL(tcp_time_wait);
void tcp_twsk_destructor(struct sock *sk)
{
#ifdef CONFIG_TCP_MD5SIG
- if (static_branch_unlikely(&tcp_md5_needed)) {
+ if (static_branch_unlikely(&tcp_md5_needed.key)) {
struct tcp_timewait_sock *twsk = tcp_twsk(sk);
- if (twsk->tw_md5_key)
+ if (twsk->tw_md5_key) {
kfree_rcu(twsk->tw_md5_key, rcu);
+ static_branch_slow_dec_deferred(&tcp_md5_needed);
+ }
}
#endif
}
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 894410dc9293..71d01cf3c13e 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -766,7 +766,7 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb,
*md5 = NULL;
#ifdef CONFIG_TCP_MD5SIG
- if (static_branch_unlikely(&tcp_md5_needed) &&
+ if (static_branch_unlikely(&tcp_md5_needed.key) &&
rcu_access_pointer(tp->md5sig_info)) {
*md5 = tp->af_specific->md5_lookup(sk, sk);
if (*md5) {
@@ -922,7 +922,7 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb
*md5 = NULL;
#ifdef CONFIG_TCP_MD5SIG
- if (static_branch_unlikely(&tcp_md5_needed) &&
+ if (static_branch_unlikely(&tcp_md5_needed.key) &&
rcu_access_pointer(tp->md5sig_info)) {
*md5 = tp->af_specific->md5_lookup(sk, sk);
if (*md5) {