diff options
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- | drivers/vhost/scsi.c | 4 | ||||
-rw-r--r-- | drivers/vhost/vdpa.c | 84 | ||||
-rw-r--r-- | drivers/vhost/vringh.c | 6 |
3 files changed, 71 insertions, 23 deletions
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c index f22fce549862..6ff8a5096691 100644 --- a/drivers/vhost/scsi.c +++ b/drivers/vhost/scsi.c @@ -220,6 +220,7 @@ struct vhost_scsi_tmf { struct list_head queue_entry; struct se_cmd se_cmd; + u8 scsi_resp; struct vhost_scsi_inflight *inflight; struct iovec resp_iov; int in_iovs; @@ -426,6 +427,7 @@ static void vhost_scsi_queue_tm_rsp(struct se_cmd *se_cmd) struct vhost_scsi_tmf *tmf = container_of(se_cmd, struct vhost_scsi_tmf, se_cmd); + tmf->scsi_resp = se_cmd->se_tmr_req->response; transport_generic_free_cmd(&tmf->se_cmd, 0); } @@ -1183,7 +1185,7 @@ static void vhost_scsi_tmf_resp_work(struct vhost_work *work) vwork); int resp_code; - if (tmf->se_cmd.se_tmr_req->response == TMR_FUNCTION_COMPLETE) + if (tmf->scsi_resp == TMR_FUNCTION_COMPLETE) resp_code = VIRTIO_SCSI_S_FUNCTION_SUCCEEDED; else resp_code = VIRTIO_SCSI_S_FUNCTION_REJECTED; diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c index 2754f3069738..29ed4173f04e 100644 --- a/drivers/vhost/vdpa.c +++ b/drivers/vhost/vdpa.c @@ -348,7 +348,9 @@ static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp) .last = v->range.last, }; - return copy_to_user(argp, &range, sizeof(range)); + if (copy_to_user(argp, &range, sizeof(range))) + return -EFAULT; + return 0; } static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd, @@ -577,6 +579,8 @@ static int vhost_vdpa_map(struct vhost_vdpa *v, if (r) vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1); + else + atomic64_add(size >> PAGE_SHIFT, &dev->mm->pinned_vm); return r; } @@ -608,8 +612,9 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v, unsigned long list_size = PAGE_SIZE / sizeof(struct page *); unsigned int gup_flags = FOLL_LONGTERM; unsigned long npages, cur_base, map_pfn, last_pfn = 0; - unsigned long locked, lock_limit, pinned, i; + unsigned long lock_limit, sz2pin, nchunks, i; u64 iova = msg->iova; + long pinned; int ret = 0; if (msg->iova < v->range.first || @@ -620,6 +625,7 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v, msg->iova + msg->size - 1)) return -EEXIST; + /* Limit the use of memory for bookkeeping */ page_list = (struct page **) __get_free_page(GFP_KERNEL); if (!page_list) return -ENOMEM; @@ -628,52 +634,75 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v, gup_flags |= FOLL_WRITE; npages = PAGE_ALIGN(msg->size + (iova & ~PAGE_MASK)) >> PAGE_SHIFT; - if (!npages) - return -EINVAL; + if (!npages) { + ret = -EINVAL; + goto free; + } mmap_read_lock(dev->mm); - locked = atomic64_add_return(npages, &dev->mm->pinned_vm); lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT; - - if (locked > lock_limit) { + if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) { ret = -ENOMEM; - goto out; + goto unlock; } cur_base = msg->uaddr & PAGE_MASK; iova &= PAGE_MASK; + nchunks = 0; while (npages) { - pinned = min_t(unsigned long, npages, list_size); - ret = pin_user_pages(cur_base, pinned, - gup_flags, page_list, NULL); - if (ret != pinned) + sz2pin = min_t(unsigned long, npages, list_size); + pinned = pin_user_pages(cur_base, sz2pin, + gup_flags, page_list, NULL); + if (sz2pin != pinned) { + if (pinned < 0) { + ret = pinned; + } else { + unpin_user_pages(page_list, pinned); + ret = -ENOMEM; + } goto out; + } + nchunks++; if (!last_pfn) map_pfn = page_to_pfn(page_list[0]); - for (i = 0; i < ret; i++) { + for (i = 0; i < pinned; i++) { unsigned long this_pfn = page_to_pfn(page_list[i]); u64 csize; if (last_pfn && (this_pfn != last_pfn + 1)) { /* Pin a contiguous chunk of memory */ csize = (last_pfn - map_pfn + 1) << PAGE_SHIFT; - if (vhost_vdpa_map(v, iova, csize, - map_pfn << PAGE_SHIFT, - msg->perm)) + ret = vhost_vdpa_map(v, iova, csize, + map_pfn << PAGE_SHIFT, + msg->perm); + if (ret) { + /* + * Unpin the pages that are left unmapped + * from this point on in the current + * page_list. The remaining outstanding + * ones which may stride across several + * chunks will be covered in the common + * error path subsequently. + */ + unpin_user_pages(&page_list[i], + pinned - i); goto out; + } + map_pfn = this_pfn; iova += csize; + nchunks = 0; } last_pfn = this_pfn; } - cur_base += ret << PAGE_SHIFT; - npages -= ret; + cur_base += pinned << PAGE_SHIFT; + npages -= pinned; } /* Pin the rest chunk */ @@ -681,10 +710,27 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v, map_pfn << PAGE_SHIFT, msg->perm); out: if (ret) { + if (nchunks) { + unsigned long pfn; + + /* + * Unpin the outstanding pages which are yet to be + * mapped but haven't due to vdpa_map() or + * pin_user_pages() failure. + * + * Mapped pages are accounted in vdpa_map(), hence + * the corresponding unpinning will be handled by + * vdpa_unmap(). + */ + WARN_ON(!last_pfn); + for (pfn = map_pfn; pfn <= last_pfn; pfn++) + unpin_user_page(pfn_to_page(pfn)); + } vhost_vdpa_unmap(v, msg->iova, msg->size); - atomic64_sub(npages, &dev->mm->pinned_vm); } +unlock: mmap_read_unlock(dev->mm); +free: free_page((unsigned long)page_list); return ret; } diff --git a/drivers/vhost/vringh.c b/drivers/vhost/vringh.c index 8bd8b403f087..b7403ba8e7f7 100644 --- a/drivers/vhost/vringh.c +++ b/drivers/vhost/vringh.c @@ -730,7 +730,7 @@ EXPORT_SYMBOL(vringh_iov_pull_user); /** * vringh_iov_push_user - copy bytes into vring_iov. * @wiov: the wiov as passed to vringh_getdesc_user() (updated as we consume) - * @dst: the place to copy. + * @src: the place to copy from. * @len: the maximum length to copy. * * Returns the bytes copied <= len or a negative errno. @@ -976,7 +976,7 @@ EXPORT_SYMBOL(vringh_iov_pull_kern); /** * vringh_iov_push_kern - copy bytes into vring_iov. * @wiov: the wiov as passed to vringh_getdesc_kern() (updated as we consume) - * @dst: the place to copy. + * @src: the place to copy from. * @len: the maximum length to copy. * * Returns the bytes copied <= len or a negative errno. @@ -1333,7 +1333,7 @@ EXPORT_SYMBOL(vringh_iov_pull_iotlb); * vringh_iov_push_iotlb - copy bytes into vring_iov. * @vrh: the vring. * @wiov: the wiov as passed to vringh_getdesc_iotlb() (updated as we consume) - * @dst: the place to copy. + * @src: the place to copy from. * @len: the maximum length to copy. * * Returns the bytes copied <= len or a negative errno. |