diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c
index 3c204d2144fac..3d9be36965f68 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c
@@ -320,19 +320,6 @@ struct mlxsw_sp_prefix_usage {
 #define mlxsw_sp_prefix_usage_for_each(prefix, prefix_usage) \
 	for_each_set_bit(prefix, (prefix_usage)->b, MLXSW_SP_PREFIX_COUNT)
 
-static bool
-mlxsw_sp_prefix_usage_subset(struct mlxsw_sp_prefix_usage *prefix_usage1,
-			     struct mlxsw_sp_prefix_usage *prefix_usage2)
-{
-	unsigned char prefix;
-
-	mlxsw_sp_prefix_usage_for_each(prefix, prefix_usage1) {
-		if (!test_bit(prefix, prefix_usage2->b))
-			return false;
-	}
-	return true;
-}
-
 static bool
 mlxsw_sp_prefix_usage_eq(struct mlxsw_sp_prefix_usage *prefix_usage1,
 			 struct mlxsw_sp_prefix_usage *prefix_usage2)
@@ -589,16 +576,14 @@ mlxsw_sp_lpm_tree_get(struct mlxsw_sp *mlxsw_sp,
 		    lpm_tree->proto == proto &&
 		    mlxsw_sp_prefix_usage_eq(&lpm_tree->prefix_usage,
 					     prefix_usage))
-			goto inc_ref_count;
+			return lpm_tree;
 	}
-	lpm_tree = mlxsw_sp_lpm_tree_create(mlxsw_sp, prefix_usage,
-					    proto);
-	if (IS_ERR(lpm_tree))
-		return lpm_tree;
+	return mlxsw_sp_lpm_tree_create(mlxsw_sp, prefix_usage, proto);
+}
 
-inc_ref_count:
+static void mlxsw_sp_lpm_tree_hold(struct mlxsw_sp_lpm_tree *lpm_tree)
+{
 	lpm_tree->ref_count++;
-	return lpm_tree;
 }
 
 static void mlxsw_sp_lpm_tree_put(struct mlxsw_sp *mlxsw_sp,
@@ -750,46 +735,6 @@ static void mlxsw_sp_vr_destroy(struct mlxsw_sp_vr *vr)
 	vr->fib4 = NULL;
 }
 
-static int
-mlxsw_sp_vr_lpm_tree_check(struct mlxsw_sp *mlxsw_sp, struct mlxsw_sp_fib *fib,
-			   struct mlxsw_sp_prefix_usage *req_prefix_usage)
-{
-	struct mlxsw_sp_lpm_tree *lpm_tree = fib->lpm_tree;
-	struct mlxsw_sp_lpm_tree *new_tree;
-	int err;
-
-	if (mlxsw_sp_prefix_usage_eq(req_prefix_usage, &lpm_tree->prefix_usage))
-		return 0;
-
-	new_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, req_prefix_usage,
-					 fib->proto);
-	if (IS_ERR(new_tree)) {
-		/* We failed to get a tree according to the required
-		 * prefix usage. However, the current tree might be still good
-		 * for us if our requirement is subset of the prefixes used
-		 * in the tree.
-		 */
-		if (mlxsw_sp_prefix_usage_subset(req_prefix_usage,
-						 &lpm_tree->prefix_usage))
-			return 0;
-		return PTR_ERR(new_tree);
-	}
-
-	/* Prevent packet loss by overwriting existing binding */
-	fib->lpm_tree = new_tree;
-	err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, new_tree->id);
-	if (err)
-		goto err_tree_bind;
-	mlxsw_sp_lpm_tree_put(mlxsw_sp, lpm_tree);
-
-	return 0;
-
-err_tree_bind:
-	fib->lpm_tree = lpm_tree;
-	mlxsw_sp_lpm_tree_put(mlxsw_sp, new_tree);
-	return err;
-}
-
 static struct mlxsw_sp_vr *mlxsw_sp_vr_get(struct mlxsw_sp *mlxsw_sp, u32 tb_id)
 {
 	struct mlxsw_sp_vr *vr;
@@ -808,6 +753,100 @@ static void mlxsw_sp_vr_put(struct mlxsw_sp_vr *vr)
 		mlxsw_sp_vr_destroy(vr);
 }
 
