diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 86f9cd5d1c4e1..868a82ad55978 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3147,19 +3147,7 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 	} else if (fn->ret_type == RET_PTR_TO_SOCKET_OR_NULL) {
 		mark_reg_known_zero(env, regs, BPF_REG_0);
 		regs[BPF_REG_0].type = PTR_TO_SOCKET_OR_NULL;
-		if (is_acquire_function(func_id)) {
-			int id = acquire_reference_state(env, insn_idx);
-
-			if (id < 0)
-				return id;
-			/* For mark_ptr_or_null_reg() */
-			regs[BPF_REG_0].id = id;
-			/* For release_reference() */
-			regs[BPF_REG_0].ref_obj_id = id;
-		} else {
-			/* For mark_ptr_or_null_reg() */
-			regs[BPF_REG_0].id = ++env->id_gen;
-		}
+		regs[BPF_REG_0].id = ++env->id_gen;
 	} else if (fn->ret_type == RET_PTR_TO_TCP_SOCK_OR_NULL) {
 		mark_reg_known_zero(env, regs, BPF_REG_0);
 		regs[BPF_REG_0].type = PTR_TO_TCP_SOCK_OR_NULL;
@@ -3170,9 +3158,19 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 		return -EINVAL;
 	}
 
-	if (is_ptr_cast_function(func_id))
+	if (is_ptr_cast_function(func_id)) {
 		/* For release_reference() */
 		regs[BPF_REG_0].ref_obj_id = meta.ref_obj_id;
+	} else if (is_acquire_function(func_id)) {
+		int id = acquire_reference_state(env, insn_idx);
+
+		if (id < 0)
+			return id;
+		/* For mark_ptr_or_null_reg() */
+		regs[BPF_REG_0].id = id;
+		/* For release_reference() */
+		regs[BPF_REG_0].ref_obj_id = id;
+	}
 
 	do_refine_retval_range(regs, fn->ret_type, func_id, &meta);