Move the sanity and compatibility checks from the attach_dev callback to
the new test_dev callback. The IOMMU core guarantees that attach_dev is
only invoked after a successful test_dev call.

Signed-off-by: Nicolin Chen <[email protected]>
---
 drivers/iommu/riscv/iommu.c | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c
index d9429097a2b51..6613ece2c9f41 100644
--- a/drivers/iommu/riscv/iommu.c
+++ b/drivers/iommu/riscv/iommu.c
@@ -1320,6 +1320,18 @@ static bool riscv_iommu_pt_supported(struct riscv_iommu_device *iommu, int pgd_m
        return false;
 }
 
+static int riscv_iommu_test_paging_domain(struct iommu_domain *iommu_domain,
+                                         struct device *dev, ioasid_t pasid,
+                                         struct iommu_domain *old)
+{
+       struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
+       struct riscv_iommu_device *iommu = dev_to_iommu(dev);
+
+       if (!riscv_iommu_pt_supported(iommu, domain->pgd_mode))
+               return -ENODEV;
+       return 0;
+}
+
 static int riscv_iommu_attach_paging_domain(struct iommu_domain *iommu_domain,
                                            struct device *dev,
                                            struct iommu_domain *old)
@@ -1329,9 +1341,6 @@ static int riscv_iommu_attach_paging_domain(struct iommu_domain *iommu_domain,
        struct riscv_iommu_info *info = dev_iommu_priv_get(dev);
        u64 fsc, ta;
 
-       if (!riscv_iommu_pt_supported(iommu, domain->pgd_mode))
-               return -ENODEV;
-
        fsc = FIELD_PREP(RISCV_IOMMU_PC_FSC_MODE, domain->pgd_mode) |
              FIELD_PREP(RISCV_IOMMU_PC_FSC_PPN, virt_to_pfn(domain->pgd_root));
        ta = FIELD_PREP(RISCV_IOMMU_PC_TA_PSCID, domain->pscid) |
@@ -1348,6 +1357,7 @@ static int riscv_iommu_attach_paging_domain(struct iommu_domain *iommu_domain,
 }
 
 static const struct iommu_domain_ops riscv_iommu_paging_domain_ops = {
+       .test_dev = riscv_iommu_test_paging_domain,
        .attach_dev = riscv_iommu_attach_paging_domain,
        .free = riscv_iommu_free_paging_domain,
        .map_pages = riscv_iommu_map_pages,
-- 
2.43.0


Reply via email to