summaryrefslogtreecommitdiffstats
path: root/drivers/iommu
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu')
-rw-r--r--drivers/iommu/intel/svm.c115
1 files changed, 65 insertions, 50 deletions
diff --git a/drivers/iommu/intel/svm.c b/drivers/iommu/intel/svm.c
index 65d2327dcd0d..c104a50a625c 100644
--- a/drivers/iommu/intel/svm.c
+++ b/drivers/iommu/intel/svm.c
@@ -228,13 +228,57 @@ static LIST_HEAD(global_svm_list);
list_for_each_entry((sdev), &(svm)->devs, list) \
if ((d) != (sdev)->dev) {} else
+static int pasid_to_svm_sdev(struct device *dev, unsigned int pasid,
+ struct intel_svm **rsvm,
+ struct intel_svm_dev **rsdev)
+{
+ struct intel_svm_dev *d, *sdev = NULL;
+ struct intel_svm *svm;
+
+ /* The caller should hold the pasid_mutex lock */
+ if (WARN_ON(!mutex_is_locked(&pasid_mutex)))
+ return -EINVAL;
+
+ if (pasid == INVALID_IOASID || pasid >= PASID_MAX)
+ return -EINVAL;
+
+ svm = ioasid_find(NULL, pasid, NULL);
+ if (IS_ERR(svm))
+ return PTR_ERR(svm);
+
+ if (!svm)
+ goto out;
+
+ /*
+ * If we found svm for the PASID, there must be at least one device
+ * bond.
+ */
+ if (WARN_ON(list_empty(&svm->devs)))
+ return -EINVAL;
+
+ rcu_read_lock();
+ list_for_each_entry_rcu(d, &svm->devs, list) {
+ if (d->dev == dev) {
+ sdev = d;
+ break;
+ }
+ }
+ rcu_read_unlock();
+
+out:
+ *rsvm = svm;
+ *rsdev = sdev;
+
+ return 0;
+}
+
int intel_svm_bind_gpasid(struct iommu_domain *domain, struct device *dev,
struct iommu_gpasid_bind_data *data)
{
struct intel_iommu *iommu = device_to_iommu(dev, NULL, NULL);
+ struct intel_svm_dev *sdev = NULL;
struct dmar_domain *dmar_domain;
- struct intel_svm_dev *sdev;
- struct intel_svm *svm;
+ struct intel_svm *svm = NULL;
int ret = 0;
if (WARN_ON(!iommu) || !data)
@@ -261,35 +305,23 @@ int intel_svm_bind_gpasid(struct iommu_domain *domain, struct device *dev,
dmar_domain = to_dmar_domain(domain);
mutex_lock(&pasid_mutex);
- svm = ioasid_find(NULL, data->hpasid, NULL);
- if (IS_ERR(svm)) {
- ret = PTR_ERR(svm);
+ ret = pasid_to_svm_sdev(dev, data->hpasid, &svm, &sdev);
+ if (ret)
goto out;
- }
-
- if (svm) {
- /*
- * If we found svm for the PASID, there must be at
- * least one device bond, otherwise svm should be freed.
- */
- if (WARN_ON(list_empty(&svm->devs))) {
- ret = -EINVAL;
- goto out;
- }
+ if (sdev) {
/*
* Do not allow multiple bindings of the same device-PASID since
* there is only one SL page tables per PASID. We may revisit
* once sharing PGD across domains are supported.
*/
- for_each_svm_dev(sdev, svm, dev) {
- dev_warn_ratelimited(dev,
- "Already bound with PASID %u\n",
- svm->pasid);
- ret = -EBUSY;
- goto out;
- }
- } else {
+ dev_warn_ratelimited(dev, "Already bound with PASID %u\n",
+ svm->pasid);
+ ret = -EBUSY;
+ goto out;
+ }
+
+ if (!svm) {
/* We come here when PASID has never been bond to a device. */
svm = kzalloc(sizeof(*svm), GFP_KERNEL);
if (!svm) {
@@ -372,25 +404,17 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
struct intel_iommu *iommu = device_to_iommu(dev, NULL, NULL);
struct intel_svm_dev *sdev;
struct intel_svm *svm;
- int ret = -EINVAL;
+ int ret;
if (WARN_ON(!iommu))
return -EINVAL;
mutex_lock(&pasid_mutex);
- svm = ioasid_find(NULL, pasid, NULL);
- if (!svm) {
- ret = -EINVAL;
- goto out;
- }
-
- if (IS_ERR(svm)) {
- ret = PTR_ERR(svm);
+ ret = pasid_to_svm_sdev(dev, pasid, &svm, &sdev);
+ if (ret)
goto out;
- }
- for_each_svm_dev(sdev, svm, dev) {
- ret = 0;
+ if (sdev) {
if (iommu_dev_feature_enabled(dev, IOMMU_DEV_FEAT_AUX))
sdev->users--;
if (!sdev->users) {
@@ -414,7 +438,6 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
kfree(svm);
}
}
- break;
}
out:
mutex_unlock(&pasid_mutex);
@@ -592,7 +615,7 @@ success:
if (sd)
*sd = sdev;
ret = 0;
- out:
+out:
return ret;
}
@@ -608,17 +631,11 @@ static int intel_svm_unbind_mm(struct device *dev, int pasid)
if (!iommu)
goto out;
- svm = ioasid_find(NULL, pasid, NULL);
- if (!svm)
- goto out;
-
- if (IS_ERR(svm)) {
- ret = PTR_ERR(svm);
+ ret = pasid_to_svm_sdev(dev, pasid, &svm, &sdev);
+ if (ret)
goto out;
- }
- for_each_svm_dev(sdev, svm, dev) {
- ret = 0;
+ if (sdev) {
sdev->users--;
if (!sdev->users) {
list_del_rcu(&sdev->list);
@@ -647,10 +664,8 @@ static int intel_svm_unbind_mm(struct device *dev, int pasid)
kfree(svm);
}
}
- break;
}
- out:
-
+out:
return ret;
}