diff --git a/arch/arm64/include/asm/kvm_nested.h b/arch/arm64/include/asm/kvm_nested.h index 233e65522716..f616e25e204f 100644 --- a/arch/arm64/include/asm/kvm_nested.h +++ b/arch/arm64/include/asm/kvm_nested.h @@ -186,7 +186,7 @@ static inline bool kvm_supported_tlbi_s1e2_op(struct kvm_vcpu *vpcu, u32 instr) return true; } -int kvm_init_nv_sysregs(struct kvm *kvm); +int kvm_init_nv_sysregs(struct kvm_vcpu *vcpu); #ifdef CONFIG_ARM64_PTR_AUTH bool kvm_auth_eretax(struct kvm_vcpu *vcpu, u64 *elr); diff --git a/arch/arm64/kvm/nested.c b/arch/arm64/kvm/nested.c index 9b36218b48de..dd6480cf90ea 100644 --- a/arch/arm64/kvm/nested.c +++ b/arch/arm64/kvm/nested.c @@ -963,14 +963,15 @@ static __always_inline void set_sysreg_masks(struct kvm *kvm, int sr, u64 res0, kvm->arch.sysreg_masks->mask[i].res1 = res1; } -int kvm_init_nv_sysregs(struct kvm *kvm) +int kvm_init_nv_sysregs(struct kvm_vcpu *vcpu) { + struct kvm *kvm = vcpu->kvm; u64 res0, res1; lockdep_assert_held(&kvm->arch.config_lock); if (kvm->arch.sysreg_masks) - return 0; + goto out; kvm->arch.sysreg_masks = kzalloc(sizeof(*(kvm->arch.sysreg_masks)), GFP_KERNEL_ACCOUNT); @@ -1271,6 +1272,10 @@ int kvm_init_nv_sysregs(struct kvm *kvm) res0 |= MDCR_EL2_EnSTEPOP; set_sysreg_masks(kvm, MDCR_EL2, res0, res1); +out: + for (enum vcpu_sysreg sr = __SANITISED_REG_START__; sr < NR_SYS_REGS; sr++) + (void)__vcpu_sys_reg(vcpu, sr); + return 0; } diff --git a/arch/arm64/kvm/sys_regs.c b/arch/arm64/kvm/sys_regs.c index 83c6b4a07ef5..8671d46f53e5 100644 --- a/arch/arm64/kvm/sys_regs.c +++ b/arch/arm64/kvm/sys_regs.c @@ -4396,6 +4396,9 @@ void kvm_reset_sys_regs(struct kvm_vcpu *vcpu) reset_vcpu_ftr_id_reg(vcpu, r); else r->reset(vcpu, r); + + if (r->reg >= __SANITISED_REG_START__ && r->reg < NR_SYS_REGS) + (void)__vcpu_sys_reg(vcpu, r->reg); } set_bit(KVM_ARCH_FLAG_ID_REGS_INITIALIZED, &kvm->arch.flags); @@ -4999,7 +5002,7 @@ int kvm_finalize_sys_regs(struct kvm_vcpu *vcpu) } if (vcpu_has_nv(vcpu)) { - int ret = kvm_init_nv_sysregs(kvm); + int ret = kvm_init_nv_sysregs(vcpu); if (ret) return ret; }