diff options
Diffstat (limited to 'fs')
-rw-r--r-- | fs/io-wq.c | 68 | ||||
-rw-r--r-- | fs/io-wq.h | 7 | ||||
-rw-r--r-- | fs/io_uring.c | 36 |
3 files changed, 81 insertions, 30 deletions
diff --git a/fs/io-wq.c b/fs/io-wq.c index 54e270ae12ab..7ccedc82a703 100644 --- a/fs/io-wq.c +++ b/fs/io-wq.c @@ -56,7 +56,8 @@ struct io_worker { struct rcu_head rcu; struct mm_struct *mm; - const struct cred *creds; + const struct cred *cur_creds; + const struct cred *saved_creds; struct files_struct *restore_files; }; @@ -109,8 +110,6 @@ struct io_wq { struct task_struct *manager; struct user_struct *user; - const struct cred *creds; - struct mm_struct *mm; refcount_t refs; struct completion done; @@ -137,9 +136,9 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker) { bool dropped_lock = false; - if (worker->creds) { - revert_creds(worker->creds); - worker->creds = NULL; + if (worker->saved_creds) { + revert_creds(worker->saved_creds); + worker->cur_creds = worker->saved_creds = NULL; } if (current->files != worker->restore_files) { @@ -398,6 +397,43 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash) return NULL; } +static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work) +{ + if (worker->mm) { + unuse_mm(worker->mm); + mmput(worker->mm); + worker->mm = NULL; + } + if (!work->mm) { + set_fs(KERNEL_DS); + return; + } + if (mmget_not_zero(work->mm)) { + use_mm(work->mm); + if (!worker->mm) + set_fs(USER_DS); + worker->mm = work->mm; + /* hang on to this mm */ + work->mm = NULL; + return; + } + + /* failed grabbing mm, ensure work gets cancelled */ + work->flags |= IO_WQ_WORK_CANCEL; +} + +static void io_wq_switch_creds(struct io_worker *worker, + struct io_wq_work *work) +{ + const struct cred *old_creds = override_creds(work->creds); + + worker->cur_creds = work->creds; + if (worker->saved_creds) + put_cred(old_creds); /* creds set by previous switch */ + else + worker->saved_creds = old_creds; +} + static void io_worker_handle_work(struct io_worker *worker) __releases(wqe->lock) { @@ -446,18 +482,10 @@ next: current->files = work->files; task_unlock(current); } - if ((work->flags & IO_WQ_WORK_NEEDS_USER) && !worker->mm && - wq->mm) { - if (mmget_not_zero(wq->mm)) { - use_mm(wq->mm); - set_fs(USER_DS); - worker->mm = wq->mm; - } else { - work->flags |= IO_WQ_WORK_CANCEL; - } - } - if (!worker->creds) - worker->creds = override_creds(wq->creds); + if (work->mm != worker->mm) + io_wq_switch_mm(worker, work); + if (worker->cur_creds != work->creds) + io_wq_switch_creds(worker, work); /* * OK to set IO_WQ_WORK_CANCEL even for uncancellable work, * the worker function will do the right thing. @@ -1037,7 +1065,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data) /* caller must already hold a reference to this */ wq->user = data->user; - wq->creds = data->creds; for_each_node(node) { struct io_wqe *wqe; @@ -1064,9 +1091,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data) init_completion(&wq->done); - /* caller must have already done mmgrab() on this mm */ - wq->mm = data->mm; - wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager"); if (!IS_ERR(wq->manager)) { wake_up_process(wq->manager); diff --git a/fs/io-wq.h b/fs/io-wq.h index 1cd039af8813..167316ad447e 100644 --- a/fs/io-wq.h +++ b/fs/io-wq.h @@ -7,7 +7,6 @@ enum { IO_WQ_WORK_CANCEL = 1, IO_WQ_WORK_HAS_MM = 2, IO_WQ_WORK_HASHED = 4, - IO_WQ_WORK_NEEDS_USER = 8, IO_WQ_WORK_NEEDS_FILES = 16, IO_WQ_WORK_UNBOUND = 32, IO_WQ_WORK_INTERNAL = 64, @@ -74,6 +73,8 @@ struct io_wq_work { }; void (*func)(struct io_wq_work **); struct files_struct *files; + struct mm_struct *mm; + const struct cred *creds; unsigned flags; }; @@ -83,15 +84,15 @@ struct io_wq_work { (work)->func = _func; \ (work)->flags = 0; \ (work)->files = NULL; \ + (work)->mm = NULL; \ + (work)->creds = NULL; \ } while (0) \ typedef void (get_work_fn)(struct io_wq_work *); typedef void (put_work_fn)(struct io_wq_work *); struct io_wq_data { - struct mm_struct *mm; struct user_struct *user; - const struct cred *creds; get_work_fn *get_work; put_work_fn *put_work; diff --git a/fs/io_uring.c b/fs/io_uring.c index 1dd20305c664..0ea36911745d 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -875,6 +875,29 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx) } } +static inline void io_req_work_grab_env(struct io_kiocb *req, + const struct io_op_def *def) +{ + if (!req->work.mm && def->needs_mm) { + mmgrab(current->mm); + req->work.mm = current->mm; + } + if (!req->work.creds) + req->work.creds = get_current_cred(); +} + +static inline void io_req_work_drop_env(struct io_kiocb *req) +{ + if (req->work.mm) { + mmdrop(req->work.mm); + req->work.mm = NULL; + } + if (req->work.creds) { + put_cred(req->work.creds); + req->work.creds = NULL; + } +} + static inline bool io_prep_async_work(struct io_kiocb *req, struct io_kiocb **link) { @@ -888,8 +911,8 @@ static inline bool io_prep_async_work(struct io_kiocb *req, if (def->unbound_nonreg_file) req->work.flags |= IO_WQ_WORK_UNBOUND; } - if (def->needs_mm) - req->work.flags |= IO_WQ_WORK_NEEDS_USER; + + io_req_work_grab_env(req, def); *link = io_prep_linked_timeout(req); return do_hashed; @@ -1180,6 +1203,8 @@ static void __io_req_aux_free(struct io_kiocb *req) else fput(req->file); } + + io_req_work_drop_env(req); } static void __io_free_req(struct io_kiocb *req) @@ -3963,6 +3988,8 @@ static int io_req_defer_prep(struct io_kiocb *req, { ssize_t ret = 0; + io_req_work_grab_env(req, &io_op_defs[req->opcode]); + switch (req->opcode) { case IORING_OP_NOP: break; @@ -5725,9 +5752,7 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx, goto err; } - data.mm = ctx->sqo_mm; data.user = ctx->user; - data.creds = ctx->creds; data.get_work = io_get_work; data.put_work = io_put_work; @@ -6535,7 +6560,8 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p) goto err; p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP | - IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS; + IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS | + IORING_FEAT_CUR_PERSONALITY; trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags); return ret; err: |