From efd402537673f9951992aea4ef0f5ff51d858f4b Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:25:55 +0000
Subject: [PATCH 1/9] sock_diag: annotate data-races around
 sock_diag_handlers[family]

__sock_diag_cmd() and sock_diag_bind() read sock_diag_handlers[family]
without a lock held.

Use READ_ONCE()/WRITE_ONCE() annotations to avoid potential issues.

Fixes: 8ef874bfc729 ("sock_diag: Move the sock_ code to net/core/")
Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/core/sock_diag.c | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c
index b1e29e18d1d60..c53b731f2d672 100644
--- a/net/core/sock_diag.c
+++ b/net/core/sock_diag.c
@@ -193,7 +193,7 @@ int sock_diag_register(const struct sock_diag_handler *hndl)
 	if (sock_diag_handlers[hndl->family])
 		err = -EBUSY;
 	else
-		sock_diag_handlers[hndl->family] = hndl;
+		WRITE_ONCE(sock_diag_handlers[hndl->family], hndl);
 	mutex_unlock(&sock_diag_table_mutex);
 
 	return err;
@@ -209,7 +209,7 @@ void sock_diag_unregister(const struct sock_diag_handler *hnld)
 
 	mutex_lock(&sock_diag_table_mutex);
 	BUG_ON(sock_diag_handlers[family] != hnld);
-	sock_diag_handlers[family] = NULL;
+	WRITE_ONCE(sock_diag_handlers[family], NULL);
 	mutex_unlock(&sock_diag_table_mutex);
 }
 EXPORT_SYMBOL_GPL(sock_diag_unregister);
@@ -227,7 +227,7 @@ static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh)
 		return -EINVAL;
 	req->sdiag_family = array_index_nospec(req->sdiag_family, AF_MAX);
 
-	if (sock_diag_handlers[req->sdiag_family] == NULL)
+	if (READ_ONCE(sock_diag_handlers[req->sdiag_family]) == NULL)
 		sock_load_diag_module(req->sdiag_family, 0);
 
 	mutex_lock(&sock_diag_table_mutex);
@@ -286,12 +286,12 @@ static int sock_diag_bind(struct net *net, int group)
 	switch (group) {
 	case SKNLGRP_INET_TCP_DESTROY:
 	case SKNLGRP_INET_UDP_DESTROY:
-		if (!sock_diag_handlers[AF_INET])
+		if (!READ_ONCE(sock_diag_handlers[AF_INET]))
 			sock_load_diag_module(AF_INET, 0);
 		break;
 	case SKNLGRP_INET6_TCP_DESTROY:
 	case SKNLGRP_INET6_UDP_DESTROY:
-		if (!sock_diag_handlers[AF_INET6])
+		if (!READ_ONCE(sock_diag_handlers[AF_INET6]))
 			sock_load_diag_module(AF_INET6, 0);
 		break;
 	}

From e50e10ae5d81ddb41547114bfdc5edc04422f082 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:25:56 +0000
Subject: [PATCH 2/9] inet_diag: annotate data-races around inet_diag_table[]

inet_diag_lock_handler() reads inet_diag_table[proto] locklessly.

Use READ_ONCE()/WRITE_ONCE() annotations to avoid potential issues.

Fixes: d523a328fb02 ("[INET]: Fix inet_diag dead-lock regression")
Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/ipv4/inet_diag.c | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index 8e6b6aa0579e1..9804e9608a5a0 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -57,7 +57,7 @@ static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
 		return ERR_PTR(-ENOENT);
 	}
 
-	if (!inet_diag_table[proto])
+	if (!READ_ONCE(inet_diag_table[proto]))
 		sock_load_diag_module(AF_INET, proto);
 
 	mutex_lock(&inet_diag_table_mutex);
@@ -1503,7 +1503,7 @@ int inet_diag_register(const struct inet_diag_handler *h)
 	mutex_lock(&inet_diag_table_mutex);
 	err = -EEXIST;
 	if (!inet_diag_table[type]) {
-		inet_diag_table[type] = h;
+		WRITE_ONCE(inet_diag_table[type], h);
 		err = 0;
 	}
 	mutex_unlock(&inet_diag_table_mutex);
@@ -1520,7 +1520,7 @@ void inet_diag_unregister(const struct inet_diag_handler *h)
 		return;
 
 	mutex_lock(&inet_diag_table_mutex);
-	inet_diag_table[type] = NULL;
+	WRITE_ONCE(inet_diag_table[type], NULL);
 	mutex_unlock(&inet_diag_table_mutex);
 }
 EXPORT_SYMBOL_GPL(inet_diag_unregister);

