summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--drivers/vfio/vfio_iommu_type1.c77
1 files changed, 42 insertions, 35 deletions
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 8a2be4e40f22..98231d10890c 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -370,6 +370,9 @@ static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start,
struct vfio_dma *split;
int ret;
+ if (!*size)
+ return 0;
+
/*
* Existing dma region is completely covered, unmap all. This is
* the likely case since userspace tends to map and unmap buffers
@@ -411,7 +414,9 @@ static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start,
dma->vaddr += overlap;
dma->size -= overlap;
vfio_insert_dma(iommu, dma);
- }
+ } else
+ kfree(dma);
+
*size = overlap;
return 0;
}
@@ -425,48 +430,41 @@ static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start,
if (ret)
return ret;
- /*
- * We may have unmapped the entire vfio_dma if the user is
- * trying to unmap a sub-region of what was originally
- * mapped. If anything left, we can resize in place since
- * iova is unchanged.
- */
- if (overlap < dma->size)
- dma->size -= overlap;
- else
- vfio_remove_dma(iommu, dma);
-
+ dma->size -= overlap;
*size = overlap;
return 0;
}
/* Split existing */
+ split = kzalloc(sizeof(*split), GFP_KERNEL);
+ if (!split)
+ return -ENOMEM;
+
offset = start - dma->iova;
ret = vfio_unmap_unpin(iommu, dma, start, size);
if (ret)
return ret;
- WARN_ON(!*size);
+ if (!*size) {
+ kfree(split);
+ return -EINVAL;
+ }
+
tmp = dma->size;
- /*
- * Resize the lower vfio_dma in place, insert new for remaining
- * upper segment.
- */
+ /* Resize the lower vfio_dma in place, before the below insert */
dma->size = offset;
- if (offset + *size < tmp) {
- split = kzalloc(sizeof(*split), GFP_KERNEL);
- if (!split)
- return -ENOMEM;
-
+ /* Insert new for remainder, assuming it didn't all get unmapped */
+ if (likely(offset + *size < tmp)) {
split->size = tmp - offset - *size;
split->iova = dma->iova + offset + *size;
split->vaddr = dma->vaddr + offset + *size;
split->prot = dma->prot;
vfio_insert_dma(iommu, split);
- }
+ } else
+ kfree(split);
return 0;
}
@@ -483,7 +481,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
if (unmap->iova & mask)
return -EINVAL;
- if (unmap->size & mask)
+ if (!unmap->size || unmap->size & mask)
return -EINVAL;
WARN_ON(mask & PAGE_MASK);
@@ -493,7 +491,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
size = unmap->size;
ret = vfio_remove_dma_overlap(iommu, unmap->iova, &size, dma);
- if (ret)
+ if (ret || !size)
break;
unmapped += size;
}
@@ -635,7 +633,6 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
if (tmp && tmp->prot == prot &&
tmp->vaddr + tmp->size == vaddr) {
tmp->size += size;
-
iova = tmp->iova;
size = tmp->size;
vaddr = tmp->vaddr;
@@ -643,19 +640,28 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
}
}
- /* Check if we abut a region above - nothing above ~0 + 1 */
+ /*
+ * Check if we abut a region above - nothing above ~0 + 1.
+ * If we abut above and below, remove and free. If only
+ * abut above, remove, modify, reinsert.
+ */
if (likely(iova + size)) {
struct vfio_dma *tmp;
-
tmp = vfio_find_dma(iommu, iova + size, 1);
if (tmp && tmp->prot == prot &&
tmp->vaddr == vaddr + size) {
vfio_remove_dma(iommu, tmp);
- if (dma)
+ if (dma) {
dma->size += tmp->size;
- else
+ kfree(tmp);
+ } else {
size += tmp->size;
- kfree(tmp);
+ tmp->size = size;
+ tmp->iova = iova;
+ tmp->vaddr = vaddr;
+ vfio_insert_dma(iommu, tmp);
+ dma = tmp;
+ }
}
}
@@ -681,11 +687,10 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
iova = map->iova;
size = map->size;
while ((tmp = vfio_find_dma(iommu, iova, size))) {
- if (vfio_remove_dma_overlap(iommu, iova, &size, tmp)) {
- pr_warn("%s: Error rolling back failed map\n",
- __func__);
+ int r = vfio_remove_dma_overlap(iommu, iova,
+ &size, tmp);
+ if (WARN_ON(r || !size))
break;
- }
}
}
@@ -813,6 +818,8 @@ static void vfio_iommu_type1_release(void *iommu_data)
struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
size_t size = dma->size;
vfio_remove_dma_overlap(iommu, dma->iova, &size, dma);
+ if (WARN_ON(!size))
+ break;
}
iommu_domain_free(iommu->domain);