diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 29756d88799b6..b2240361f1a14 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -396,13 +396,10 @@ static inline unsigned long busy_clock(void)
 	return local_clock() >> 10;
 }
 
-static bool vhost_can_busy_poll(struct vhost_dev *dev,
-				unsigned long endtime)
+static bool vhost_can_busy_poll(unsigned long endtime)
 {
-	return likely(!need_resched()) &&
-	       likely(!time_after(busy_clock(), endtime)) &&
-	       likely(!signal_pending(current)) &&
-	       !vhost_has_work(dev);
+	return likely(!need_resched() && !time_after(busy_clock(), endtime) &&
+		      !signal_pending(current));
 }
 
 static void vhost_net_disable_vq(struct vhost_net *n,
@@ -434,7 +431,8 @@ static int vhost_net_enable_vq(struct vhost_net *n,
 static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 				    struct vhost_virtqueue *vq,
 				    struct iovec iov[], unsigned int iov_size,
-				    unsigned int *out_num, unsigned int *in_num)
+				    unsigned int *out_num, unsigned int *in_num,
+				    bool *busyloop_intr)
 {
 	unsigned long uninitialized_var(endtime);
 	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
@@ -443,9 +441,15 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 	if (r == vq->num && vq->busyloop_timeout) {
 		preempt_disable();
 		endtime = busy_clock() + vq->busyloop_timeout;
-		while (vhost_can_busy_poll(vq->dev, endtime) &&
-		       vhost_vq_avail_empty(vq->dev, vq))
+		while (vhost_can_busy_poll(endtime)) {
+			if (vhost_has_work(vq->dev)) {
+				*busyloop_intr = true;
+				break;
+			}
+			if (!vhost_vq_avail_empty(vq->dev, vq))
+				break;
 			cpu_relax();
+		}
 		preempt_enable();
 		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 				      out_num, in_num, NULL, NULL);
@@ -501,20 +505,24 @@ static void handle_tx(struct vhost_net *net)
 	zcopy = nvq->ubufs;
 
 	for (;;) {
+		bool busyloop_intr;
+
 		/* Release DMAs done buffers first */
 		if (zcopy)
 			vhost_zerocopy_signal_used(net, vq);
 
-
+		busyloop_intr = false;
 		head = vhost_net_tx_get_vq_desc(net, vq, vq->iov,
 						ARRAY_SIZE(vq->iov),
-						&out, &in);
+						&out, &in, &busyloop_intr);
 		/* On error, stop handling until the next kick. */
 		if (unlikely(head < 0))
 			break;
 		/* Nothing new?  Wait for eventfd to tell us they refilled. */
 		if (head == vq->num) {
-			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+			if (unlikely(busyloop_intr)) {
+				vhost_poll_queue(&vq->poll);
+			} else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
 				vhost_disable_notify(&net->dev, vq);
 				continue;
 			}
@@ -645,41 +653,50 @@ static void vhost_rx_signal_used(struct vhost_net_virtqueue *nvq)
 	nvq->done_idx = 0;
 }
 
-static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
+static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk,
+				      bool *busyloop_intr)
 {
-	struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX];
-	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
-	struct vhost_virtqueue *vq = &nvq->vq;
+	struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
+	struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX];
+	struct vhost_virtqueue *rvq = &rnvq->vq;
+	struct vhost_virtqueue *tvq = &tnvq->vq;
 	unsigned long uninitialized_var(endtime);
-	int len = peek_head_len(rvq, sk);
+	int len = peek_head_len(rnvq, sk);
 
-	if (!len && vq->busyloop_timeout) {
+	if (!len && tvq->busyloop_timeout) {
 		/* Flush batched heads first */
-		vhost_rx_signal_used(rvq);
+		vhost_rx_signal_used(rnvq);
 		/* Both tx vq and rx socket were polled here */
-		mutex_lock_nested(&vq->mutex, 1);
-		vhost_disable_notify(&net->dev, vq);
+		mutex_lock_nested(&tvq->mutex, 1);
+		vhost_disable_notify(&net->dev, tvq);
 
 		preempt_disable();
-		endtime = busy_clock() + vq->busyloop_timeout;
+		endtime = busy_clock() + tvq->busyloop_timeout;
 
-		while (vhost_can_busy_poll(&net->dev, endtime) &&
-		       !sk_has_rx_data(sk) &&
-		       vhost_vq_avail_empty(&net->dev, vq))
+		while (vhost_can_busy_poll(endtime)) {
+			if (vhost_has_work(&net->dev)) {
+				*busyloop_intr = true;
+				break;
+			}
+			if ((sk_has_rx_data(sk) &&
+			     !vhost_vq_avail_empty(&net->dev, rvq)) ||
+			    !vhost_vq_avail_empty(&net->dev, tvq))
+				break;
 			cpu_relax();
+		}
 
 		preempt_enable();
 
-		if (!vhost_vq_avail_empty(&net->dev, vq))
-			vhost_poll_queue(&vq->poll);
-		else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
-			vhost_disable_notify(&net->dev, vq);
-			vhost_poll_queue(&vq->poll);
+		if (!vhost_vq_avail_empty(&net->dev, tvq)) {
+			vhost_poll_queue(&tvq->poll);
+		} else if (unlikely(vhost_enable_notify(&net->dev, tvq))) {
+			vhost_disable_notify(&net->dev, tvq);
+			vhost_poll_queue(&tvq->poll);
 		}
 
-		mutex_unlock(&vq->mutex);
+		mutex_unlock(&tvq->mutex);
 
-		len = peek_head_len(rvq, sk);
+		len = peek_head_len(rnvq, sk);
 	}
 
 	return len;
@@ -786,6 +803,7 @@ static void handle_rx(struct vhost_net *net)
 	s16 headcount;
 	size_t vhost_hlen, sock_hlen;
 	size_t vhost_len, sock_len;
+	bool busyloop_intr = false;
 	struct socket *sock;
 	struct iov_iter fixup;
 	__virtio16 num_buffers;
@@ -809,7 +827,8 @@ static void handle_rx(struct vhost_net *net)
 		vq->log : NULL;
 	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
 
-	while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
+	while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+						      &busyloop_intr))) {
 		sock_len += sock_hlen;
 		vhost_len = sock_len + vhost_hlen;
 		headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
@@ -820,7 +839,9 @@ static void handle_rx(struct vhost_net *net)
 			goto out;
 		/* OK, now we need to know about added descriptors. */
 		if (!headcount) {
-			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+			if (unlikely(busyloop_intr)) {
+				vhost_poll_queue(&vq->poll);
+			} else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
 				/* They have slipped one in as we were
 				 * doing that: check again. */
 				vhost_disable_notify(&net->dev, vq);
@@ -830,6 +851,7 @@ static void handle_rx(struct vhost_net *net)
 			 * they refilled. */
 			goto out;
 		}
+		busyloop_intr = false;
 		if (nvq->rx_ring)
 			msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
 		/* On overrun, truncate and discard */
@@ -896,7 +918,10 @@ static void handle_rx(struct vhost_net *net)
 			goto out;
 		}
 	}
-	vhost_net_enable_vq(net, vq);
+	if (unlikely(busyloop_intr))
+		vhost_poll_queue(&vq->poll);
+	else
+		vhost_net_enable_vq(net, vq);
 out:
 	vhost_rx_signal_used(nvq);
 	mutex_unlock(&vq->mutex);