summaryrefslogtreecommitdiffstats
path: root/fs
diff options
context:
space:
mode:
Diffstat (limited to 'fs')
-rw-r--r--fs/io-wq.c68
-rw-r--r--fs/io-wq.h7
-rw-r--r--fs/io_uring.c36
3 files changed, 81 insertions, 30 deletions
diff --git a/fs/io-wq.c b/fs/io-wq.c
index 54e270ae12ab..7ccedc82a703 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -56,7 +56,8 @@ struct io_worker {
struct rcu_head rcu;
struct mm_struct *mm;
- const struct cred *creds;
+ const struct cred *cur_creds;
+ const struct cred *saved_creds;
struct files_struct *restore_files;
};
@@ -109,8 +110,6 @@ struct io_wq {
struct task_struct *manager;
struct user_struct *user;
- const struct cred *creds;
- struct mm_struct *mm;
refcount_t refs;
struct completion done;
@@ -137,9 +136,9 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
{
bool dropped_lock = false;
- if (worker->creds) {
- revert_creds(worker->creds);
- worker->creds = NULL;
+ if (worker->saved_creds) {
+ revert_creds(worker->saved_creds);
+ worker->cur_creds = worker->saved_creds = NULL;
}
if (current->files != worker->restore_files) {
@@ -398,6 +397,43 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
return NULL;
}
+static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work)
+{
+ if (worker->mm) {
+ unuse_mm(worker->mm);
+ mmput(worker->mm);
+ worker->mm = NULL;
+ }
+ if (!work->mm) {
+ set_fs(KERNEL_DS);
+ return;
+ }
+ if (mmget_not_zero(work->mm)) {
+ use_mm(work->mm);
+ if (!worker->mm)
+ set_fs(USER_DS);
+ worker->mm = work->mm;
+ /* hang on to this mm */
+ work->mm = NULL;
+ return;
+ }
+
+ /* failed grabbing mm, ensure work gets cancelled */
+ work->flags |= IO_WQ_WORK_CANCEL;
+}
+
+static void io_wq_switch_creds(struct io_worker *worker,
+ struct io_wq_work *work)
+{
+ const struct cred *old_creds = override_creds(work->creds);
+
+ worker->cur_creds = work->creds;
+ if (worker->saved_creds)
+ put_cred(old_creds); /* creds set by previous switch */
+ else
+ worker->saved_creds = old_creds;
+}
+
static void io_worker_handle_work(struct io_worker *worker)
__releases(wqe->lock)
{
@@ -446,18 +482,10 @@ next:
current->files = work->files;
task_unlock(current);
}
- if ((work->flags & IO_WQ_WORK_NEEDS_USER) && !worker->mm &&
- wq->mm) {
- if (mmget_not_zero(wq->mm)) {
- use_mm(wq->mm);
- set_fs(USER_DS);
- worker->mm = wq->mm;
- } else {
- work->flags |= IO_WQ_WORK_CANCEL;
- }
- }
- if (!worker->creds)
- worker->creds = override_creds(wq->creds);
+ if (work->mm != worker->mm)
+ io_wq_switch_mm(worker, work);
+ if (worker->cur_creds != work->creds)
+ io_wq_switch_creds(worker, work);
/*
* OK to set IO_WQ_WORK_CANCEL even for uncancellable work,
* the worker function will do the right thing.
@@ -1037,7 +1065,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
/* caller must already hold a reference to this */
wq->user = data->user;
- wq->creds = data->creds;
for_each_node(node) {
struct io_wqe *wqe;
@@ -1064,9 +1091,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
init_completion(&wq->done);
- /* caller must have already done mmgrab() on this mm */
- wq->mm = data->mm;
-
wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager");
if (!IS_ERR(wq->manager)) {
wake_up_process(wq->manager);
diff --git a/fs/io-wq.h b/fs/io-wq.h
index 1cd039af8813..167316ad447e 100644
--- a/fs/io-wq.h
+++ b/fs/io-wq.h
@@ -7,7 +7,6 @@ enum {
IO_WQ_WORK_CANCEL = 1,
IO_WQ_WORK_HAS_MM = 2,
IO_WQ_WORK_HASHED = 4,
- IO_WQ_WORK_NEEDS_USER = 8,
IO_WQ_WORK_NEEDS_FILES = 16,
IO_WQ_WORK_UNBOUND = 32,
IO_WQ_WORK_INTERNAL = 64,
@@ -74,6 +73,8 @@ struct io_wq_work {
};
void (*func)(struct io_wq_work **);
struct files_struct *files;
+ struct mm_struct *mm;
+ const struct cred *creds;
unsigned flags;
};
@@ -83,15 +84,15 @@ struct io_wq_work {
(work)->func = _func; \
(work)->flags = 0; \
(work)->files = NULL; \
+ (work)->mm = NULL; \
+ (work)->creds = NULL; \
} while (0) \
typedef void (get_work_fn)(struct io_wq_work *);
typedef void (put_work_fn)(struct io_wq_work *);
struct io_wq_data {
- struct mm_struct *mm;
struct user_struct *user;
- const struct cred *creds;
get_work_fn *get_work;
put_work_fn *put_work;
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 1dd20305c664..0ea36911745d 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -875,6 +875,29 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
}
}
+static inline void io_req_work_grab_env(struct io_kiocb *req,
+ const struct io_op_def *def)
+{
+ if (!req->work.mm && def->needs_mm) {
+ mmgrab(current->mm);
+ req->work.mm = current->mm;
+ }
+ if (!req->work.creds)
+ req->work.creds = get_current_cred();
+}
+
+static inline void io_req_work_drop_env(struct io_kiocb *req)
+{
+ if (req->work.mm) {
+ mmdrop(req->work.mm);
+ req->work.mm = NULL;
+ }
+ if (req->work.creds) {
+ put_cred(req->work.creds);
+ req->work.creds = NULL;
+ }
+}
+
static inline bool io_prep_async_work(struct io_kiocb *req,
struct io_kiocb **link)
{
@@ -888,8 +911,8 @@ static inline bool io_prep_async_work(struct io_kiocb *req,
if (def->unbound_nonreg_file)
req->work.flags |= IO_WQ_WORK_UNBOUND;
}
- if (def->needs_mm)
- req->work.flags |= IO_WQ_WORK_NEEDS_USER;
+
+ io_req_work_grab_env(req, def);
*link = io_prep_linked_timeout(req);
return do_hashed;
@@ -1180,6 +1203,8 @@ static void __io_req_aux_free(struct io_kiocb *req)
else
fput(req->file);
}
+
+ io_req_work_drop_env(req);
}
static void __io_free_req(struct io_kiocb *req)
@@ -3963,6 +3988,8 @@ static int io_req_defer_prep(struct io_kiocb *req,
{
ssize_t ret = 0;
+ io_req_work_grab_env(req, &io_op_defs[req->opcode]);
+
switch (req->opcode) {
case IORING_OP_NOP:
break;
@@ -5725,9 +5752,7 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
goto err;
}
- data.mm = ctx->sqo_mm;
data.user = ctx->user;
- data.creds = ctx->creds;
data.get_work = io_get_work;
data.put_work = io_put_work;
@@ -6535,7 +6560,8 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p)
goto err;
p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP |
- IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS;
+ IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS |
+ IORING_FEAT_CUR_PERSONALITY;
trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags);
return ret;
err: