Skip to content

Commit

Permalink
net: rtnetlink: use rcu to free rtnl message handlers
Browse files Browse the repository at this point in the history
rtnetlink is littered with READ_ONCE() because we can have read accesses
while another cpu can write to the structure we're reading by
(un)registering doit or dumpit handlers.

This patch changes this so that the (un)registering cpu allocates a new
structure and then publishes it via rcu_assign_pointer(), i.e. once
another cpu can see such a pointer, no modifications will occur anymore.

Based on an initial patch from Peter Zijlstra.

Cc: Peter Zijlstra <peterz@infradead.org>
Signed-off-by: Florian Westphal <fw@strlen.de>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
Florian Westphal authored and David S. Miller committed Dec 4, 2017
1 parent 9753c21 commit addf9b9
Showing 1 changed file with 101 additions and 53 deletions.
154 changes: 101 additions & 53 deletions net/core/rtnetlink.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ struct rtnl_link {
rtnl_doit_func doit;
rtnl_dumpit_func dumpit;
unsigned int flags;
struct rcu_head rcu;
};

static DEFINE_MUTEX(rtnl_mutex);
Expand Down Expand Up @@ -127,7 +128,7 @@ bool lockdep_rtnl_is_held(void)
EXPORT_SYMBOL(lockdep_rtnl_is_held);
#endif /* #ifdef CONFIG_PROVE_LOCKING */

static struct rtnl_link __rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
static struct rtnl_link __rcu **rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];

static inline int rtm_msgindex(int msgtype)
Expand All @@ -144,6 +145,20 @@ static inline int rtm_msgindex(int msgtype)
return msgindex;
}

/**
 * rtnl_get_link - look up the handler entry for a protocol/message pair
 * @protocol: protocol family; values beyond the handler array are mapped
 *            to PF_UNSPEC
 * @msgtype: index into the per-protocol handler table (not range-checked
 *           here; the table is allocated with RTM_NR_MSGTYPES slots, so
 *           callers are expected to pass an index below that)
 *
 * If @protocol has no table of its own, the PF_UNSPEC table is consulted
 * instead.  Both the table pointer and the entry pointer are published
 * with rcu_assign_pointer(), so both are read through
 * rcu_dereference_rtnl().
 *
 * Return: the registered entry, or NULL when no table or entry exists.
 */
static struct rtnl_link *rtnl_get_link(int protocol, int msgtype)
{
	struct rtnl_link **tab;

	if (protocol >= ARRAY_SIZE(rtnl_msg_handlers))
		protocol = PF_UNSPEC;

	tab = rcu_dereference_rtnl(rtnl_msg_handlers[protocol]);
	if (!tab)
		tab = rcu_dereference_rtnl(rtnl_msg_handlers[PF_UNSPEC]);
	/* No family table and no PF_UNSPEC fallback: nothing to index. */
	if (!tab)
		return NULL;

	/*
	 * The slot may be replaced or cleared concurrently by
	 * (un)registration; pair the load with rcu_assign_pointer().
	 */
	return rcu_dereference_rtnl(tab[msgtype]);
}

/**
* __rtnl_register - Register a rtnetlink message type
* @protocol: Protocol family or PF_UNSPEC
Expand All @@ -166,28 +181,52 @@ int __rtnl_register(int protocol, int msgtype,
rtnl_doit_func doit, rtnl_dumpit_func dumpit,
unsigned int flags)
{
struct rtnl_link *tab;
struct rtnl_link **tab, *link, *old;
int msgindex;
int ret = -ENOBUFS;

BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
msgindex = rtm_msgindex(msgtype);

tab = rcu_dereference_raw(rtnl_msg_handlers[protocol]);
rtnl_lock();
tab = rtnl_msg_handlers[protocol];
if (tab == NULL) {
tab = kcalloc(RTM_NR_MSGTYPES, sizeof(*tab), GFP_KERNEL);
if (tab == NULL)
return -ENOBUFS;
tab = kcalloc(RTM_NR_MSGTYPES, sizeof(void *), GFP_KERNEL);
if (!tab)
goto unlock;

/* ensures we see the 0 stores */
rcu_assign_pointer(rtnl_msg_handlers[protocol], tab);
}

old = rtnl_dereference(tab[msgindex]);
if (old) {
link = kmemdup(old, sizeof(*old), GFP_KERNEL);
if (!link)
goto unlock;
} else {
link = kzalloc(sizeof(*link), GFP_KERNEL);
if (!link)
goto unlock;
}

WARN_ON(doit && link->doit && link->doit != doit);
if (doit)
tab[msgindex].doit = doit;
link->doit = doit;
WARN_ON(dumpit && link->dumpit && link->dumpit != dumpit);
if (dumpit)
tab[msgindex].dumpit = dumpit;
tab[msgindex].flags |= flags;
link->dumpit = dumpit;

return 0;
link->flags |= flags;

/* publish protocol:msgtype */
rcu_assign_pointer(tab[msgindex], link);
ret = 0;
if (old)
kfree_rcu(old, rcu);
unlock:
rtnl_unlock();
return ret;
}
EXPORT_SYMBOL_GPL(__rtnl_register);

