diff options
Diffstat (limited to 'fs/userfaultfd.c')
-rw-r--r-- | fs/userfaultfd.c | 79 |
1 file changed, 46 insertions, 33 deletions
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 0756d97b0666..1f2ddaaf3c03 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -50,7 +50,7 @@ struct userfaultfd_ctx { }; struct userfaultfd_wait_queue { - unsigned long address; + struct uffd_msg msg; wait_queue_t wq; bool pending; struct userfaultfd_ctx *ctx; @@ -77,7 +77,8 @@ static int userfaultfd_wake_function(wait_queue_t *wq, unsigned mode, /* len == 0 means wake all */ start = range->start; len = range->len; - if (len && (start > uwq->address || start + len <= uwq->address)) + if (len && (start > uwq->msg.arg.pagefault.address || + start + len <= uwq->msg.arg.pagefault.address)) goto out; ret = wake_up_state(wq->private, mode); if (ret) @@ -135,28 +136,43 @@ static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx) } } -static inline unsigned long userfault_address(unsigned long address, - unsigned int flags, - unsigned long reason) +static inline void msg_init(struct uffd_msg *msg) { - BUILD_BUG_ON(PAGE_SHIFT < UFFD_BITS); - address &= PAGE_MASK; + BUILD_BUG_ON(sizeof(struct uffd_msg) != 32); + /* + * Must use memset to zero out the paddings or kernel data is + * leaked to userland. + */ + memset(msg, 0, sizeof(struct uffd_msg)); +} + +static inline struct uffd_msg userfault_msg(unsigned long address, + unsigned int flags, + unsigned long reason) +{ + struct uffd_msg msg; + msg_init(&msg); + msg.event = UFFD_EVENT_PAGEFAULT; + msg.arg.pagefault.address = address; if (flags & FAULT_FLAG_WRITE) /* - * Encode "write" fault information in the LSB of the - * address read by userland, without depending on - * FAULT_FLAG_WRITE kernel internal value. + * If UFFD_FEATURE_PAGEFAULT_FLAG_WRITE was set in the + * uffdio_api.features and UFFD_PAGEFAULT_FLAG_WRITE + * was not set in a UFFD_EVENT_PAGEFAULT, it means it + * was a read fault, otherwise if set it means it's + * a write fault. 
*/ - address |= UFFD_BIT_WRITE; + msg.arg.pagefault.flags |= UFFD_PAGEFAULT_FLAG_WRITE; if (reason & VM_UFFD_WP) /* - * Encode "reason" fault information as bit number 1 - * in the address read by userland. If bit number 1 is - * clear it means the reason is a VM_FAULT_MISSING - * fault. + * If UFFD_FEATURE_PAGEFAULT_FLAG_WP was set in the + * uffdio_api.features and UFFD_PAGEFAULT_FLAG_WP was + * not set in a UFFD_EVENT_PAGEFAULT, it means it was + * a missing fault, otherwise if set it means it's a + * write protect fault. */ - address |= UFFD_BIT_WP; - return address; + msg.arg.pagefault.flags |= UFFD_PAGEFAULT_FLAG_WP; + return msg; } /* @@ -242,7 +258,7 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address, init_waitqueue_func_entry(&uwq.wq, userfaultfd_wake_function); uwq.wq.private = current; - uwq.address = userfault_address(address, flags, reason); + uwq.msg = userfault_msg(address, flags, reason); uwq.pending = true; uwq.ctx = ctx; @@ -398,7 +414,7 @@ static unsigned int userfaultfd_poll(struct file *file, poll_table *wait) } static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait, - __u64 *addr) + struct uffd_msg *msg) { ssize_t ret; DECLARE_WAITQUEUE(wait, current); @@ -416,8 +432,8 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait, * disappear from under us. 
*/ uwq->pending = false; - /* careful to always initialize addr if ret == 0 */ - *addr = uwq->address; + /* careful to always initialize msg if ret == 0 */ + *msg = uwq->msg; spin_unlock(&ctx->fault_wqh.lock); ret = 0; break; @@ -447,8 +463,7 @@ static ssize_t userfaultfd_read(struct file *file, char __user *buf, { struct userfaultfd_ctx *ctx = file->private_data; ssize_t _ret, ret = 0; - /* careful to always initialize addr if ret == 0 */ - __u64 uninitialized_var(addr); + struct uffd_msg msg; int no_wait = file->f_flags & O_NONBLOCK; if (ctx->state == UFFD_STATE_WAIT_API) @@ -456,16 +471,16 @@ static ssize_t userfaultfd_read(struct file *file, char __user *buf, BUG_ON(ctx->state != UFFD_STATE_RUNNING); for (;;) { - if (count < sizeof(addr)) + if (count < sizeof(msg)) return ret ? ret : -EINVAL; - _ret = userfaultfd_ctx_read(ctx, no_wait, &addr); + _ret = userfaultfd_ctx_read(ctx, no_wait, &msg); if (_ret < 0) return ret ? ret : _ret; - if (put_user(addr, (__u64 __user *) buf)) + if (copy_to_user((__u64 __user *) buf, &msg, sizeof(msg))) return ret ? ret : -EFAULT; - ret += sizeof(addr); - buf += sizeof(addr); - count -= sizeof(addr); + ret += sizeof(msg); + buf += sizeof(msg); + count -= sizeof(msg); /* * Allow to read more than one fault at time but only * block if waiting for the very first one. 
@@ -873,17 +888,15 @@ static int userfaultfd_api(struct userfaultfd_ctx *ctx, if (ctx->state != UFFD_STATE_WAIT_API) goto out; ret = -EFAULT; - if (copy_from_user(&uffdio_api, buf, sizeof(__u64))) + if (copy_from_user(&uffdio_api, buf, sizeof(uffdio_api))) goto out; - if (uffdio_api.api != UFFD_API) { - /* careful not to leak info, we only read the first 8 bytes */ + if (uffdio_api.api != UFFD_API || uffdio_api.features) { memset(&uffdio_api, 0, sizeof(uffdio_api)); if (copy_to_user(buf, &uffdio_api, sizeof(uffdio_api))) goto out; ret = -EINVAL; goto out; } - /* careful not to leak info, we only read the first 8 bytes */ uffdio_api.features = UFFD_API_FEATURES; uffdio_api.ioctls = UFFD_API_IOCTLS; ret = -EFAULT; |