Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpf: Fix use-after-free of sockmap #8686

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions net/core/skmsg.c
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,14 @@ static void sk_psock_backlog(struct work_struct *work)
bool ingress;
int ret;

/* Increment the psock refcnt to synchronize with close(fd) path in
* sock_map_close(), ensuring we wait for backlog thread completion
* before sk_socket freed. If refcnt increment fails, it indicates
* sock_map_close() completed with sk_socket potentially already freed.
*/
if (!sk_psock_get(psock->sk))
return;

mutex_lock(&psock->work_mutex);
if (unlikely(state->len)) {
len = state->len;
Expand Down Expand Up @@ -702,6 +710,7 @@ static void sk_psock_backlog(struct work_struct *work)
}
end:
mutex_unlock(&psock->work_mutex);
sk_psock_put(psock->sk, psock);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node)
Expand Down Expand Up @@ -1222,17 +1231,24 @@ static int sk_psock_verdict_recv(struct sock *sk, struct sk_buff *skb)

static void sk_psock_verdict_data_ready(struct sock *sk)
{
struct socket *sock = sk->sk_socket;
struct socket *sock;
const struct proto_ops *ops;
int copied;

trace_sk_data_ready(sk);

if (unlikely(!sock))
rcu_read_lock();
sock = sk->sk_socket;
if (unlikely(!sock)) {
rcu_read_unlock();
return;
}
ops = READ_ONCE(sock->ops);
if (!ops || !ops->read_skb)
if (!ops || !ops->read_skb) {
rcu_read_unlock();
return;
}
rcu_read_unlock();
copied = ops->read_skb(sk, sk_psock_verdict_recv);
if (copied >= 0) {
struct sk_psock *psock;
Expand Down
13 changes: 12 additions & 1 deletion tools/testing/selftests/bpf/prog_tests/socket_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,22 @@ static inline int recv_timeout(int fd, void *buf, size_t len, int flags,

static inline int create_pair(int family, int sotype, int *p0, int *p1)
{
__close_fd int s, c = -1, p = -1;
__close_fd int s = -1, c = -1, p = -1;
struct sockaddr_storage addr;
socklen_t len = sizeof(addr);
int err;

if (family == AF_UNIX) {
int fds[2];

err = socketpair(family, sotype, 0, fds);
if (!err) {
*p0 = fds[0];
*p1 = fds[1];
}
return err;
}

s = socket_loopback(family, sotype);
if (s < 0)
return s;
Expand Down
60 changes: 60 additions & 0 deletions tools/testing/selftests/bpf/prog_tests/sockmap_basic.c
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,59 @@ static void test_sockmap_vsock_unconnected(void)
xclose(map);
}

void *close_thread(void *arg)
{
int *fd = (int *)arg;

sleep(1);
close(*fd);
*fd = -1;
return NULL;
}

void test_sockmap_with_close_on_write(int family, int sotype)
{
struct test_sockmap_pass_prog *skel;
int err, map, verdict;
pthread_t tid;
int zero = 0;
int c = -1, p = -1;

skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;

verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
map = bpf_map__fd(skel->maps.sock_map_rx);

err = bpf_prog_attach(verdict, map, BPF_SK_SKB_STREAM_VERDICT, 0);
if (!ASSERT_OK(err, "bpf_prog_attach"))
goto out;

err = create_pair(family, sotype, &c, &p);
if (!ASSERT_OK(err, "create_pair"))
goto out;

err = bpf_map_update_elem(map, &zero, &p, BPF_ANY);
if (!ASSERT_OK(err, "bpf_map_update_elem"))
goto out;

err = pthread_create(&tid, 0, close_thread, &p);
if (!ASSERT_OK(err, "pthread_create"))
goto out;

while (p > 0)
send(c, "a", 1, MSG_NOSIGNAL);

pthread_join(tid, NULL);
out:
if (c > 0)
close(c);
if (p > 0)
close(p);
test_sockmap_pass_prog__destroy(skel);
}

void test_sockmap_basic(void)
{
if (test__start_subtest("sockmap create_update_free"))
Expand Down Expand Up @@ -1108,4 +1161,11 @@ void test_sockmap_basic(void)
test_sockmap_skb_verdict_vsock_poll();
if (test__start_subtest("sockmap vsock unconnected"))
test_sockmap_vsock_unconnected();
if (test__start_subtest("sockmap with write on close")) {
test_sockmap_with_close_on_write(AF_UNIX, SOCK_STREAM);
test_sockmap_with_close_on_write(AF_UNIX, SOCK_DGRAM);
test_sockmap_with_close_on_write(AF_INET, SOCK_STREAM);
test_sockmap_with_close_on_write(AF_INET, SOCK_DGRAM);
test_sockmap_with_close_on_write(AF_VSOCK, SOCK_STREAM);
}
}
Loading