diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index c4e0b97f82acc..9d617b9dc78f0 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -3730,7 +3730,9 @@ static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
 		vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
 	} else
 		BUG();
-	vcpu->arch.mmu->root_cr3 = vcpu->arch.mmu->get_cr3(vcpu);
+
+	/* root_cr3 is ignored for direct MMUs. */
+	vcpu->arch.mmu->root_cr3 = 0;
 
 	return 0;
 }
@@ -4272,8 +4274,8 @@ static bool cached_root_available(struct kvm_vcpu *vcpu, gpa_t new_cr3,
 	for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
 		swap(root, mmu->prev_roots[i]);
 
-		if (new_cr3 == root.cr3 && VALID_PAGE(root.hpa) &&
-		    page_header(root.hpa) != NULL &&
+		if ((new_role.direct || new_cr3 == root.cr3) &&
+		    VALID_PAGE(root.hpa) && page_header(root.hpa) &&
 		    new_role.word == page_header(root.hpa)->role.word)
 			break;
 	}