diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c index bd0732d5d219e..6957608788981 100644 --- a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c +++ b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c @@ -513,7 +513,13 @@ static int mlx5e_create_rq(struct mlx5e_channel *c, rq->channel = c; rq->ix = c->ix; rq->priv = c->priv; - rq->xdp_prog = priv->xdp_prog; + + rq->xdp_prog = priv->xdp_prog ? bpf_prog_inc(priv->xdp_prog) : NULL; + if (IS_ERR(rq->xdp_prog)) { + err = PTR_ERR(rq->xdp_prog); + rq->xdp_prog = NULL; + goto err_rq_wq_destroy; + } rq->buff.map_dir = DMA_FROM_DEVICE; if (rq->xdp_prog) @@ -590,12 +596,11 @@ static int mlx5e_create_rq(struct mlx5e_channel *c, rq->page_cache.head = 0; rq->page_cache.tail = 0; - if (rq->xdp_prog) - bpf_prog_add(rq->xdp_prog, 1); - return 0; err_rq_wq_destroy: + if (rq->xdp_prog) + bpf_prog_put(rq->xdp_prog); mlx5_wq_destroy(&rq->wq_ctrl); return err; @@ -3139,11 +3144,21 @@ static int mlx5e_xdp_set(struct net_device *netdev, struct bpf_prog *prog) if (was_opened && reset) mlx5e_close_locked(netdev); + if (was_opened && !reset) { + /* num_channels is invariant here, so we can take the + * batched reference right upfront. + */ + prog = bpf_prog_add(prog, priv->params.num_channels); + if (IS_ERR(prog)) { + err = PTR_ERR(prog); + goto unlock; + } + } - /* exchange programs */ + /* exchange programs, extra prog reference we got from caller + * as long as we don't fail from this point onwards. + */ old_prog = xchg(&priv->xdp_prog, prog); - if (prog) - bpf_prog_add(prog, 1); if (old_prog) bpf_prog_put(old_prog); @@ -3159,7 +3174,6 @@ static int mlx5e_xdp_set(struct net_device *netdev, struct bpf_prog *prog) /* exchanging programs w/o reset, we update ref counts on behalf * of the channels RQs here. */ - bpf_prog_add(prog, priv->params.num_channels); for (i = 0; i < priv->params.num_channels; i++) { struct mlx5e_channel *c = priv->channel[i]; @@ -3691,6 +3705,9 @@ static void mlx5e_nic_cleanup(struct mlx5e_priv *priv) if (MLX5_CAP_GEN(mdev, vport_group_manager)) mlx5_eswitch_unregister_vport_rep(esw, 0); + + if (priv->xdp_prog) + bpf_prog_put(priv->xdp_prog); } static int mlx5e_init_nic_rx(struct mlx5e_priv *priv) diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 01c1487277b28..69d0a7f12a3bd 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -233,14 +233,14 @@ void bpf_register_map_type(struct bpf_map_type_list *tl); struct bpf_prog *bpf_prog_get(u32 ufd); struct bpf_prog *bpf_prog_get_type(u32 ufd, enum bpf_prog_type type); -struct bpf_prog *bpf_prog_add(struct bpf_prog *prog, int i); +struct bpf_prog * __must_check bpf_prog_add(struct bpf_prog *prog, int i); void bpf_prog_sub(struct bpf_prog *prog, int i); -struct bpf_prog *bpf_prog_inc(struct bpf_prog *prog); +struct bpf_prog * __must_check bpf_prog_inc(struct bpf_prog *prog); void bpf_prog_put(struct bpf_prog *prog); struct bpf_map *bpf_map_get_with_uref(u32 ufd); struct bpf_map *__bpf_map_get(struct fd f); -struct bpf_map *bpf_map_inc(struct bpf_map *map, bool uref); +struct bpf_map * __must_check bpf_map_inc(struct bpf_map *map, bool uref); void bpf_map_put_with_uref(struct bpf_map *map); void bpf_map_put(struct bpf_map *map); int bpf_map_precharge_memlock(u32 pages); @@ -299,7 +299,8 @@ static inline struct bpf_prog *bpf_prog_get_type(u32 ufd, { return ERR_PTR(-EOPNOTSUPP); } -static inline struct bpf_prog *bpf_prog_add(struct bpf_prog *prog, int i) +static inline struct bpf_prog * __must_check bpf_prog_add(struct bpf_prog *prog, + int i) { return ERR_PTR(-EOPNOTSUPP); } @@ -311,7 +312,8 @@ static inline void bpf_prog_sub(struct bpf_prog *prog, int i) static inline void bpf_prog_put(struct bpf_prog *prog) { } -static inline struct bpf_prog *bpf_prog_inc(struct bpf_prog *prog) + +static inline struct bpf_prog * __must_check bpf_prog_inc(struct bpf_prog *prog) { return ERR_PTR(-EOPNOTSUPP); } diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c index ce1b7de7d72c6..eb15498b8d55c 100644 --- a/kernel/bpf/syscall.c +++ b/kernel/bpf/syscall.c @@ -696,6 +696,7 @@ struct bpf_prog *bpf_prog_inc(struct bpf_prog *prog) { return bpf_prog_add(prog, 1); } +EXPORT_SYMBOL_GPL(bpf_prog_inc); static struct bpf_prog *__bpf_prog_get(u32 ufd, enum bpf_prog_type *type) {