diff --git a/drivers/vfio/pci/mlx5/cmd.c b/drivers/vfio/pci/mlx5/cmd.c
index f6293da033cca..993749818d90f 100644
--- a/drivers/vfio/pci/mlx5/cmd.c
+++ b/drivers/vfio/pci/mlx5/cmd.c
@@ -598,9 +598,11 @@ int mlx5vf_cmd_load_vhca_state(struct mlx5vf_pci_core_device *mvdev,
 	if (mvdev->mdev_detach)
 		return -ENOTCONN;
 
-	err = mlx5vf_dma_data_buffer(buf);
-	if (err)
-		return err;
+	if (!buf->dmaed) {
+		err = mlx5vf_dma_data_buffer(buf);
+		if (err)
+			return err;
+	}
 
 	MLX5_SET(load_vhca_state_in, in, opcode,
 		 MLX5_CMD_OP_LOAD_VHCA_STATE);
@@ -644,6 +646,11 @@ void mlx5fv_cmd_clean_migf_resources(struct mlx5_vf_migration_file *migf)
 		migf->buf = NULL;
 	}
 
+	if (migf->buf_header) {
+		mlx5vf_free_data_buffer(migf->buf_header);
+		migf->buf_header = NULL;
+	}
+
 	list_splice(&migf->avail_list, &migf->buf_list);
 
 	while ((entry = list_first_entry_or_null(&migf->buf_list,
diff --git a/drivers/vfio/pci/mlx5/cmd.h b/drivers/vfio/pci/mlx5/cmd.h
index d048f23977dd1..7729eac8c78c1 100644
--- a/drivers/vfio/pci/mlx5/cmd.h
+++ b/drivers/vfio/pci/mlx5/cmd.h
@@ -22,6 +22,14 @@ enum mlx5_vf_migf_state {
 	MLX5_MIGF_STATE_COMPLETE,
 };
 
+enum mlx5_vf_load_state {
+	MLX5_VF_LOAD_STATE_READ_IMAGE_NO_HEADER,
+	MLX5_VF_LOAD_STATE_READ_HEADER,
+	MLX5_VF_LOAD_STATE_PREP_IMAGE,
+	MLX5_VF_LOAD_STATE_READ_IMAGE,
+	MLX5_VF_LOAD_STATE_LOAD_IMAGE,
+};
+
 struct mlx5_vf_migration_header {
 	__le64 image_size;
 	/* For future use in case we may need to change the kernel protocol */
@@ -60,9 +68,11 @@ struct mlx5_vf_migration_file {
 	struct mutex lock;
 	enum mlx5_vf_migf_state state;
 
+	enum mlx5_vf_load_state load_state;
 	u32 pdn;
 	loff_t max_pos;
 	struct mlx5_vhca_data_buffer *buf;
+	struct mlx5_vhca_data_buffer *buf_header;
 	spinlock_t list_lock;
 	struct list_head buf_list;
 	struct list_head avail_list;
diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
index 44b1543c751c9..5a669b73994a7 100644
--- a/drivers/vfio/pci/mlx5/main.c
+++ b/drivers/vfio/pci/mlx5/main.c
@@ -518,13 +518,162 @@ mlx5vf_pci_save_device_data(struct mlx5vf_pci_core_device *mvdev, bool track)
 	return ERR_PTR(ret);
 }
 
+static int
+mlx5vf_append_page_to_mig_buf(struct mlx5_vhca_data_buffer *vhca_buf,
+			      const char __user **buf, size_t *len,
+			      loff_t *pos, ssize_t *done)
+{
+	unsigned long offset;
+	size_t page_offset;
+	struct page *page;
+	size_t page_len;
+	u8 *to_buff;
+	int ret;
+
+	offset = *pos - vhca_buf->start_pos;
+	page_offset = offset % PAGE_SIZE;
+
+	page = mlx5vf_get_migration_page(vhca_buf, offset - page_offset);
+	if (!page)
+		return -EINVAL;
+	page_len = min_t(size_t, *len, PAGE_SIZE - page_offset);
+	to_buff = kmap_local_page(page);
+	ret = copy_from_user(to_buff + page_offset, *buf, page_len);
+	kunmap_local(to_buff);
+	if (ret)
+		return -EFAULT;
+
+	*pos += page_len;
+	*done += page_len;
+	*buf += page_len;
+	*len -= page_len;
+	vhca_buf->length += page_len;
+	return 0;
+}
+
+static int
+mlx5vf_resume_read_image_no_header(struct mlx5_vhca_data_buffer *vhca_buf,
+				   loff_t requested_length,
+				   const char __user **buf, size_t *len,
+				   loff_t *pos, ssize_t *done)
+{
+	int ret;
+
+	if (requested_length > MAX_MIGRATION_SIZE)
+		return -ENOMEM;
+
+	if (vhca_buf->allocated_length < requested_length) {
+		ret = mlx5vf_add_migration_pages(
+			vhca_buf,
+			DIV_ROUND_UP(requested_length - vhca_buf->allocated_length,
+				     PAGE_SIZE));
+		if (ret)
+			return ret;
+	}
+
+	while (*len) {
+		ret = mlx5vf_append_page_to_mig_buf(vhca_buf, buf, len, pos,
+						    done);
+		if (ret)
+			return ret;
+	}
+
+	return 0;
+}
+
+static ssize_t
+mlx5vf_resume_read_image(struct mlx5_vf_migration_file *migf,
+			 struct mlx5_vhca_data_buffer *vhca_buf,
+			 size_t image_size, const char __user **buf,
+			 size_t *len, loff_t *pos, ssize_t *done,
+			 bool *has_work)
+{
+	size_t copy_len, to_copy;
+	int ret;
+
+	to_copy = min_t(size_t, *len, image_size - vhca_buf->length);
+	copy_len = to_copy;
+	while (to_copy) {
+		ret = mlx5vf_append_page_to_mig_buf(vhca_buf, buf, &to_copy, pos,
+						    done);
+		if (ret)
+			return ret;
+	}
+
+	*len -= copy_len;
+	if (vhca_buf->length == image_size) {
+		migf->load_state = MLX5_VF_LOAD_STATE_LOAD_IMAGE;
+		migf->max_pos += image_size;
+		*has_work = true;
+	}
+
+	return 0;
+}
+
+static int
+mlx5vf_resume_read_header(struct mlx5_vf_migration_file *migf,
+			  struct mlx5_vhca_data_buffer *vhca_buf,
+			  const char __user **buf,
+			  size_t *len, loff_t *pos,
+			  ssize_t *done, bool *has_work)
+{
+	struct page *page;
+	size_t copy_len;
+	u8 *to_buff;
+	int ret;
+
+	copy_len = min_t(size_t, *len,
+		sizeof(struct mlx5_vf_migration_header) - vhca_buf->length);
+	page = mlx5vf_get_migration_page(vhca_buf, 0);
+	if (!page)
+		return -EINVAL;
+	to_buff = kmap_local_page(page);
+	ret = copy_from_user(to_buff + vhca_buf->length, *buf, copy_len);
+	if (ret) {
+		ret = -EFAULT;
+		goto end;
+	}
+
+	*buf += copy_len;
+	*pos += copy_len;
+	*done += copy_len;
+	*len -= copy_len;
+	vhca_buf->length += copy_len;
+	if (vhca_buf->length == sizeof(struct mlx5_vf_migration_header)) {
+		u64 flags;
+
+		vhca_buf->header_image_size = le64_to_cpup((__le64 *)to_buff);
+		if (vhca_buf->header_image_size > MAX_MIGRATION_SIZE) {
+			ret = -ENOMEM;
+			goto end;
+		}
+
+		flags = le64_to_cpup((__le64 *)(to_buff +
+			    offsetof(struct mlx5_vf_migration_header, flags)));
+		if (flags) {
+			ret = -EOPNOTSUPP;
+			goto end;
+		}
+
+		migf->load_state = MLX5_VF_LOAD_STATE_PREP_IMAGE;
+		migf->max_pos += vhca_buf->length;
+		*has_work = true;
+	}
+end:
+	kunmap_local(to_buff);
+	return ret;
+}
+
 static ssize_t mlx5vf_resume_write(struct file *filp, const char __user *buf,
 				   size_t len, loff_t *pos)
 {
 	struct mlx5_vf_migration_file *migf = filp->private_data;
 	struct mlx5_vhca_data_buffer *vhca_buf = migf->buf;
+	struct mlx5_vhca_data_buffer *vhca_buf_header = migf->buf_header;
 	loff_t requested_length;
+	bool has_work = false;
 	ssize_t done = 0;
+	int ret = 0;
 
 	if (pos)
 		return -ESPIPE;
@@ -534,56 +683,83 @@ static ssize_t mlx5vf_resume_write(struct file *filp, const char __user *buf,
 	    check_add_overflow((loff_t)len, *pos, &requested_length))
 		return -EINVAL;
 
-	if (requested_length > MAX_MIGRATION_SIZE)
-		return -ENOMEM;
-
+	mutex_lock(&migf->mvdev->state_mutex);
 	mutex_lock(&migf->lock);
 	if (migf->state == MLX5_MIGF_STATE_ERROR) {
-		done = -ENODEV;
+		ret = -ENODEV;
 		goto out_unlock;
 	}
 
-	if (vhca_buf->allocated_length < requested_length) {
-		done = mlx5vf_add_migration_pages(
-			vhca_buf,
-			DIV_ROUND_UP(requested_length - vhca_buf->allocated_length,
-				     PAGE_SIZE));
-		if (done)
-			goto out_unlock;
-	}
+	while (len || has_work) {
+		has_work = false;
+		switch (migf->load_state) {
+		case MLX5_VF_LOAD_STATE_READ_HEADER:
+			ret = mlx5vf_resume_read_header(migf, vhca_buf_header,
+							&buf, &len, pos,
+							&done, &has_work);
+			if (ret)
+				goto out_unlock;
+			break;
+		case MLX5_VF_LOAD_STATE_PREP_IMAGE:
+		{
+			u64 size = vhca_buf_header->header_image_size;
+
+			if (vhca_buf->allocated_length < size) {
+				mlx5vf_free_data_buffer(vhca_buf);
+
+				migf->buf = mlx5vf_alloc_data_buffer(migf,
+							size, DMA_TO_DEVICE);
+				if (IS_ERR(migf->buf)) {
+					ret = PTR_ERR(migf->buf);
+					migf->buf = NULL;
+					goto out_unlock;
+				}
 
-	while (len) {
-		size_t page_offset;
-		struct page *page;
-		size_t page_len;
-		u8 *to_buff;
-		int ret;
+				vhca_buf = migf->buf;
+			}
 
-		page_offset = (*pos) % PAGE_SIZE;
-		page = mlx5vf_get_migration_page(vhca_buf, *pos - page_offset);
-		if (!page) {
-			if (done == 0)
-				done = -EINVAL;
-			goto out_unlock;
+			vhca_buf->start_pos = migf->max_pos;
+			migf->load_state = MLX5_VF_LOAD_STATE_READ_IMAGE;
+			break;
 		}
+		case MLX5_VF_LOAD_STATE_READ_IMAGE_NO_HEADER:
+			ret = mlx5vf_resume_read_image_no_header(vhca_buf,
+						requested_length,
+						&buf, &len, pos, &done);
+			if (ret)
+				goto out_unlock;
+			break;
+		case MLX5_VF_LOAD_STATE_READ_IMAGE:
+			ret = mlx5vf_resume_read_image(migf, vhca_buf,
+						vhca_buf_header->header_image_size,
+						&buf, &len, pos, &done, &has_work);
+			if (ret)
+				goto out_unlock;
+			break;
+		case MLX5_VF_LOAD_STATE_LOAD_IMAGE:
+			ret = mlx5vf_cmd_load_vhca_state(migf->mvdev, migf, vhca_buf);
+			if (ret)
+				goto out_unlock;
+			migf->load_state = MLX5_VF_LOAD_STATE_READ_HEADER;
 
-		page_len = min_t(size_t, len, PAGE_SIZE - page_offset);
-		to_buff = kmap_local_page(page);
-		ret = copy_from_user(to_buff + page_offset, buf, page_len);
-		kunmap_local(to_buff);
-		if (ret) {
-			done = -EFAULT;
-			goto out_unlock;
+			/* prep header buf for next image */
+			vhca_buf_header->length = 0;
+			vhca_buf_header->header_image_size = 0;
+			/* prep data buf for next image */
+			vhca_buf->length = 0;
+
+			break;
+		default:
+			break;
 		}
-		*pos += page_len;
-		len -= page_len;
-		done += page_len;
-		buf += page_len;
-		vhca_buf->length += page_len;
 	}
+
 out_unlock:
+	if (ret)
+		migf->state = MLX5_MIGF_STATE_ERROR;
 	mutex_unlock(&migf->lock);
-	return done;
+	mlx5vf_state_mutex_unlock(migf->mvdev);
+	return ret ? ret : done;
 }
 
 static const struct file_operations mlx5vf_resume_fops = {
@@ -623,12 +799,29 @@ mlx5vf_pci_resume_device_data(struct mlx5vf_pci_core_device *mvdev)
 	}
 
 	migf->buf = buf;
+	if (MLX5VF_PRE_COPY_SUPP(mvdev)) {
+		buf = mlx5vf_alloc_data_buffer(migf,
+			sizeof(struct mlx5_vf_migration_header), DMA_NONE);
+		if (IS_ERR(buf)) {
+			ret = PTR_ERR(buf);
+			goto out_buf;
+		}
+
+		migf->buf_header = buf;
+		migf->load_state = MLX5_VF_LOAD_STATE_READ_HEADER;
+	} else {
+		/* Initial state will be to read the image */
+		migf->load_state = MLX5_VF_LOAD_STATE_READ_IMAGE_NO_HEADER;
+	}
+
 	stream_open(migf->filp->f_inode, migf->filp);
 	mutex_init(&migf->lock);
 	INIT_LIST_HEAD(&migf->buf_list);
 	INIT_LIST_HEAD(&migf->avail_list);
 	spin_lock_init(&migf->list_lock);
 	return migf;
+out_buf:
+	mlx5vf_free_data_buffer(buf);
 out_pd:
 	mlx5vf_cmd_dealloc_pd(migf);
 out_free:
@@ -728,11 +921,13 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
 	}
 
 	if (cur == VFIO_DEVICE_STATE_RESUMING && new == VFIO_DEVICE_STATE_STOP) {
-		ret = mlx5vf_cmd_load_vhca_state(mvdev,
-						 mvdev->resuming_migf,
-						 mvdev->resuming_migf->buf);
-		if (ret)
-			return ERR_PTR(ret);
+		if (!MLX5VF_PRE_COPY_SUPP(mvdev)) {
+			ret = mlx5vf_cmd_load_vhca_state(mvdev,
+							 mvdev->resuming_migf,
+							 mvdev->resuming_migf->buf);
+			if (ret)
+				return ERR_PTR(ret);
+		}
 		mlx5vf_disable_fds(mvdev);
 		return NULL;
 	}