From db5914695a84a7b128ec2e4e9272e6e8091753e1 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:25:57 +0000
Subject: [PATCH 3/9] inet_diag: add module pointer to "struct
 inet_diag_handler"

Following patch is going to use RCU instead of
inet_diag_table_mutex acquisition.

This patch is a preparation, no change of behavior yet.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 include/linux/inet_diag.h | 1 +
 net/dccp/diag.c           | 1 +
 net/ipv4/raw_diag.c       | 1 +
 net/ipv4/tcp_diag.c       | 1 +
 net/ipv4/udp_diag.c       | 2 ++
 net/mptcp/mptcp_diag.c    | 1 +
 net/sctp/diag.c           | 1 +
 7 files changed, 8 insertions(+)

diff --git a/include/linux/inet_diag.h b/include/linux/inet_diag.h
index 84abb30a3fbb1..a9033696b0aad 100644
--- a/include/linux/inet_diag.h
+++ b/include/linux/inet_diag.h
@@ -8,6 +8,7 @@
 struct inet_hashinfo;
 
 struct inet_diag_handler {
+	struct module	*owner;
 	void		(*dump)(struct sk_buff *skb,
 				struct netlink_callback *cb,
 				const struct inet_diag_req_v2 *r);
diff --git a/net/dccp/diag.c b/net/dccp/diag.c
index 8a82c5a2c5a8c..f5019d95c3ae5 100644
--- a/net/dccp/diag.c
+++ b/net/dccp/diag.c
@@ -58,6 +58,7 @@ static int dccp_diag_dump_one(struct netlink_callback *cb,
 }
 
 static const struct inet_diag_handler dccp_diag_handler = {
+	.owner		 = THIS_MODULE,
 	.dump		 = dccp_diag_dump,
 	.dump_one	 = dccp_diag_dump_one,
 	.idiag_get_info	 = dccp_diag_get_info,
diff --git a/net/ipv4/raw_diag.c b/net/ipv4/raw_diag.c
index fe2140c8375c8..cc793bd8de258 100644
--- a/net/ipv4/raw_diag.c
+++ b/net/ipv4/raw_diag.c
@@ -213,6 +213,7 @@ static int raw_diag_destroy(struct sk_buff *in_skb,
 #endif
 
 static const struct inet_diag_handler raw_diag_handler = {
+	.owner			= THIS_MODULE,
 	.dump			= raw_diag_dump,
 	.dump_one		= raw_diag_dump_one,
 	.idiag_get_info		= raw_diag_get_info,
diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c
index 4cbe4b44425a6..f428ecf9120f2 100644
--- a/net/ipv4/tcp_diag.c
+++ b/net/ipv4/tcp_diag.c
@@ -222,6 +222,7 @@ static int tcp_diag_destroy(struct sk_buff *in_skb,
 #endif
 
 static const struct inet_diag_handler tcp_diag_handler = {
+	.owner			= THIS_MODULE,
 	.dump			= tcp_diag_dump,
 	.dump_one		= tcp_diag_dump_one,
 	.idiag_get_info		= tcp_diag_get_info,
diff --git a/net/ipv4/udp_diag.c b/net/ipv4/udp_diag.c
index dc41a22ee80e8..38cb3a28e4ed6 100644
--- a/net/ipv4/udp_diag.c
+++ b/net/ipv4/udp_diag.c
@@ -237,6 +237,7 @@ static int udplite_diag_destroy(struct sk_buff *in_skb,
 #endif
 
 static const struct inet_diag_handler udp_diag_handler = {
+	.owner		 = THIS_MODULE,
 	.dump		 = udp_diag_dump,
 	.dump_one	 = udp_diag_dump_one,
 	.idiag_get_info  = udp_diag_get_info,
@@ -260,6 +261,7 @@ static int udplite_diag_dump_one(struct netlink_callback *cb,
 }
 
 static const struct inet_diag_handler udplite_diag_handler = {
+	.owner		 = THIS_MODULE,
 	.dump		 = udplite_diag_dump,
 	.dump_one	 = udplite_diag_dump_one,
 	.idiag_get_info  = udp_diag_get_info,
diff --git a/net/mptcp/mptcp_diag.c b/net/mptcp/mptcp_diag.c
index 5409c2ea3f572..bd8ff5950c8d3 100644
--- a/net/mptcp/mptcp_diag.c
+++ b/net/mptcp/mptcp_diag.c
@@ -225,6 +225,7 @@ static void mptcp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
 }
 
 static const struct inet_diag_handler mptcp_diag_handler = {
+	.owner		 = THIS_MODULE,
 	.dump		 = mptcp_diag_dump,
 	.dump_one	 = mptcp_diag_dump_one,
 	.idiag_get_info  = mptcp_diag_get_info,
diff --git a/net/sctp/diag.c b/net/sctp/diag.c
index eb05131ff1dd6..23359e522273f 100644
--- a/net/sctp/diag.c
+++ b/net/sctp/diag.c
@@ -507,6 +507,7 @@ static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
 }
 
 static const struct inet_diag_handler sctp_diag_handler = {
+	.owner		 = THIS_MODULE,
 	.dump		 = sctp_diag_dump,
 	.dump_one	 = sctp_diag_dump_one,
 	.idiag_get_info  = sctp_diag_get_info,

From 223f55196bbdb182a9b8de6108a0834b5e5e832e Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:25:58 +0000
Subject: [PATCH 4/9] inet_diag: allow concurrent operations

inet_diag_lock_handler() current implementation uses a mutex
to protect inet_diag_table[] array against concurrent changes.

This makes inet_diag dump serialized, thus less scalable
than legacy /proc files.

It is time to switch to full RCU protection.

As a bonus, if a target is statically linked instead of being
modular, inet_diag_lock_handler() & inet_diag_unlock_handler()
reduce to reads only.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/ipv4/inet_diag.c | 80 ++++++++++++++++++++++----------------------
 1 file changed, 40 insertions(+), 40 deletions(-)

diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index 9804e9608a5a0..abf7dc9827969 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -32,7 +32,7 @@
 #include <linux/inet_diag.h>
 #include <linux/sock_diag.h>
 
-static const struct inet_diag_handler **inet_diag_table;
+static const struct inet_diag_handler __rcu **inet_diag_table;
 
 struct inet_diag_entry {
 	const __be32 *saddr;
@@ -48,28 +48,28 @@ struct inet_diag_entry {
 #endif
 };
 
-static DEFINE_MUTEX(inet_diag_table_mutex);
-
 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
 {
-	if (proto < 0 || proto >= IPPROTO_MAX) {
-		mutex_lock(&inet_diag_table_mutex);
-		return ERR_PTR(-ENOENT);
-	}
+	const struct inet_diag_handler *handler;
+
+	if (proto < 0 || proto >= IPPROTO_MAX)
+		return NULL;
 
 	if (!READ_ONCE(inet_diag_table[proto]))
 		sock_load_diag_module(AF_INET, proto);
 
-	mutex_lock(&inet_diag_table_mutex);
-	if (!inet_diag_table[proto])
-		return ERR_PTR(-ENOENT);
+	rcu_read_lock();
+	handler = rcu_dereference(inet_diag_table[proto]);
+	if (handler && !try_module_get(handler->owner))
+		handler = NULL;
+	rcu_read_unlock();
 
-	return inet_diag_table[proto];
+	return handler;
 }
 
 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
 {
-	mutex_unlock(&inet_diag_table_mutex);
+	module_put(handler->owner);
 }
 
 void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
@@ -104,9 +104,12 @@ static size_t inet_sk_attr_size(struct sock *sk,
 	const struct inet_diag_handler *handler;
 	size_t aux = 0;
 
-	handler = inet_diag_table[req->sdiag_protocol];
+	rcu_read_lock();
+	handler = rcu_dereference(inet_diag_table[req->sdiag_protocol]);
+	DEBUG_NET_WARN_ON_ONCE(!handler);
 	if (handler && handler->idiag_get_aux_size)
 		aux = handler->idiag_get_aux_size(sk, net_admin);
+	rcu_read_unlock();
 
 	return	  nla_total_size(sizeof(struct tcp_info))
 		+ nla_total_size(sizeof(struct inet_diag_msg))
@@ -244,10 +247,16 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
 	struct nlmsghdr  *nlh;
 	struct nlattr *attr;
 	void *info = NULL;
+	int protocol;
 
 	cb_data = cb->data;
-	handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)];
-	BUG_ON(!handler);
+	protocol = inet_diag_get_protocol(req, cb_data);
+
+	/* inet_diag_lock_handler() made sure inet_diag_table[] is stable. */
+	handler = rcu_dereference_protected(inet_diag_table[protocol], 1);
+	DEBUG_NET_WARN_ON_ONCE(!handler);
+	if (!handler)
+		return -ENXIO;
 
 	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
 			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
@@ -605,9 +614,10 @@ static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
 	protocol = inet_diag_get_protocol(req, &dump_data);
 
 	handler = inet_diag_lock_handler(protocol);
-	if (IS_ERR(handler)) {
-		err = PTR_ERR(handler);
-	} else if (cmd == SOCK_DIAG_BY_FAMILY) {
+	if (!handler)
+		return -ENOENT;
+
+	if (cmd == SOCK_DIAG_BY_FAMILY) {
 		struct netlink_callback cb = {
 			.nlh = nlh,
 			.skb = in_skb,
@@ -1259,12 +1269,12 @@ static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
 again:
 	prev_min_dump_alloc = cb->min_dump_alloc;
 	handler = inet_diag_lock_handler(protocol);
-	if (!IS_ERR(handler))
+	if (handler) {
 		handler->dump(skb, cb, r);
-	else
-		err = PTR_ERR(handler);
-	inet_diag_unlock_handler(handler);
-
+		inet_diag_unlock_handler(handler);
+	} else {
+		err = -ENOENT;
+	}
 	/* The skb is not large enough to fit one sk info and
 	 * inet_sk_diag_fill() has requested for a larger skb.
 	 */
@@ -1457,10 +1467,9 @@ int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
 	}
 
 	handler = inet_diag_lock_handler(sk->sk_protocol);
-	if (IS_ERR(handler)) {
-		inet_diag_unlock_handler(handler);
+	if (!handler) {
 		nlmsg_cancel(skb, nlh);
-		return PTR_ERR(handler);
+		return -ENOENT;
 	}
 
 	attr = handler->idiag_info_size
@@ -1495,20 +1504,12 @@ static const struct sock_diag_handler inet6_diag_handler = {
 int inet_diag_register(const struct inet_diag_handler *h)
 {
 	const __u16 type = h->idiag_type;
-	int err = -EINVAL;
 
 	if (type >= IPPROTO_MAX)
-		goto out;
+		return -EINVAL;
 
-	mutex_lock(&inet_diag_table_mutex);
-	err = -EEXIST;
-	if (!inet_diag_table[type]) {
-		WRITE_ONCE(inet_diag_table[type], h);
-		err = 0;
-	}
-	mutex_unlock(&inet_diag_table_mutex);
-out:
-	return err;
+	return !cmpxchg((const struct inet_diag_handler **)&inet_diag_table[type],
+			NULL, h) ? 0 : -EEXIST;
 }
 EXPORT_SYMBOL_GPL(inet_diag_register);
 
@@ -1519,9 +1520,8 @@ void inet_diag_unregister(const struct inet_diag_handler *h)
 	if (type >= IPPROTO_MAX)
 		return;
 
-	mutex_lock(&inet_diag_table_mutex);
-	WRITE_ONCE(inet_diag_table[type], NULL);
-	mutex_unlock(&inet_diag_table_mutex);
+	xchg((const struct inet_diag_handler **)&inet_diag_table[type],
+	     NULL);
 }
 EXPORT_SYMBOL_GPL(inet_diag_unregister);
 

From 114b4bb1cc19239b272d52ebbe156053483fe2f8 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:25:59 +0000
Subject: [PATCH 5/9] sock_diag: add module pointer to "struct
 sock_diag_handler"

Following patch is going to use RCU instead of
sock_diag_table_mutex acquisition.

This patch is a preparation, no change of behavior yet.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 include/linux/sock_diag.h | 1 +
 net/ipv4/inet_diag.c      | 2 ++
 net/netlink/diag.c        | 1 +
 net/packet/diag.c         | 1 +
 net/smc/smc_diag.c        | 1 +
 net/tipc/diag.c           | 1 +
 net/unix/diag.c           | 1 +
 net/vmw_vsock/diag.c      | 1 +
 net/xdp/xsk_diag.c        | 1 +
 9 files changed, 10 insertions(+)

diff --git a/include/linux/sock_diag.h b/include/linux/sock_diag.h
index 0b9ecd8cf9793..7c07754d711b9 100644
--- a/include/linux/sock_diag.h
+++ b/include/linux/sock_diag.h
@@ -13,6 +13,7 @@ struct nlmsghdr;
 struct sock;
 
 struct sock_diag_handler {
+	struct module *owner;
 	__u8 family;
 	int (*dump)(struct sk_buff *skb, struct nlmsghdr *nlh);
 	int (*get_info)(struct sk_buff *skb, struct sock *sk);
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index abf7dc9827969..52ce20691e4ef 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -1488,6 +1488,7 @@ int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
 }
 
 static const struct sock_diag_handler inet_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_INET,
 	.dump = inet_diag_handler_cmd,
 	.get_info = inet_diag_handler_get_info,
@@ -1495,6 +1496,7 @@ static const struct sock_diag_handler inet_diag_handler = {
 };
 
 static const struct sock_diag_handler inet6_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_INET6,
 	.dump = inet_diag_handler_cmd,
 	.get_info = inet_diag_handler_get_info,
diff --git a/net/netlink/diag.c b/net/netlink/diag.c
index 1eeff9422856e..e12c90d5f6ad2 100644
--- a/net/netlink/diag.c
+++ b/net/netlink/diag.c
@@ -241,6 +241,7 @@ static int netlink_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 }
 
 static const struct sock_diag_handler netlink_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_NETLINK,
 	.dump = netlink_diag_handler_dump,
 };
diff --git a/net/packet/diag.c b/net/packet/diag.c
index 9a7980e3309d6..b3bd2f6c2bf7b 100644
--- a/net/packet/diag.c
+++ b/net/packet/diag.c
@@ -245,6 +245,7 @@ static int packet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 }
 
 static const struct sock_diag_handler packet_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_PACKET,
 	.dump = packet_diag_handler_dump,
 };
diff --git a/net/smc/smc_diag.c b/net/smc/smc_diag.c
index 52f7c4f1e7670..32bad267fa3e2 100644
--- a/net/smc/smc_diag.c
+++ b/net/smc/smc_diag.c
@@ -255,6 +255,7 @@ static int smc_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 }
 
 static const struct sock_diag_handler smc_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_SMC,
 	.dump = smc_diag_handler_dump,
 };
diff --git a/net/tipc/diag.c b/net/tipc/diag.c
index 18733451c9e0c..54dde8c4e4d46 100644
--- a/net/tipc/diag.c
+++ b/net/tipc/diag.c
@@ -95,6 +95,7 @@ static int tipc_sock_diag_handler_dump(struct sk_buff *skb,
 }
 
 static const struct sock_diag_handler tipc_sock_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_TIPC,
 	.dump = tipc_sock_diag_handler_dump,
 };
diff --git a/net/unix/diag.c b/net/unix/diag.c
index bec09a3a1d44c..c3648b7065096 100644
--- a/net/unix/diag.c
+++ b/net/unix/diag.c
@@ -322,6 +322,7 @@ static int unix_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 }
 
 static const struct sock_diag_handler unix_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_UNIX,
 	.dump = unix_diag_handler_dump,
 };
diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c
index 2e29994f92ffa..ab87ef66c1e88 100644
--- a/net/vmw_vsock/diag.c
+++ b/net/vmw_vsock/diag.c
@@ -157,6 +157,7 @@ static int vsock_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 }
 
 static const struct sock_diag_handler vsock_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_VSOCK,
 	.dump = vsock_diag_handler_dump,
 };
diff --git a/net/xdp/xsk_diag.c b/net/xdp/xsk_diag.c
index 9f8955367275e..09dcea0cbbed9 100644
--- a/net/xdp/xsk_diag.c
+++ b/net/xdp/xsk_diag.c
@@ -194,6 +194,7 @@ static int xsk_diag_handler_dump(struct sk_buff *nlskb, struct nlmsghdr *hdr)
 }
 
 static const struct sock_diag_handler xsk_diag_handler = {
+	.owner = THIS_MODULE,
 	.family = AF_XDP,
 	.dump = xsk_diag_handler_dump,
 };

