summaryrefslogtreecommitdiffstats
path: root/drivers/iommu/iommu-sva-lib.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu/iommu-sva-lib.c')
-rw-r--r--drivers/iommu/iommu-sva-lib.c39
1 files changed, 12 insertions, 27 deletions
diff --git a/drivers/iommu/iommu-sva-lib.c b/drivers/iommu/iommu-sva-lib.c
index bd41405d34e9..106506143896 100644
--- a/drivers/iommu/iommu-sva-lib.c
+++ b/drivers/iommu/iommu-sva-lib.c
@@ -18,8 +18,7 @@ static DECLARE_IOASID_SET(iommu_sva_pasid);
*
* Try to allocate a PASID for this mm, or take a reference to the existing one
* provided it fits within the [@min, @max] range. On success the PASID is
- * available in mm->pasid, and must be released with iommu_sva_free_pasid().
- * @min must be greater than 0, because 0 indicates an unused mm->pasid.
+ * available in mm->pasid and will be available for the lifetime of the mm.
*
* Returns 0 on success and < 0 on error.
*/
@@ -33,38 +32,24 @@ int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
return -EINVAL;
mutex_lock(&iommu_sva_lock);
- if (mm->pasid) {
- if (mm->pasid >= min && mm->pasid <= max)
- ioasid_get(mm->pasid);
- else
+ /* Is a PASID already associated with this mm? */
+ if (pasid_valid(mm->pasid)) {
+ if (mm->pasid < min || mm->pasid >= max)
ret = -EOVERFLOW;
- } else {
- pasid = ioasid_alloc(&iommu_sva_pasid, min, max, mm);
- if (pasid == INVALID_IOASID)
- ret = -ENOMEM;
- else
- mm->pasid = pasid;
+ goto out;
}
+
+ pasid = ioasid_alloc(&iommu_sva_pasid, min, max, mm);
+ if (!pasid_valid(pasid))
+ ret = -ENOMEM;
+ else
+ mm_pasid_set(mm, pasid);
+out:
mutex_unlock(&iommu_sva_lock);
return ret;
}
EXPORT_SYMBOL_GPL(iommu_sva_alloc_pasid);
-/**
- * iommu_sva_free_pasid - Release the mm's PASID
- * @mm: the mm
- *
- * Drop one reference to a PASID allocated with iommu_sva_alloc_pasid()
- */
-void iommu_sva_free_pasid(struct mm_struct *mm)
-{
- mutex_lock(&iommu_sva_lock);
- if (ioasid_put(mm->pasid))
- mm->pasid = 0;
- mutex_unlock(&iommu_sva_lock);
-}
-EXPORT_SYMBOL_GPL(iommu_sva_free_pasid);
-
/* ioasid_find getter() requires a void * argument */
static bool __mmget_not_zero(void *mm)
{