diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index c953b8c0d2f43..888a4b217829f 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -100,6 +100,11 @@ struct sk_psock { void (*saved_close)(struct sock *sk, long timeout); void (*saved_write_space)(struct sock *sk); void (*saved_data_ready)(struct sock *sk); + /* psock_update_sk_prot may be called with restore=false many times + * so the handler must be safe for this case. It will be called + * exactly once with restore=true when the psock is being destroyed + * and psock refcnt is zero, but before an RCU grace period. + */ int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock, bool restore); struct proto *sk_proto; diff --git a/net/unix/unix_bpf.c b/net/unix/unix_bpf.c index 7ea7c3a0d0d06..bd84785bf8d6c 100644 --- a/net/unix/unix_bpf.c +++ b/net/unix/unix_bpf.c @@ -161,15 +161,30 @@ int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool r { struct sock *sk_pair; + /* Restore does not decrement the sk_pair reference yet because we must + * keep the a reference to the socket until after an RCU grace period + * and any pending sends have completed. + */ if (restore) { sk->sk_write_space = psock->saved_write_space; sock_replace_proto(sk, psock->sk_proto); return 0; } - sk_pair = unix_peer(sk); - sock_hold(sk_pair); - psock->sk_pair = sk_pair; + /* psock_update_sk_prot can be called multiple times if psock is + * added to multiple maps and/or slots in the same map. There is + * also an edge case where replacing a psock with itself can trigger + * an extra psock_update_sk_prot during the insert process. So it + * must be safe to do multiple calls. Here we need to ensure we don't + * increment the refcnt through sock_hold many times. There will only + * be a single matching destroy operation. + */ + if (!psock->sk_pair) { + sk_pair = unix_peer(sk); + sock_hold(sk_pair); + psock->sk_pair = sk_pair; + } + unix_stream_bpf_check_needs_rebuild(psock->sk_proto); sock_replace_proto(sk, &unix_stream_bpf_prot); return 0; diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c index 7c2241fae19a6..77e26ecffa9d7 100644 --- a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c +++ b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c @@ -555,6 +555,213 @@ static void test_sockmap_unconnected_unix(void) close(dgram); } +static void test_sockmap_many_socket(void) +{ + struct test_sockmap_pass_prog *skel; + int stream[2], dgram, udp, tcp; + int i, err, map, entry = 0; + + skel = test_sockmap_pass_prog__open_and_load(); + if (!ASSERT_OK_PTR(skel, "open_and_load")) + return; + + map = bpf_map__fd(skel->maps.sock_map_rx); + + dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0); + if (dgram < 0) { + test_sockmap_pass_prog__destroy(skel); + return; + } + + tcp = connected_socket_v4(); + if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) { + close(dgram); + test_sockmap_pass_prog__destroy(skel); + return; + } + + udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0); + if (udp < 0) { + close(dgram); + close(tcp); + test_sockmap_pass_prog__destroy(skel); + return; + } + + err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream); + ASSERT_OK(err, "socketpair(af_unix, sock_stream)"); + if (err) + goto out; + + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map, &entry, &stream[0], BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(stream)"); + } + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map, &entry, &dgram, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(dgram)"); + } + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map, &entry, &udp, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(udp)"); + } + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map, &entry, &tcp, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(tcp)"); + } + for (entry--; entry >= 0; entry--) { + err = bpf_map_delete_elem(map, &entry); + ASSERT_OK(err, "bpf_map_delete_elem(entry)"); + } + + close(stream[0]); + close(stream[1]); +out: + close(dgram); + close(tcp); + close(udp); + test_sockmap_pass_prog__destroy(skel); +} + +static void test_sockmap_many_maps(void) +{ + struct test_sockmap_pass_prog *skel; + int stream[2], dgram, udp, tcp; + int i, err, map[2], entry = 0; + + skel = test_sockmap_pass_prog__open_and_load(); + if (!ASSERT_OK_PTR(skel, "open_and_load")) + return; + + map[0] = bpf_map__fd(skel->maps.sock_map_rx); + map[1] = bpf_map__fd(skel->maps.sock_map_tx); + + dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0); + if (dgram < 0) { + test_sockmap_pass_prog__destroy(skel); + return; + } + + tcp = connected_socket_v4(); + if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) { + close(dgram); + test_sockmap_pass_prog__destroy(skel); + return; + } + + udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0); + if (udp < 0) { + close(dgram); + close(tcp); + test_sockmap_pass_prog__destroy(skel); + return; + } + + err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream); + ASSERT_OK(err, "socketpair(af_unix, sock_stream)"); + if (err) + goto out; + + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map[i], &entry, &stream[0], BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(stream)"); + } + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map[i], &entry, &dgram, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(dgram)"); + } + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map[i], &entry, &udp, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(udp)"); + } + for (i = 0; i < 2; i++, entry++) { + err = bpf_map_update_elem(map[i], &entry, &tcp, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(tcp)"); + } + for (entry--; entry >= 0; entry--) { + err = bpf_map_delete_elem(map[1], &entry); + entry--; + ASSERT_OK(err, "bpf_map_delete_elem(entry)"); + err = bpf_map_delete_elem(map[0], &entry); + ASSERT_OK(err, "bpf_map_delete_elem(entry)"); + } + + close(stream[0]); + close(stream[1]); +out: + close(dgram); + close(tcp); + close(udp); + test_sockmap_pass_prog__destroy(skel); +} + +static void test_sockmap_same_sock(void) +{ + struct test_sockmap_pass_prog *skel; + int stream[2], dgram, udp, tcp; + int i, err, map, zero = 0; + + skel = test_sockmap_pass_prog__open_and_load(); + if (!ASSERT_OK_PTR(skel, "open_and_load")) + return; + + map = bpf_map__fd(skel->maps.sock_map_rx); + + dgram = xsocket(AF_UNIX, SOCK_DGRAM, 0); + if (dgram < 0) { + test_sockmap_pass_prog__destroy(skel); + return; + } + + tcp = connected_socket_v4(); + if (!ASSERT_GE(tcp, 0, "connected_socket_v4")) { + close(dgram); + test_sockmap_pass_prog__destroy(skel); + return; + } + + udp = xsocket(AF_INET, SOCK_DGRAM | SOCK_NONBLOCK, 0); + if (udp < 0) { + close(dgram); + close(tcp); + test_sockmap_pass_prog__destroy(skel); + return; + } + + err = socketpair(AF_UNIX, SOCK_STREAM, 0, stream); + ASSERT_OK(err, "socketpair(af_unix, sock_stream)"); + if (err) + goto out; + + for (i = 0; i < 2; i++) { + err = bpf_map_update_elem(map, &zero, &stream[0], BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(stream)"); + } + for (i = 0; i < 2; i++) { + err = bpf_map_update_elem(map, &zero, &dgram, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(dgram)"); + } + for (i = 0; i < 2; i++) { + err = bpf_map_update_elem(map, &zero, &udp, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(udp)"); + } + for (i = 0; i < 2; i++) { + err = bpf_map_update_elem(map, &zero, &tcp, BPF_ANY); + ASSERT_OK(err, "bpf_map_update_elem(tcp)"); + } + + err = bpf_map_delete_elem(map, &zero); + ASSERT_OK(err, "bpf_map_delete_elem(entry)"); + + close(stream[0]); + close(stream[1]); +out: + close(dgram); + close(tcp); + close(udp); + test_sockmap_pass_prog__destroy(skel); +} + void test_sockmap_basic(void) { if (test__start_subtest("sockmap create_update_free")) @@ -597,7 +804,12 @@ void test_sockmap_basic(void) test_sockmap_skb_verdict_fionread(false); if (test__start_subtest("sockmap skb_verdict msg_f_peek")) test_sockmap_skb_verdict_peek(); - if (test__start_subtest("sockmap unconnected af_unix")) test_sockmap_unconnected_unix(); + if (test__start_subtest("sockmap one socket to many map entries")) + test_sockmap_many_socket(); + if (test__start_subtest("sockmap one socket to many maps")) + test_sockmap_many_maps(); + if (test__start_subtest("sockmap same socket replace")) + test_sockmap_same_sock(); }