diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 0ddc4c718833..5e913b62929e 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -655,6 +655,14 @@ static void sk_psock_backlog(struct work_struct *work)
 	bool ingress;
 	int ret;
 
+	/* Increment the psock refcnt to synchronize with the close(fd) path
+	 * in sock_map_close(), ensuring we wait for backlog thread completion
+	 * before sk_socket is freed. If the refcnt increment fails, it means
+	 * sock_map_close() completed and sk_socket may already be freed.
+	 */
+	if (!sk_psock_get(psock->sk))
+		return;
+
 	mutex_lock(&psock->work_mutex);
 	if (unlikely(state->len)) {
 		len = state->len;
@@ -702,6 +710,7 @@ static void sk_psock_backlog(struct work_struct *work)
 	}
 end:
 	mutex_unlock(&psock->work_mutex);
+	sk_psock_put(psock->sk, psock);
 }
 
 struct sk_psock *sk_psock_init(struct sock *sk, int node)
@@ -1222,17 +1231,28 @@ static int sk_psock_verdict_recv(struct sock *sk, struct sk_buff *skb)
 
 static void sk_psock_verdict_data_ready(struct sock *sk)
 {
-	struct socket *sock = sk->sk_socket;
+	struct socket *sock;
 	const struct proto_ops *ops;
 	int copied;
 
 	trace_sk_data_ready(sk);
 
-	if (unlikely(!sock))
+	/* sk->sk_socket may be cleared by a concurrent close(); load it once
+	 * inside the RCU read-side section so the NULL check and the ->ops
+	 * load below operate on the same pointer.
+	 */
+	rcu_read_lock();
+	sock = READ_ONCE(sk->sk_socket);
+	if (unlikely(!sock)) {
+		rcu_read_unlock();
 		return;
+	}
 	ops = READ_ONCE(sock->ops);
-	if (!ops || !ops->read_skb)
+	if (!ops || !ops->read_skb) {
+		rcu_read_unlock();
 		return;
+	}
+	rcu_read_unlock();
 	copied = ops->read_skb(sk, sk_psock_verdict_recv);
 	if (copied >= 0) {
 		struct sk_psock *psock;
diff --git a/tools/testing/selftests/bpf/prog_tests/socket_helpers.h b/tools/testing/selftests/bpf/prog_tests/socket_helpers.h
index 1bdfb79ef009..a805143dd84f 100644
--- a/tools/testing/selftests/bpf/prog_tests/socket_helpers.h
+++ b/tools/testing/selftests/bpf/prog_tests/socket_helpers.h
@@ -313,11 +313,23 @@ static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
 
 static inline int create_pair(int family, int sotype, int *p0, int *p1)
 {
-	__close_fd int s, c = -1, p = -1;
+	__close_fd int s = -1, c = -1, p = -1;
 	struct sockaddr_storage addr;
 	socklen_t len = sizeof(addr);
 	int err;
 
+	/* AF_UNIX: obtain the connected pair directly from socketpair(). */
+	if (family == AF_UNIX) {
+		int fds[2];
+
+		err = socketpair(family, sotype, 0, fds);
+		if (!err) {
+			*p0 = fds[0];
+			*p1 = fds[1];
+		}
+		return err;
+	}
+
 	s = socket_loopback(family, sotype);
 	if (s < 0)
 		return s;
diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c
index 1e3e4392dcca..c72357f41035 100644
--- a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c
+++ b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c
@@ -1042,6 +1042,68 @@ static void test_sockmap_vsock_unconnected(void)
 	xclose(map);
 }
 
+/* Thread body: wait one second, close the fd behind @arg, then store -1
+ * through @arg so the sender loop in the main thread stops.
+ */
+static void *close_thread(void *arg)
+{
+	int *fd = (int *)arg;
+
+	sleep(1);
+	close(*fd);
+	*fd = -1;
+	return NULL;
+}
+
+/* Keep sending on one end of a socket pair while a helper thread closes the
+ * peer that was inserted into the sockmap, exercising send() racing with
+ * close() of a sockmap socket.
+ */
+static void test_sockmap_with_close_on_write(int family, int sotype)
+{
+	struct test_sockmap_pass_prog *skel;
+	int err, map, verdict;
+	pthread_t tid;
+	int zero = 0;
+	int c = -1, p = -1;
+
+	skel = test_sockmap_pass_prog__open_and_load();
+	if (!ASSERT_OK_PTR(skel, "open_and_load"))
+		return;
+
+	verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
+	map = bpf_map__fd(skel->maps.sock_map_rx);
+
+	err = bpf_prog_attach(verdict, map, BPF_SK_SKB_STREAM_VERDICT, 0);
+	if (!ASSERT_OK(err, "bpf_prog_attach"))
+		goto out;
+
+	err = create_pair(family, sotype, &c, &p);
+	if (!ASSERT_OK(err, "create_pair"))
+		goto out;
+
+	err = bpf_map_update_elem(map, &zero, &p, BPF_ANY);
+	if (!ASSERT_OK(err, "bpf_map_update_elem"))
+		goto out;
+
+	err = pthread_create(&tid, NULL, close_thread, &p);
+	if (!ASSERT_OK(err, "pthread_create"))
+		goto out;
+
+	/* close_thread() sets p to -1 once it has closed the peer. */
+	while (p >= 0)
+		send(c, "a", 1, MSG_NOSIGNAL);
+
+	pthread_join(tid, NULL);
+out:
+	/* 0 is a valid descriptor, so test against the -1 sentinel. */
+	if (c >= 0)
+		close(c);
+	if (p >= 0)
+		close(p);
+	test_sockmap_pass_prog__destroy(skel);
+}
+
 void test_sockmap_basic(void)
 {
 	if (test__start_subtest("sockmap create_update_free"))
@@ -1108,4 +1170,11 @@ void test_sockmap_basic(void)
 		test_sockmap_skb_verdict_vsock_poll();
 	if (test__start_subtest("sockmap vsock unconnected"))
 		test_sockmap_vsock_unconnected();
+	if (test__start_subtest("sockmap with write on close")) {
+		test_sockmap_with_close_on_write(AF_UNIX, SOCK_STREAM);
+		test_sockmap_with_close_on_write(AF_UNIX, SOCK_DGRAM);
+		test_sockmap_with_close_on_write(AF_INET, SOCK_STREAM);
+		test_sockmap_with_close_on_write(AF_INET, SOCK_DGRAM);
+		test_sockmap_with_close_on_write(AF_VSOCK, SOCK_STREAM);
+	}
 }