From 1d55a6974756cf3979efd2cc68bcece611a44053 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:26:00 +0000
Subject: [PATCH 6/9] sock_diag: allow concurrent operations

sock_diag_broadcast_destroy_work() and __sock_diag_cmd()
are currently using sock_diag_table_mutex to protect
against concurrent sock_diag_handlers[] changes.

This makes inet_diag dump serialized, thus less scalable
than legacy /proc files.

It is time to switch to full RCU protection.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/core/sock_diag.c | 73 +++++++++++++++++++++++++-------------------
 1 file changed, 42 insertions(+), 31 deletions(-)

diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c
index c53b731f2d672..72009e1f4380d 100644
--- a/net/core/sock_diag.c
+++ b/net/core/sock_diag.c
@@ -16,7 +16,7 @@
 #include <linux/inet_diag.h>
 #include <linux/sock_diag.h>
 
-static const struct sock_diag_handler *sock_diag_handlers[AF_MAX];
+static const struct sock_diag_handler __rcu *sock_diag_handlers[AF_MAX];
 static int (*inet_rcv_compat)(struct sk_buff *skb, struct nlmsghdr *nlh);
 static DEFINE_MUTEX(sock_diag_table_mutex);
 static struct workqueue_struct *broadcast_wq;
@@ -122,6 +122,24 @@ static size_t sock_diag_nlmsg_size(void)
 	       + nla_total_size_64bit(sizeof(struct tcp_info))); /* INET_DIAG_INFO */
 }
 
