diff options
Diffstat (limited to 'include/net/inet_hashtables.h')
-rw-r--r-- | include/net/inet_hashtables.h | 49 |
1 files changed, 47 insertions, 2 deletions
diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h index c0532cc7587f..1177effabed3 100644 --- a/include/net/inet_hashtables.h +++ b/include/net/inet_hashtables.h @@ -449,6 +449,46 @@ static inline struct sock *inet_lookup(struct net *net, return sk; } +static inline +struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff, + const __be32 saddr, const __be16 sport, + const __be32 daddr, const __be16 dport, + bool *refcounted, inet_ehashfn_t *ehashfn) +{ + struct sock *sk, *reuse_sk; + bool prefetched; + + sk = skb_steal_sock(skb, refcounted, &prefetched); + if (!sk) + return NULL; + + if (!prefetched) + return sk; + + if (sk->sk_protocol == IPPROTO_TCP) { + if (sk->sk_state != TCP_LISTEN) + return sk; + } else if (sk->sk_protocol == IPPROTO_UDP) { + if (sk->sk_state != TCP_CLOSE) + return sk; + } else { + return sk; + } + + reuse_sk = inet_lookup_reuseport(net, sk, skb, doff, + saddr, sport, daddr, ntohs(dport), + ehashfn); + if (!reuse_sk) + return sk; + + /* We've chosen a new reuseport sock which is never refcounted. This + * implies that sk also isn't refcounted. + */ + WARN_ON_ONCE(*refcounted); + + return reuse_sk; +} + static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo, struct sk_buff *skb, int doff, @@ -457,13 +497,18 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo, const int sdif, bool *refcounted) { - struct sock *sk = skb_steal_sock(skb, refcounted); + struct net *net = dev_net(skb_dst(skb)->dev); const struct iphdr *iph = ip_hdr(skb); + struct sock *sk; + sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport, + refcounted, inet_ehashfn); + if (IS_ERR(sk)) + return NULL; if (sk) return sk; - return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb, + return __inet_lookup(net, hashinfo, skb, doff, iph->saddr, sport, iph->daddr, dport, inet_iif(skb), sdif, refcounted); |