summaryrefslogtreecommitdiffstats
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/scsi.c4
-rw-r--r--drivers/vhost/vdpa.c84
-rw-r--r--drivers/vhost/vringh.c6
3 files changed, 71 insertions, 23 deletions
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index f22fce549862..6ff8a5096691 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -220,6 +220,7 @@ struct vhost_scsi_tmf {
struct list_head queue_entry;
struct se_cmd se_cmd;
+ u8 scsi_resp;
struct vhost_scsi_inflight *inflight;
struct iovec resp_iov;
int in_iovs;
@@ -426,6 +427,7 @@ static void vhost_scsi_queue_tm_rsp(struct se_cmd *se_cmd)
struct vhost_scsi_tmf *tmf = container_of(se_cmd, struct vhost_scsi_tmf,
se_cmd);
+ tmf->scsi_resp = se_cmd->se_tmr_req->response;
transport_generic_free_cmd(&tmf->se_cmd, 0);
}
@@ -1183,7 +1185,7 @@ static void vhost_scsi_tmf_resp_work(struct vhost_work *work)
vwork);
int resp_code;
- if (tmf->se_cmd.se_tmr_req->response == TMR_FUNCTION_COMPLETE)
+ if (tmf->scsi_resp == TMR_FUNCTION_COMPLETE)
resp_code = VIRTIO_SCSI_S_FUNCTION_SUCCEEDED;
else
resp_code = VIRTIO_SCSI_S_FUNCTION_REJECTED;
diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c
index 2754f3069738..29ed4173f04e 100644
--- a/drivers/vhost/vdpa.c
+++ b/drivers/vhost/vdpa.c
@@ -348,7 +348,9 @@ static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
.last = v->range.last,
};
- return copy_to_user(argp, &range, sizeof(range));
+ if (copy_to_user(argp, &range, sizeof(range)))
+ return -EFAULT;
+ return 0;
}
static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
@@ -577,6 +579,8 @@ static int vhost_vdpa_map(struct vhost_vdpa *v,
if (r)
vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
+ else
+ atomic64_add(size >> PAGE_SHIFT, &dev->mm->pinned_vm);
return r;
}
@@ -608,8 +612,9 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
unsigned int gup_flags = FOLL_LONGTERM;
unsigned long npages, cur_base, map_pfn, last_pfn = 0;
- unsigned long locked, lock_limit, pinned, i;
+ unsigned long lock_limit, sz2pin, nchunks, i;
u64 iova = msg->iova;
+ long pinned;
int ret = 0;
if (msg->iova < v->range.first ||
@@ -620,6 +625,7 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
msg->iova + msg->size - 1))
return -EEXIST;
+ /* Limit the use of memory for bookkeeping */
page_list = (struct page **) __get_free_page(GFP_KERNEL);
if (!page_list)
return -ENOMEM;
@@ -628,52 +634,75 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
gup_flags |= FOLL_WRITE;
npages = PAGE_ALIGN(msg->size + (iova & ~PAGE_MASK)) >> PAGE_SHIFT;
- if (!npages)
- return -EINVAL;
+ if (!npages) {
+ ret = -EINVAL;
+ goto free;
+ }
mmap_read_lock(dev->mm);
- locked = atomic64_add_return(npages, &dev->mm->pinned_vm);
lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
-
- if (locked > lock_limit) {
+ if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
ret = -ENOMEM;
- goto out;
+ goto unlock;
}
cur_base = msg->uaddr & PAGE_MASK;
iova &= PAGE_MASK;
+ nchunks = 0;
while (npages) {
- pinned = min_t(unsigned long, npages, list_size);
- ret = pin_user_pages(cur_base, pinned,
- gup_flags, page_list, NULL);
- if (ret != pinned)
+ sz2pin = min_t(unsigned long, npages, list_size);
+ pinned = pin_user_pages(cur_base, sz2pin,
+ gup_flags, page_list, NULL);
+ if (sz2pin != pinned) {
+ if (pinned < 0) {
+ ret = pinned;
+ } else {
+ unpin_user_pages(page_list, pinned);
+ ret = -ENOMEM;
+ }
goto out;
+ }
+ nchunks++;
if (!last_pfn)
map_pfn = page_to_pfn(page_list[0]);
- for (i = 0; i < ret; i++) {
+ for (i = 0; i < pinned; i++) {
unsigned long this_pfn = page_to_pfn(page_list[i]);
u64 csize;
if (last_pfn && (this_pfn != last_pfn + 1)) {
/* Pin a contiguous chunk of memory */
csize = (last_pfn - map_pfn + 1) << PAGE_SHIFT;
- if (vhost_vdpa_map(v, iova, csize,
- map_pfn << PAGE_SHIFT,
- msg->perm))
+ ret = vhost_vdpa_map(v, iova, csize,
+ map_pfn << PAGE_SHIFT,
+ msg->perm);
+ if (ret) {
+ /*
+ * Unpin the pages that are left unmapped
+ * from this point on in the current
+ * page_list. The remaining outstanding
+ * ones which may stride across several
+ * chunks will be covered in the common
+ * error path subsequently.
+ */
+ unpin_user_pages(&page_list[i],
+ pinned - i);
goto out;
+ }
+
map_pfn = this_pfn;
iova += csize;
+ nchunks = 0;
}
last_pfn = this_pfn;
}
- cur_base += ret << PAGE_SHIFT;
- npages -= ret;
+ cur_base += pinned << PAGE_SHIFT;
+ npages -= pinned;
}
/* Pin the rest chunk */
@@ -681,10 +710,27 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
map_pfn << PAGE_SHIFT, msg->perm);
out:
if (ret) {
+ if (nchunks) {
+ unsigned long pfn;
+
+ /*
+ * Unpin the outstanding pages which are yet to be
+ * mapped but haven't due to vdpa_map() or
+ * pin_user_pages() failure.
+ *
+ * Mapped pages are accounted in vdpa_map(), hence
+ * the corresponding unpinning will be handled by
+ * vdpa_unmap().
+ */
+ WARN_ON(!last_pfn);
+ for (pfn = map_pfn; pfn <= last_pfn; pfn++)
+ unpin_user_page(pfn_to_page(pfn));
+ }
vhost_vdpa_unmap(v, msg->iova, msg->size);
- atomic64_sub(npages, &dev->mm->pinned_vm);
}
+unlock:
mmap_read_unlock(dev->mm);
+free:
free_page((unsigned long)page_list);
return ret;
}
diff --git a/drivers/vhost/vringh.c b/drivers/vhost/vringh.c
index 8bd8b403f087..b7403ba8e7f7 100644
--- a/drivers/vhost/vringh.c
+++ b/drivers/vhost/vringh.c
@@ -730,7 +730,7 @@ EXPORT_SYMBOL(vringh_iov_pull_user);
/**
* vringh_iov_push_user - copy bytes into vring_iov.
* @wiov: the wiov as passed to vringh_getdesc_user() (updated as we consume)
- * @dst: the place to copy.
+ * @src: the place to copy from.
* @len: the maximum length to copy.
*
* Returns the bytes copied <= len or a negative errno.
@@ -976,7 +976,7 @@ EXPORT_SYMBOL(vringh_iov_pull_kern);
/**
* vringh_iov_push_kern - copy bytes into vring_iov.
* @wiov: the wiov as passed to vringh_getdesc_kern() (updated as we consume)
- * @dst: the place to copy.
+ * @src: the place to copy from.
* @len: the maximum length to copy.
*
* Returns the bytes copied <= len or a negative errno.
@@ -1333,7 +1333,7 @@ EXPORT_SYMBOL(vringh_iov_pull_iotlb);
* vringh_iov_push_iotlb - copy bytes into vring_iov.
* @vrh: the vring.
* @wiov: the wiov as passed to vringh_getdesc_iotlb() (updated as we consume)
- * @dst: the place to copy.
+ * @src: the place to copy from.
* @len: the maximum length to copy.
*
* Returns the bytes copied <= len or a negative errno.