+static const struct sock_diag_handler *sock_diag_lock_handler(int family)
+{
+	const struct sock_diag_handler *handler;
+
+	rcu_read_lock();
+	handler = rcu_dereference(sock_diag_handlers[family]);
+	if (handler && !try_module_get(handler->owner))
+		handler = NULL;
+	rcu_read_unlock();
+
+	return handler;
+}
+
+static void sock_diag_unlock_handler(const struct sock_diag_handler *handler)
+{
+	module_put(handler->owner);
+}
+
 static void sock_diag_broadcast_destroy_work(struct work_struct *work)
 {
 	struct broadcast_sk *bsk =
@@ -138,12 +156,12 @@ static void sock_diag_broadcast_destroy_work(struct work_struct *work)
 	if (!skb)
 		goto out;
 
-	mutex_lock(&sock_diag_table_mutex);
-	hndl = sock_diag_handlers[sk->sk_family];
-	if (hndl && hndl->get_info)
-		err = hndl->get_info(skb, sk);
-	mutex_unlock(&sock_diag_table_mutex);
-
+	hndl = sock_diag_lock_handler(sk->sk_family);
+	if (hndl) {
+		if (hndl->get_info)
+			err = hndl->get_info(skb, sk);
+		sock_diag_unlock_handler(hndl);
+	}
 	if (!err)
 		nlmsg_multicast(sock_net(sk)->diag_nlsk, skb, 0, group,
 				GFP_KERNEL);
@@ -184,33 +202,26 @@ EXPORT_SYMBOL_GPL(sock_diag_unregister_inet_compat);
 
 int sock_diag_register(const struct sock_diag_handler *hndl)
 {
-	int err = 0;
+	int family = hndl->family;
 
-	if (hndl->family >= AF_MAX)
+	if (family >= AF_MAX)
 		return -EINVAL;
 
-	mutex_lock(&sock_diag_table_mutex);
-	if (sock_diag_handlers[hndl->family])
-		err = -EBUSY;
-	else
-		WRITE_ONCE(sock_diag_handlers[hndl->family], hndl);
-	mutex_unlock(&sock_diag_table_mutex);
-
-	return err;
+	return !cmpxchg((const struct sock_diag_handler **)
+				&sock_diag_handlers[family],
+			NULL, hndl) ? 0 : -EBUSY;
 }
 EXPORT_SYMBOL_GPL(sock_diag_register);
 
-void sock_diag_unregister(const struct sock_diag_handler *hnld)
+void sock_diag_unregister(const struct sock_diag_handler *hndl)
 {
-	int family = hnld->family;
+	int family = hndl->family;
 
 	if (family >= AF_MAX)
 		return;
 
-	mutex_lock(&sock_diag_table_mutex);
-	BUG_ON(sock_diag_handlers[family] != hnld);
-	WRITE_ONCE(sock_diag_handlers[family], NULL);
-	mutex_unlock(&sock_diag_table_mutex);
+	xchg((const struct sock_diag_handler **)&sock_diag_handlers[family],
+	     NULL);
 }
 EXPORT_SYMBOL_GPL(sock_diag_unregister);
 
@@ -227,20 +238,20 @@ static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh)
 		return -EINVAL;
 	req->sdiag_family = array_index_nospec(req->sdiag_family, AF_MAX);
 
