diff --git a/mm/memory.c b/mm/memory.c
index 36b164ee9ffb0..44d11812a88f2 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -2933,10 +2933,9 @@ static gfp_t __get_fault_gfp_mask(struct vm_area_struct *vma)
  *
  * We do this without the lock held, so that it can sleep if it needs to.
  */
-static vm_fault_t do_page_mkwrite(struct vm_fault *vmf)
+static vm_fault_t do_page_mkwrite(struct vm_fault *vmf, struct folio *folio)
 {
 	vm_fault_t ret;
-	struct folio *folio = page_folio(vmf->page);
 	unsigned int old_flags = vmf->flags;
 
 	vmf->flags = FAULT_FLAG_WRITE|FAULT_FLAG_MKWRITE;
@@ -3298,7 +3297,7 @@ static vm_fault_t wp_page_shared(struct vm_fault *vmf, struct folio *folio)
 		vm_fault_t tmp;
 
 		pte_unmap_unlock(vmf->pte, vmf->ptl);
-		tmp = do_page_mkwrite(vmf);
+		tmp = do_page_mkwrite(vmf, folio);
 		if (unlikely(!tmp || (tmp &
 				      (VM_FAULT_ERROR | VM_FAULT_NOPAGE)))) {
 			folio_put(folio);
@@ -4621,7 +4620,7 @@ static vm_fault_t do_shared_fault(struct vm_fault *vmf)
 	 */
 	if (vma->vm_ops->page_mkwrite) {
 		folio_unlock(folio);
-		tmp = do_page_mkwrite(vmf);
+		tmp = do_page_mkwrite(vmf, folio);
 		if (unlikely(!tmp ||
 				(tmp & (VM_FAULT_ERROR | VM_FAULT_NOPAGE)))) {
 			folio_put(folio);