diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index bfba91ecbd0a5..c71d573f1c949 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -233,16 +233,9 @@ void vhost_poll_stop(struct vhost_poll *poll)
 }
 EXPORT_SYMBOL_GPL(vhost_poll_stop);
 
-static bool vhost_worker_queue(struct vhost_worker *worker,
+static void vhost_worker_queue(struct vhost_worker *worker,
 			       struct vhost_work *work)
 {
-	if (!worker)
-		return false;
-	/*
-	 * vsock can queue while we do a VHOST_SET_OWNER, so we have a smp_wmb
-	 * when setting up the worker. We don't have a smp_rmb here because
-	 * test_and_set_bit gives us a mb already.
-	 */
 	if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
 		/* We can only add the work to the list after we're
 		 * sure it was not in the list.
@@ -251,47 +244,85 @@ static bool vhost_worker_queue(struct vhost_worker *worker,
 		llist_add(&work->node, &worker->work_list);
 		vhost_task_wake(worker->vtsk);
 	}
-
-	return true;
 }
 
 bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
 {
-	return vhost_worker_queue(vq->worker, work);
+	struct vhost_worker *worker;
+	bool queued = false;
+
+	rcu_read_lock();
+	worker = rcu_dereference(vq->worker);
+	if (worker) {
+		queued = true;
+		vhost_worker_queue(worker, work);
+	}
+	rcu_read_unlock();
+
+	return queued;
 }
 EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
 
-static void vhost_worker_flush(struct vhost_worker *worker)
+void vhost_vq_flush(struct vhost_virtqueue *vq)
 {
 	struct vhost_flush_struct flush;
 
 	init_completion(&flush.wait_event);
 	vhost_work_init(&flush.work, vhost_flush_work);
 
-	if (vhost_worker_queue(worker, &flush.work))
+	if (vhost_vq_work_queue(vq, &flush.work))
 		wait_for_completion(&flush.wait_event);
 }
+EXPORT_SYMBOL_GPL(vhost_vq_flush);
 
-void vhost_vq_flush(struct vhost_virtqueue *vq)
+/**
+ * vhost_worker_flush - flush a worker
+ * @worker: worker to flush
+ *
+ * This does not use RCU to protect the worker, so the device or worker
+ * mutex must be held.
+ */
+static void vhost_worker_flush(struct vhost_worker *worker)
 {
-	vhost_worker_flush(vq->worker);
+	struct vhost_flush_struct flush;
+
+	init_completion(&flush.wait_event);
+	vhost_work_init(&flush.work, vhost_flush_work);
+
+	vhost_worker_queue(worker, &flush.work);
+	wait_for_completion(&flush.wait_event);
 }
-EXPORT_SYMBOL_GPL(vhost_vq_flush);
 
 void vhost_dev_flush(struct vhost_dev *dev)
 {
 	struct vhost_worker *worker;
 	unsigned long i;
 
-	xa_for_each(&dev->worker_xa, i, worker)
+	xa_for_each(&dev->worker_xa, i, worker) {
+		mutex_lock(&worker->mutex);
+		if (!worker->attachment_cnt) {
+			mutex_unlock(&worker->mutex);
+			continue;
+		}
 		vhost_worker_flush(worker);
+		mutex_unlock(&worker->mutex);
+	}
 }
 EXPORT_SYMBOL_GPL(vhost_dev_flush);
 
 /* A lockless hint for busy polling code to exit the loop */
 bool vhost_vq_has_work(struct vhost_virtqueue *vq)
 {
-	return !llist_empty(&vq->worker->work_list);
+	struct vhost_worker *worker;
+	bool has_work = false;
+
+	rcu_read_lock();
+	worker = rcu_dereference(vq->worker);
+	if (worker && !llist_empty(&worker->work_list))
+		has_work = true;
+	rcu_read_unlock();
+
+	return has_work;
 }
 EXPORT_SYMBOL_GPL(vhost_vq_has_work);
 
@@ -356,7 +387,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->busyloop_timeout = 0;
 	vq->umem = NULL;
 	vq->iotlb = NULL;
-	vq->worker = NULL;
+	rcu_assign_pointer(vq->worker, NULL);
 	vhost_vring_call_reset(&vq->call_ctx);
 	__vhost_vq_meta_reset(vq);
 }
@@ -578,7 +609,7 @@ static void vhost_workers_free(struct vhost_dev *dev)
 		return;
 
 	for (i = 0; i < dev->nvqs; i++)
-		dev->vqs[i]->worker = NULL;
+		rcu_assign_pointer(dev->vqs[i]->worker, NULL);
 	/*
 	 * Free the default worker we created and cleanup workers userspace
 	 * created but couldn't clean up (it forgot or crashed).
@@ -606,6 +637,7 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 	if (!vtsk)
 		goto free_worker;
 
+	mutex_init(&worker->mutex);
 	init_llist_head(&worker->work_list);
 	worker->kcov_handle = kcov_common_handle();
 	worker->vtsk = vtsk;
@@ -630,13 +662,54 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
 				     struct vhost_worker *worker)
 {
-	if (vq->worker)
-		vq->worker->attachment_cnt--;
+	struct vhost_worker *old_worker;
+
+	old_worker = rcu_dereference_check(vq->worker,
+					   lockdep_is_held(&vq->dev->mutex));
+
+	mutex_lock(&worker->mutex);
 	worker->attachment_cnt++;
-	vq->worker = worker;
+	mutex_unlock(&worker->mutex);
+	rcu_assign_pointer(vq->worker, worker);
+
+	if (!old_worker)
+		return;
+	/*
+	 * Take the worker mutex to make sure we see the work queued from
+	 * device wide flushes which doesn't use RCU for execution.
+	 */
+	mutex_lock(&old_worker->mutex);
+	old_worker->attachment_cnt--;
+	/*
+	 * We don't want to call synchronize_rcu for every vq during setup
+	 * because it will slow down VM startup. If we haven't done
+	 * VHOST_SET_VRING_KICK and not done the driver specific
+	 * SET_ENDPOINT/RUNNUNG then we can skip the sync since there will
+	 * not be any works queued for scsi and net.
+	 */
+	mutex_lock(&vq->mutex);
+	if (!vhost_vq_get_backend(vq) && !vq->kick) {
+		mutex_unlock(&vq->mutex);
+		mutex_unlock(&old_worker->mutex);
+		/*
+		 * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
+		 * Warn if it adds support for multiple workers but forgets to
+		 * handle the early queueing case.
+		 */
+		WARN_ON(!old_worker->attachment_cnt &&
+			!llist_empty(&old_worker->work_list));
+		return;
+	}
+	mutex_unlock(&vq->mutex);
+
+	/* Make sure new vq queue/flush/poll calls see the new worker */
+	synchronize_rcu();
+	/* Make sure whatever was queued gets run */
+	vhost_worker_flush(old_worker);
+	mutex_unlock(&old_worker->mutex);
 }
 
