Skip to content

Commit

Permalink
Merge branch 'mptcp-annotate-lockless'
Browse files Browse the repository at this point in the history
Matthieu Baerts says:

====================
mptcp: annotate lockless access

This is a series of 5 patches from Paolo to annotate lockless access.

The MPTCP locking schema is already quite complex. We need to clarify it
and make the lockless access already there consistent, or later changes
will be even harder to follow and understand.

This series goes through all the msk fields accessed in the RX and TX
path and makes the lockless annotation consistent with the in-use
locking schema.

As a bonus, this should fix data races eventually found by fuzzers --
even if we haven't seen many such reports so far.

Patch 1/5 hints we could remove "local_key" and "remote_key" from the
subflow context, and always use the ones from the msk socket, possibly
reducing the context memory usage. That change is left over as a
possible follow-up.
====================

Signed-off-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
David S. Miller committed Feb 5, 2024
2 parents 1e08223 + 28e5c13 commit c3b39ea
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 49 deletions.
20 changes: 10 additions & 10 deletions net/mptcp/options.c
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,8 @@ static bool mptcp_established_options_add_addr(struct sock *sk, struct sk_buff *
opts->suboptions |= OPTION_MPTCP_ADD_ADDR;
if (!echo) {
MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_ADDADDRTX);
opts->ahmac = add_addr_generate_hmac(msk->local_key,
msk->remote_key,
opts->ahmac = add_addr_generate_hmac(READ_ONCE(msk->local_key),
READ_ONCE(msk->remote_key),
&opts->addr);
} else {
MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_ECHOADDTX);
Expand Down Expand Up @@ -792,7 +792,7 @@ static bool mptcp_established_options_fastclose(struct sock *sk,

*size = TCPOLEN_MPTCP_FASTCLOSE;
opts->suboptions |= OPTION_MPTCP_FASTCLOSE;
opts->rcvr_key = msk->remote_key;
opts->rcvr_key = READ_ONCE(msk->remote_key);

pr_debug("FASTCLOSE key=%llu", opts->rcvr_key);
MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPFASTCLOSETX);
Expand Down Expand Up @@ -1030,7 +1030,7 @@ u64 __mptcp_expand_seq(u64 old_seq, u64 cur_seq)
static void __mptcp_snd_una_update(struct mptcp_sock *msk, u64 new_snd_una)
{
msk->bytes_acked += new_snd_una - msk->snd_una;
msk->snd_una = new_snd_una;
WRITE_ONCE(msk->snd_una, new_snd_una);
}

static void ack_update_msk(struct mptcp_sock *msk,
Expand All @@ -1057,10 +1057,10 @@ static void ack_update_msk(struct mptcp_sock *msk,
new_wnd_end = new_snd_una + tcp_sk(ssk)->snd_wnd;

if (after64(new_wnd_end, msk->wnd_end))
msk->wnd_end = new_wnd_end;
WRITE_ONCE(msk->wnd_end, new_wnd_end);

/* this assumes mptcp_incoming_options() is invoked after tcp_ack() */
if (after64(msk->wnd_end, READ_ONCE(msk->snd_nxt)))
if (after64(msk->wnd_end, snd_nxt))
__mptcp_check_push(sk, ssk);

if (after64(new_snd_una, old_snd_una)) {
Expand All @@ -1071,7 +1071,7 @@ static void ack_update_msk(struct mptcp_sock *msk,

trace_ack_update_msk(mp_opt->data_ack,
old_snd_una, new_snd_una,
new_wnd_end, msk->wnd_end);
new_wnd_end, READ_ONCE(msk->wnd_end));
}

bool mptcp_update_rcv_data_fin(struct mptcp_sock *msk, u64 data_fin_seq, bool use_64bit)
Expand Down Expand Up @@ -1099,8 +1099,8 @@ static bool add_addr_hmac_valid(struct mptcp_sock *msk,
if (mp_opt->echo)
return true;

hmac = add_addr_generate_hmac(msk->remote_key,
msk->local_key,
hmac = add_addr_generate_hmac(READ_ONCE(msk->remote_key),
READ_ONCE(msk->local_key),
&mp_opt->addr);

pr_debug("msk=%p, ahmac=%llu, mp_opt->ahmac=%llu\n",
Expand Down Expand Up @@ -1147,7 +1147,7 @@ bool mptcp_incoming_options(struct sock *sk, struct sk_buff *skb)

if (unlikely(mp_opt.suboptions != OPTION_MPTCP_DSS)) {
if ((mp_opt.suboptions & OPTION_MPTCP_FASTCLOSE) &&
msk->local_key == mp_opt.rcvr_key) {
READ_ONCE(msk->local_key) == mp_opt.rcvr_key) {
WRITE_ONCE(msk->rcv_fastclose, true);
mptcp_schedule_work((struct sock *)msk);
MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPFASTCLOSERX);
Expand Down
2 changes: 1 addition & 1 deletion net/mptcp/pm.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void mptcp_pm_new_connection(struct mptcp_sock *msk, const struct sock *ssk, int
{
struct mptcp_pm_data *pm = &msk->pm;

pr_debug("msk=%p, token=%u side=%d", msk, msk->token, server_side);
pr_debug("msk=%p, token=%u side=%d", msk, READ_ONCE(msk->token), server_side);

WRITE_ONCE(pm->server_side, server_side);
mptcp_event(MPTCP_EVENT_CREATED, msk, ssk, GFP_ATOMIC);
Expand Down
10 changes: 5 additions & 5 deletions net/mptcp/pm_netlink.c
Original file line number Diff line number Diff line change
Expand Up @@ -1997,7 +1997,7 @@ static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
const struct mptcp_subflow_context *sf;
u8 sk_err;

if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)))
return -EMSGSIZE;

if (mptcp_event_add_subflow(skb, ssk))
Expand Down Expand Up @@ -2055,7 +2055,7 @@ static int mptcp_event_created(struct sk_buff *skb,
const struct mptcp_sock *msk,
const struct sock *ssk)
{
int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token));

if (err)
return err;
Expand Down Expand Up @@ -2083,7 +2083,7 @@ void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
if (!nlh)
goto nla_put_failure;

if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)))
goto nla_put_failure;

if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
Expand Down Expand Up @@ -2118,7 +2118,7 @@ void mptcp_event_addr_announced(const struct sock *ssk,
if (!nlh)
goto nla_put_failure;

if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)))
goto nla_put_failure;

