diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 7a0f7998376a5..98ac73938bd81 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -107,8 +107,8 @@ static void remote_address(const struct sock_common *skc, #endif } -static bool lookup_subflow_by_saddr(const struct list_head *list, - const struct mptcp_addr_info *saddr) +bool mptcp_lookup_subflow_by_saddr(const struct list_head *list, + const struct mptcp_addr_info *saddr) { struct mptcp_subflow_context *subflow; struct mptcp_addr_info cur; @@ -1447,8 +1447,8 @@ int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info) return ret; } -static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, - const struct mptcp_addr_info *addr) +bool mptcp_remove_anno_list_by_saddr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) { struct mptcp_pm_add_entry *entry; @@ -1476,7 +1476,7 @@ static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, list.ids[list.nr++] = mptcp_endp_get_local_id(msk, addr); - ret = remove_anno_list_by_saddr(msk, addr); + ret = mptcp_remove_anno_list_by_saddr(msk, addr); if (ret || force) { spin_lock_bh(&msk->pm.lock); if (ret) { @@ -1520,7 +1520,7 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, } lock_sock(sk); - remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr); + remove_subflow = mptcp_lookup_subflow_by_saddr(&msk->conn_list, addr); mptcp_pm_remove_anno_addr(msk, addr, remove_subflow && !(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)); @@ -1633,36 +1633,6 @@ int mptcp_pm_nl_del_addr_doit(struct sk_buff *skb, struct genl_info *info) return ret; } -/* Called from the userspace PM only */ -void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list) -{ - struct mptcp_rm_list alist = { .nr = 0 }; - struct mptcp_pm_addr_entry *entry; - int anno_nr = 0; - - list_for_each_entry(entry, rm_list, list) { - if (alist.nr >= MPTCP_RM_IDS_MAX) - break; - - /* only delete if either announced or matching a subflow */ - if (remove_anno_list_by_saddr(msk, &entry->addr)) - anno_nr++; - else if (!lookup_subflow_by_saddr(&msk->conn_list, - &entry->addr)) - continue; - - alist.ids[alist.nr++] = entry->addr.id; - } - - if (alist.nr) { - spin_lock_bh(&msk->pm.lock); - msk->pm.add_addr_signaled -= anno_nr; - mptcp_pm_remove_addr(msk, &alist); - spin_unlock_bh(&msk->pm.lock); - } -} - -/* Called from the in-kernel PM only */ static void mptcp_pm_flush_addrs_and_subflows(struct mptcp_sock *msk, struct list_head *rm_list) { @@ -1671,11 +1641,11 @@ static void mptcp_pm_flush_addrs_and_subflows(struct mptcp_sock *msk, list_for_each_entry(entry, rm_list, list) { if (slist.nr < MPTCP_RM_IDS_MAX && - lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) + mptcp_lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) slist.ids[slist.nr++] = mptcp_endp_get_local_id(msk, &entry->addr); if (alist.nr < MPTCP_RM_IDS_MAX && - remove_anno_list_by_saddr(msk, &entry->addr)) + mptcp_remove_anno_list_by_saddr(msk, &entry->addr)) alist.ids[alist.nr++] = mptcp_endp_get_local_id(msk, &entry->addr); } diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c index e35178f5205fa..740a10d669f85 100644 --- a/net/mptcp/pm_userspace.c +++ b/net/mptcp/pm_userspace.c @@ -8,6 +8,10 @@ #include "mib.h" #include "mptcp_pm_gen.h" +#define mptcp_for_each_userspace_pm_addr(__msk, __entry) \ + list_for_each_entry(__entry, \ + &((__msk)->pm.userspace_pm_local_addr_list), list) + void mptcp_free_local_addr_list(struct mptcp_sock *msk) { struct mptcp_pm_addr_entry *entry, *tmp; @@ -26,6 +30,19 @@ void mptcp_free_local_addr_list(struct mptcp_sock *msk) } } +static struct mptcp_pm_addr_entry * +mptcp_userspace_pm_lookup_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + struct mptcp_pm_addr_entry *entry; + + mptcp_for_each_userspace_pm_addr(msk, entry) { + if (mptcp_addresses_equal(&entry->addr, addr, false)) + return entry; + } + return NULL; +} + static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, struct mptcp_pm_addr_entry *entry, bool needs_id) @@ -41,7 +58,7 @@ static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); spin_lock_bh(&msk->pm.lock); - list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) { + mptcp_for_each_userspace_pm_addr(msk, e) { addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true); if (addr_match && entry->addr.id == 0 && needs_id) entry->addr.id = e->addr.id; @@ -90,22 +107,20 @@ static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, static int mptcp_userspace_pm_delete_local_addr(struct mptcp_sock *msk, struct mptcp_pm_addr_entry *addr) { - struct mptcp_pm_addr_entry *entry, *tmp; struct sock *sk = (struct sock *)msk; + struct mptcp_pm_addr_entry *entry; - list_for_each_entry_safe(entry, tmp, &msk->pm.userspace_pm_local_addr_list, list) { - if (mptcp_addresses_equal(&entry->addr, &addr->addr, false)) { - /* TODO: a refcount is needed because the entry can - * be used multiple times (e.g. fullmesh mode). - */ - list_del_rcu(&entry->list); - sock_kfree_s(sk, entry, sizeof(*entry)); - msk->pm.local_addr_used--; - return 0; - } - } - - return -EINVAL; + entry = mptcp_userspace_pm_lookup_addr(msk, &addr->addr); + if (!entry) + return -EINVAL; + + /* TODO: a refcount is needed because the entry can + * be used multiple times (e.g. fullmesh mode). + */ + list_del_rcu(&entry->list); + sock_kfree_s(sk, entry, sizeof(*entry)); + msk->pm.local_addr_used--; + return 0; } static struct mptcp_pm_addr_entry * @@ -113,7 +128,7 @@ mptcp_userspace_pm_lookup_addr_by_id(struct mptcp_sock *msk, unsigned int id) { struct mptcp_pm_addr_entry *entry; - list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { + mptcp_for_each_userspace_pm_addr(msk, entry) { if (entry->addr.id == id) return entry; } @@ -123,17 +138,12 @@ mptcp_userspace_pm_lookup_addr_by_id(struct mptcp_sock *msk, unsigned int id) int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc) { - struct mptcp_pm_addr_entry *entry = NULL, *e, new_entry; + struct mptcp_pm_addr_entry *entry = NULL, new_entry; __be16 msk_sport = ((struct inet_sock *) inet_sk((struct sock *)msk))->inet_sport; spin_lock_bh(&msk->pm.lock); - list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) { - if (mptcp_addresses_equal(&e->addr, skc, false)) { - entry = e; - break; - } - } + entry = mptcp_userspace_pm_lookup_addr(msk, skc); spin_unlock_bh(&msk->pm.lock); if (entry) return entry->addr.id; @@ -153,50 +163,60 @@ bool mptcp_userspace_pm_is_backup(struct mptcp_sock *msk, struct mptcp_addr_info *skc) { struct mptcp_pm_addr_entry *entry; - bool backup = false; + bool backup; spin_lock_bh(&msk->pm.lock); - list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { - if (mptcp_addresses_equal(&entry->addr, skc, false)) { - backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); - break; - } - } + entry = mptcp_userspace_pm_lookup_addr(msk, skc); + backup = entry && !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); spin_unlock_bh(&msk->pm.lock); return backup; } -int mptcp_pm_nl_announce_doit(struct sk_buff *skb, struct genl_info *info) +static struct mptcp_sock *mptcp_userspace_pm_get_sock(const struct genl_info *info) { struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; + struct mptcp_sock *msk; + + if (!token) { + GENL_SET_ERR_MSG(info, "missing required token"); + return NULL; + } + + msk = mptcp_token_get_sock(genl_info_net(info), nla_get_u32(token)); + if (!msk) { + NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + return NULL; + } + + if (!mptcp_pm_is_userspace(msk)) { + GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); + sock_put((struct sock *)msk); + return NULL; + } + + return msk; +} + +int mptcp_pm_nl_announce_doit(struct sk_buff *skb, struct genl_info *info) +{ struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR]; struct mptcp_pm_addr_entry addr_val; struct mptcp_sock *msk; int err = -EINVAL; struct sock *sk; - u32 token_val; - if (!addr || !token) { - GENL_SET_ERR_MSG(info, "missing required inputs"); + if (!addr) { + GENL_SET_ERR_MSG(info, "missing required address"); return err; } - token_val = nla_get_u32(token); - - msk = mptcp_token_get_sock(sock_net(skb->sk), token_val); - if (!msk) { - NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) return err; - } sk = (struct sock *)msk; - if (!mptcp_pm_is_userspace(msk)) { - GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); - goto announce_err; - } - err = mptcp_pm_parse_entry(addr, info, true, &addr_val); if (err < 0) { GENL_SET_ERR_MSG(info, "error parsing local address"); @@ -267,40 +287,48 @@ static int mptcp_userspace_pm_remove_id_zero_address(struct mptcp_sock *msk, return err; } +void mptcp_pm_remove_addr_entry(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *entry) +{ + struct mptcp_rm_list alist = { .nr = 0 }; + int anno_nr = 0; + + /* only delete if either announced or matching a subflow */ + if (mptcp_remove_anno_list_by_saddr(msk, &entry->addr)) + anno_nr++; + else if (!mptcp_lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) + return; + + alist.ids[alist.nr++] = entry->addr.id; + + spin_lock_bh(&msk->pm.lock); + msk->pm.add_addr_signaled -= anno_nr; + mptcp_pm_remove_addr(msk, &alist); + spin_unlock_bh(&msk->pm.lock); +} + int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info) { - struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID]; struct mptcp_pm_addr_entry *match; - struct mptcp_pm_addr_entry *entry; struct mptcp_sock *msk; - LIST_HEAD(free_list); int err = -EINVAL; struct sock *sk; - u32 token_val; u8 id_val; - if (!id || !token) { - GENL_SET_ERR_MSG(info, "missing required inputs"); + if (!id) { + GENL_SET_ERR_MSG(info, "missing required ID"); return err; } id_val = nla_get_u8(id); - token_val = nla_get_u32(token); - msk = mptcp_token_get_sock(sock_net(skb->sk), token_val); - if (!msk) { - NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) return err; - } sk = (struct sock *)msk; - if (!mptcp_pm_is_userspace(msk)) { - GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); - goto out; - } - if (id_val == 0) { err = mptcp_userspace_pm_remove_id_zero_address(msk, info); goto out; @@ -317,16 +345,14 @@ int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info) goto out; } - list_move(&match->list, &free_list); + list_del_rcu(&match->list); spin_unlock_bh(&msk->pm.lock); - mptcp_pm_remove_addrs(msk, &free_list); + mptcp_pm_remove_addr_entry(msk, match); release_sock(sk); - list_for_each_entry_safe(match, entry, &free_list, list) { - sock_kfree_s(sk, match, sizeof(*match)); - } + sock_kfree_s(sk, match, sizeof(*match)); err = 0; out: @@ -337,7 +363,6 @@ int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info) int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info) { struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; - struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; struct mptcp_pm_addr_entry entry = { 0 }; struct mptcp_addr_info addr_r; @@ -345,28 +370,18 @@ int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info) struct mptcp_sock *msk; int err = -EINVAL; struct sock *sk; - u32 token_val; - if (!laddr || !raddr || !token) { - GENL_SET_ERR_MSG(info, "missing required inputs"); + if (!laddr || !raddr) { + GENL_SET_ERR_MSG(info, "missing required address(es)"); return err; } - token_val = nla_get_u32(token); - - msk = mptcp_token_get_sock(genl_info_net(info), token_val); - if (!msk) { - NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) return err; - } sk = (struct sock *)msk; - if (!mptcp_pm_is_userspace(msk)) { - GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); - goto create_err; - } - err = mptcp_pm_parse_entry(laddr, info, true, &entry); if (err < 0) { NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr"); @@ -469,36 +484,25 @@ static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk, int mptcp_pm_nl_subflow_destroy_doit(struct sk_buff *skb, struct genl_info *info) { struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; - struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; - struct mptcp_addr_info addr_l; + struct mptcp_pm_addr_entry addr_l; struct mptcp_addr_info addr_r; struct mptcp_sock *msk; struct sock *sk, *ssk; int err = -EINVAL; - u32 token_val; - if (!laddr || !raddr || !token) { - GENL_SET_ERR_MSG(info, "missing required inputs"); + if (!laddr || !raddr) { + GENL_SET_ERR_MSG(info, "missing required address(es)"); return err; } - token_val = nla_get_u32(token); - - msk = mptcp_token_get_sock(genl_info_net(info), token_val); - if (!msk) { - NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) return err; - } sk = (struct sock *)msk; - if (!mptcp_pm_is_userspace(msk)) { - GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); - goto destroy_err; - } - - err = mptcp_pm_parse_addr(laddr, info, &addr_l); + err = mptcp_pm_parse_entry(laddr, info, true, &addr_l); if (err < 0) { NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr"); goto destroy_err; @@ -511,43 +515,41 @@ int mptcp_pm_nl_subflow_destroy_doit(struct sk_buff *skb, struct genl_info *info } #if IS_ENABLED(CONFIG_MPTCP_IPV6) - if (addr_l.family == AF_INET && ipv6_addr_v4mapped(&addr_r.addr6)) { - ipv6_addr_set_v4mapped(addr_l.addr.s_addr, &addr_l.addr6); - addr_l.family = AF_INET6; + if (addr_l.addr.family == AF_INET && ipv6_addr_v4mapped(&addr_r.addr6)) { + ipv6_addr_set_v4mapped(addr_l.addr.addr.s_addr, &addr_l.addr.addr6); + addr_l.addr.family = AF_INET6; } - if (addr_r.family == AF_INET && ipv6_addr_v4mapped(&addr_l.addr6)) { - ipv6_addr_set_v4mapped(addr_r.addr.s_addr, &addr_r.addr6); + if (addr_r.family == AF_INET && ipv6_addr_v4mapped(&addr_l.addr.addr6)) { + ipv6_addr_set_v4mapped(addr_r.addr.s_addr, &addr_l.addr.addr6); addr_r.family = AF_INET6; } #endif - if (addr_l.family != addr_r.family) { + if (addr_l.addr.family != addr_r.family) { GENL_SET_ERR_MSG(info, "address families do not match"); err = -EINVAL; goto destroy_err; } - if (!addr_l.port || !addr_r.port) { + if (!addr_l.addr.port || !addr_r.port) { GENL_SET_ERR_MSG(info, "missing local or remote port"); err = -EINVAL; goto destroy_err; } lock_sock(sk); - ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r); - if (ssk) { - struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); - struct mptcp_pm_addr_entry entry = { .addr = addr_l }; - - spin_lock_bh(&msk->pm.lock); - mptcp_userspace_pm_delete_local_addr(msk, &entry); - spin_unlock_bh(&msk->pm.lock); - mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); - mptcp_close_ssk(sk, ssk, subflow); - MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW); - err = 0; - } else { + ssk = mptcp_nl_find_ssk(msk, &addr_l.addr, &addr_r); + if (!ssk) { err = -ESRCH; + goto release_sock; } + + spin_lock_bh(&msk->pm.lock); + mptcp_userspace_pm_delete_local_addr(msk, &addr_l); + spin_unlock_bh(&msk->pm.lock); + mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); + mptcp_close_ssk(sk, ssk, mptcp_subflow_ctx(ssk)); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW); +release_sock: release_sock(sk); destroy_err: @@ -560,31 +562,19 @@ int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info) struct mptcp_pm_addr_entry loc = { .addr = { .family = AF_UNSPEC }, }; struct mptcp_pm_addr_entry rem = { .addr = { .family = AF_UNSPEC }, }; struct nlattr *attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; - struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; - struct net *net = sock_net(skb->sk); struct mptcp_pm_addr_entry *entry; struct mptcp_sock *msk; int ret = -EINVAL; struct sock *sk; - u32 token_val; u8 bkup = 0; - token_val = nla_get_u32(token); - - msk = mptcp_token_get_sock(net, token_val); - if (!msk) { - NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) return ret; - } sk = (struct sock *)msk; - if (!mptcp_pm_is_userspace(msk)) { - GENL_SET_ERR_MSG(info, "userspace PM not selected"); - goto set_flags_err; - } - ret = mptcp_pm_parse_entry(attr, info, false, &loc); if (ret < 0) goto set_flags_err; @@ -606,13 +596,12 @@ int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info) bkup = 1; spin_lock_bh(&msk->pm.lock); - list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { - if (mptcp_addresses_equal(&entry->addr, &loc.addr, false)) { - if (bkup) - entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP; - else - entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP; - } + entry = mptcp_userspace_pm_lookup_addr(msk, &loc.addr); + if (entry) { + if (bkup) + entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP; + else + entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP; } spin_unlock_bh(&msk->pm.lock); @@ -632,33 +621,23 @@ int mptcp_userspace_pm_dump_addr(struct sk_buff *msg, DECLARE_BITMAP(map, MPTCP_PM_MAX_ADDR_ID + 1); } *bitmap; const struct genl_info *info = genl_info_dump(cb); - struct net *net = sock_net(msg->sk); struct mptcp_pm_addr_entry *entry; struct mptcp_sock *msk; - struct nlattr *token; int ret = -EINVAL; struct sock *sk; void *hdr; bitmap = (struct id_bitmap *)cb->ctx; - token = info->attrs[MPTCP_PM_ATTR_TOKEN]; - msk = mptcp_token_get_sock(net, nla_get_u32(token)); - if (!msk) { - NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) return ret; - } sk = (struct sock *)msk; - if (!mptcp_pm_is_userspace(msk)) { - GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); - goto out; - } - lock_sock(sk); spin_lock_bh(&msk->pm.lock); - list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { + mptcp_for_each_userspace_pm_addr(msk, entry) { if (test_bit(entry->addr.id, bitmap->map)) continue; @@ -680,7 +659,6 @@ int mptcp_userspace_pm_dump_addr(struct sk_buff *msg, release_sock(sk); ret = msg->len; -out: sock_put(sk); return ret; } @@ -689,28 +667,19 @@ int mptcp_userspace_pm_get_addr(struct sk_buff *skb, struct genl_info *info) { struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; - struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; struct mptcp_pm_addr_entry addr, *entry; - struct net *net = sock_net(skb->sk); struct mptcp_sock *msk; struct sk_buff *msg; int ret = -EINVAL; struct sock *sk; void *reply; - msk = mptcp_token_get_sock(net, nla_get_u32(token)); - if (!msk) { - NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) return ret; - } sk = (struct sock *)msk; - if (!mptcp_pm_is_userspace(msk)) { - GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); - goto out; - } - ret = mptcp_pm_parse_entry(attr, info, false, &addr); if (ret < 0) goto out; diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index a93e661ef5c43..cd5132fe7d220 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -1027,6 +1027,10 @@ mptcp_pm_del_add_timer(struct mptcp_sock *msk, struct mptcp_pm_add_entry * mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk, const struct mptcp_addr_info *addr); +bool mptcp_lookup_subflow_by_saddr(const struct list_head *list, + const struct mptcp_addr_info *saddr); +bool mptcp_remove_anno_list_by_saddr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr); int mptcp_pm_set_flags(struct sk_buff *skb, struct genl_info *info); int mptcp_pm_nl_set_flags(struct sk_buff *skb, struct genl_info *info); int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info); @@ -1034,7 +1038,8 @@ int mptcp_pm_announce_addr(struct mptcp_sock *msk, const struct mptcp_addr_info *addr, bool echo); int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list); -void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list); +void mptcp_pm_remove_addr_entry(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *entry); void mptcp_free_local_addr_list(struct mptcp_sock *msk);