diff options
author | Luca Boccassi <luca.boccassi@microsoft.com> | 2022-01-31 14:56:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-31 14:56:04 +0100 |
commit | 06d4d83fa79306b7f01531a509db62729ab5fe43 (patch) | |
tree | 9709f7daa55d45f6af63ce2973f844d08b980757 /src | |
parent | core: don't fail on EEXIST when creating mount point (diff) | |
parent | resolve: llmnr: fix never hit condition (diff) | |
download | systemd-06d4d83fa79306b7f01531a509db62729ab5fe43.tar.xz systemd-06d4d83fa79306b7f01531a509db62729ab5fe43.zip |
Merge pull request #22274 from yuwata/resolve-comment
resolve: cleanups for on_stream_io()
Diffstat (limited to 'src')
-rw-r--r-- | src/resolve/resolved-dns-stream.c | 84 | ||||
-rw-r--r-- | src/resolve/resolved-dns-stream.h | 16 | ||||
-rw-r--r-- | src/resolve/resolved-dns-stub.c | 12 | ||||
-rw-r--r-- | src/resolve/resolved-dns-transaction.c | 14 | ||||
-rw-r--r-- | src/resolve/resolved-llmnr.c | 14 | ||||
-rw-r--r-- | src/resolve/test-resolved-stream.c | 8 |
6 files changed, 81 insertions, 67 deletions
diff --git a/src/resolve/resolved-dns-stream.c b/src/resolve/resolved-dns-stream.c index 51ffa6b4b0..cf9d1a9d5e 100644 --- a/src/resolve/resolved-dns-stream.c +++ b/src/resolve/resolved-dns-stream.c @@ -281,6 +281,29 @@ static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) { return dns_stream_complete(s, ETIMEDOUT); } +static DnsPacket *dns_stream_take_read_packet(DnsStream *s) { + assert(s); + + /* Note, dns_stream_update() should be called after this is called. When this is called, the + * stream may be already full and the EPOLLIN flag is dropped from the stream IO event source. + * Even this makes a room to read in the stream, this does not call dns_stream_update(), hence + * EPOLLIN flag is not set automatically. So, to read further packets from the stream, + * dns_stream_update() must be called explicitly. Currently, this is only called from + * on_stream_io_impl(), and there dns_stream_update() is called. */ + + if (!s->read_packet) + return NULL; + + if (s->n_read < sizeof(s->read_size)) + return NULL; + + if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) + return NULL; + + s->n_read = 0; + return TAKE_PTR(s->read_packet); +} + static int on_stream_io_impl(DnsStream *s, uint32_t revents) { bool progressed = false; int r; @@ -411,31 +434,37 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) { s->n_read += ss; } - /* Are we done? If so, disable the event source for EPOLLIN */ - if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) { - /* If there's a packet handler - * installed, call that. Note that - * this is optional... */ - if (s->on_packet) { - r = s->on_packet(s); - if (r < 0) - return r; - } + /* Are we done? If so, call the packet handler and re-enable EPOLLIN for the + * event source if necessary. */ + _cleanup_(dns_packet_unrefp) DnsPacket *p = dns_stream_take_read_packet(s); + if (p) { + assert(s->on_packet); + r = s->on_packet(s, p); + if (r < 0) + return r; r = dns_stream_update_io(s); if (r < 0) return dns_stream_complete(s, -r); + + s->packet_received = true; } } } - /* Call "complete" callback if finished reading and writing one packet, and there's nothing else left - * to write. */ - if (s->type == DNS_STREAM_LLMNR_SEND && - (s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) && - ordered_set_isempty(s->write_queue) && - (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size)) - return dns_stream_complete(s, 0); + if (s->type == DNS_STREAM_LLMNR_SEND && s->packet_received) { + uint32_t events; + + /* Complete the stream if finished reading and writing one packet, and there's nothing + * else left to write. */ + + r = sd_event_source_get_io_events(s->io_event_source, &events); + if (r < 0) + return r; + + if (!FLAGS_SET(events, EPOLLOUT)) + return dns_stream_complete(s, 0); + } /* If we did something, let's restart the timeout event source */ if (progressed && s->timeout_event_source) { @@ -523,6 +552,8 @@ int dns_stream_new( DnsProtocol protocol, int fd, const union sockaddr_union *tfo_address, + int (on_packet)(DnsStream*, DnsPacket*), + int (complete)(DnsStream*, int), /* optional */ usec_t connect_timeout_usec) { _cleanup_(dns_stream_unrefp) DnsStream *s = NULL; @@ -535,6 +566,7 @@ int dns_stream_new( assert(protocol >= 0); assert(protocol < _DNS_PROTOCOL_MAX); assert(fd >= 0); + assert(on_packet); if (m->n_dns_streams[type] > DNS_STREAMS_MAX) return -EBUSY; @@ -576,6 +608,8 @@ int dns_stream_new( s->manager = m; s->fd = fd; + s->on_packet = on_packet; + s->complete = complete; if (tfo_address) { s->tfo_address = *tfo_address; @@ -602,22 +636,6 @@ int dns_stream_write_packet(DnsStream *s, DnsPacket *p) { return dns_stream_update_io(s); } -DnsPacket *dns_stream_take_read_packet(DnsStream *s) { - assert(s); - - if (!s->read_packet) - return NULL; - - if (s->n_read < sizeof(s->read_size)) - return NULL; - - if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) - return NULL; - - s->n_read = 0; - return TAKE_PTR(s->read_packet); -} - void dns_stream_detach(DnsStream *s) { assert(s); diff --git a/src/resolve/resolved-dns-stream.h b/src/resolve/resolved-dns-stream.h index 96b977f628..1c606365cd 100644 --- a/src/resolve/resolved-dns-stream.h +++ b/src/resolve/resolved-dns-stream.h @@ -60,6 +60,7 @@ struct DnsStream { int ifindex; uint32_t ttl; bool identified; + bool packet_received; /* At least one packet is received. Used by LLMNR. */ /* only when using TCP fast open */ union sockaddr_union tfo_address; @@ -78,7 +79,7 @@ struct DnsStream { size_t n_written, n_read; OrderedSet *write_queue; - int (*on_packet)(DnsStream *s); + int (*on_packet)(DnsStream *s, DnsPacket *p); int (*complete)(DnsStream *s, int error); LIST_HEAD(DnsTransaction, transactions); /* when used by the transaction logic */ @@ -93,7 +94,16 @@ struct DnsStream { LIST_FIELDS(DnsStream, streams); }; -int dns_stream_new(Manager *m, DnsStream **s, DnsStreamType type, DnsProtocol protocol, int fd, const union sockaddr_union *tfo_address, usec_t timeout); +int dns_stream_new( + Manager *m, + DnsStream **ret, + DnsStreamType type, + DnsProtocol protocol, + int fd, + const union sockaddr_union *tfo_address, + int (on_packet)(DnsStream*, DnsPacket*), + int (complete)(DnsStream*, int), /* optional */ + usec_t connect_timeout_usec); #if ENABLE_DNS_OVER_TLS int dns_stream_connect_tls(DnsStream *s, void *tls_session); #endif @@ -114,6 +124,4 @@ static inline bool DNS_STREAM_QUEUED(DnsStream *s) { return !!s->write_packet; } -DnsPacket *dns_stream_take_read_packet(DnsStream *s); - void dns_stream_detach(DnsStream *s); diff --git a/src/resolve/resolved-dns-stub.c b/src/resolve/resolved-dns-stub.c index 73590e3f9b..992ae19bbc 100644 --- a/src/resolve/resolved-dns-stub.c +++ b/src/resolve/resolved-dns-stub.c @@ -1044,12 +1044,9 @@ static int on_dns_stub_packet_extra(sd_event_source *s, int fd, uint32_t revents return on_dns_stub_packet_internal(s, fd, revents, l->manager, l); } -static int on_dns_stub_stream_packet(DnsStream *s) { - _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL; - +static int on_dns_stub_stream_packet(DnsStream *s, DnsPacket *p) { assert(s); - - p = dns_stream_take_read_packet(s); + assert(s->manager); assert(p); if (dns_packet_validate_query(p) > 0) { @@ -1074,15 +1071,14 @@ static int on_dns_stub_stream_internal(sd_event_source *s, int fd, uint32_t reve return -errno; } - r = dns_stream_new(m, &stream, DNS_STREAM_STUB, DNS_PROTOCOL_DNS, cfd, NULL, DNS_STREAM_STUB_TIMEOUT_USEC); + r = dns_stream_new(m, &stream, DNS_STREAM_STUB, DNS_PROTOCOL_DNS, cfd, NULL, + on_dns_stub_stream_packet, dns_stub_stream_complete, DNS_STREAM_STUB_TIMEOUT_USEC); if (r < 0) { safe_close(cfd); return r; } stream->stub_listener_extra = l; - stream->on_packet = on_dns_stub_stream_packet; - stream->complete = dns_stub_stream_complete; /* We let the reference to the stream dangle here, it will be dropped later by the complete callback. */ diff --git a/src/resolve/resolved-dns-transaction.c b/src/resolve/resolved-dns-transaction.c index 0cf9912712..f937f9f7b5 100644 --- a/src/resolve/resolved-dns-transaction.c +++ b/src/resolve/resolved-dns-transaction.c @@ -644,14 +644,12 @@ static int on_stream_complete(DnsStream *s, int error) { return 0; } -static int on_stream_packet(DnsStream *s) { - _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL; +static int on_stream_packet(DnsStream *s, DnsPacket *p) { DnsTransaction *t; assert(s); - - /* Take ownership of packet to be able to receive new packets */ - assert_se(p = dns_stream_take_read_packet(s)); + assert(s->manager); + assert(p); t = hashmap_get(s->manager->dns_transactions, UINT_TO_PTR(DNS_PACKET_ID(p))); if (t && t->stream == s) /* Validate that the stream we got this on actually is the stream the @@ -754,7 +752,8 @@ static int dns_transaction_emit_tcp(DnsTransaction *t) { if (fd < 0) return fd; - r = dns_stream_new(t->scope->manager, &s, type, t->scope->protocol, fd, &sa, stream_timeout_usec); + r = dns_stream_new(t->scope->manager, &s, type, t->scope->protocol, fd, &sa, + on_stream_packet, on_stream_complete, stream_timeout_usec); if (r < 0) return r; @@ -777,9 +776,6 @@ static int dns_transaction_emit_tcp(DnsTransaction *t) { t->server->stream = dns_stream_ref(s); } - s->complete = on_stream_complete; - s->on_packet = on_stream_packet; - /* The interface index is difficult to determine if we are * connecting to the local host, hence fill this in right away * instead of determining it from the socket */ diff --git a/src/resolve/resolved-llmnr.c b/src/resolve/resolved-llmnr.c index 32483006b1..76e42940f4 100644 --- a/src/resolve/resolved-llmnr.c +++ b/src/resolve/resolved-llmnr.c @@ -277,13 +277,11 @@ int manager_llmnr_ipv6_udp_fd(Manager *m) { return m->llmnr_ipv6_udp_fd = TAKE_FD(s); } -static int on_llmnr_stream_packet(DnsStream *s) { - _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL; +static int on_llmnr_stream_packet(DnsStream *s, DnsPacket *p) { DnsScope *scope; assert(s); - - p = dns_stream_take_read_packet(s); + assert(s->manager); assert(p); scope = manager_find_scope(s->manager, p); @@ -296,7 +294,6 @@ static int on_llmnr_stream_packet(DnsStream *s) { } else log_debug("Invalid LLMNR TCP packet, ignoring."); - dns_stream_unref(s); return 0; } @@ -313,15 +310,14 @@ static int on_llmnr_stream(sd_event_source *s, int fd, uint32_t revents, void *u return -errno; } - r = dns_stream_new(m, &stream, DNS_STREAM_LLMNR_RECV, DNS_PROTOCOL_LLMNR, cfd, NULL, DNS_STREAM_DEFAULT_TIMEOUT_USEC); + /* We don't configure a "complete" handler here, we rely on the default handler, thus freeing it */ + r = dns_stream_new(m, &stream, DNS_STREAM_LLMNR_RECV, DNS_PROTOCOL_LLMNR, cfd, NULL, + on_llmnr_stream_packet, NULL, DNS_STREAM_DEFAULT_TIMEOUT_USEC); if (r < 0) { safe_close(cfd); return r; } - stream->on_packet = on_llmnr_stream_packet; - /* We don't configure a "complete" handler here, we rely on the default handler than simply drops the - * reference to the stream, thus freeing it */ return 0; } diff --git a/src/resolve/test-resolved-stream.c b/src/resolve/test-resolved-stream.c index 50173389dd..f9428989f0 100644 --- a/src/resolve/test-resolved-stream.c +++ b/src/resolve/test-resolved-stream.c @@ -194,9 +194,9 @@ static const size_t MAX_RECEIVED_PACKETS = 2; static DnsPacket *received_packets[2] = {}; static size_t n_received_packets = 0; -static int on_stream_packet(DnsStream *stream) { +static int on_stream_packet(DnsStream *stream, DnsPacket *p) { assert_se(n_received_packets < MAX_RECEIVED_PACKETS); - assert_se(received_packets[n_received_packets++] = dns_stream_take_read_packet(stream)); + assert_se(received_packets[n_received_packets++] = dns_packet_ref(p)); return 0; } @@ -253,8 +253,8 @@ static void test_dns_stream(bool tls) { /* Initialize DNS stream */ assert_se(dns_stream_new(&manager, &stream, DNS_STREAM_LOOKUP, DNS_PROTOCOL_DNS, - TAKE_FD(clientfd), NULL, DNS_STREAM_DEFAULT_TIMEOUT_USEC) >= 0); - stream->on_packet = on_stream_packet; + TAKE_FD(clientfd), NULL, on_stream_packet, NULL, + DNS_STREAM_DEFAULT_TIMEOUT_USEC) >= 0); #if ENABLE_DNS_OVER_TLS if (tls) { DnsServer server = { |