if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
Expand Down Expand Up @@ -2234,7 +2234,7 @@ void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
goto nla_put_failure;
break;
case MPTCP_EVENT_CLOSED:
if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)) < 0)
goto nla_put_failure;
break;
case MPTCP_EVENT_ANNOUNCED:
Expand Down
52 changes: 27 additions & 25 deletions net/mptcp/protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ static void mptcp_close_wake_up(struct sock *sk)
sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
}

/* called under the msk socket lock */
static bool mptcp_pending_data_fin_ack(struct sock *sk)
{
struct mptcp_sock *msk = mptcp_sk(sk);
Expand Down Expand Up @@ -441,16 +442,17 @@ static void mptcp_check_data_fin_ack(struct sock *sk)
}
}

/* can be called with no lock acquired */
static bool mptcp_pending_data_fin(struct sock *sk, u64 *seq)
{
struct mptcp_sock *msk = mptcp_sk(sk);

if (READ_ONCE(msk->rcv_data_fin) &&
((1 << sk->sk_state) &
((1 << inet_sk_state_load(sk)) &
(TCPF_ESTABLISHED | TCPF_FIN_WAIT1 | TCPF_FIN_WAIT2))) {
u64 rcv_data_fin_seq = READ_ONCE(msk->rcv_data_fin_seq);

if (msk->ack_seq == rcv_data_fin_seq) {
if (READ_ONCE(msk->ack_seq) == rcv_data_fin_seq) {
if (seq)
*seq = rcv_data_fin_seq;

Expand Down Expand Up @@ -748,7 +750,7 @@ static bool __mptcp_ofo_queue(struct mptcp_sock *msk)
__skb_queue_tail(&sk->sk_receive_queue, skb);
}
msk->bytes_received += end_seq - msk->ack_seq;
msk->ack_seq = end_seq;
WRITE_ONCE(msk->ack_seq, end_seq);
moved = true;
}
return moved;
Expand Down Expand Up @@ -985,6 +987,7 @@ static void dfrag_clear(struct sock *sk, struct mptcp_data_frag *dfrag)
put_page(dfrag->page);
}

/* called under both the msk socket lock and the data lock */
static void __mptcp_clean_una(struct sock *sk)
{
struct mptcp_sock *msk = mptcp_sk(sk);
Expand Down Expand Up @@ -1033,13 +1036,15 @@ static void __mptcp_clean_una(struct sock *sk)
msk->recovery = false;

out:
if (snd_una == READ_ONCE(msk->snd_nxt) &&
snd_una == READ_ONCE(msk->write_seq)) {
if (snd_una == msk->snd_nxt && snd_una == msk->write_seq) {
if (mptcp_rtx_timer_pending(sk) && !mptcp_data_fin_enabled(msk))
mptcp_stop_rtx_timer(sk);
} else {
mptcp_reset_rtx_timer(sk);
}

if (mptcp_pending_data_fin_ack(sk))
mptcp_schedule_work(sk);
}

static void __mptcp_clean_una_wakeup(struct sock *sk)
Expand Down Expand Up @@ -1499,7 +1504,7 @@ static void mptcp_update_post_push(struct mptcp_sock *msk,
*/
if (likely(after64(snd_nxt_new, msk->snd_nxt))) {
msk->bytes_sent += snd_nxt_new - msk->snd_nxt;
msk->snd_nxt = snd_nxt_new;
WRITE_ONCE(msk->snd_nxt, snd_nxt_new);
}
}

Expand Down Expand Up @@ -2108,7 +2113,7 @@ static unsigned int mptcp_inq_hint(const struct sock *sk)

skb = skb_peek(&msk->receive_queue);
if (skb) {
u64 hint_val = msk->ack_seq - MPTCP_SKB_CB(skb)->map_seq;
u64 hint_val = READ_ONCE(msk->ack_seq) - MPTCP_SKB_CB(skb)->map_seq;

if (hint_val >= INT_MAX)
return INT_MAX;
Expand Down Expand Up @@ -2752,7 +2757,7 @@ static void __mptcp_init_sock(struct sock *sk)
__skb_queue_head_init(&msk->receive_queue);
msk->out_of_order_queue = RB_ROOT;
msk->first_pending = NULL;
msk->rmem_fwd_alloc = 0;
WRITE_ONCE(msk->rmem_fwd_alloc, 0);
WRITE_ONCE(msk->rmem_released, 0);
msk->timer_ival = TCP_RTO_MIN;
msk->scaling_ratio = TCP_DEFAULT_SCALING_RATIO;
Expand Down Expand Up @@ -2968,7 +2973,7 @@ static void __mptcp_destroy_sock(struct sock *sk)

sk->sk_prot->destroy(sk);

WARN_ON_ONCE(msk->rmem_fwd_alloc);
WARN_ON_ONCE(READ_ONCE(msk->rmem_fwd_alloc));
WARN_ON_ONCE(msk->rmem_released);
sk_stream_kill_queues(sk);
xfrm_sk_free_policy(sk);
Expand Down Expand Up @@ -3144,16 +3149,16 @@ static int mptcp_disconnect(struct sock *sk, int flags)
msk->cb_flags = 0;
msk->push_pending = 0;
msk->recovery = false;
msk->can_ack = false;
msk->fully_established = false;
msk->rcv_data_fin = false;
msk->snd_data_fin_enable = false;
msk->rcv_fastclose = false;
msk->use_64bit_ack = false;
msk->bytes_consumed = 0;
WRITE_ONCE(msk->can_ack, false);
WRITE_ONCE(msk->fully_established, false);
WRITE_ONCE(msk->rcv_data_fin, false);
WRITE_ONCE(msk->snd_data_fin_enable, false);
WRITE_ONCE(msk->rcv_fastclose, false);
WRITE_ONCE(msk->use_64bit_ack, false);
WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk)));
mptcp_pm_data_reset(msk);
mptcp_ca_reset(sk);
msk->bytes_consumed = 0;
msk->bytes_acked = 0;
msk->bytes_received = 0;
msk->bytes_sent = 0;
Expand Down Expand Up @@ -3193,17 +3198,17 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk,
__mptcp_init_sock(nsk);

msk = mptcp_sk(nsk);
msk->local_key = subflow_req->local_key;
msk->token = subflow_req->token;
WRITE_ONCE(msk->local_key, subflow_req->local_key);
WRITE_ONCE(msk->token, subflow_req->token);
msk->in_accept_queue = 1;
WRITE_ONCE(msk->fully_established, false);
if (mp_opt->suboptions & OPTION_MPTCP_CSUMREQD)
WRITE_ONCE(msk->csum_enabled, true);

msk->write_seq = subflow_req->idsn + 1;
msk->snd_nxt = msk->write_seq;
msk->snd_una = msk->write_seq;
msk->wnd_end = msk->snd_nxt + req->rsk_rcv_wnd;
WRITE_ONCE(msk->write_seq, subflow_req->idsn + 1);
WRITE_ONCE(msk->snd_nxt, msk->write_seq);
WRITE_ONCE(msk->snd_una, msk->write_seq);
WRITE_ONCE(msk->wnd_end, msk->snd_nxt + req->rsk_rcv_wnd);
msk->setsockopt_seq = mptcp_sk(sk)->setsockopt_seq;
mptcp_init_sched(msk, mptcp_sk(sk)->sched);

Expand Down Expand Up @@ -3303,9 +3308,6 @@ void __mptcp_data_acked(struct sock *sk)
__mptcp_clean_una(sk);
else
__set_bit(MPTCP_CLEAN_UNA, &mptcp_sk(sk)->cb_flags);

if (mptcp_pending_data_fin_ack(sk))
mptcp_schedule_work(sk);
}