-	if (READ_ONCE(sock_diag_handlers[req->sdiag_family]) == NULL)
+	if (!rcu_access_pointer(sock_diag_handlers[req->sdiag_family]))
 		sock_load_diag_module(req->sdiag_family, 0);
 
-	mutex_lock(&sock_diag_table_mutex);
-	hndl = sock_diag_handlers[req->sdiag_family];
+	hndl = sock_diag_lock_handler(req->sdiag_family);
 	if (hndl == NULL)
-		err = -ENOENT;
-	else if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY)
+		return -ENOENT;
+
+	if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY)
 		err = hndl->dump(skb, nlh);
 	else if (nlh->nlmsg_type == SOCK_DESTROY && hndl->destroy)
 		err = hndl->destroy(skb, nlh);
 	else
 		err = -EOPNOTSUPP;
-	mutex_unlock(&sock_diag_table_mutex);
+	sock_diag_unlock_handler(hndl);
 
 	return err;
 }
@@ -286,12 +297,12 @@ static int sock_diag_bind(struct net *net, int group)
 	switch (group) {
 	case SKNLGRP_INET_TCP_DESTROY:
 	case SKNLGRP_INET_UDP_DESTROY:
-		if (!READ_ONCE(sock_diag_handlers[AF_INET]))
+		if (!rcu_access_pointer(sock_diag_handlers[AF_INET]))
 			sock_load_diag_module(AF_INET, 0);
 		break;
 	case SKNLGRP_INET6_TCP_DESTROY:
 	case SKNLGRP_INET6_UDP_DESTROY:
-		if (!READ_ONCE(sock_diag_handlers[AF_INET6]))
+		if (!rcu_access_pointer(sock_diag_handlers[AF_INET6]))
 			sock_load_diag_module(AF_INET6, 0);
 		break;
 	}