+static bool
+mlxsw_sp_vr_lpm_tree_should_replace(struct mlxsw_sp_vr *vr,
+				    enum mlxsw_sp_l3proto proto, u8 tree_id)
+{
+	struct mlxsw_sp_fib *fib = mlxsw_sp_vr_fib(vr, proto);
+
+	if (!mlxsw_sp_vr_is_used(vr))
+		return false;
+	if (fib->lpm_tree && fib->lpm_tree->id == tree_id)
+		return true;
+	return false;
+}
+
+static int mlxsw_sp_vr_lpm_tree_replace(struct mlxsw_sp *mlxsw_sp,
+					struct mlxsw_sp_fib *fib,
+					struct mlxsw_sp_lpm_tree *new_tree)
+{
+	struct mlxsw_sp_lpm_tree *old_tree = fib->lpm_tree;
+	int err;
+
+	err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, new_tree->id);
+	if (err)
+		return err;
+	fib->lpm_tree = new_tree;
+	mlxsw_sp_lpm_tree_hold(new_tree);
+	mlxsw_sp_lpm_tree_put(mlxsw_sp, old_tree);
+	return 0;
+}
+
+static int mlxsw_sp_vrs_lpm_tree_replace(struct mlxsw_sp *mlxsw_sp,
+					 struct mlxsw_sp_fib *fib,
+					 struct mlxsw_sp_lpm_tree *new_tree)
+{
+	struct mlxsw_sp_lpm_tree *old_tree = fib->lpm_tree;
+	enum mlxsw_sp_l3proto proto = fib->proto;
+	u8 old_id, new_id = new_tree->id;
+	struct mlxsw_sp_vr *vr;
+	int i, err;
+
+	if (!old_tree)
+		goto no_replace;
+	old_id = old_tree->id;
+
+	for (i = 0; i < MLXSW_CORE_RES_GET(mlxsw_sp->core, MAX_VRS); i++) {
+		vr = &mlxsw_sp->router->vrs[i];
+		if (!mlxsw_sp_vr_lpm_tree_should_replace(vr, proto, old_id))
+			continue;
+		err = mlxsw_sp_vr_lpm_tree_replace(mlxsw_sp,
+						   mlxsw_sp_vr_fib(vr, proto),
+						   new_tree);
+		if (err)
+			goto err_tree_replace;
+	}
+
+	return 0;
+
+err_tree_replace:
+	for (i--; i >= 0; i--) {
+		if (!mlxsw_sp_vr_lpm_tree_should_replace(vr, proto, new_id))
+			continue;
+		mlxsw_sp_vr_lpm_tree_replace(mlxsw_sp,
+					     mlxsw_sp_vr_fib(vr, proto),
+					     old_tree);
+	}
+	return err;
+
+no_replace:
+	err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, new_tree->id);
+	if (err)
+		return err;
+	fib->lpm_tree = new_tree;
+	mlxsw_sp_lpm_tree_hold(new_tree);
+	return 0;
+}
+
+static void
+mlxsw_sp_vrs_prefixes(struct mlxsw_sp *mlxsw_sp,
+		      enum mlxsw_sp_l3proto proto,
+		      struct mlxsw_sp_prefix_usage *req_prefix_usage)
+{
+	int i;
+
+	for (i = 0; i < MLXSW_CORE_RES_GET(mlxsw_sp->core, MAX_VRS); i++) {
+		struct mlxsw_sp_vr *vr = &mlxsw_sp->router->vrs[i];
+		struct mlxsw_sp_fib *fib = mlxsw_sp_vr_fib(vr, proto);
+		unsigned char prefix;
+
+		if (!mlxsw_sp_vr_is_used(vr))
+			continue;
+		mlxsw_sp_prefix_usage_for_each(prefix, &fib->prefix_usage)
+			mlxsw_sp_prefix_usage_set(req_prefix_usage, prefix);
+	}
+}
+
 static int mlxsw_sp_vrs_init(struct mlxsw_sp *mlxsw_sp)
 {
 	struct mlxsw_sp_vr *vr;
@@ -2586,6 +2625,67 @@ mlxsw_sp_fib_node_entry_is_first(const struct mlxsw_sp_fib_node *fib_node,
 				struct mlxsw_sp_fib_entry, list) == fib_entry;
 }
 
+static int mlxsw_sp_fib_lpm_tree_link(struct mlxsw_sp *mlxsw_sp,
+				      struct mlxsw_sp_fib *fib,
+				      struct mlxsw_sp_fib_node *fib_node)
+{
+	struct mlxsw_sp_prefix_usage req_prefix_usage = {{ 0 } };
+	struct mlxsw_sp_lpm_tree *lpm_tree;
+	int err;
+
+	/* Since the tree is shared between all virtual routers we must
+	 * make sure it contains all the required prefix lengths. This
+	 * can be computed by either adding the new prefix length to the
+	 * existing prefix usage of a bound tree, or by aggregating the
+	 * prefix lengths across all virtual routers and adding the new
+	 * one as well.
+	 */
+	if (fib->lpm_tree)
+		mlxsw_sp_prefix_usage_cpy(&req_prefix_usage,
+					  &fib->lpm_tree->prefix_usage);
+	else
+		mlxsw_sp_vrs_prefixes(mlxsw_sp, fib->proto, &req_prefix_usage);
+	mlxsw_sp_prefix_usage_set(&req_prefix_usage, fib_node->key.prefix_len);
+
+	lpm_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, &req_prefix_usage,
+					 fib->proto);
+	if (IS_ERR(lpm_tree))
+		return PTR_ERR(lpm_tree);
+
+	if (fib->lpm_tree && fib->lpm_tree->id == lpm_tree->id)
+		return 0;
+
+	err = mlxsw_sp_vrs_lpm_tree_replace(mlxsw_sp, fib, lpm_tree);
+	if (err)
+		return err;
+
+	return 0;
+}
+
+static void mlxsw_sp_fib_lpm_tree_unlink(struct mlxsw_sp *mlxsw_sp,
+					 struct mlxsw_sp_fib *fib)
+{
+	struct mlxsw_sp_prefix_usage req_prefix_usage = {{ 0 } };
+	struct mlxsw_sp_lpm_tree *lpm_tree;
+
+	/* Aggregate prefix lengths across all virtual routers to make
+	 * sure we only have used prefix lengths in the LPM tree.
+	 */
+	mlxsw_sp_vrs_prefixes(mlxsw_sp, fib->proto, &req_prefix_usage);
+	lpm_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, &req_prefix_usage,
+					 fib->proto);
+	if (IS_ERR(lpm_tree))
+		goto err_tree_get;
+	mlxsw_sp_vrs_lpm_tree_replace(mlxsw_sp, fib, lpm_tree);
+
+err_tree_get:
+	if (!mlxsw_sp_prefix_usage_none(&fib->prefix_usage))
+		return;
+	mlxsw_sp_vr_lpm_tree_unbind(mlxsw_sp, fib);
+	mlxsw_sp_lpm_tree_put(mlxsw_sp, fib->lpm_tree);
+	fib->lpm_tree = NULL;
+}
+
 static void mlxsw_sp_fib_node_prefix_inc(struct mlxsw_sp_fib_node *fib_node)
 {
 	unsigned char prefix_len = fib_node->key.prefix_len;
@@ -2608,8 +2708,6 @@ static int mlxsw_sp_fib_node_init(struct mlxsw_sp *mlxsw_sp,
 				  struct mlxsw_sp_fib_node *fib_node,
 				  struct mlxsw_sp_fib *fib)
 {
-	struct mlxsw_sp_prefix_usage req_prefix_usage;
-	struct mlxsw_sp_lpm_tree *lpm_tree;
 	int err;
 
 	err = mlxsw_sp_fib_node_insert(fib, fib_node);
@@ -2617,33 +2715,15 @@ static int mlxsw_sp_fib_node_init(struct mlxsw_sp *mlxsw_sp,
 		return err;
 	fib_node->fib = fib;
 
-	mlxsw_sp_prefix_usage_cpy(&req_prefix_usage, &fib->prefix_usage);
-	mlxsw_sp_prefix_usage_set(&req_prefix_usage, fib_node->key.prefix_len);
-
-	if (!mlxsw_sp_prefix_usage_none(&fib->prefix_usage)) {
-		err = mlxsw_sp_vr_lpm_tree_check(mlxsw_sp, fib,
-						 &req_prefix_usage);
-		if (err)
-			goto err_tree_check;
-	} else {
-		lpm_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, &req_prefix_usage,
-						 fib->proto);
-		if (IS_ERR(lpm_tree))
-			return PTR_ERR(lpm_tree);
-		fib->lpm_tree = lpm_tree;
-		err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, lpm_tree->id);
-		if (err)
-			goto err_tree_bind;
-	}
+	err = mlxsw_sp_fib_lpm_tree_link(mlxsw_sp, fib, fib_node);
+	if (err)
+		goto err_fib_lpm_tree_link;
 
 	mlxsw_sp_fib_node_prefix_inc(fib_node);
 
 	return 0;
 
-err_tree_bind:
-	fib->lpm_tree = NULL;
-	mlxsw_sp_lpm_tree_put(mlxsw_sp, lpm_tree);
-err_tree_check:
+err_fib_lpm_tree_link:
 	fib_node->fib = NULL;
 	mlxsw_sp_fib_node_remove(fib, fib_node);
 	return err;
@@ -2652,19 +2732,10 @@ static int mlxsw_sp_fib_node_init(struct mlxsw_sp *mlxsw_sp,
 static void mlxsw_sp_fib_node_fini(struct mlxsw_sp *mlxsw_sp,
 				   struct mlxsw_sp_fib_node *fib_node)
 {
-	struct mlxsw_sp_lpm_tree *lpm_tree = fib_node->fib->lpm_tree;
 	struct mlxsw_sp_fib *fib = fib_node->fib;
 
 	mlxsw_sp_fib_node_prefix_dec(fib_node);
-
-	if (mlxsw_sp_prefix_usage_none(&fib->prefix_usage)) {
-		mlxsw_sp_vr_lpm_tree_unbind(mlxsw_sp, fib);
-		fib->lpm_tree = NULL;
-		mlxsw_sp_lpm_tree_put(mlxsw_sp, lpm_tree);
-	} else {
-		mlxsw_sp_vr_lpm_tree_check(mlxsw_sp, fib, &fib->prefix_usage);
-	}
-
+	mlxsw_sp_fib_lpm_tree_unlink(mlxsw_sp, fib);
 	fib_node->fib = NULL;
 	mlxsw_sp_fib_node_remove(fib, fib_node);
 }