void __mptcp_check_push(struct sock *sk, struct sock *ssk)
Expand Down
8 changes: 5 additions & 3 deletions net/mptcp/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ struct mptcp_data_frag {
struct mptcp_sock {
/* inet_connection_sock must be the first member */
struct inet_connection_sock sk;
u64 local_key;
u64 remote_key;
u64 local_key; /* protected by the first subflow socket lock
* lockless access read
*/
u64 remote_key; /* same as above */
u64 write_seq;
u64 bytes_sent;
u64 snd_nxt;
Expand Down Expand Up @@ -400,7 +402,7 @@ static inline struct mptcp_data_frag *mptcp_rtx_head(struct sock *sk)
{
struct mptcp_sock *msk = mptcp_sk(sk);

if (msk->snd_una == READ_ONCE(msk->snd_nxt))
if (msk->snd_una == msk->snd_nxt)
return NULL;

return list_first_entry_or_null(&msk->rtx_queue, struct mptcp_data_frag, list);
Expand Down
2 changes: 1 addition & 1 deletion net/mptcp/sockopt.c
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ void mptcp_diag_fill_info(struct mptcp_sock *msk, struct mptcp_info *info)
mptcp_data_unlock(sk);

slow = lock_sock_fast(sk);
info->mptcpi_csum_enabled = msk->csum_enabled;
info->mptcpi_csum_enabled = READ_ONCE(msk->csum_enabled);
info->mptcpi_token = msk->token;
info->mptcpi_write_seq = msk->write_seq;
info->mptcpi_retransmits = inet_csk(sk)->icsk_retransmits;
Expand Down
10 changes: 6 additions & 4 deletions net/mptcp/subflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ static void subflow_req_create_thmac(struct mptcp_subflow_request_sock *subflow_

get_random_bytes(&subflow_req->local_nonce, sizeof(u32));

subflow_generate_hmac(msk->local_key, msk->remote_key,
subflow_generate_hmac(READ_ONCE(msk->local_key),
READ_ONCE(msk->remote_key),
subflow_req->local_nonce,
subflow_req->remote_nonce, hmac);

Expand Down Expand Up @@ -694,7 +695,8 @@ static bool subflow_hmac_valid(const struct request_sock *req,
if (!msk)
return false;

subflow_generate_hmac(msk->remote_key, msk->local_key,
subflow_generate_hmac(READ_ONCE(msk->remote_key),
READ_ONCE(msk->local_key),
subflow_req->remote_nonce,
subflow_req->local_nonce, hmac);

Expand Down Expand Up @@ -1530,8 +1532,8 @@ int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc,
mptcp_pm_get_flags_and_ifindex_by_id(msk, local_id,
&flags, &ifindex);
subflow->remote_key_valid = 1;
subflow->remote_key = msk->remote_key;
subflow->local_key = msk->local_key;
subflow->remote_key = READ_ONCE(msk->remote_key);
subflow->local_key = READ_ONCE(msk->local_key);
subflow->token = msk->token;
mptcp_info2sockaddr(loc, &addr, ssk->sk_family);

Expand Down

0 comments on commit c3b39ea

Please sign in to comment.