summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/linux/skmsg.h20
-rw-r--r--net/ipv4/tcp_bpf.c16
2 files changed, 17 insertions, 19 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 4d3d75d63066..2be51b7a5800 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -347,11 +347,23 @@ static inline void sk_psock_update_proto(struct sock *sk,
struct sk_psock *psock,
struct proto *ops)
{
- psock->saved_unhash = sk->sk_prot->unhash;
- psock->saved_close = sk->sk_prot->close;
- psock->saved_write_space = sk->sk_write_space;
+ /* Initialize saved callbacks and original proto only once, since this
+ * function may be called multiple times for a psock, e.g. when
+ * psock->progs.msg_parser is updated.
+ *
+ * Since we've not installed the new proto, psock is not yet in use and
+ * we can initialize it without synchronization.
+ */
+ if (!psock->sk_proto) {
+ struct proto *orig = READ_ONCE(sk->sk_prot);
+
+ psock->saved_unhash = orig->unhash;
+ psock->saved_close = orig->close;
+ psock->saved_write_space = sk->sk_write_space;
+
+ psock->sk_proto = orig;
+ }
- psock->sk_proto = sk->sk_prot;
/* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, ops);
}
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 7d6e1b75d4d4..3327afa05c3d 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -637,20 +637,6 @@ static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}
-static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
-{
- int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
- int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
-
- /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
- * or added requiring sk_prot hook updates. We keep original saved
- * hooks in this case.
- *
- * Pairs with lockless read in sk_clone_lock().
- */
- WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
-}
-
static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
/* In order to avoid retpoline, we make assumptions when we call
@@ -670,7 +656,7 @@ void tcp_bpf_reinit(struct sock *sk)
rcu_read_lock();
psock = sk_psock(sk);
- tcp_bpf_reinit_sk_prot(sk, psock);
+ tcp_bpf_update_sk_prot(sk, psock);
rcu_read_unlock();
}