diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c index 8dc7ed678e3fb..c6f4852a8a0c0 100644 --- a/drivers/iommu/iommufd/device.c +++ b/drivers/iommu/iommufd/device.c @@ -22,7 +22,8 @@ void iommufd_device_destroy(struct iommufd_object *obj) iommu_device_release_dma_owner(idev->dev); iommu_group_put(idev->group); - iommufd_ctx_put(idev->ictx); + if (!iommufd_selftest_is_mock_dev(idev->dev)) + iommufd_ctx_put(idev->ictx); } /** @@ -69,7 +70,8 @@ struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx, goto out_release_owner; } idev->ictx = ictx; - iommufd_ctx_get(ictx); + if (!iommufd_selftest_is_mock_dev(dev)) + iommufd_ctx_get(ictx); idev->dev = dev; idev->enforce_cache_coherency = device_iommu_capable(dev, IOMMU_CAP_ENFORCE_CACHE_COHERENCY); @@ -151,7 +153,8 @@ static int iommufd_device_setup_msi(struct iommufd_device *idev, * operation from the device (eg a simple DMA) cannot trigger an * interrupt outside this iommufd context. */ - if (!iommu_group_has_isolated_msi(idev->group)) { + if (!iommufd_selftest_is_mock_dev(idev->dev) && + !iommu_group_has_isolated_msi(idev->group)) { if (!allow_unsafe_interrupts) return -EPERM; @@ -706,34 +709,3 @@ int iommufd_access_rw(struct iommufd_access *access, unsigned long iova, return rc; } EXPORT_SYMBOL_NS_GPL(iommufd_access_rw, IOMMUFD); - -#ifdef CONFIG_IOMMUFD_TEST -/* - * Creating a real iommufd_device is too hard, bypass creating a iommufd_device - * and go directly to attaching a domain. - */ -struct iommufd_hw_pagetable * -iommufd_device_selftest_attach(struct iommufd_ctx *ictx, - struct iommufd_ioas *ioas, - struct device *mock_dev) -{ - struct iommufd_device tmp_idev = { .dev = mock_dev }; - struct iommufd_hw_pagetable *hwpt; - - mutex_lock(&ioas->mutex); - hwpt = iommufd_hw_pagetable_alloc(ictx, ioas, &tmp_idev, false); - mutex_unlock(&ioas->mutex); - if (IS_ERR(hwpt)) - return hwpt; - - refcount_inc(&hwpt->obj.users); - iommufd_object_finalize(ictx, &hwpt->obj); - return hwpt; -} - -void iommufd_device_selftest_detach(struct iommufd_ctx *ictx, - struct iommufd_hw_pagetable *hwpt) -{ - refcount_dec(&hwpt->obj.users); -} -#endif diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h index 331664e917b77..d523ef12890e1 100644 --- a/drivers/iommu/iommufd/iommufd_private.h +++ b/drivers/iommu/iommufd/iommufd_private.h @@ -297,12 +297,6 @@ void iopt_remove_access(struct io_pagetable *iopt, void iommufd_access_destroy_object(struct iommufd_object *obj); #ifdef CONFIG_IOMMUFD_TEST -struct iommufd_hw_pagetable * -iommufd_device_selftest_attach(struct iommufd_ctx *ictx, - struct iommufd_ioas *ioas, - struct device *mock_dev); -void iommufd_device_selftest_detach(struct iommufd_ctx *ictx, - struct iommufd_hw_pagetable *hwpt); int iommufd_test(struct iommufd_ucmd *ucmd); void iommufd_selftest_destroy(struct iommufd_object *obj); extern size_t iommufd_test_memory_limit; @@ -311,6 +305,7 @@ void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd, bool iommufd_should_fail(void); void __init iommufd_test_init(void); void iommufd_test_exit(void); +bool iommufd_selftest_is_mock_dev(struct device *dev); #else static inline void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd, unsigned int ioas_id, @@ -327,5 +322,9 @@ static inline void __init iommufd_test_init(void) static inline void iommufd_test_exit(void) { } +static inline bool iommufd_selftest_is_mock_dev(struct device *dev) +{ + return false; +} #endif #endif diff --git a/drivers/iommu/iommufd/selftest.c b/drivers/iommu/iommufd/selftest.c index e05b41059630a..17cb7b95eb275 100644 --- a/drivers/iommu/iommufd/selftest.c +++ b/drivers/iommu/iommufd/selftest.c @@ -91,23 +91,50 @@ enum selftest_obj_type { TYPE_IDEV, }; +struct mock_dev { + struct device dev; +}; + struct selftest_obj { struct iommufd_object obj; enum selftest_obj_type type; union { struct { - struct iommufd_hw_pagetable *hwpt; + struct iommufd_device *idev; struct iommufd_ctx *ictx; - struct device mock_dev; + struct mock_dev *mock_dev; } idev; }; }; +static void mock_domain_blocking_free(struct iommu_domain *domain) +{ +} + +static int mock_domain_nop_attach(struct iommu_domain *domain, + struct device *dev) +{ + return 0; +} + +static const struct iommu_domain_ops mock_blocking_ops = { + .free = mock_domain_blocking_free, + .attach_dev = mock_domain_nop_attach, +}; + +static struct iommu_domain mock_blocking_domain = { + .type = IOMMU_DOMAIN_BLOCKED, + .ops = &mock_blocking_ops, +}; + static struct iommu_domain *mock_domain_alloc(unsigned int iommu_domain_type) { struct mock_iommu_domain *mock; + if (iommu_domain_type == IOMMU_DOMAIN_BLOCKED) + return &mock_blocking_domain; + if (WARN_ON(iommu_domain_type != IOMMU_DOMAIN_UNMANAGED)) return NULL; @@ -236,19 +263,39 @@ static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain, return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE; } +static bool mock_domain_capable(struct device *dev, enum iommu_cap cap) +{ + return cap == IOMMU_CAP_CACHE_COHERENCY; +} + +static void mock_domain_set_plaform_dma_ops(struct device *dev) +{ + /* + * mock doesn't setup default domains because we can't hook into the + * normal probe path + */ +} + static const struct iommu_ops mock_ops = { .owner = THIS_MODULE, .pgsize_bitmap = MOCK_IO_PAGE_SIZE, .domain_alloc = mock_domain_alloc, + .capable = mock_domain_capable, + .set_platform_dma_ops = mock_domain_set_plaform_dma_ops, .default_domain_ops = &(struct iommu_domain_ops){ .free = mock_domain_free, + .attach_dev = mock_domain_nop_attach, .map_pages = mock_domain_map_pages, .unmap_pages = mock_domain_unmap_pages, .iova_to_phys = mock_domain_iova_to_phys, }, }; +struct iommu_device mock_iommu_device = { + .ops = &mock_ops, +}; + static inline struct iommufd_hw_pagetable * get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, struct mock_iommu_domain **mock) @@ -269,48 +316,142 @@ get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, return hwpt; } +static struct bus_type iommufd_mock_bus_type = { + .name = "iommufd_mock", + .iommu_ops = &mock_ops, +}; + +static void mock_dev_release(struct device *dev) +{ + struct mock_dev *mdev = container_of(dev, struct mock_dev, dev); + + kfree(mdev); +} + +static struct mock_dev *mock_dev_create(void) +{ + struct iommu_group *iommu_group; + struct dev_iommu *dev_iommu; + struct mock_dev *mdev; + int rc; + + mdev = kzalloc(sizeof(*mdev), GFP_KERNEL); + if (!mdev) + return ERR_PTR(-ENOMEM); + + device_initialize(&mdev->dev); + mdev->dev.release = mock_dev_release; + mdev->dev.bus = &iommufd_mock_bus_type; + + iommu_group = iommu_group_alloc(); + if (IS_ERR(iommu_group)) { + rc = PTR_ERR(iommu_group); + goto err_put; + } + + rc = dev_set_name(&mdev->dev, "iommufd_mock%u", + iommu_group_id(iommu_group)); + if (rc) + goto err_group; + + /* + * The iommu core has no way to associate a single device with an iommu + * driver (heck currently it can't even support two iommu_drivers + * registering). Hack it together with an open coded dev_iommu_get(). + * Notice that the normal notifier triggered iommu release process also + * does not work here because this bus is not in iommu_buses. + */ + mdev->dev.iommu = kzalloc(sizeof(*dev_iommu), GFP_KERNEL); + if (!mdev->dev.iommu) { + rc = -ENOMEM; + goto err_group; + } + mutex_init(&mdev->dev.iommu->lock); + mdev->dev.iommu->iommu_dev = &mock_iommu_device; + + rc = device_add(&mdev->dev); + if (rc) + goto err_dev_iommu; + + rc = iommu_group_add_device(iommu_group, &mdev->dev); + if (rc) + goto err_del; + iommu_group_put(iommu_group); + return mdev; + +err_del: + device_del(&mdev->dev); +err_dev_iommu: + kfree(mdev->dev.iommu); + mdev->dev.iommu = NULL; +err_group: + iommu_group_put(iommu_group); +err_put: + put_device(&mdev->dev); + return ERR_PTR(rc); +} + +static void mock_dev_destroy(struct mock_dev *mdev) +{ + iommu_group_remove_device(&mdev->dev); + device_del(&mdev->dev); + kfree(mdev->dev.iommu); + mdev->dev.iommu = NULL; + put_device(&mdev->dev); +} + +bool iommufd_selftest_is_mock_dev(struct device *dev) +{ + return dev->release == mock_dev_release; +} + /* Create an hw_pagetable with the mock domain so we can test the domain ops */ static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd, struct iommu_test_cmd *cmd) { - static struct bus_type mock_bus = { .iommu_ops = &mock_ops }; - struct iommufd_hw_pagetable *hwpt; + struct iommufd_device *idev; struct selftest_obj *sobj; - struct iommufd_ioas *ioas; + u32 pt_id = cmd->id; + u32 idev_id; int rc; - ioas = iommufd_get_ioas(ucmd, cmd->id); - if (IS_ERR(ioas)) - return PTR_ERR(ioas); - sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST); - if (IS_ERR(sobj)) { - rc = PTR_ERR(sobj); - goto out_ioas; - } + if (IS_ERR(sobj)) + return PTR_ERR(sobj); + sobj->idev.ictx = ucmd->ictx; sobj->type = TYPE_IDEV; - sobj->idev.mock_dev.bus = &mock_bus; - hwpt = iommufd_device_selftest_attach(ucmd->ictx, ioas, - &sobj->idev.mock_dev); - if (IS_ERR(hwpt)) { - rc = PTR_ERR(hwpt); + sobj->idev.mock_dev = mock_dev_create(); + if (IS_ERR(sobj->idev.mock_dev)) { + rc = PTR_ERR(sobj->idev.mock_dev); goto out_sobj; } - sobj->idev.hwpt = hwpt; - /* Userspace must destroy both of these IDs to destroy the object */ - cmd->mock_domain.out_hwpt_id = hwpt->obj.id; + idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev, + &idev_id); + if (IS_ERR(idev)) { + rc = PTR_ERR(idev); + goto out_mdev; + } + sobj->idev.idev = idev; + + rc = iommufd_device_attach(idev, &pt_id); + if (rc) + goto out_unbind; + + /* Userspace must destroy the device_id to destroy the object */ + cmd->mock_domain.out_hwpt_id = pt_id; cmd->mock_domain.out_stdev_id = sobj->obj.id; iommufd_object_finalize(ucmd->ictx, &sobj->obj); - iommufd_put_object(&ioas->obj); return iommufd_ucmd_respond(ucmd, sizeof(*cmd)); +out_unbind: + iommufd_device_unbind(idev); +out_mdev: + mock_dev_destroy(sobj->idev.mock_dev); out_sobj: iommufd_object_abort(ucmd->ictx, &sobj->obj); -out_ioas: - iommufd_put_object(&ioas->obj); return rc; } @@ -780,8 +921,9 @@ void iommufd_selftest_destroy(struct iommufd_object *obj) switch (sobj->type) { case TYPE_IDEV: - iommufd_device_selftest_detach(sobj->idev.ictx, - sobj->idev.hwpt); + iommufd_device_detach(sobj->idev.idev); + iommufd_device_unbind(sobj->idev.idev); + mock_dev_destroy(sobj->idev.mock_dev); break; } } @@ -845,9 +987,11 @@ void __init iommufd_test_init(void) { dbgfs_root = fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd); + WARN_ON(bus_register(&iommufd_mock_bus_type)); } void iommufd_test_exit(void) { debugfs_remove_recursive(dbgfs_root); + bus_unregister(&iommufd_mock_bus_type); } diff --git a/tools/testing/selftests/iommu/iommufd.c b/tools/testing/selftests/iommu/iommufd.c index 1d72c48157b79..fe20342abfb02 100644 --- a/tools/testing/selftests/iommu/iommufd.c +++ b/tools/testing/selftests/iommu/iommufd.c @@ -645,7 +645,6 @@ TEST_F(iommufd_ioas, access_pin) &check_map_cmd)); test_ioctl_destroy(mock_stdev_id); - test_ioctl_destroy(mock_hwpt_id); test_cmd_destroy_access_pages( access_cmd.id, access_cmd.access_pages.out_access_pages_id); @@ -1214,7 +1213,6 @@ TEST_F(iommufd_mock_domain, all_aligns_copy) 1); test_ioctl_destroy(mock_stdev_id); - test_ioctl_destroy(self->hwpt_ids[1]); self->hwpt_ids[1] = old_id; test_ioctl_ioas_unmap(iova, length); diff --git a/tools/testing/selftests/iommu/iommufd_fail_nth.c b/tools/testing/selftests/iommu/iommufd_fail_nth.c index e7d535680721b..d9afcb23810e1 100644 --- a/tools/testing/selftests/iommu/iommufd_fail_nth.c +++ b/tools/testing/selftests/iommu/iommufd_fail_nth.c @@ -323,8 +323,6 @@ TEST_FAIL_NTH(basic_fail_nth, map_domain) if (_test_ioctl_destroy(self->fd, stdev_id)) return -1; - if (_test_ioctl_destroy(self->fd, hwpt_id)) - return -1; if (_test_cmd_mock_domain(self->fd, ioas_id, &stdev_id, &hwpt_id)) return -1; @@ -365,13 +363,9 @@ TEST_FAIL_NTH(basic_fail_nth, map_two_domains) if (_test_ioctl_destroy(self->fd, stdev_id)) return -1; - if (_test_ioctl_destroy(self->fd, hwpt_id)) - return -1; if (_test_ioctl_destroy(self->fd, stdev_id2)) return -1; - if (_test_ioctl_destroy(self->fd, hwpt_id2)) - return -1; if (_test_cmd_mock_domain(self->fd, ioas_id, &stdev_id, &hwpt_id)) return -1; @@ -572,8 +566,6 @@ TEST_FAIL_NTH(basic_fail_nth, access_pin_domain) if (_test_ioctl_destroy(self->fd, stdev_id)) return -1; - if (_test_ioctl_destroy(self->fd, hwpt_id)) - return -1; return 0; }