- /* Caller must have device and virtqueue mutex */
+ /* Caller must have device mutex */
 static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
 				  struct vhost_vring_worker *info)
 {
@@ -647,15 +720,6 @@ static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
 	if (!dev->use_worker)
 		return -EINVAL;
 
-	/*
-	 * We only allow userspace to set a virtqueue's worker if it's not
-	 * active and polling is not enabled. We also assume drivers
-	 * supporting this will not be internally queueing works directly or
-	 * via calls like vhost_dev_flush at this time.
-	 */
-	if (vhost_vq_get_backend(vq) || vq->kick)
-		return -EBUSY;
-
 	worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
 	if (!worker || worker->id != info->worker_id)
 		return -ENODEV;
@@ -689,8 +753,12 @@ static int vhost_free_worker(struct vhost_dev *dev,
 	if (!worker || worker->id != info->worker_id)
 		return -ENODEV;
 
-	if (worker->attachment_cnt)
+	mutex_lock(&worker->mutex);
+	if (worker->attachment_cnt) {
+		mutex_unlock(&worker->mutex);
 		return -EBUSY;
+	}
+	mutex_unlock(&worker->mutex);
 
 	vhost_worker_destroy(dev, worker);
 	return 0;
@@ -723,6 +791,7 @@ long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
 {
 	struct vhost_vring_worker ring_worker;
 	struct vhost_worker_state state;
+	struct vhost_worker *worker;
 	struct vhost_virtqueue *vq;
 	long ret;
 	u32 idx;
@@ -760,7 +829,6 @@ long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
 	if (ret)
 		return ret;
 
-	mutex_lock(&vq->mutex);
 	switch (ioctl) {
 	case VHOST_ATTACH_VRING_WORKER:
 		if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
@@ -771,8 +839,15 @@ long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
 		ret = vhost_vq_attach_worker(vq, &ring_worker);
 		break;
 	case VHOST_GET_VRING_WORKER:
+		worker = rcu_dereference_check(vq->worker,
+					       lockdep_is_held(&dev->mutex));
+		if (!worker) {
+			ret = -EINVAL;
+			break;
+		}
+
 		ring_worker.index = idx;
-		ring_worker.worker_id = vq->worker->id;
+		ring_worker.worker_id = worker->id;
 
 		if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
 			ret = -EFAULT;
@@ -782,7 +857,6 @@ long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
 		break;
 	}
 
-	mutex_unlock(&vq->mutex);
 	return ret;
 }
 EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
@@ -817,11 +891,6 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 			err = -ENOMEM;
 			goto err_worker;
 		}
-		/*
-		 * vsock can already try to queue so make sure the worker
-		 * is setup before vhost_vq_work_queue sees vq->worker is set.
-		 */
-		smp_wmb();
 
 		for (i = 0; i < dev->nvqs; i++)
 			__vhost_vq_attach_worker(dev->vqs[i], worker);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 4920ca63b8de2..f1e7d4d132190 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -28,6 +28,8 @@ struct vhost_work {
 
 struct vhost_worker {
 	struct vhost_task	*vtsk;
+	/* Used to serialize device wide flushing with worker swapping. */
+	struct mutex		mutex;
 	struct llist_head	work_list;
 	u64			kcov_handle;
 	u32			id;
@@ -76,7 +78,7 @@ struct vhost_vring_call {
 /* The virtqueue structure describes a queue attached to a device. */
 struct vhost_virtqueue {
 	struct vhost_dev *dev;
-	struct vhost_worker *worker;
+	struct vhost_worker __rcu *worker;
 
 	/* The actual ring of buffers. */
 	struct mutex mutex;
diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h
index 96dc146c2d15c..f5c48b61ab622 100644
--- a/include/uapi/linux/vhost.h
+++ b/include/uapi/linux/vhost.h
@@ -90,9 +90,7 @@
 #define VHOST_SET_VRING_ENDIAN _IOW(VHOST_VIRTIO, 0x13, struct vhost_vring_state)
 #define VHOST_GET_VRING_ENDIAN _IOW(VHOST_VIRTIO, 0x14, struct vhost_vring_state)
 /* Attach a vhost_worker created with VHOST_NEW_WORKER to one of the device's
- * virtqueues. This must be done before VHOST_SET_VRING_KICK and the driver
- * specific ioctl to activate the virtqueue (VHOST_SCSI_SET_ENDPOINT,
- * VHOST_NET_SET_BACKEND, VHOST_VSOCK_SET_RUNNING) has been run.
+ * virtqueues.
  *
  * This will replace the virtqueue's existing worker. If the replaced worker
  * is no longer attached to any virtqueues, it can be freed with