Move the sanity and compatibility tests from the attach_dev callbacks to the new test_dev callback functions. The IOMMU core guarantees that an attach_dev call is invoked only after a successful test_dev call.
Correct the errno upon malloc failure. Also, drop the function prototype of iommu_sva_set_dev_pasid() from the header and make it static, as only pasid.c uses it. Signed-off-by: Nicolin Chen <[email protected]> --- drivers/iommu/amd/amd_iommu.h | 3 --- drivers/iommu/amd/iommu.c | 27 +++++++++++++++++++-------- drivers/iommu/amd/pasid.c | 29 +++++++++++++++++++---------- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/drivers/iommu/amd/amd_iommu.h b/drivers/iommu/amd/amd_iommu.h index 9b4b589a54b57..f99fa4da34996 100644 --- a/drivers/iommu/amd/amd_iommu.h +++ b/drivers/iommu/amd/amd_iommu.h @@ -52,9 +52,6 @@ struct protection_domain *protection_domain_alloc(void); struct iommu_domain *amd_iommu_domain_alloc_sva(struct device *dev, struct mm_struct *mm); void amd_iommu_domain_free(struct iommu_domain *dom); -int iommu_sva_set_dev_pasid(struct iommu_domain *domain, - struct device *dev, ioasid_t pasid, - struct iommu_domain *old); void amd_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid, struct iommu_domain *domain); diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c index e16ad510c8c8a..dc0406427dcf8 100644 --- a/drivers/iommu/amd/iommu.c +++ b/drivers/iommu/amd/iommu.c @@ -70,6 +70,8 @@ int amd_iommu_max_glx_val = -1; */ DEFINE_IDA(pdom_ids); +static int amd_iommu_test_device(struct iommu_domain *dom, struct device *dev, + ioasid_t pasid, struct iommu_domain *old); static int amd_iommu_attach_device(struct iommu_domain *dom, struct device *dev, struct iommu_domain *old); @@ -2670,6 +2672,7 @@ static struct iommu_domain blocked_domain = { static struct protection_domain identity_domain; static const struct iommu_domain_ops identity_domain_ops = { + .test_dev = amd_iommu_test_device, .attach_dev = amd_iommu_attach_device, }; @@ -2686,12 +2689,26 @@ void amd_iommu_init_identity_domain(void) protection_domain_init(&identity_domain); } +static int amd_iommu_test_device(struct iommu_domain *dom, struct device *dev, + ioasid_t 
pasid, struct iommu_domain *old) +{ + struct amd_iommu *iommu = get_amd_iommu_from_dev(dev); + + /* + * Restrict to devices with compatible IOMMU hardware support + * when enforcement of dirty tracking is enabled. + */ + if (dom->dirty_ops && !amd_iommu_hd_support(iommu)) + return -EINVAL; + + return 0; +} + static int amd_iommu_attach_device(struct iommu_domain *dom, struct device *dev, struct iommu_domain *old) { struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev); struct protection_domain *domain = to_pdomain(dom); - struct amd_iommu *iommu = get_amd_iommu_from_dev(dev); int ret; /* @@ -2703,13 +2720,6 @@ static int amd_iommu_attach_device(struct iommu_domain *dom, struct device *dev, dev_data->defer_attach = false; - /* - * Restrict to devices with compatible IOMMU hardware support - * when enforcement of dirty tracking is enabled. - */ - if (dom->dirty_ops && !amd_iommu_hd_support(iommu)) - return -EINVAL; - if (dev_data->domain) detach_device(dev); @@ -3047,6 +3057,7 @@ const struct iommu_ops amd_iommu_ops = { .def_domain_type = amd_iommu_def_domain_type, .page_response = amd_iommu_page_response, .default_domain_ops = &(const struct iommu_domain_ops) { + .test_dev = amd_iommu_test_device, .attach_dev = amd_iommu_attach_device, .map_pages = amd_iommu_map_pages, .unmap_pages = amd_iommu_unmap_pages, diff --git a/drivers/iommu/amd/pasid.c b/drivers/iommu/amd/pasid.c index 77c8e9a91cbca..474494a66dd40 100644 --- a/drivers/iommu/amd/pasid.c +++ b/drivers/iommu/amd/pasid.c @@ -99,31 +99,39 @@ static const struct mmu_notifier_ops sva_mn = { .release = sva_mn_release, }; -int iommu_sva_set_dev_pasid(struct iommu_domain *domain, - struct device *dev, ioasid_t pasid, - struct iommu_domain *old) +static int iommu_sva_test_dev(struct iommu_domain *domain, struct device *dev, + ioasid_t pasid, struct iommu_domain *old) { - struct pdom_dev_data *pdom_dev_data; - struct protection_domain *sva_pdom = to_pdomain(domain); struct iommu_dev_data *dev_data = 
dev_iommu_priv_get(dev); - unsigned long flags; - int ret = -EINVAL; if (old) return -EOPNOTSUPP; /* PASID zero is used for requests from the I/O device without PASID */ if (!is_pasid_valid(dev_data, pasid)) - return ret; + return -EINVAL; /* Make sure PASID is enabled */ if (!is_pasid_enabled(dev_data)) - return ret; + return -EINVAL; + + return 0; +} + +static int iommu_sva_set_dev_pasid(struct iommu_domain *domain, + struct device *dev, ioasid_t pasid, + struct iommu_domain *old) +{ + struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev); + struct protection_domain *sva_pdom = to_pdomain(domain); + struct pdom_dev_data *pdom_dev_data; + unsigned long flags; + int ret; /* Add PASID to protection domain pasid list */ pdom_dev_data = kzalloc(sizeof(*pdom_dev_data), GFP_KERNEL); if (pdom_dev_data == NULL) - return ret; + return -ENOMEM; pdom_dev_data->pasid = pasid; pdom_dev_data->dev_data = dev_data; @@ -175,6 +183,7 @@ static void iommu_sva_domain_free(struct iommu_domain *domain) } static const struct iommu_domain_ops amd_sva_domain_ops = { + .test_dev = iommu_sva_test_dev, .set_dev_pasid = iommu_sva_set_dev_pasid, .free = iommu_sva_domain_free }; -- 2.43.0

