diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index f65b7e1a5758f..0b97aee6b735c 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -4777,34 +4777,6 @@ kvm_calc_cpu_role(struct kvm_vcpu *vcpu, const struct kvm_mmu_role_regs *regs)
 	return role;
 }
 
-static union kvm_mmu_role kvm_calc_mmu_role_common(struct kvm_vcpu *vcpu,
-						   const struct kvm_mmu_role_regs *regs)
-{
-	union kvm_mmu_role role = {0};
-
-	role.base.access = ACC_ALL;
-	if (____is_cr0_pg(regs)) {
-		role.ext.cr0_pg = 1;
-		role.base.efer_nx = ____is_efer_nx(regs);
-		role.base.cr0_wp = ____is_cr0_wp(regs);
-
-		role.ext.cr4_pae = ____is_cr4_pae(regs);
-		role.ext.cr4_smep = ____is_cr4_smep(regs);
-		role.ext.cr4_smap = ____is_cr4_smap(regs);
-		role.ext.cr4_pse = ____is_cr4_pse(regs);
-
-		/* PKEY and LA57 are active iff long mode is active. */
-		role.ext.cr4_pke = ____is_efer_lma(regs) && ____is_cr4_pke(regs);
-		role.ext.cr4_la57 = ____is_efer_lma(regs) && ____is_cr4_la57(regs);
-		role.ext.efer_lma = ____is_efer_lma(regs);
-	}
-	role.base.smm = is_smm(vcpu);
-	role.base.guest_mode = is_guest_mode(vcpu);
-	role.ext.valid = 1;
-
-	return role;
-}
-
 static inline int kvm_mmu_get_tdp_level(struct kvm_vcpu *vcpu)
 {
 	/* tdp_root_level is architecture forced level, use it if nonzero */
@@ -4820,14 +4792,20 @@ static inline int kvm_mmu_get_tdp_level(struct kvm_vcpu *vcpu)
 
 static union kvm_mmu_role
 kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu,
-				const struct kvm_mmu_role_regs *regs)
+				union kvm_mmu_role cpu_role)
 {
-	union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, regs);
+	union kvm_mmu_role role = {0};
 
+	role.base.access = ACC_ALL;
+	role.base.cr0_wp = true;
+	role.base.efer_nx = true;
+	role.base.smm = cpu_role.base.smm;
+	role.base.guest_mode = cpu_role.base.guest_mode;
 	role.base.ad_disabled = (shadow_accessed_mask == 0);
 	role.base.level = kvm_mmu_get_tdp_level(vcpu);
 	role.base.direct = true;
 	role.base.has_4_byte_gpte = false;
+	role.ext.valid = true;
 
 	return role;
 }
@@ -4837,8 +4815,7 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu,
 {
 	struct kvm_mmu *context = &vcpu->arch.root_mmu;
 	union kvm_mmu_role cpu_role = kvm_calc_cpu_role(vcpu, regs);
-	union kvm_mmu_role mmu_role =
-		kvm_calc_tdp_mmu_root_page_role(vcpu, regs);
+	union kvm_mmu_role mmu_role = kvm_calc_tdp_mmu_root_page_role(vcpu, cpu_role);
 
 	if (cpu_role.as_u64 == context->cpu_role.as_u64 &&
 	    mmu_role.as_u64 == context->mmu_role.as_u64)