Expand Down Expand Up @@ -220,24 +259,25 @@ EXPORT_SYMBOL_GPL(rtnl_register);
*/
int rtnl_unregister(int protocol, int msgtype)
{
struct rtnl_link *handlers;
struct rtnl_link **tab, *link;
int msgindex;

BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
msgindex = rtm_msgindex(msgtype);

rtnl_lock();
handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
if (!handlers) {
tab = rtnl_dereference(rtnl_msg_handlers[protocol]);
if (!tab) {
rtnl_unlock();
return -ENOENT;
}

handlers[msgindex].doit = NULL;
handlers[msgindex].dumpit = NULL;
handlers[msgindex].flags = 0;
link = tab[msgindex];
rcu_assign_pointer(tab[msgindex], NULL);
rtnl_unlock();

kfree_rcu(link, rcu);

return 0;
}
EXPORT_SYMBOL_GPL(rtnl_unregister);
Expand All @@ -251,20 +291,29 @@ EXPORT_SYMBOL_GPL(rtnl_unregister);
*/
void rtnl_unregister_all(int protocol)
{
struct rtnl_link *handlers;
struct rtnl_link **tab, *link;
int msgindex;

BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);

rtnl_lock();
handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
tab = rtnl_msg_handlers[protocol];
RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL);
for (msgindex = 0; msgindex < RTM_NR_MSGTYPES; msgindex++) {
link = tab[msgindex];
if (!link)
continue;

rcu_assign_pointer(tab[msgindex], NULL);
kfree_rcu(link, rcu);
}
rtnl_unlock();

synchronize_net();

while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 1)
schedule();
kfree(handlers);
kfree(tab);
}
EXPORT_SYMBOL_GPL(rtnl_unregister_all);

Expand Down Expand Up @@ -2973,18 +3022,26 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb)
s_idx = 1;

for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) {
struct rtnl_link **tab;
int type = cb->nlh->nlmsg_type-RTM_BASE;
struct rtnl_link *handlers;
struct rtnl_link *link;
rtnl_dumpit_func dumpit;

if (idx < s_idx || idx == PF_PACKET)
continue;

handlers = rtnl_dereference(rtnl_msg_handlers[idx]);
if (!handlers)
if (type < 0 || type >= RTM_NR_MSGTYPES)
continue;

dumpit = READ_ONCE(handlers[type].dumpit);
tab = rcu_dereference_rtnl(rtnl_msg_handlers[idx]);
if (!tab)
continue;

link = tab[type];
if (!link)
continue;

dumpit = link->dumpit;
if (!dumpit)
continue;

Expand Down Expand Up @@ -4314,7 +4371,7 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
struct netlink_ext_ack *extack)
{
struct net *net = sock_net(skb->sk);
struct rtnl_link *handlers;
struct rtnl_link *link;
int err = -EOPNOTSUPP;
rtnl_doit_func doit;
unsigned int flags;
Expand All @@ -4338,32 +4395,20 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
if (kind != 2 && !netlink_net_capable(skb, CAP_NET_ADMIN))
return -EPERM;

if (family >= ARRAY_SIZE(rtnl_msg_handlers))
family = PF_UNSPEC;

rcu_read_lock();
handlers = rcu_dereference(rtnl_msg_handlers[family]);
if (!handlers) {
family = PF_UNSPEC;
handlers = rcu_dereference(rtnl_msg_handlers[family]);
}

if (kind == 2 && nlh->nlmsg_flags&NLM_F_DUMP) {
struct sock *rtnl;
rtnl_dumpit_func dumpit;
u16 min_dump_alloc = 0;

dumpit = READ_ONCE(handlers[type].dumpit);
if (!dumpit) {
link = rtnl_get_link(family, type);
if (!link || !link->dumpit) {
family = PF_UNSPEC;
handlers = rcu_dereference(rtnl_msg_handlers[PF_UNSPEC]);
if (!handlers)
goto err_unlock;

dumpit = READ_ONCE(handlers[type].dumpit);
if (!dumpit)
link = rtnl_get_link(family, type);
if (!link || !link->dumpit)
goto err_unlock;
}
dumpit = link->dumpit;

refcount_inc(&rtnl_msg_handlers_ref[family]);

Expand All @@ -4384,33 +4429,36 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
return err;
}

doit = READ_ONCE(handlers[type].doit);
if (!doit) {
link = rtnl_get_link(family, type);
if (!link || !link->doit) {
family = PF_UNSPEC;
handlers = rcu_dereference(rtnl_msg_handlers[family]);
link = rtnl_get_link(PF_UNSPEC, type);
if (!link || !link->doit)
goto out_unlock;
}

flags = READ_ONCE(handlers[type].flags);
flags = link->flags;
if (flags & RTNL_FLAG_DOIT_UNLOCKED) {
refcount_inc(&rtnl_msg_handlers_ref[family]);
doit = READ_ONCE(handlers[type].doit);
doit = link->doit;
rcu_read_unlock();
if (doit)
err = doit(skb, nlh, extack);
refcount_dec(&rtnl_msg_handlers_ref[family]);
return err;
}

rcu_read_unlock();

rtnl_lock();
handlers = rtnl_dereference(rtnl_msg_handlers[family]);
if (handlers) {
doit = READ_ONCE(handlers[type].doit);
if (doit)
err = doit(skb, nlh, extack);
}
link = rtnl_get_link(family, type);
if (link && link->doit)
err = link->doit(skb, nlh, extack);
rtnl_unlock();

return err;

out_unlock:
rcu_read_unlock();
return err;

err_unlock:
Expand Down

0 comments on commit addf9b9

Please sign in to comment.