From 86e8921df05c6e9423ab74ab8d41022775d8b83a Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:26:01 +0000
Subject: [PATCH 7/9] sock_diag: allow concurrent operation in
 sock_diag_rcv_msg()

TCPDIAG_GETSOCK and DCCPDIAG_GETSOCK diag are serialized
on sock_diag_table_mutex.

This is to make sure inet_diag module is not unloaded
while diag was ongoing.

It is time to get rid of this mutex and use RCU protection,
allowing full parallelism.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 include/linux/sock_diag.h |  9 ++++++--
 net/core/sock_diag.c      | 43 +++++++++++++++++++++++----------------
 net/ipv4/inet_diag.c      |  9 ++++++--
 3 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/include/linux/sock_diag.h b/include/linux/sock_diag.h
index 7c07754d711b9..110978dc9af1b 100644
--- a/include/linux/sock_diag.h
+++ b/include/linux/sock_diag.h
@@ -23,8 +23,13 @@ struct sock_diag_handler {
 int sock_diag_register(const struct sock_diag_handler *h);
 void sock_diag_unregister(const struct sock_diag_handler *h);
 
-void sock_diag_register_inet_compat(int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh));
-void sock_diag_unregister_inet_compat(int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh));
+struct sock_diag_inet_compat {
+	struct module *owner;
+	int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh);
+};
+
+void sock_diag_register_inet_compat(const struct sock_diag_inet_compat *ptr);
+void sock_diag_unregister_inet_compat(const struct sock_diag_inet_compat *ptr);
 
 u64 __sock_gen_cookie(struct sock *sk);
 
diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c
index 72009e1f4380d..5c3666431df49 100644
--- a/net/core/sock_diag.c
+++ b/net/core/sock_diag.c
@@ -17,8 +17,9 @@
 #include <linux/sock_diag.h>
 
 static const struct sock_diag_handler __rcu *sock_diag_handlers[AF_MAX];
-static int (*inet_rcv_compat)(struct sk_buff *skb, struct nlmsghdr *nlh);
-static DEFINE_MUTEX(sock_diag_table_mutex);
+
+static struct sock_diag_inet_compat __rcu *inet_rcv_compat;
+
 static struct workqueue_struct *broadcast_wq;
 
 DEFINE_COOKIE(sock_cookie);
@@ -184,19 +185,20 @@ void sock_diag_broadcast_destroy(struct sock *sk)
 	queue_work(broadcast_wq, &bsk->work);
 }
 
-void sock_diag_register_inet_compat(int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh))
+void sock_diag_register_inet_compat(const struct sock_diag_inet_compat *ptr)
 {
-	mutex_lock(&sock_diag_table_mutex);
-	inet_rcv_compat = fn;
-	mutex_unlock(&sock_diag_table_mutex);
+	xchg((__force const struct sock_diag_inet_compat **)&inet_rcv_compat,
+	     ptr);
 }
 EXPORT_SYMBOL_GPL(sock_diag_register_inet_compat);
 
-void sock_diag_unregister_inet_compat(int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh))
+void sock_diag_unregister_inet_compat(const struct sock_diag_inet_compat *ptr)
 {
-	mutex_lock(&sock_diag_table_mutex);
-	inet_rcv_compat = NULL;
-	mutex_unlock(&sock_diag_table_mutex);
+	const struct sock_diag_inet_compat *old;
+
+	old = xchg((__force const struct sock_diag_inet_compat **)&inet_rcv_compat,
+		   NULL);
+	WARN_ON_ONCE(old != ptr);
 }
 EXPORT_SYMBOL_GPL(sock_diag_unregister_inet_compat);
 
