summaryrefslogtreecommitdiffstats
path: root/io_uring/net.c
diff options
context:
space:
mode:
Diffstat (limited to 'io_uring/net.c')
-rw-r--r--io_uring/net.c180
1 files changed, 161 insertions, 19 deletions
diff --git a/io_uring/net.c b/io_uring/net.c
index 5bc3440a8290..616d5f04cc74 100644
--- a/io_uring/net.c
+++ b/io_uring/net.c
@@ -325,6 +325,21 @@ int io_send(struct io_kiocb *req, unsigned int issue_flags)
return IOU_OK;
}
+static bool io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
+{
+ unsigned long hdr;
+
+ if (check_add_overflow(sizeof(struct io_uring_recvmsg_out),
+ (unsigned long)iomsg->namelen, &hdr))
+ return true;
+ if (check_add_overflow(hdr, iomsg->controllen, &hdr))
+ return true;
+ if (hdr > INT_MAX)
+ return true;
+
+ return false;
+}
+
static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
struct io_async_msghdr *iomsg)
{
@@ -352,6 +367,13 @@ static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
sr->len = iomsg->fast_iov[0].iov_len;
iomsg->free_iov = NULL;
}
+
+ if (req->flags & REQ_F_APOLL_MULTISHOT) {
+ iomsg->namelen = msg.msg_namelen;
+ iomsg->controllen = msg.msg_controllen;
+ if (io_recvmsg_multishot_overflow(iomsg))
+ return -EOVERFLOW;
+ }
} else {
iomsg->free_iov = iomsg->fast_iov;
ret = __import_iovec(READ, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
@@ -399,6 +421,13 @@ static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
sr->len = clen;
iomsg->free_iov = NULL;
}
+
+ if (req->flags & REQ_F_APOLL_MULTISHOT) {
+ iomsg->namelen = msg.msg_namelen;
+ iomsg->controllen = msg.msg_controllen;
+ if (io_recvmsg_multishot_overflow(iomsg))
+ return -EOVERFLOW;
+ }
} else {
iomsg->free_iov = iomsg->fast_iov;
ret = __import_iovec(READ, (struct iovec __user *)uiov, msg.msg_iovlen,
@@ -455,8 +484,6 @@ int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
if (sr->msg_flags & MSG_ERRQUEUE)
req->flags |= REQ_F_CLEAR_POLLIN;
if (sr->flags & IORING_RECV_MULTISHOT) {
- if (req->opcode == IORING_OP_RECVMSG)
- return -EINVAL;
if (!(req->flags & REQ_F_BUFFER_SELECT))
return -EINVAL;
if (sr->msg_flags & MSG_WAITALL)
@@ -483,12 +510,13 @@ static inline void io_recv_prep_retry(struct io_kiocb *req)
}
/*
- * Finishes io_recv
+ * Finishes io_recv and io_recvmsg.
*
* Returns true if it is actually finished, or false if it should run
* again (for multishot).
*/
-static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int cflags)
+static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
+ unsigned int cflags, bool mshot_finished)
{
if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
io_req_set_res(req, *ret, cflags);
@@ -496,7 +524,7 @@ static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int c
return true;
}
- if (*ret > 0) {
+ if (!mshot_finished) {
if (io_post_aux_cqe(req->ctx, req->cqe.user_data, *ret,
cflags | IORING_CQE_F_MORE, false)) {
io_recv_prep_retry(req);
@@ -518,6 +546,90 @@ static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int c
return true;
}
+static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg,
+ struct io_sr_msg *sr, void __user **buf,
+ size_t *len)
+{
+ unsigned long ubuf = (unsigned long) *buf;
+ unsigned long hdr;
+
+ hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
+ kmsg->controllen;
+ if (*len < hdr)
+ return -EFAULT;
+
+ if (kmsg->controllen) {
+ unsigned long control = ubuf + hdr - kmsg->controllen;
+
+ kmsg->msg.msg_control_user = (void *) control;
+ kmsg->msg.msg_controllen = kmsg->controllen;
+ }
+
+ sr->buf = *buf; /* stash for later copy */
+ *buf = (void *) (ubuf + hdr);
+ kmsg->payloadlen = *len = *len - hdr;
+ return 0;
+}
+
+struct io_recvmsg_multishot_hdr {
+ struct io_uring_recvmsg_out msg;
+ struct sockaddr_storage addr;
+};
+
+static int io_recvmsg_multishot(struct socket *sock, struct io_sr_msg *io,
+ struct io_async_msghdr *kmsg,
+ unsigned int flags, bool *finished)
+{
+ int err;
+ int copy_len;
+ struct io_recvmsg_multishot_hdr hdr;
+
+ if (kmsg->namelen)
+ kmsg->msg.msg_name = &hdr.addr;
+ kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
+ kmsg->msg.msg_namelen = 0;
+
+ if (sock->file->f_flags & O_NONBLOCK)
+ flags |= MSG_DONTWAIT;
+
+ err = sock_recvmsg(sock, &kmsg->msg, flags);
+ *finished = err <= 0;
+ if (err < 0)
+ return err;
+
+ hdr.msg = (struct io_uring_recvmsg_out) {
+ .controllen = kmsg->controllen - kmsg->msg.msg_controllen,
+ .flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
+ };
+
+ hdr.msg.payloadlen = err;
+ if (err > kmsg->payloadlen)
+ err = kmsg->payloadlen;
+
+ copy_len = sizeof(struct io_uring_recvmsg_out);
+ if (kmsg->msg.msg_namelen > kmsg->namelen)
+ copy_len += kmsg->namelen;
+ else
+ copy_len += kmsg->msg.msg_namelen;
+
+ /*
+ * "fromlen shall refer to the value before truncation.."
+ * 1003.1g
+ */
+ hdr.msg.namelen = kmsg->msg.msg_namelen;
+
+ /* ensure that there is no gap between hdr and sockaddr_storage */
+ BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
+ sizeof(struct io_uring_recvmsg_out));
+ if (copy_to_user(io->buf, &hdr, copy_len)) {
+ *finished = true;
+ return -EFAULT;
+ }
+
+ return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
+ kmsg->controllen + err;
+}
+
int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
{
struct io_sr_msg *sr = io_kiocb_to_cmd(req);
@@ -527,6 +639,7 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
unsigned flags;
int ret, min_ret = 0;
bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
+ bool mshot_finished = true;
sock = sock_from_file(req->file);
if (unlikely(!sock))
@@ -545,16 +658,27 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
(sr->flags & IORING_RECVSEND_POLL_FIRST))
return io_setup_async_msg(req, kmsg, issue_flags);
+retry_multishot:
if (io_do_buffer_select(req)) {
void __user *buf;
+ size_t len = sr->len;
- buf = io_buffer_select(req, &sr->len, issue_flags);
+ buf = io_buffer_select(req, &len, issue_flags);
if (!buf)
return -ENOBUFS;
+
+ if (req->flags & REQ_F_APOLL_MULTISHOT) {
+ ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
+ if (ret) {
+ io_kbuf_recycle(req, issue_flags);
+ return ret;
+ }
+ }
+
kmsg->fast_iov[0].iov_base = buf;
- kmsg->fast_iov[0].iov_len = sr->len;
+ kmsg->fast_iov[0].iov_len = len;
iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->fast_iov, 1,
- sr->len);
+ len);
}
flags = sr->msg_flags;
@@ -564,10 +688,23 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
min_ret = iov_iter_count(&kmsg->msg.msg_iter);
kmsg->msg.msg_get_inq = 1;
- ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg, kmsg->uaddr, flags);
+ if (req->flags & REQ_F_APOLL_MULTISHOT)
+ ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
+ &mshot_finished);
+ else
+ ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg,
+ kmsg->uaddr, flags);
+
if (ret < min_ret) {
- if (ret == -EAGAIN && force_nonblock)
- return io_setup_async_msg(req, kmsg, issue_flags);
+ if (ret == -EAGAIN && force_nonblock) {
+ ret = io_setup_async_msg(req, kmsg, issue_flags);
+ if (ret == -EAGAIN && (req->flags & IO_APOLL_MULTI_POLLED) ==
+ IO_APOLL_MULTI_POLLED) {
+ io_kbuf_recycle(req, issue_flags);
+ return IOU_ISSUE_SKIP_COMPLETE;
+ }
+ return ret;
+ }
if (ret == -ERESTARTSYS)
ret = -EINTR;
if (ret > 0 && io_net_retry(sock, flags)) {
@@ -580,11 +717,6 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
req_set_fail(req);
}
- /* fast path, check for non-NULL to avoid function call */
- if (kmsg->free_iov)
- kfree(kmsg->free_iov);
- io_netmsg_recycle(req, issue_flags);
- req->flags &= ~REQ_F_NEED_CLEANUP;
if (ret > 0)
ret += sr->done_io;
else if (sr->done_io)
@@ -596,8 +728,18 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
if (kmsg->msg.msg_inq)
cflags |= IORING_CQE_F_SOCK_NONEMPTY;
- io_req_set_res(req, ret, cflags);
- return IOU_OK;
+ if (!io_recv_finish(req, &ret, cflags, mshot_finished))
+ goto retry_multishot;
+
+ if (mshot_finished) {
+ io_netmsg_recycle(req, issue_flags);
+ /* fast path, check for non-NULL to avoid function call */
+ if (kmsg->free_iov)
+ kfree(kmsg->free_iov);
+ req->flags &= ~REQ_F_NEED_CLEANUP;
+ }
+
+ return ret;
}
int io_recv(struct io_kiocb *req, unsigned int issue_flags)
@@ -684,7 +826,7 @@ out_free:
if (msg.msg_inq)
cflags |= IORING_CQE_F_SOCK_NONEMPTY;
- if (!io_recv_finish(req, &ret, cflags))
+ if (!io_recv_finish(req, &ret, cflags, ret <= 0))
goto retry_multishot;
return ret;