Move the sanity and compatibility checks from the attach_dev callbacks into the new test_dev callback functions. The IOMMU core ensures that attach_dev is invoked only after a successful test_dev call.
Note that apple_dart_finalize_domain() has a caller other than attach_dev (apple_dart_domain_alloc_paging()), so that caller has to duplicate the pgsize sanity check too. Signed-off-by: Nicolin Chen <[email protected]> --- drivers/iommu/apple-dart.c | 50 +++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/drivers/iommu/apple-dart.c b/drivers/iommu/apple-dart.c index b5848770ef482..fb63dcb7462a7 100644 --- a/drivers/iommu/apple-dart.c +++ b/drivers/iommu/apple-dart.c @@ -593,9 +593,6 @@ static int apple_dart_finalize_domain(struct apple_dart_domain *dart_domain, int ret = 0; int i, j; - if (dart->pgsize > PAGE_SIZE) - return -EINVAL; - mutex_lock(&dart_domain->init_lock); if (dart_domain->finalized) @@ -643,11 +640,6 @@ apple_dart_mod_streams(struct apple_dart_atomic_stream_map *domain_maps, { int i, j; - for (i = 0; i < MAX_DARTS_PER_DEVICE; ++i) { - if (domain_maps[i].dart != master_maps[i].dart) - return -EINVAL; - } - for (i = 0; i < MAX_DARTS_PER_DEVICE; ++i) { if (!domain_maps[i].dart) break; @@ -671,6 +663,29 @@ static int apple_dart_domain_add_streams(struct apple_dart_domain *domain, true); } +static int apple_dart_test_dev_paging(struct iommu_domain *domain, + struct device *dev, ioasid_t pasid, + struct iommu_domain *old) +{ + struct apple_dart_domain *dart_domain = to_dart_domain(domain); + struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev); + struct apple_dart *dart = cfg->stream_maps[0].dart; + + if (dart->pgsize > PAGE_SIZE) + return -EINVAL; + if (dart_domain->finalized) { + int i; + + for (i = 0; i < MAX_DARTS_PER_DEVICE; ++i) { + if (dart_domain->stream_maps[i].dart != + cfg->stream_maps[i].dart) + return -EINVAL; + } + } + + return 0; +} + static int apple_dart_attach_dev_paging(struct iommu_domain *domain, struct device *dev, struct iommu_domain *old) @@ -693,6 +708,17 @@ static int apple_dart_attach_dev_paging(struct iommu_domain *domain, return 0; } +static int apple_dart_test_dev_identity(struct iommu_domain *domain, + struct device 
*dev, ioasid_t pasid, + struct iommu_domain *old) +{ + struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev); + + if (!cfg->supports_bypass) + return -EINVAL; + return 0; +} + static int apple_dart_attach_dev_identity(struct iommu_domain *domain, struct device *dev, struct iommu_domain *old) @@ -701,15 +727,13 @@ static int apple_dart_attach_dev_identity(struct iommu_domain *domain, struct apple_dart_stream_map *stream_map; int i; - if (!cfg->supports_bypass) - return -EINVAL; - for_each_stream_map(i, cfg, stream_map) apple_dart_hw_enable_bypass(stream_map); return 0; } static const struct iommu_domain_ops apple_dart_identity_ops = { + .test_dev = apple_dart_test_dev_identity, .attach_dev = apple_dart_attach_dev_identity, }; @@ -776,8 +800,11 @@ static struct iommu_domain *apple_dart_domain_alloc_paging(struct device *dev) if (dev) { struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev); + struct apple_dart *dart = cfg->stream_maps[0].dart; int ret; + if (dart->pgsize > PAGE_SIZE) + return ERR_PTR(-EINVAL); ret = apple_dart_finalize_domain(dart_domain, cfg); if (ret) { kfree(dart_domain); @@ -1010,6 +1037,7 @@ static const struct iommu_ops apple_dart_iommu_ops = { .get_resv_regions = apple_dart_get_resv_regions, .owner = THIS_MODULE, .default_domain_ops = &(const struct iommu_domain_ops) { + .test_dev = apple_dart_test_dev_paging, .attach_dev = apple_dart_attach_dev_paging, .map_pages = apple_dart_map_pages, .unmap_pages = apple_dart_unmap_pages, -- 2.43.0