@@ -259,20 +261,27 @@ static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh)
 static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
 			     struct netlink_ext_ack *extack)
 {
+	const struct sock_diag_inet_compat *ptr;
 	int ret;
 
 	switch (nlh->nlmsg_type) {
 	case TCPDIAG_GETSOCK:
 	case DCCPDIAG_GETSOCK:
-		if (inet_rcv_compat == NULL)
+
+		if (!rcu_access_pointer(inet_rcv_compat))
 			sock_load_diag_module(AF_INET, 0);
 
-		mutex_lock(&sock_diag_table_mutex);
-		if (inet_rcv_compat != NULL)
-			ret = inet_rcv_compat(skb, nlh);
-		else
-			ret = -EOPNOTSUPP;
-		mutex_unlock(&sock_diag_table_mutex);
+		rcu_read_lock();
+		ptr = rcu_dereference(inet_rcv_compat);
+		if (ptr && !try_module_get(ptr->owner))
+			ptr = NULL;
+		rcu_read_unlock();
+
+		ret = -EOPNOTSUPP;
+		if (ptr) {
+			ret = ptr->fn(skb, nlh);
+			module_put(ptr->owner);
+		}
 
 		return ret;
 	case SOCK_DIAG_BY_FAMILY:
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index 52ce20691e4ef..2c2d8b9dd8e9b 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -1527,6 +1527,11 @@ void inet_diag_unregister(const struct inet_diag_handler *h)
 }
 EXPORT_SYMBOL_GPL(inet_diag_unregister);
 
+static const struct sock_diag_inet_compat inet_diag_compat = {
+	.owner	= THIS_MODULE,
+	.fn	= inet_diag_rcv_msg_compat,
+};
+
 static int __init inet_diag_init(void)
 {
 	const int inet_diag_table_size = (IPPROTO_MAX *
@@ -1545,7 +1550,7 @@ static int __init inet_diag_init(void)
 	if (err)
 		goto out_free_inet;
 
-	sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
+	sock_diag_register_inet_compat(&inet_diag_compat);
 out:
 	return err;
 
@@ -1560,7 +1565,7 @@ static void __exit inet_diag_exit(void)
 {
 	sock_diag_unregister(&inet6_diag_handler);
 	sock_diag_unregister(&inet_diag_handler);
-	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
+	sock_diag_unregister_inet_compat(&inet_diag_compat);
 	kfree(inet_diag_table);
 }
 

From f44e64990beb41167bd7c313d90bcf7e290c3582 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:26:02 +0000
Subject: [PATCH 8/9] sock_diag: remove sock_diag_mutex

sock_diag_rcv() is still serializing its operations using
a mutex, for no good reason.

This came with commit 0a9c73014415 ("[INET_DIAG]: Fix oops
in netlink_rcv_skb"), but the root cause has been fixed
with commit cd40b7d3983c ("[NET]: make netlink user -> kernel
interface synchronious")

Remove this mutex to let multiple threads run concurrently.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/core/sock_diag.c | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c
index 5c3666431df49..6541228380252 100644
--- a/net/core/sock_diag.c
+++ b/net/core/sock_diag.c
@@ -292,13 +292,9 @@ static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
 	}
 }
 
-static DEFINE_MUTEX(sock_diag_mutex);
-
 static void sock_diag_rcv(struct sk_buff *skb)
 {
-	mutex_lock(&sock_diag_mutex);
 	netlink_rcv_skb(skb, &sock_diag_rcv_msg);
-	mutex_unlock(&sock_diag_mutex);
 }
 
 static int sock_diag_bind(struct net *net, int group)

From 622a08e8de9f9ad576fd56e3a2e97a84370a8100 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Mon, 22 Jan 2024 11:26:03 +0000
Subject: [PATCH 9/9] inet_diag: skip over empty buckets

After the removal of inet_diag_table_mutex, sock_diag_table_mutex
and sock_diag_mutex, I was able so see spinlock contention from
inet_diag_dump_icsk() when running 100 parallel invocations.

It is time to skip over empty buckets.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/ipv4/inet_diag.c | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index 2c2d8b9dd8e9b..7adace541fe29 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -1045,6 +1045,10 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
 			num = 0;
 			ilb = &hashinfo->lhash2[i];
 
+			if (hlist_nulls_empty(&ilb->nulls_head)) {
+				s_num = 0;
+				continue;
+			}
 			spin_lock(&ilb->lock);
 			sk_nulls_for_each(sk, node, &ilb->nulls_head) {
 				struct inet_sock *inet = inet_sk(sk);
@@ -1109,6 +1113,10 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
 			accum = 0;
 			ibb = &hashinfo->bhash2[i];
 
+			if (hlist_empty(&ibb->chain)) {
+				s_num = 0;
+				continue;
+			}
 			spin_lock_bh(&ibb->lock);
 			inet_bind_bucket_for_each(tb2, &ibb->chain) {
 				if (!net_eq(ib2_net(tb2), net))