Move the sanity and compatibility checks from the attach_dev callbacks into
the new test_dev callbacks. The IOMMU core guarantees that attach_dev is
only invoked after a successful test_dev call.

Note that apple_dart_finalize_domain() has a caller other than attach_dev,
apple_dart_domain_alloc_paging(), which therefore has to duplicate the
pgsize sanity check.

Signed-off-by: Nicolin Chen <[email protected]>
---
 drivers/iommu/apple-dart.c | 50 +++++++++++++++++++++++++++++---------
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/drivers/iommu/apple-dart.c b/drivers/iommu/apple-dart.c
index b5848770ef482..fb63dcb7462a7 100644
--- a/drivers/iommu/apple-dart.c
+++ b/drivers/iommu/apple-dart.c
@@ -593,9 +593,11 @@ static int apple_dart_finalize_domain(struct 
apple_dart_domain *dart_domain,
        int ret = 0;
        int i, j;
 
-       if (dart->pgsize > PAGE_SIZE)
-               return -EINVAL;
-
+       /*
+        * Both callers reject dart->pgsize > PAGE_SIZE before getting here:
+        * attach_dev is preceded by the test_dev op, and
+        * apple_dart_domain_alloc_paging() checks it directly.
+        */
        mutex_lock(&dart_domain->init_lock);
 
        if (dart_domain->finalized)
@@ -643,11 +640,17 @@ apple_dart_mod_streams(struct apple_dart_atomic_stream_map 
*domain_maps,
 {
        int i, j;
 
+       /*
+        * Keep this attach-time check even though test_dev validates
+        * finalized domains: a domain that is not finalized yet passes
+        * test_dev unchecked, and another device may finalize it with
+        * different DARTs before this device's attach_dev adds its streams.
+        */
        for (i = 0; i < MAX_DARTS_PER_DEVICE; ++i) {
                if (domain_maps[i].dart != master_maps[i].dart)
                        return -EINVAL;
        }
 
        for (i = 0; i < MAX_DARTS_PER_DEVICE; ++i) {
                if (!domain_maps[i].dart)
                        break;
@@ -671,6 +663,43 @@ static int apple_dart_domain_add_streams(struct 
apple_dart_domain *domain,
                                      true);
 }
 
+static int apple_dart_test_dev_paging(struct iommu_domain *domain,
+                                     struct device *dev, ioasid_t pasid,
+                                     struct iommu_domain *old)
+{
+       struct apple_dart_domain *dart_domain = to_dart_domain(domain);
+       struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev);
+       struct apple_dart *dart = cfg->stream_maps[0].dart;
+       int ret = 0;
+       int i;
+
+       if (dart->pgsize > PAGE_SIZE)
+               return -EINVAL;
+
+       /*
+        * A finalized domain records the DARTs it was set up for in
+        * stream_maps[].dart; a device behind different DARTs cannot use it.
+        * apple_dart_finalize_domain() tests ->finalized under init_lock, so
+        * take the same lock here to not observe a concurrent finalize
+        * half-done (this assumes test_dev may sleep, like attach_dev).
+        * This is only an early check: a domain that is not finalized yet
+        * may still be finalized by another device before attach_dev runs.
+        */
+       mutex_lock(&dart_domain->init_lock);
+       if (dart_domain->finalized) {
+               for (i = 0; i < MAX_DARTS_PER_DEVICE; ++i) {
+                       if (dart_domain->stream_maps[i].dart !=
+                           cfg->stream_maps[i].dart) {
+                               ret = -EINVAL;
+                               break;
+                       }
+               }
+       }
+       mutex_unlock(&dart_domain->init_lock);
+
+       return ret;
+}
+
 static int apple_dart_attach_dev_paging(struct iommu_domain *domain,
                                        struct device *dev,
                                        struct iommu_domain *old)
@@ -693,6 +708,16 @@ static int apple_dart_attach_dev_paging(struct 
iommu_domain *domain,
        return 0;
 }
 
+static int apple_dart_test_dev_identity(struct iommu_domain *domain,
+                                       struct device *dev, ioasid_t pasid,
+                                       struct iommu_domain *old)
+{
+       struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev);
+
+       /* Identity (bypass) attach needs a master that supports bypass. */
+       return cfg->supports_bypass ? 0 : -EINVAL;
+}
+
 static int apple_dart_attach_dev_identity(struct iommu_domain *domain,
                                          struct device *dev,
                                          struct iommu_domain *old)
@@ -701,15 +726,13 @@ static int apple_dart_attach_dev_identity(struct 
iommu_domain *domain,
        struct apple_dart_stream_map *stream_map;
        int i;
 
-       if (!cfg->supports_bypass)
-               return -EINVAL;
-
        for_each_stream_map(i, cfg, stream_map)
                apple_dart_hw_enable_bypass(stream_map);
        return 0;
 }
 
 static const struct iommu_domain_ops apple_dart_identity_ops = {
+       .test_dev = apple_dart_test_dev_identity,
        .attach_dev = apple_dart_attach_dev_identity,
 };

@@ -776,8 +800,14 @@ static struct iommu_domain 
*apple_dart_domain_alloc_paging(struct device *dev)

        if (dev) {
                struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev);
+               struct apple_dart *dart = cfg->stream_maps[0].dart;
                int ret;

+               /* Mirrors the pgsize check in apple_dart_test_dev_paging(). */
+               if (dart->pgsize > PAGE_SIZE) {
+                       kfree(dart_domain);
+                       return ERR_PTR(-EINVAL);
+               }
                ret = apple_dart_finalize_domain(dart_domain, cfg);
                if (ret) {
                        kfree(dart_domain);
@@ -1010,6 +1040,7 @@ static const struct iommu_ops apple_dart_iommu_ops = {
        .get_resv_regions = apple_dart_get_resv_regions,
        .owner = THIS_MODULE,
        .default_domain_ops = &(const struct iommu_domain_ops) {
+               .test_dev       = apple_dart_test_dev_paging,
                .attach_dev     = apple_dart_attach_dev_paging,
                .map_pages      = apple_dart_map_pages,
                .unmap_pages    = apple_dart_unmap_pages,
-- 
2.43.0


Reply via email to