From 97b9af7a70936e331170c79040cc9bf20071b566 Mon Sep 17 00:00:00 2001 From: Wen Gu Date: Fri, 22 Apr 2022 15:56:18 +0800 Subject: [PATCH 1/2] net/smc: Only save the original clcsock callback functions Both listen and fallback process will save the current clcsock callback functions and establish new ones. But if both of them happen, the saved callback functions will be overwritten. So this patch introduces some helpers to ensure that only save the original callback functions of clcsock. Fixes: 341adeec9ada ("net/smc: Forward wakeup to smc socket waitqueue after fallback") Signed-off-by: Wen Gu Acked-by: Karsten Graul Signed-off-by: Jakub Kicinski --- net/smc/af_smc.c | 55 +++++++++++++++++++++++++++++---------------- net/smc/smc.h | 29 ++++++++++++++++++++++++ net/smc/smc_close.c | 3 ++- 3 files changed, 67 insertions(+), 20 deletions(-) diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index bbb1a4ce50505..d8433f17c5c94 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -373,6 +373,7 @@ static struct sock *smc_sock_alloc(struct net *net, struct socket *sock, sk->sk_prot->hash(sk); sk_refcnt_debug_inc(sk); mutex_init(&smc->clcsock_release_lock); + smc_init_saved_callbacks(smc); return sk; } @@ -782,9 +783,24 @@ static void smc_fback_error_report(struct sock *clcsk) smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_error_report); } +static void smc_fback_replace_callbacks(struct smc_sock *smc) +{ + struct sock *clcsk = smc->clcsock->sk; + + clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); + + smc_clcsock_replace_cb(&clcsk->sk_state_change, smc_fback_state_change, + &smc->clcsk_state_change); + smc_clcsock_replace_cb(&clcsk->sk_data_ready, smc_fback_data_ready, + &smc->clcsk_data_ready); + smc_clcsock_replace_cb(&clcsk->sk_write_space, smc_fback_write_space, + &smc->clcsk_write_space); + smc_clcsock_replace_cb(&clcsk->sk_error_report, smc_fback_error_report, + &smc->clcsk_error_report); +} + static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code) { - struct sock *clcsk; int rc = 0; mutex_lock(&smc->clcsock_release_lock); @@ -792,10 +808,7 @@ static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code) rc = -EBADF; goto out; } - clcsk = smc->clcsock->sk; - if (smc->use_fallback) - goto out; smc->use_fallback = true; smc->fallback_rsn = reason_code; smc_stat_fallback(smc); @@ -810,18 +823,7 @@ static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code) * in smc sk->sk_wq and they should be woken up * as clcsock's wait queue is woken up. */ - smc->clcsk_state_change = clcsk->sk_state_change; - smc->clcsk_data_ready = clcsk->sk_data_ready; - smc->clcsk_write_space = clcsk->sk_write_space; - smc->clcsk_error_report = clcsk->sk_error_report; - - clcsk->sk_state_change = smc_fback_state_change; - clcsk->sk_data_ready = smc_fback_data_ready; - clcsk->sk_write_space = smc_fback_write_space; - clcsk->sk_error_report = smc_fback_error_report; - - smc->clcsock->sk->sk_user_data = - (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); + smc_fback_replace_callbacks(smc); } out: mutex_unlock(&smc->clcsock_release_lock); @@ -1596,6 +1598,19 @@ static int smc_clcsock_accept(struct smc_sock *lsmc, struct smc_sock **new_smc) * function; switch it back to the original sk_data_ready function */ new_clcsock->sk->sk_data_ready = lsmc->clcsk_data_ready; + + /* if new clcsock has also inherited the fallback-specific callback + * functions, switch them back to the original ones. + */ + if (lsmc->use_fallback) { + if (lsmc->clcsk_state_change) + new_clcsock->sk->sk_state_change = lsmc->clcsk_state_change; + if (lsmc->clcsk_write_space) + new_clcsock->sk->sk_write_space = lsmc->clcsk_write_space; + if (lsmc->clcsk_error_report) + new_clcsock->sk->sk_error_report = lsmc->clcsk_error_report; + } + (*new_smc)->clcsock = new_clcsock; out: return rc; @@ -2397,10 +2412,10 @@ static int smc_listen(struct socket *sock, int backlog) /* save original sk_data_ready function and establish * smc-specific sk_data_ready function */ - smc->clcsk_data_ready = smc->clcsock->sk->sk_data_ready; - smc->clcsock->sk->sk_data_ready = smc_clcsock_data_ready; smc->clcsock->sk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); + smc_clcsock_replace_cb(&smc->clcsock->sk->sk_data_ready, + smc_clcsock_data_ready, &smc->clcsk_data_ready); /* save original ops */ smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops; @@ -2415,7 +2430,9 @@ static int smc_listen(struct socket *sock, int backlog) rc = kernel_listen(smc->clcsock, backlog); if (rc) { - smc->clcsock->sk->sk_data_ready = smc->clcsk_data_ready; + smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, + &smc->clcsk_data_ready); + smc->clcsock->sk->sk_user_data = NULL; goto out; } sk->sk_max_ack_backlog = backlog; diff --git a/net/smc/smc.h b/net/smc/smc.h index ea0620529ebea..5ed765ea0c731 100644 --- a/net/smc/smc.h +++ b/net/smc/smc.h @@ -288,12 +288,41 @@ static inline struct smc_sock *smc_sk(const struct sock *sk) return (struct smc_sock *)sk; } +static inline void smc_init_saved_callbacks(struct smc_sock *smc) +{ + smc->clcsk_state_change = NULL; + smc->clcsk_data_ready = NULL; + smc->clcsk_write_space = NULL; + smc->clcsk_error_report = NULL; +} + static inline struct smc_sock *smc_clcsock_user_data(const struct sock *clcsk) { return (struct smc_sock *) ((uintptr_t)clcsk->sk_user_data & ~SK_USER_DATA_NOCOPY); } +/* save target_cb in saved_cb, and replace target_cb with new_cb */ +static inline void smc_clcsock_replace_cb(void (**target_cb)(struct sock *), + void (*new_cb)(struct sock *), + void (**saved_cb)(struct sock *)) +{ + /* only save once */ + if (!*saved_cb) + *saved_cb = *target_cb; + *target_cb = new_cb; +} + +/* restore target_cb to saved_cb, and reset saved_cb to NULL */ +static inline void smc_clcsock_restore_cb(void (**target_cb)(struct sock *), + void (**saved_cb)(struct sock *)) +{ + if (!*saved_cb) + return; + *target_cb = *saved_cb; + *saved_cb = NULL; +} + extern struct workqueue_struct *smc_hs_wq; /* wq for handshake work */ extern struct workqueue_struct *smc_close_wq; /* wq for close work */ diff --git a/net/smc/smc_close.c b/net/smc/smc_close.c index 676cb2333d3c4..7bd1ef55b9dfd 100644 --- a/net/smc/smc_close.c +++ b/net/smc/smc_close.c @@ -214,7 +214,8 @@ int smc_close_active(struct smc_sock *smc) sk->sk_state = SMC_CLOSED; sk->sk_state_change(sk); /* wake up accept */ if (smc->clcsock && smc->clcsock->sk) { - smc->clcsock->sk->sk_data_ready = smc->clcsk_data_ready; + smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, + &smc->clcsk_data_ready); smc->clcsock->sk->sk_user_data = NULL; rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR); } From 0558226cebee256aa3f8ec0cc5a800a10bf120a6 Mon Sep 17 00:00:00 2001 From: Wen Gu Date: Fri, 22 Apr 2022 15:56:19 +0800 Subject: [PATCH 2/2] net/smc: Fix slab-out-of-bounds issue in fallback syzbot reported a slab-out-of-bounds/use-after-free issue, which was caused by accessing an already freed smc sock in fallback-specific callback functions of clcsock. This patch fixes the issue by restoring fallback-specific callback functions to original ones and resetting clcsock sk_user_data to NULL before freeing smc sock. Meanwhile, this patch introduces sk_callback_lock to make the access and assignment to sk_user_data mutually exclusive. Reported-by: syzbot+b425899ed22c6943e00b@syzkaller.appspotmail.com Fixes: 341adeec9ada ("net/smc: Forward wakeup to smc socket waitqueue after fallback") Link: https://lore.kernel.org/r/00000000000013ca8105d7ae3ada@google.com/ Signed-off-by: Wen Gu Acked-by: Karsten Graul Signed-off-by: Jakub Kicinski --- net/smc/af_smc.c | 80 ++++++++++++++++++++++++++++++++------------- net/smc/smc_close.c | 2 ++ 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index d8433f17c5c94..fce16b9d6e1a4 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -243,11 +243,27 @@ struct proto smc_proto6 = { }; EXPORT_SYMBOL_GPL(smc_proto6); +static void smc_fback_restore_callbacks(struct smc_sock *smc) +{ + struct sock *clcsk = smc->clcsock->sk; + + write_lock_bh(&clcsk->sk_callback_lock); + clcsk->sk_user_data = NULL; + + smc_clcsock_restore_cb(&clcsk->sk_state_change, &smc->clcsk_state_change); + smc_clcsock_restore_cb(&clcsk->sk_data_ready, &smc->clcsk_data_ready); + smc_clcsock_restore_cb(&clcsk->sk_write_space, &smc->clcsk_write_space); + smc_clcsock_restore_cb(&clcsk->sk_error_report, &smc->clcsk_error_report); + + write_unlock_bh(&clcsk->sk_callback_lock); +} + static void smc_restore_fallback_changes(struct smc_sock *smc) { if (smc->clcsock->file) { /* non-accepted sockets have no file yet */ smc->clcsock->file->private_data = smc->sk.sk_socket; smc->clcsock->file = NULL; + smc_fback_restore_callbacks(smc); } } @@ -745,48 +761,57 @@ static void smc_fback_forward_wakeup(struct smc_sock *smc, struct sock *clcsk, static void smc_fback_state_change(struct sock *clcsk) { - struct smc_sock *smc = - smc_clcsock_user_data(clcsk); + struct smc_sock *smc; - if (!smc) - return; - smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_state_change); + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_state_change); + read_unlock_bh(&clcsk->sk_callback_lock); } static void smc_fback_data_ready(struct sock *clcsk) { - struct smc_sock *smc = - smc_clcsock_user_data(clcsk); + struct smc_sock *smc; - if (!smc) - return; - smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_data_ready); + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_data_ready); + read_unlock_bh(&clcsk->sk_callback_lock); } static void smc_fback_write_space(struct sock *clcsk) { - struct smc_sock *smc = - smc_clcsock_user_data(clcsk); + struct smc_sock *smc; - if (!smc) - return; - smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_write_space); + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_write_space); + read_unlock_bh(&clcsk->sk_callback_lock); } static void smc_fback_error_report(struct sock *clcsk) { - struct smc_sock *smc = - smc_clcsock_user_data(clcsk); + struct smc_sock *smc; - if (!smc) - return; - smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_error_report); + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_error_report); + read_unlock_bh(&clcsk->sk_callback_lock); } static void smc_fback_replace_callbacks(struct smc_sock *smc) { struct sock *clcsk = smc->clcsock->sk; + write_lock_bh(&clcsk->sk_callback_lock); clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); smc_clcsock_replace_cb(&clcsk->sk_state_change, smc_fback_state_change, @@ -797,6 +822,8 @@ static void smc_fback_replace_callbacks(struct smc_sock *smc) &smc->clcsk_write_space); smc_clcsock_replace_cb(&clcsk->sk_error_report, smc_fback_error_report, &smc->clcsk_error_report); + + write_unlock_bh(&clcsk->sk_callback_lock); } static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code) @@ -2370,17 +2397,20 @@ static void smc_tcp_listen_work(struct work_struct *work) static void smc_clcsock_data_ready(struct sock *listen_clcsock) { - struct smc_sock *lsmc = - smc_clcsock_user_data(listen_clcsock); + struct smc_sock *lsmc; + read_lock_bh(&listen_clcsock->sk_callback_lock); + lsmc = smc_clcsock_user_data(listen_clcsock); if (!lsmc) - return; + goto out; lsmc->clcsk_data_ready(listen_clcsock); if (lsmc->sk.sk_state == SMC_LISTEN) { sock_hold(&lsmc->sk); /* sock_put in smc_tcp_listen_work() */ if (!queue_work(smc_tcp_ls_wq, &lsmc->tcp_listen_work)) sock_put(&lsmc->sk); } +out: + read_unlock_bh(&listen_clcsock->sk_callback_lock); } static int smc_listen(struct socket *sock, int backlog) @@ -2412,10 +2442,12 @@ static int smc_listen(struct socket *sock, int backlog) /* save original sk_data_ready function and establish * smc-specific sk_data_ready function */ + write_lock_bh(&smc->clcsock->sk->sk_callback_lock); smc->clcsock->sk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); smc_clcsock_replace_cb(&smc->clcsock->sk->sk_data_ready, smc_clcsock_data_ready, &smc->clcsk_data_ready); + write_unlock_bh(&smc->clcsock->sk->sk_callback_lock); /* save original ops */ smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops; @@ -2430,9 +2462,11 @@ static int smc_listen(struct socket *sock, int backlog) rc = kernel_listen(smc->clcsock, backlog); if (rc) { + write_lock_bh(&smc->clcsock->sk->sk_callback_lock); smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, &smc->clcsk_data_ready); smc->clcsock->sk->sk_user_data = NULL; + write_unlock_bh(&smc->clcsock->sk->sk_callback_lock); goto out; } sk->sk_max_ack_backlog = backlog; diff --git a/net/smc/smc_close.c b/net/smc/smc_close.c index 7bd1ef55b9dfd..31db7438857c9 100644 --- a/net/smc/smc_close.c +++ b/net/smc/smc_close.c @@ -214,9 +214,11 @@ int smc_close_active(struct smc_sock *smc) sk->sk_state = SMC_CLOSED; sk->sk_state_change(sk); /* wake up accept */ if (smc->clcsock && smc->clcsock->sk) { + write_lock_bh(&smc->clcsock->sk->sk_callback_lock); smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, &smc->clcsk_data_ready); smc->clcsock->sk->sk_user_data = NULL; + write_unlock_bh(&smc->clcsock->sk->sk_callback_lock); rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR); } smc_close_cleanup_listen(sk);