diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c index c58dd78509a27..383ea8b91cc78 100644 --- a/net/ipv4/ipmr.c +++ b/net/ipv4/ipmr.c @@ -120,6 +120,11 @@ static void ipmr_expire_process(struct timer_list *t); lockdep_rtnl_is_held() || \ list_empty(&net->ipv4.mr_tables)) +static bool ipmr_can_free_table(struct net *net) +{ + return !check_net(net) || !net->ipv4.mr_rules_ops; +} + static struct mr_table *ipmr_mr_table_iter(struct net *net, struct mr_table *mrt) { @@ -137,7 +142,7 @@ static struct mr_table *ipmr_mr_table_iter(struct net *net, return ret; } -static struct mr_table *ipmr_get_table(struct net *net, u32 id) +static struct mr_table *__ipmr_get_table(struct net *net, u32 id) { struct mr_table *mrt; @@ -148,6 +153,16 @@ static struct mr_table *ipmr_get_table(struct net *net, u32 id) return NULL; } +static struct mr_table *ipmr_get_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + rcu_read_lock(); + mrt = __ipmr_get_table(net, id); + rcu_read_unlock(); + return mrt; +} + static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4, struct mr_table **mrt) { @@ -189,7 +204,7 @@ static int ipmr_rule_action(struct fib_rule *rule, struct flowi *flp, arg->table = fib_rule_get_table(rule, arg); - mrt = ipmr_get_table(rule->fr_net, arg->table); + mrt = __ipmr_get_table(rule->fr_net, arg->table); if (!mrt) return -EAGAIN; res->mrt = mrt; @@ -302,6 +317,11 @@ EXPORT_SYMBOL(ipmr_rule_default); #define ipmr_for_each_table(mrt, net) \ for (mrt = net->ipv4.mrt; mrt; mrt = NULL) +static bool ipmr_can_free_table(struct net *net) +{ + return !check_net(net); +} + static struct mr_table *ipmr_mr_table_iter(struct net *net, struct mr_table *mrt) { @@ -315,6 +335,8 @@ static struct mr_table *ipmr_get_table(struct net *net, u32 id) return net->ipv4.mrt; } +#define __ipmr_get_table ipmr_get_table + static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4, struct mr_table **mrt) { @@ -403,7 +425,7 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id) if (id != RT_TABLE_DEFAULT && id >= 1000000000) return ERR_PTR(-EINVAL); - mrt = ipmr_get_table(net, id); + mrt = __ipmr_get_table(net, id); if (mrt) return mrt; @@ -413,6 +435,10 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id) static void ipmr_free_table(struct mr_table *mrt) { + struct net *net = read_pnet(&mrt->net); + + DEBUG_NET_WARN_ON_ONCE(!ipmr_can_free_table(net)); + timer_shutdown_sync(&mrt->ipmr_expire_timer); mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC | MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC); @@ -1374,7 +1400,7 @@ int ip_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval, goto out_unlock; } - mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT); + mrt = __ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT); if (!mrt) { ret = -ENOENT; goto out_unlock; @@ -2262,11 +2288,13 @@ int ipmr_get_route(struct net *net, struct sk_buff *skb, struct mr_table *mrt; int err; - mrt = ipmr_get_table(net, RT_TABLE_DEFAULT); - if (!mrt) + rcu_read_lock(); + mrt = __ipmr_get_table(net, RT_TABLE_DEFAULT); + if (!mrt) { + rcu_read_unlock(); return -ENOENT; + } - rcu_read_lock(); cache = ipmr_cache_find(mrt, saddr, daddr); if (!cache && skb->dev) { int vif = ipmr_find_vif(mrt, skb->dev); @@ -2550,7 +2578,7 @@ static int ipmr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, grp = nla_get_in_addr_default(tb[RTA_DST], 0); tableid = nla_get_u32_default(tb[RTA_TABLE], 0); - mrt = ipmr_get_table(net, tableid ? tableid : RT_TABLE_DEFAULT); + mrt = __ipmr_get_table(net, tableid ? tableid : RT_TABLE_DEFAULT); if (!mrt) { err = -ENOENT; goto errout_free; @@ -2604,7 +2632,7 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) if (filter.table_id) { struct mr_table *mrt; - mrt = ipmr_get_table(sock_net(skb->sk), filter.table_id); + mrt = __ipmr_get_table(sock_net(skb->sk), filter.table_id); if (!mrt) { if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IPMR) return skb->len; @@ -2712,7 +2740,7 @@ static int rtm_to_ipmr_mfcc(struct net *net, struct nlmsghdr *nlh, break; } } - mrt = ipmr_get_table(net, tblid); + mrt = __ipmr_get_table(net, tblid); if (!mrt) { ret = -ENOENT; goto out; @@ -2920,13 +2948,15 @@ static void *ipmr_vif_seq_start(struct seq_file *seq, loff_t *pos) struct net *net = seq_file_net(seq); struct mr_table *mrt; - mrt = ipmr_get_table(net, RT_TABLE_DEFAULT); - if (!mrt) + rcu_read_lock(); + mrt = __ipmr_get_table(net, RT_TABLE_DEFAULT); + if (!mrt) { + rcu_read_unlock(); return ERR_PTR(-ENOENT); + } iter->mrt = mrt; - rcu_read_lock(); return mr_vif_seq_start(seq, pos); } diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c index d66f58932a793..4147890fe98ff 100644 --- a/net/ipv6/ip6mr.c +++ b/net/ipv6/ip6mr.c @@ -108,6 +108,11 @@ static void ipmr_expire_process(struct timer_list *t); lockdep_rtnl_is_held() || \ list_empty(&net->ipv6.mr6_tables)) +static bool ip6mr_can_free_table(struct net *net) +{ + return !check_net(net) || !net->ipv6.mr6_rules_ops; +} + static struct mr_table *ip6mr_mr_table_iter(struct net *net, struct mr_table *mrt) { @@ -125,7 +130,7 @@ static struct mr_table *ip6mr_mr_table_iter(struct net *net, return ret; } -static struct mr_table *ip6mr_get_table(struct net *net, u32 id) +static struct mr_table *__ip6mr_get_table(struct net *net, u32 id) { struct mr_table *mrt; @@ -136,6 +141,16 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id) return NULL; } +static struct mr_table *ip6mr_get_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + rcu_read_lock(); + mrt = __ip6mr_get_table(net, id); + rcu_read_unlock(); + return mrt; +} + static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6, struct mr_table **mrt) { @@ -177,7 +192,7 @@ static int ip6mr_rule_action(struct fib_rule *rule, struct flowi *flp, arg->table = fib_rule_get_table(rule, arg); - mrt = ip6mr_get_table(rule->fr_net, arg->table); + mrt = __ip6mr_get_table(rule->fr_net, arg->table); if (!mrt) return -EAGAIN; res->mrt = mrt; @@ -291,6 +306,11 @@ EXPORT_SYMBOL(ip6mr_rule_default); #define ip6mr_for_each_table(mrt, net) \ for (mrt = net->ipv6.mrt6; mrt; mrt = NULL) +static bool ip6mr_can_free_table(struct net *net) +{ + return !check_net(net); +} + static struct mr_table *ip6mr_mr_table_iter(struct net *net, struct mr_table *mrt) { @@ -304,6 +324,8 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id) return net->ipv6.mrt6; } +#define __ip6mr_get_table ip6mr_get_table + static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6, struct mr_table **mrt) { @@ -382,7 +404,7 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id) { struct mr_table *mrt; - mrt = ip6mr_get_table(net, id); + mrt = __ip6mr_get_table(net, id); if (mrt) return mrt; @@ -392,6 +414,10 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id) static void ip6mr_free_table(struct mr_table *mrt) { + struct net *net = read_pnet(&mrt->net); + + DEBUG_NET_WARN_ON_ONCE(!ip6mr_can_free_table(net)); + timer_shutdown_sync(&mrt->ipmr_expire_timer); mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC | MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC); @@ -411,13 +437,15 @@ static void *ip6mr_vif_seq_start(struct seq_file *seq, loff_t *pos) struct net *net = seq_file_net(seq); struct mr_table *mrt; - mrt = ip6mr_get_table(net, RT6_TABLE_DFLT); - if (!mrt) + rcu_read_lock(); + mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT); + if (!mrt) { + rcu_read_unlock(); return ERR_PTR(-ENOENT); + } iter->mrt = mrt; - rcu_read_lock(); return mr_vif_seq_start(seq, pos); } @@ -2278,11 +2306,13 @@ int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm, struct mfc6_cache *cache; struct rt6_info *rt = dst_rt6_info(skb_dst(skb)); - mrt = ip6mr_get_table(net, RT6_TABLE_DFLT); - if (!mrt) + rcu_read_lock(); + mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT); + if (!mrt) { + rcu_read_unlock(); return -ENOENT; + } - rcu_read_lock(); cache = ip6mr_cache_find(mrt, &rt->rt6i_src.addr, &rt->rt6i_dst.addr); if (!cache && skb->dev) { int vif = ip6mr_find_vif(mrt, skb->dev); @@ -2562,7 +2592,7 @@ static int ip6mr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, grp = nla_get_in6_addr(tb[RTA_DST]); tableid = nla_get_u32_default(tb[RTA_TABLE], 0); - mrt = ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT); + mrt = __ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT); if (!mrt) { NL_SET_ERR_MSG_MOD(extack, "MR table does not exist"); return -ENOENT; @@ -2609,7 +2639,7 @@ static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) if (filter.table_id) { struct mr_table *mrt; - mrt = ip6mr_get_table(sock_net(skb->sk), filter.table_id); + mrt = __ip6mr_get_table(sock_net(skb->sk), filter.table_id); if (!mrt) { if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IP6MR) return skb->len;