Skip to content

Commit

Permalink
netlink: Add compare function for netlink_table
Browse files Browse the repository at this point in the history
As we know, netlink sockets are private resource of
net namespace, they can communicate with each other
only when they in the same net namespace. this works
well until we try to add namespace support for other
subsystems which use netlink.

Don't like ipv4 and route table.., it is not suited to
make these subsytems belong to net namespace, Such as
audit and crypto subsystems,they are more suitable to
user namespace.

So we must have the ability to make the netlink sockets
in same user namespace can communicate with each other.

This patch adds a new function pointer "compare" for
netlink_table, we can decide if the netlink sockets can
communicate with each other through this netlink_table
self-defined compare function.

The behavior isn't changed if we don't provide the compare
function for netlink_table.

Signed-off-by: Gao feng <gaofeng@cn.fujitsu.com>
Acked-by: Serge E. Hallyn <serge.hallyn@ubuntu.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
Gao feng authored and David S. Miller committed Jun 11, 2013
1 parent 8249152 commit da12c90
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
1 change: 1 addition & 0 deletions include/linux/netlink.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ struct netlink_kernel_cfg {
void (*input)(struct sk_buff *skb);
struct mutex *cb_mutex;
void (*bind)(int group);
bool (*compare)(struct net *net, struct sock *sk);
};

extern struct sock *__netlink_kernel_create(struct net *net, int unit,
Expand Down
33 changes: 25 additions & 8 deletions net/netlink/af_netlink.c
Original file line number Diff line number Diff line change
Expand Up @@ -858,16 +858,23 @@ netlink_unlock_table(void)
wake_up(&nl_table_wait);
}

static bool netlink_compare(struct net *net, struct sock *sk)
{
return net_eq(sock_net(sk), net);
}

static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
{
struct nl_portid_hash *hash = &nl_table[protocol].hash;
struct netlink_table *table = &nl_table[protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head;
struct sock *sk;

read_lock(&nl_table_lock);
head = nl_portid_hashfn(hash, portid);
sk_for_each(sk, head) {
if (net_eq(sock_net(sk), net) && (nlk_sk(sk)->portid == portid)) {
if (table->compare(net, sk) &&
(nlk_sk(sk)->portid == portid)) {
sock_hold(sk);
goto found;
}
Expand Down Expand Up @@ -980,7 +987,8 @@ netlink_update_listeners(struct sock *sk)

static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
{
struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash;
struct netlink_table *table = &nl_table[sk->sk_protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head;
int err = -EADDRINUSE;
struct sock *osk;
Expand All @@ -990,7 +998,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
head = nl_portid_hashfn(hash, portid);
len = 0;
sk_for_each(osk, head) {
if (net_eq(sock_net(osk), net) && (nlk_sk(osk)->portid == portid))
if (table->compare(net, osk) &&
(nlk_sk(osk)->portid == portid))
break;
len++;
}
Expand Down Expand Up @@ -1165,6 +1174,7 @@ static int netlink_release(struct socket *sock)
kfree_rcu(old, rcu);
nl_table[sk->sk_protocol].module = NULL;
nl_table[sk->sk_protocol].bind = NULL;
nl_table[sk->sk_protocol].compare = NULL;
nl_table[sk->sk_protocol].flags = 0;
nl_table[sk->sk_protocol].registered = 0;
}
Expand All @@ -1187,7 +1197,8 @@ static int netlink_autobind(struct socket *sock)
{
struct sock *sk = sock->sk;
struct net *net = sock_net(sk);
struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash;
struct netlink_table *table = &nl_table[sk->sk_protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head;
struct sock *osk;
s32 portid = task_tgid_vnr(current);
Expand All @@ -1199,7 +1210,7 @@ static int netlink_autobind(struct socket *sock)
netlink_table_grab();
head = nl_portid_hashfn(hash, portid);
sk_for_each(osk, head) {
if (!net_eq(sock_net(osk), net))
if (!table->compare(net, osk))
continue;
if (nlk_sk(osk)->portid == portid) {
/* Bind collision, search negative portid values. */
Expand Down Expand Up @@ -2315,9 +2326,12 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
rcu_assign_pointer(nl_table[unit].listeners, listeners);
nl_table[unit].cb_mutex = cb_mutex;
nl_table[unit].module = module;
nl_table[unit].compare = netlink_compare;
if (cfg) {
nl_table[unit].bind = cfg->bind;
nl_table[unit].flags = cfg->flags;
if (cfg->compare)
nl_table[unit].compare = cfg->compare;
}
nl_table[unit].registered = 1;
} else {
Expand Down Expand Up @@ -2740,18 +2754,20 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{
struct sock *s;
struct nl_seq_iter *iter;
struct net *net;
int i, j;

++*pos;

if (v == SEQ_START_TOKEN)
return netlink_seq_socket_idx(seq, 0);

net = seq_file_net(seq);
iter = seq->private;
s = v;
do {
s = sk_next(s);
} while (s && sock_net(s) != seq_file_net(seq));
} while (s && !nl_table[s->sk_protocol].compare(net, s));
if (s)
return s;

Expand All @@ -2763,7 +2779,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)

for (; j <= hash->mask; j++) {
s = sk_head(&hash->table[j]);
while (s && sock_net(s) != seq_file_net(seq))

while (s && !nl_table[s->sk_protocol].compare(net, s))
s = sk_next(s);
if (s) {
iter->link = i;
Expand Down
1 change: 1 addition & 0 deletions net/netlink/af_netlink.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ struct netlink_table {
struct mutex *cb_mutex;
struct module *module;
void (*bind)(int group);
bool (*compare)(struct net *net, struct sock *sock);
int registered;
};

Expand Down

0 comments on commit da12c90

Please sign in to comment.