Skip to content

Commit

Permalink
Merge branch 'wireguard-fixes'
Browse files Browse the repository at this point in the history
Jason A. Donenfeld says:

====================
wireguard fixes for 5.13-rc5

Here are bug fixes to WireGuard for 5.13-rc5:

1-2,6) These are small, trivial tweaks to our test harness.

3) Linus thinks -O3 is still dangerous to enable. The code gen wasn't so
   much different with -O2 either.

4) We were accidentally calling synchronize_rcu instead of
   synchronize_net while holding the rtnl_lock, resulting in some rather
   large stalls that hit production machines.

5) Peer allocation was wasting literally hundreds of megabytes on real
   world deployments, due to oddly sized large objects not fitting
   nicely into a kmalloc slab.

7-9) We move from an insanely expensive O(n) algorithm to a fast O(1)
     algorithm, and cleanup a massive memory leak in the process, in
     which allowed ips churn would leave danging nodes hanging around
     without cleanup until the interface was removed. The O(1) algorithm
     eliminates packet stalls and high latency issues, in addition to
     bringing operations that took as much as 10 minutes down to less
     than a second.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
David S. Miller committed Jun 4, 2021
2 parents 579028d + bf7b042 commit 6fd815b
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 195 deletions.
3 changes: 1 addition & 2 deletions drivers/net/wireguard/Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
ccflags-y := -O3
ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
ccflags-y := -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG
wireguard-y := main.o
wireguard-y += noise.o
Expand Down
189 changes: 99 additions & 90 deletions drivers/net/wireguard/allowedips.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "allowedips.h"
#include "peer.h"

static struct kmem_cache *node_cache;

static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
if (bits == 32) {
Expand All @@ -28,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
node->bitlen = bits;
memcpy(node->bits, src, bits / 8U);
}
#define CHOOSE_NODE(parent, key) \
parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]

static inline u8 choose(struct allowedips_node *node, const u8 *key)
{
return (key[node->bit_at_a] >> node->bit_at_b) & 1;
}

static void push_rcu(struct allowedips_node **stack,
struct allowedips_node __rcu *p, unsigned int *len)
Expand All @@ -40,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack,
}
}

static void node_free_rcu(struct rcu_head *rcu)
{
kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu));
}

static void root_free_rcu(struct rcu_head *rcu)
{
struct allowedips_node *node, *stack[128] = {
Expand All @@ -49,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu)
while (len > 0 && (node = stack[--len])) {
push_rcu(stack, node->bit[0], &len);
push_rcu(stack, node->bit[1], &len);
kfree(node);
kmem_cache_free(node_cache, node);
}
}

Expand All @@ -66,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
}
}

static void walk_remove_by_peer(struct allowedips_node __rcu **top,
struct wg_peer *peer, struct mutex *lock)
{
#define REF(p) rcu_access_pointer(p)
#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
#define PUSH(p) ({ \
WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
stack[len++] = p; \
})

struct allowedips_node __rcu **stack[128], **nptr;
struct allowedips_node *node, *prev;
unsigned int len;

if (unlikely(!peer || !REF(*top)))
return;

for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
nptr = stack[len - 1];
node = DEREF(nptr);
if (!node) {
--len;
continue;
}
if (!prev || REF(prev->bit[0]) == node ||
REF(prev->bit[1]) == node) {
if (REF(node->bit[0]))
PUSH(&node->bit[0]);
else if (REF(node->bit[1]))
PUSH(&node->bit[1]);
} else if (REF(node->bit[0]) == prev) {
if (REF(node->bit[1]))
PUSH(&node->bit[1]);
} else {
if (rcu_dereference_protected(node->peer,
lockdep_is_held(lock)) == peer) {
RCU_INIT_POINTER(node->peer, NULL);
list_del_init(&node->peer_list);
if (!node->bit[0] || !node->bit[1]) {
rcu_assign_pointer(*nptr, DEREF(
&node->bit[!REF(node->bit[0])]));
kfree_rcu(node, rcu);
node = DEREF(nptr);
}
}
--len;
}
}

#undef REF
#undef DEREF
#undef PUSH
}

static unsigned int fls128(u64 a, u64 b)
{
return a ? fls64(a) + 64U : fls64(b);
Expand Down Expand Up @@ -159,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
found = node;
if (node->cidr == bits)
break;
node = rcu_dereference_bh(CHOOSE_NODE(node, key));
node = rcu_dereference_bh(node->bit[choose(node, key)]);
}
return found;
}
Expand Down Expand Up @@ -191,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
u8 cidr, u8 bits, struct allowedips_node **rnode,
struct mutex *lock)
{
struct allowedips_node *node = rcu_dereference_protected(trie,
lockdep_is_held(lock));
struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
struct allowedips_node *parent = NULL;
bool exact = false;

Expand All @@ -202,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
exact = true;
break;
}
node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
lockdep_is_held(lock));
node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
}
*rnode = parent;
return exact;
}

static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node)
{
node->parent_bit_packed = (unsigned long)parent | bit;
rcu_assign_pointer(*parent, node);
}

static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
{
u8 bit = choose(parent, node->bits);
connect_node(&parent->bit[bit], bit, node);
}

static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
Expand All @@ -218,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return -EINVAL;

if (!rcu_access_pointer(*trie)) {
node = kzalloc(sizeof(*node), GFP_KERNEL);
node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!node))
return -ENOMEM;
RCU_INIT_POINTER(node->peer, peer);
list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits);
rcu_assign_pointer(*trie, node);
connect_node(trie, 2, node);
return 0;
}
if (node_placement(*trie, key, cidr, bits, &node, lock)) {
Expand All @@ -233,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return 0;
}

newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!newnode))
return -ENOMEM;
RCU_INIT_POINTER(newnode->peer, peer);
Expand All @@ -243,41 +209,40 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (!node) {
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
} else {
down = rcu_dereference_protected(CHOOSE_NODE(node, key),
lockdep_is_held(lock));
const u8 bit = choose(node, key);
down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
if (!down) {
rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
connect_node(&node->bit[bit], bit, newnode);
return 0;
}
}
cidr = min(cidr, common_bits(down, key, bits));
parent = node;

if (newnode->cidr == cidr) {
rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
choose_and_connect_node(newnode, down);
if (!parent)
rcu_assign_pointer(*trie, newnode);
connect_node(trie, 2, newnode);
else
rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
newnode);
} else {
node = kzalloc(sizeof(*node), GFP_KERNEL);
if (unlikely(!node)) {
list_del(&newnode->peer_list);
kfree(newnode);
return -ENOMEM;
}
INIT_LIST_HEAD(&node->peer_list);
copy_and_assign_cidr(node, newnode->bits, cidr, bits);
choose_and_connect_node(parent, newnode);
return 0;
}

rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
if (!parent)
rcu_assign_pointer(*trie, node);
else
rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
node);
node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!node)) {
list_del(&newnode->peer_list);
kmem_cache_free(node_cache, newnode);
return -ENOMEM;
}
INIT_LIST_HEAD(&node->peer_list);
copy_and_assign_cidr(node, newnode->bits, cidr, bits);

choose_and_connect_node(node, down);
choose_and_connect_node(node, newnode);
if (!parent)
connect_node(trie, 2, node);
else
choose_and_connect_node(parent, node);
return 0;
}

Expand Down Expand Up @@ -335,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock)
{
struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
bool free_parent;

if (list_empty(&peer->allowedips_list))
return;
++table->seq;
walk_remove_by_peer(&table->root4, peer, lock);
walk_remove_by_peer(&table->root6, peer, lock);
list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
list_del_init(&node->peer_list);
RCU_INIT_POINTER(node->peer, NULL);
if (node->bit[0] && node->bit[1])
continue;
child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
lockdep_is_held(lock));
if (child)
child->parent_bit_packed = node->parent_bit_packed;
parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
*parent_bit = child;
parent = (void *)parent_bit -
offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
free_parent = !rcu_access_pointer(node->bit[0]) &&
!rcu_access_pointer(node->bit[1]) &&
(node->parent_bit_packed & 3) <= 1 &&
!rcu_access_pointer(parent->peer);
if (free_parent)
child = rcu_dereference_protected(
parent->bit[!(node->parent_bit_packed & 1)],
lockdep_is_held(lock));
call_rcu(&node->rcu, node_free_rcu);
if (!free_parent)
continue;
if (child)
child->parent_bit_packed = parent->parent_bit_packed;
*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
call_rcu(&parent->rcu, node_free_rcu);
}
}

int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
Expand Down Expand Up @@ -374,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
return NULL;
}

int __init wg_allowedips_slab_init(void)
{
node_cache = KMEM_CACHE(allowedips_node, 0);
return node_cache ? 0 : -ENOMEM;
}

void wg_allowedips_slab_uninit(void)
{
rcu_barrier();
kmem_cache_destroy(node_cache);
}

#include "selftest/allowedips.c"
14 changes: 7 additions & 7 deletions drivers/net/wireguard/allowedips.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@ struct wg_peer;
struct allowedips_node {
struct wg_peer __rcu *peer;
struct allowedips_node __rcu *bit[2];
/* While it may seem scandalous that we waste space for v4,
* we're alloc'ing to the nearest power of 2 anyway, so this
* doesn't actually make a difference.
*/
u8 bits[16] __aligned(__alignof(u64));
u8 cidr, bit_at_a, bit_at_b, bitlen;
u8 bits[16] __aligned(__alignof(u64));

/* Keep rarely used list at bottom to be beyond cache line. */
/* Keep rarely used members at bottom to be beyond cache line. */
unsigned long parent_bit_packed;
union {
struct list_head peer_list;
struct rcu_head rcu;
Expand All @@ -33,7 +30,7 @@ struct allowedips {
struct allowedips_node __rcu *root4;
struct allowedips_node __rcu *root6;
u64 seq;
};
} __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */

void wg_allowedips_init(struct allowedips *table);
void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
Expand All @@ -56,4 +53,7 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
bool wg_allowedips_selftest(void);
#endif

int wg_allowedips_slab_init(void);
void wg_allowedips_slab_uninit(void);

#endif /* _WG_ALLOWEDIPS_H */
17 changes: 16 additions & 1 deletion drivers/net/wireguard/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,22 @@ static int __init mod_init(void)
{
int ret;

ret = wg_allowedips_slab_init();
if (ret < 0)
goto err_allowedips;

#ifdef DEBUG
ret = -ENOTRECOVERABLE;
if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() ||
!wg_ratelimiter_selftest())
return -ENOTRECOVERABLE;
goto err_peer;
#endif
wg_noise_init();

ret = wg_peer_init();
if (ret < 0)
goto err_peer;

ret = wg_device_init();
if (ret < 0)
goto err_device;
Expand All @@ -44,13 +53,19 @@ static int __init mod_init(void)
err_netlink:
wg_device_uninit();
err_device:
wg_peer_uninit();
err_peer:
wg_allowedips_slab_uninit();
err_allowedips:
return ret;
}

static void __exit mod_exit(void)
{
wg_genetlink_uninit();
wg_device_uninit();
wg_peer_uninit();
wg_allowedips_slab_uninit();
}

module_init(mod_init);
Expand Down
Loading

0 comments on commit 6fd815b

Please sign in to comment.