diff options
author | Thomas Markwalder <tmark@isc.org> | 2023-01-04 15:18:15 +0100 |
---|---|---|
committer | Thomas Markwalder <tmark@isc.org> | 2023-01-11 15:46:18 +0100 |
commit | a0f1a3665d8cb92e69c76a1cf294600fd8f562c9 (patch) | |
tree | 1d522cc5719c66b1f0470c4d63e7f38e169b6968 | |
parent | [#969] Addressed remaining review comments (diff) | |
download | kea-a0f1a3665d8cb92e69c76a1cf294600fd8f562c9.tar.xz kea-a0f1a3665d8cb92e69c76a1cf294600fd8f562c9.zip |
[#2634] Added TLS unit tests to lib tcp
src/lib/tcp/tests/tls_listener_unittests.cc - new file
src/lib/tcp/tcp_connection.cc
Added missing log argument
src/lib/tcp/tests/Makefile.am
Added tls_listner_unittests.cc
src/lib/tcp/tests/mt_tcp_listener_mgr_unittests.cc
Udpated client ctor calls
src/lib/tcp/tests/tcp_test_client.h
Added support for TLS to TcpTestClient
src/lib/tcp/tests/tcp_listener_unittests.cc
Modified to use generalized TcpTestClient
-rw-r--r-- | src/lib/tcp/tcp_connection.cc | 4 | ||||
-rw-r--r-- | src/lib/tcp/tests/Makefile.am | 17 | ||||
-rw-r--r-- | src/lib/tcp/tests/mt_tcp_listener_mgr_unittests.cc | 2 | ||||
-rw-r--r-- | src/lib/tcp/tests/tcp_listener_unittests.cc | 50 | ||||
-rw-r--r-- | src/lib/tcp/tests/tcp_test_client.h | 462 | ||||
-rw-r--r-- | src/lib/tcp/tests/tls_listener_unittests.cc | 497 |
6 files changed, 774 insertions, 258 deletions
diff --git a/src/lib/tcp/tcp_connection.cc b/src/lib/tcp/tcp_connection.cc index f910aff5da..e805afcfcb 100644 --- a/src/lib/tcp/tcp_connection.cc +++ b/src/lib/tcp/tcp_connection.cc @@ -350,8 +350,8 @@ TcpConnection::handshakeCallback(const boost::system::error_code& ec) { } else { LOG_DEBUG(tcp_logger, isc::log::DBGLVL_TRACE_DETAIL, TLS_REQUEST_RECEIVE_START) - .arg(getRemoteEndpointAddressAsText()); - + .arg(getRemoteEndpointAddressAsText()) + .arg(static_cast<unsigned>(idle_timeout_/1000)); doRead(); } } diff --git a/src/lib/tcp/tests/Makefile.am b/src/lib/tcp/tests/Makefile.am index 4380f99779..41ac0fcb5c 100644 --- a/src/lib/tcp/tests/Makefile.am +++ b/src/lib/tcp/tests/Makefile.am @@ -24,17 +24,12 @@ run_unittests_SOURCES += tcp_test_client.h tcp_test_listener.h run_unittests_SOURCES += tcp_listener_unittests.cc run_unittests_SOURCES += mt_tcp_listener_mgr_unittests.cc -# @todo These tests will be needed. These are here as a reminder. -#if HAVE_OPENSSL -#run_unittests_SOURCES += tls_unittest.cc -#run_unittests_SOURCES += tls_acceptor_unittest.cc -#run_unittests_SOURCES += tls_socket_unittest.cc -#endif -#if HAVE_BOTAN_BOOST -#run_unittests_SOURCES += tls_unittest.cc -#run_unittests_SOURCES += tls_acceptor_unittest.cc -#run_unittests_SOURCES += tls_socket_unittest.cc -#endif +if HAVE_OPENSSL +run_unittests_SOURCES += tls_listener_unittests.cc +endif +if HAVE_BOTAN_BOOST +run_unittests_SOURCES += tls_listener_unittests.cc +endif run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES) diff --git a/src/lib/tcp/tests/mt_tcp_listener_mgr_unittests.cc b/src/lib/tcp/tests/mt_tcp_listener_mgr_unittests.cc index b5d805f382..dd2112ab31 100644 --- a/src/lib/tcp/tests/mt_tcp_listener_mgr_unittests.cc +++ b/src/lib/tcp/tests/mt_tcp_listener_mgr_unittests.cc @@ -151,6 +151,7 @@ public: // Instantiate the client. TcpTestClientPtr client(new TcpTestClient(io_service_, std::bind(&MtTcpListenerMgrTest::clientDone, this), + TlsContextPtr(), SERVER_ADDRESS, SERVER_PORT)); // Add it to the list of clients. clients_.push_back(client); @@ -179,6 +180,7 @@ public: // Create a new client. TcpTestClientPtr client(new TcpTestClient(io_service_, std::bind(&MtTcpListenerMgrTest::clientDone, this), + TlsContextPtr(), SERVER_ADDRESS, SERVER_PORT)); // Construct the "thread" command post including the argument, diff --git a/src/lib/tcp/tests/tcp_listener_unittests.cc b/src/lib/tcp/tests/tcp_listener_unittests.cc index c6e3235358..f2b268f868 100644 --- a/src/lib/tcp/tests/tcp_listener_unittests.cc +++ b/src/lib/tcp/tests/tcp_listener_unittests.cc @@ -9,7 +9,6 @@ #include <asiolink/interval_timer.h> #include <asiolink/io_service.h> #include <tcp_test_listener.h> -#include <tcp_test_client.h> #include <gtest/gtest.h> @@ -82,15 +81,16 @@ public: } } - /// @brief Connect to the endpoint. + /// @brief Create a new client. /// /// This method creates TcpTestClient instance and retains it in /// the clients_ list. - TcpTestClientPtr connectClient() { + /// @param tls_context TLS context to assign to the client. + TcpTestClientPtr createClient(TlsContextPtr tls_context = TlsContextPtr()) { TcpTestClientPtr client(new TcpTestClient(io_service_, - std::bind(&TcpListenerTest::clientDone, this))); + std::bind(&TcpListenerTest::clientDone, this), + tls_context)); clients_.push_back(client); - client->connect(); return (client); } @@ -100,18 +100,24 @@ public: /// the clients_ list. /// /// @param request String containing the request to be sent. - void startRequest(const std::string& request) { - TcpTestClientPtr client(new TcpTestClient(io_service_, - std::bind(&TcpListenerTest::clientDone, this))); - clients_.push_back(client); - clients_.back()->startRequest(request); + /// @param tls_context TLS context to assign to the client. + void startRequest(const std::string& request, + TlsContextPtr tls_context = TlsContextPtr()) { + TcpTestClientPtr client = createClient(tls_context); + client->startRequest(request); } - void startRequests(const std::list<std::string>& requests) { - TcpTestClientPtr client(new TcpTestClient(io_service_, - std::bind(&TcpListenerTest::clientDone, this))); - clients_.push_back(client); - clients_.back()->startRequests(requests); + /// @brief Connect to the endpoint and send a list of requests. + /// + /// This method creates a TcpTestClient instance and initiates a + /// series of requests. + /// + /// @param request String containing the request to be sent. + /// @param tls_context TLS context to assign to the client. + void startRequests(const std::list<std::string>& requests, + TlsContextPtr tls_context = TlsContextPtr()) { + TcpTestClientPtr client = createClient(tls_context); + client->startRequests(requests); } /// @brief Callback function invoke upon test timeout. @@ -275,16 +281,16 @@ TEST_F(TcpListenerTest, idleTimeoutTest) { ASSERT_NO_THROW(listener.start()); ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); - ASSERT_NO_THROW(connectClient()); - ASSERT_EQ(1, clients_.size()); - TcpTestClientPtr client = *clients_.begin(); - ASSERT_TRUE(client); - - // Tell the client expecting reading to fail with an EOF. - ASSERT_NO_THROW(client->waitForEof()); + // Start a client with an empty request. Empty requests tell the client to read + // without sending anything and expect the read to fail when the listener idle + // times out the socket. + ASSERT_NO_THROW(startRequest("")); // Run until idle timer expires. ASSERT_NO_THROW(runIOService()); + + ASSERT_EQ(1, clients_.size()); + TcpTestClientPtr client = *clients_.begin(); EXPECT_FALSE(client->receiveDone()); EXPECT_TRUE(client->expectedEof()); diff --git a/src/lib/tcp/tests/tcp_test_client.h b/src/lib/tcp/tests/tcp_test_client.h index 338ae5a1ae..51f967ea98 100644 --- a/src/lib/tcp/tests/tcp_test_client.h +++ b/src/lib/tcp/tests/tcp_test_client.h @@ -1,4 +1,4 @@ -// Copyright (C) 2022 Internet Systems Consortium, Inc. ("ISC") +// Copyright (C) 2022-2023 Internet Systems Consortium, Inc. ("ISC") // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,9 @@ #define TCP_TEST_CLIENT_H #include <cc/data.h> +#include <asiolink/tcp_socket.h> +#include <asiolink/tls_socket.h> +#include <asiolink/testutils/test_tls.h> #include <tcp/tcp_connection.h> #include <tcp/tcp_stream_msg.h> #include <boost/asio/read.hpp> @@ -15,29 +18,88 @@ #include <boost/asio/ip/tcp.hpp> #include <gtest/gtest.h> -using namespace boost::asio::ip; -using namespace isc::asiolink; -using namespace isc::tcp; - -/// @brief Entity which can connect to the TCP server endpoint. +/// @brief Entity which can connect to the TCP server endpoint with or +/// or without TLS. class TcpTestClient : public boost::noncopyable { + +private: + /// @brief Type of the function implementing a callback invoked by the + /// @c SocketCallback functor. + typedef std::function<void(boost::system::error_code ec, size_t length)> + SocketCallbackFunction; + + /// @brief Functor associated with the socket object. + /// + /// This functor calls a callback function specified in the constructor. + class SocketCallback { + public: + + /// @brief Constructor. + /// + /// @param socket_callback Callback to be invoked by the functor upon + /// an event associated with the socket. + SocketCallback(SocketCallbackFunction socket_callback) + : callback_(socket_callback) { + } + + /// @brief Operator called when event associated with a socket occurs. + /// + /// This operator returns immediately when received error code is + /// @c boost::system::error_code is equal to + /// @c boost::asio::error::operation_aborted, i.e. the callback is not + /// invoked. + /// + /// @param ec Error code. + /// @param length Data length. + void operator()(boost::system::error_code ec, size_t length = 0) { + if (ec.value() == boost::asio::error::operation_aborted) { + return; + } + + callback_(ec, length); + } + + private: + /// @brief Supplied callback. + SocketCallbackFunction callback_; + }; + public: /// @brief Constructor. /// /// This constructor creates new socket instance. It doesn't connect. Call - /// connect() to connect to the server. + /// start() to connect to the server. /// /// @param io_service IO service to be stopped on error or completion. + /// @param done_callback Function cient should invoke when it has finished + /// all its requests or failed. + /// @param tls_context /// @param server_address string containing the IP address of the server. /// @param port port number of the server. - explicit TcpTestClient(IOService& io_service, + explicit TcpTestClient(isc::asiolink::IOService& io_service, std::function<void()> done_callback, + isc::asiolink::TlsContextPtr tls_context = + isc::asiolink::TlsContextPtr(), const std::string& server_address = "127.0.0.1", uint16_t port = 18123) - : io_service_(io_service.get_io_service()), socket_(io_service_), buf_(), - response_(), done_callback_(done_callback), server_address_(server_address), - server_port_(port), receive_done_(false), expected_eof_(true) { + : io_service_(io_service.get_io_service()), + tls_context_(tls_context), + tcp_socket_(), tls_socket_(), + done_callback_(done_callback), + server_address_(server_address), server_port_(port), + buf_(), response_(), + receive_done_(false), expected_eof_(false), handshake_failed_(false) { + if (!tls_context_) { + tcp_socket_.reset(new isc::asiolink::TCPSocket<SocketCallback>(io_service)); + } else { + tls_socket_.reset(new isc::asiolink::TLSSocket<SocketCallback>(io_service, + tls_context)); + } + } + + bool useTls() { + return (tls_context_ != 0); } /// @brief Destructor. @@ -47,64 +109,62 @@ public: close(); } - /// @brief Connect to the listener. + /// @brief Connect to the listener and initiate request processing. /// - /// @param request request string to send. - void connect() { - tcp::endpoint endpoint(address::from_string(server_address_), server_port_); - socket_.async_connect(endpoint, - [this](const boost::system::error_code& ec) { - receive_done_ = false; - expected_eof_ = false; - if (ec) { - // One would expect that async_connect wouldn't return - // EINPROGRESS error code, but simply wait for the connection - // to get established before the handler is invoked. It turns out, - // however, that on some OSes the connect handler may receive this - // error code which doesn't necessarily indicate a problem. - // Making an attempt to write and read from this socket will - // typically succeed. So, we ignore this error. - if (ec.value() != boost::asio::error::in_progress) { - ADD_FAILURE() << "error occurred while connecting: " - << ec.message(); - done_callback_(); - return; + /// Upon successful connection, carry out the TLS handshake. If the handshake + /// completes successful start sending requests. + void start() { + isc::asiolink::TCPEndpoint endpoint(boost::asio::ip::address::from_string(server_address_), server_port_); + SocketCallback socket_cb( + [this](boost::system::error_code ec, size_t /*length */) { + receive_done_ = false; + expected_eof_ = false; + handshake_failed_ = false; + if (ec) { + // One would expect that open wouldn't return + // EINPROGRESS error code, but simply wait for the connection + // to get established before the handler is invoked. It turns out, + // however, that on some OSes the connect handler may receive this + // error code which doesn't necessarily indicate a problem. + // Making an attempt to write and read from this socket will + // typically succeed. So, we ignore this error. + if (ec.value() != boost::asio::error::in_progress) { + ADD_FAILURE() << "error occurred while connecting: " + << ec.message(); + done_callback_(); + } } + + if (useTls()) { + SocketCallback socket_cb( + [this](boost::system::error_code ec, size_t /*length */) { + if (ec) { + handshake_failed_ = true; + done_callback_(); + } else { + sendNextRequest(); + } + }); + + tls_socket_->handshake(socket_cb); + } else { + sendNextRequest(); } }); + + if (useTls()) { + tls_socket_->open(&endpoint, socket_cb); + } else { + tcp_socket_->open(&endpoint, socket_cb); + } } /// @brief Send request specified in textual format. /// /// @param request request in the textual format. void startRequest(const std::string& request) { - tcp::endpoint endpoint(address::from_string(server_address_), server_port_); - socket_.async_connect(endpoint, - [this, request](const boost::system::error_code& ec) { - receive_done_ = false; - expected_eof_ = false; - if (ec) { - // One would expect that async_connect wouldn't return - // EINPROGRESS error code, but simply wait for the connection - // to get established before the handler is invoked. It turns out, - // however, that on some OSes the connect handler may receive this - // error code which doesn't necessarily indicate a problem. - // Making an attempt to write and read from this socket will - // typically succeed. So, we ignore this error. - if (ec.value() != boost::asio::error::in_progress) { - ADD_FAILURE() << "error occurred while connecting: " - << ec.message(); - done_callback_(); - return; - } - } - - if (request.empty()) { - waitForEof(); - } else { - sendRequest(request); - } - }); + requests_to_send_.push_back(request); + start(); } /// @brief Send request specified in textual format. @@ -112,30 +172,7 @@ public: /// @param request request in the textual format. void startRequests(const std::list<std::string>& requests) { requests_to_send_ = requests; - - tcp::endpoint endpoint(address::from_string(server_address_), server_port_); - socket_.async_connect(endpoint, - [this](const boost::system::error_code& ec) { - receive_done_ = false; - expected_eof_ = false; - if (ec) { - // One would expect that async_connect wouldn't return - // EINPROGRESS error code, but simply wait for the connection - // to get established before the handler is invoked. It turns out, - // however, that on some OSes the connect handler may receive this - // error code which doesn't necessarily indicate a problem. - // Making an attempt to write and read from this socket will - // typically succeed. So, we ignore this error. - if (ec.value() != boost::asio::error::in_progress) { - ADD_FAILURE() << "error occurred while connecting: " - << ec.message(); - done_callback_(); - return; - } - } - - sendNextRequest(); - }); + start(); } /// @brief Sends the next request from the list of requests to send. @@ -144,7 +181,11 @@ public: if (!requests_to_send_.empty()) { std::string request = requests_to_send_.front(); requests_to_send_.pop_front(); - sendRequest(request); + if (request.empty()) { + waitForEof(); + } else { + sendRequest(request); + } } } @@ -156,7 +197,7 @@ public: void sendRequest(const std::string& request, const size_t send_length = 0) { // Prepend the length of the request. uint16_t size = static_cast<uint16_t>(request.size()); - WireData wire_request; + isc::tcp::WireData wire_request; if (!request.empty()) { wire_request.push_back(static_cast<uint8_t>((size & 0xff00U) >> 8)); wire_request.push_back(static_cast<uint8_t>(size & 0x00ffU)); @@ -168,7 +209,7 @@ public: /// @brief Wait for a server to close the connection. void waitForEof() { - stream_response_.reset(new TcpStreamRequest()); + stream_response_.reset(new isc::tcp::TcpStreamRequest()); receivePartialResponse(true); } @@ -177,7 +218,7 @@ public: /// @param request part of the request to be sent. /// @param send_length number of bytes to send. If not zero, can be used /// to truncate the amount of data sent. - void sendPartialRequest(WireData& wire_request, size_t send_length = 0) { + void sendPartialRequest(isc::tcp::WireData& wire_request, size_t send_length = 0) { if (!send_length) { send_length = wire_request.size(); } else { @@ -185,87 +226,105 @@ public: << "broken test, send_length exceeds wire size"; } - socket_.async_send(boost::asio::buffer(wire_request.data(), send_length), - [this, wire_request](const boost::system::error_code& ec, - std::size_t bytes_transferred) mutable { - if (ec) { - if (ec.value() == boost::asio::error::operation_aborted) { - return; + SocketCallback socket_cb( + [this, wire_request](boost::system::error_code ec, size_t bytes_transferred) mutable { + if (ec) { + if (ec.value() == boost::asio::error::operation_aborted) { + return; + + } else if ((ec.value() == boost::asio::error::try_again) || + (ec.value() == boost::asio::error::would_block)) { + // If we should try again make sure there is no garbage in the + // bytes_transferred. + bytes_transferred = 0; + } else { + ADD_FAILURE() << "error occurred while connecting: " + << ec.message(); + done_callback_(); + return; + } + } - } else if ((ec.value() == boost::asio::error::try_again) || - (ec.value() == boost::asio::error::would_block)) { - // If we should try again make sure there is no garbage in the - // bytes_transferred. - bytes_transferred = 0; + // Remove the part of the request which has been sent. + if (bytes_transferred > 0 && (wire_request.size() <= bytes_transferred)) { + wire_request.erase(wire_request.begin(), + (wire_request.begin() + bytes_transferred)); + } + // Continue sending request data if there are still some data to be + // sent. + if (!wire_request.empty()) { + sendPartialRequest(wire_request); } else { - ADD_FAILURE() << "error occurred while connecting: " - << ec.message(); - done_callback_(); - return; + // Request has been sent. Start receiving response. + receivePartialResponse(); } - } + }); - // Remove the part of the request which has been sent. - if (bytes_transferred > 0 && (wire_request.size() <= bytes_transferred)) { - wire_request.erase(wire_request.begin(), wire_request.begin() + bytes_transferred); - } - - // Continue sending request data if there are still some data to be - // sent. - if (!wire_request.empty()) { - sendPartialRequest(wire_request); - } else { - // Request has been sent. Start receiving response. - receivePartialResponse(); - } - }); + if (useTls()) { + tls_socket_->asyncSend(static_cast<const void *>(wire_request.data()), + send_length, socket_cb); + } else { + tcp_socket_->asyncSend(static_cast<const void *>(wire_request.data()), + send_length, socket_cb); + } } /// @brief Receive response from the server. void receivePartialResponse(bool expect_eof = false) { - socket_.async_read_some(boost::asio::buffer(buf_.data(), buf_.size()), - [this, expect_eof](const boost::system::error_code& ec, - std::size_t bytes_transferred) { - if (!stream_response_) { - stream_response_.reset(new TcpStreamRequest()); - } + SocketCallback socket_cb( + [this, expect_eof](const boost::system::error_code& ec, + std::size_t bytes_transferred) { + if (!stream_response_) { + stream_response_.reset(new isc::tcp::TcpStreamRequest()); + } - if (ec) { - // IO service stopped so simply return. - if (ec.value() == boost::asio::error::operation_aborted) { - return; - } else if ((ec.value() == boost::asio::error::try_again) || - (ec.value() == boost::asio::error::would_block)) { - // If we should try again, make sure that there is no garbage - // in the bytes_transferred. - bytes_transferred = 0; - } else if (ec.value() == boost::asio::error::eof && expect_eof) { - expected_eof_ = true; - done_callback_(); - return; - } else { - // Error occurred, bail... - ADD_FAILURE() << "client: " << this << " error occurred while receiving TCP" - " response from the server: " << ec.message(); - done_callback_(); - return; + if (ec) { + // IO service stopped so simply return. + if (ec.value() == boost::asio::error::operation_aborted) { + return; + } else if ((ec.value() == boost::asio::error::try_again) || + (ec.value() == boost::asio::error::would_block)) { + // If we should try again, make sure that there is no garbage + // in the bytes_transferred. + bytes_transferred = 0; + } else if (expect_eof && ((ec.value() == boost::asio::error::eof) || + (ec.value() == boost::asio::ssl::error::stream_truncated))) { + expected_eof_ = true; + done_callback_(); + return; + } else { + // Error occurred, bail... + ADD_FAILURE() << "client: " << this + << " error occurred while receiving TCP" + << " response from the server: " << ec.message(); + done_callback_(); + return; + } } - } - // Post received data to the current response. - if (bytes_transferred > 0) { - stream_response_->postBuffer(buf_.data(), bytes_transferred); - } + // Post received data to the current response. + if (bytes_transferred > 0) { + stream_response_->postBuffer(buf_.data(), bytes_transferred); + } - if (stream_response_->needData()) { - // Response is incomplete, keep reading. - receivePartialResponse(); - } else { - // Response is complete, process it. - responseReceived(); - } - }); + if (stream_response_->needData()) { + // Response is incomplete, keep reading. + receivePartialResponse(); + } else { + // Response is complete, process it. + responseReceived(); + } + }); + + isc::asiolink::TCPEndpoint from; + if (useTls()) { + tls_socket_->asyncReceive(static_cast<void*>(buf_.data()), buf_.size(), 0, + &from, socket_cb); + } else { + tcp_socket_->asyncReceive(static_cast<void*>(buf_.data()), buf_.size(), 0, + &from, socket_cb); + } } /// @brief Process a completed response received from the server. @@ -287,69 +346,13 @@ public: sendNextRequest(); } - /// @brief Checks if the TCP connection is still open. - /// - /// Tests the TCP connection by trying to read from the socket. - /// Unfortunately expected failure depends on a race between the read - /// and the other side close so to check if the connection is closed - /// please use @c isConnectionClosed instead. - /// - /// @return true if the TCP connection is open. - bool isConnectionAlive() { - // Remember the current non blocking setting. - const bool non_blocking_orig = socket_.non_blocking(); - // Set the socket to non blocking mode. We're going to test if the socket - // returns would_block status on the attempt to read from it. - socket_.non_blocking(true); - - // We need to provide a buffer for a call to read. - char data[2]; - boost::system::error_code ec; - boost::asio::read(socket_, boost::asio::buffer(data, sizeof(data)), ec); - - // Revert the original non_blocking flag on the socket. - socket_.non_blocking(non_blocking_orig); - - // If the connection is alive we'd typically get would_block status code. - // If there are any data that haven't been read we may also get success - // status. We're guessing that try_again may also be returned by some - // implementations in some situations. Any other error code indicates a - // problem with the connection so we assume that the connection has been - // closed. - return (!ec || (ec.value() == boost::asio::error::try_again) || - (ec.value() == boost::asio::error::would_block)); - } - - /// @brief Checks if the TCP connection is already closed. - /// - /// Tests the TCP connection by trying to read from the socket. - /// The read can block so this must be used to check if a connection - /// is alive so to check if the connection is alive please always - /// use @c isConnectionAlive. - /// - /// @return true if the TCP connection is closed. - bool isConnectionClosed() { - // Remember the current non blocking setting. - const bool non_blocking_orig = socket_.non_blocking(); - // Set the socket to blocking mode. We're going to test if the socket - // returns eof status on the attempt to read from it. - socket_.non_blocking(false); - - // We need to provide a buffer for a call to read. - char data[2]; - boost::system::error_code ec; - boost::asio::read(socket_, boost::asio::buffer(data, sizeof(data)), ec); - - // Revert the original non_blocking flag on the socket. - socket_.non_blocking(non_blocking_orig); - - // If the connection is closed we'd typically get eof status code. - return (ec.value() == boost::asio::error::eof); - } - /// @brief Close connection. void close() { - socket_.close(); + if (useTls()) { + tls_socket_->close(); + } else { + tcp_socket_->close(); + } } /// @brief Returns true if the receive completed without error. @@ -374,19 +377,23 @@ public: return (responses_received_); } + bool handshakeFailed() { + return(handshake_failed_); + } + private: /// @brief Holds reference to the IO service. boost::asio::io_service& io_service_; - /// @brief A socket used for the connection. - boost::asio::ip::tcp::socket socket_; + /// @brief TLS context. + isc::asiolink::TlsContextPtr tls_context_; - /// @brief Buffer into which response is written. - std::array<char, 8192> buf_; + /// @brief TCP socket used by this connection. + std::unique_ptr<isc::asiolink::TCPSocket<SocketCallback> > tcp_socket_; - /// @brief Response in the textual format. - std::string response_; + /// @brief TLS socket used by this connection. + std::unique_ptr<isc::asiolink::TLSSocket<SocketCallback> > tls_socket_; /// @brief Callback to invoke when the client has finished its work or /// failed. @@ -398,6 +405,12 @@ private: /// @brief IP port of the server. uint16_t server_port_; + /// @brief Buffer into which response is written. + std::array<char, 8192> buf_; + + /// @brief Response in the textual format. + std::string response_; + /// @brief Set to true when the receive has completed successfully. bool receive_done_; @@ -406,8 +419,11 @@ private: /// expected it to do. bool expected_eof_; + /// @brief Set to true if the TLS handshake failed. + bool handshake_failed_; + /// @brief Pointer to the server response currently being received. - TcpStreamRequestPtr stream_response_; + isc::tcp::TcpStreamRequestPtr stream_response_; /// @brief List of string requests to send. std::list<std::string> requests_to_send_; diff --git a/src/lib/tcp/tests/tls_listener_unittests.cc b/src/lib/tcp/tests/tls_listener_unittests.cc new file mode 100644 index 0000000000..bdfd1e74c7 --- /dev/null +++ b/src/lib/tcp/tests/tls_listener_unittests.cc @@ -0,0 +1,497 @@ +// Copyright (C) 2023 Internet Systems Consortium, Inc. ("ISC") +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include <config.h> +#include <asiolink/asio_wrapper.h> +#include <asiolink/interval_timer.h> +#include <asiolink/testutils/test_tls.h> +#include <asiolink/io_service.h> +#include <tcp_test_listener.h> +#include <tcp_test_client.h> + +#include <gtest/gtest.h> + +#include <sstream> + +using namespace boost::asio::ip; +using namespace isc::asiolink; +using namespace isc::asiolink::test; +using namespace isc::tcp; + +namespace ph = std::placeholders; + +namespace { + +/// @brief IP address to which service is bound. +const std::string SERVER_ADDRESS = "127.0.0.1"; + +/// @brief IPv6 address to whch service is bound. +const std::string IPV6_SERVER_ADDRESS = "::1"; + +/// @brief Port number to which service is bound. +const unsigned short SERVER_PORT = 18123; + +/// @brief Request Timeout used in most of the tests (ms). +const long REQUEST_TIMEOUT = 10000; + +/// @brief Connection idle timeout used in tests where idle connections +/// are tested (ms). +const long SHORT_REQUEST_TIMEOUT = 200; + +/// @brief Connection idle timeout used in most of the tests (ms). +const long IDLE_TIMEOUT = 10000; + +/// @brief Connection idle timeout used in tests where idle connections +/// are tested (ms). +const long SHORT_IDLE_TIMEOUT = 200; + +/// @brief Test timeout (ms). +const long TEST_TIMEOUT = 10000; + +/// @brief Test fixture class for @ref TcpListener that uses TLS. +class TlsListenerTest : public ::testing::Test { +public: + + /// @brief Constructor. + /// + /// Starts test timer which detects timeouts. + TlsListenerTest() + : io_service_(), test_timer_(io_service_), + run_io_service_timer_(io_service_), + clients_(), clients_done_(0) { + test_timer_.setup(std::bind(&TlsListenerTest::timeoutHandler, this, true), + TEST_TIMEOUT, + IntervalTimer::ONE_SHOT); + } + + /// @brief Destructor. + /// + /// Removes active clients. + virtual ~TlsListenerTest() { + for (auto client : clients_) { + client->close(); + } + } + + /// @brief Fetch the server TLS context. + TlsContextPtr serverContext() { + TlsContextPtr tls_context; + configServer(tls_context); + return(tls_context); + } + + /// @brief Fetch a client TLS context that works with the server context. + TlsContextPtr clientContext() { + TlsContextPtr tls_context; + configClient(tls_context); + return(tls_context); + } + + /// @brief Create a new client. + /// + /// This method creates TcpTestClient instance and retains it in + /// the clients_ list. + /// @param tls_context TLS context to assign to the client. + TcpTestClientPtr createClient(TlsContextPtr tls_context) { + TcpTestClientPtr client(new TcpTestClient(io_service_, + std::bind(&TlsListenerTest::clientDone, this), + tls_context)); + clients_.push_back(client); + return (client); + } + + /// @brief Connect to the endpoint and send a request. + /// + /// This method creates TcpTestClient instance and retains it in + /// the clients_ list. + /// + /// @param request String containing the request to be sent. + /// @param tls_context TLS context to assign to the client. + void startRequest(const std::string& request, TlsContextPtr tls_context) { + ASSERT_TRUE(tls_context); + TcpTestClientPtr client = createClient(tls_context); + client->startRequest(request); + } + + /// @brief Connect to the endpoint and send a list of requests. + /// + /// This method creates a TcpTestClient instance and initiates a + /// series of requests. + /// + /// @param request String containing the request to be sent. + /// @param tls_context TLS context to assign to the client. + void startRequests(const std::list<std::string>& requests, TlsContextPtr tls_context) { + ASSERT_TRUE(tls_context); + TcpTestClientPtr client = createClient(tls_context); + client->startRequests(requests); + } + + /// @brief Callback function invoke upon test timeout. + /// + /// It stops the IO service and reports test timeout. + /// + /// @param fail_on_timeout Specifies if test failure should be reported. + void timeoutHandler(const bool fail_on_timeout) { + if (fail_on_timeout) { + ADD_FAILURE() << "Timeout occurred while running the test!"; + } + io_service_.stop(); + } + + /// @brief Callback function each client invokes when done. + /// + /// It stops the IO service when all clients are done. + /// + /// @param fail_on_timeout Specifies if test failure should be reported. + void clientDone() { + ++clients_done_; + if (clients_done_ >= clients_.size()) { + // They're all done or dead. Stop the service. + io_service_.stop(); + } + } + + /// @brief Runs IO service with optional timeout. + /// + /// @param timeout Optional value specifying for how long the io service + /// should be ran. + void runIOService(long timeout = 0) { + io_service_.get_io_service().reset(); + + if (timeout > 0) { + run_io_service_timer_.setup(std::bind(&TlsListenerTest::timeoutHandler, + this, false), + timeout, + IntervalTimer::ONE_SHOT); + } + io_service_.run(); + io_service_.get_io_service().reset(); + io_service_.poll(); + } + + /// @brief Filter that denies every other connection. + /// + /// @param remote_endpoint_address ip address of the remote end of + /// a connection. + bool connectionFilter(const boost::asio::ip::tcp::endpoint& remote_endpoint) { + static size_t count = 0; + // If the address doesn't match, something hinky is going on, so + // we'll reject them all. If it does match, then cool, it works + // as expected. + if ((count++ % 2) || + (remote_endpoint.address().to_string() != SERVER_ADDRESS)) { + // Reject every other connection; + return (false); + } + + return (true); + } + + /// @brief IO service used in the tests. + IOService io_service_; + + /// @brief Asynchronous timer service to detect timeouts. + IntervalTimer test_timer_; + + /// @brief Asynchronous timer for running IO service for a specified amount + /// of time. + IntervalTimer run_io_service_timer_; + + /// @brief List of client connections. + std::list<TcpTestClientPtr> clients_; + + /// @brief Counts the number of clients that have reported as done. + size_t clients_done_; +}; + +// This test verifies that a connection can be established with a client +// with valid TLS credentials. +TEST_F(TlsListenerTest, listen) { + const std::string request = "I am done"; + + TcpTestListener listener(io_service_, + IOAddress(SERVER_ADDRESS), + SERVER_PORT, + serverContext(), + TcpListener::IdleTimeout(IDLE_TIMEOUT)); + + ASSERT_NO_THROW(listener.start()); + ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); + ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); + ASSERT_NO_THROW(startRequest(request, clientContext())); + ASSERT_NO_THROW(runIOService()); + ASSERT_EQ(1, clients_.size()); + TcpTestClientPtr client = *clients_.begin(); + ASSERT_TRUE(client); + EXPECT_TRUE(client->receiveDone()); + EXPECT_FALSE(client->expectedEof()); + + // Verify the audit trail for the connection. + // Sanity check to make sure we don't have more entries than we expect. + ASSERT_EQ(listener.audit_trail_->entries_.size(), 2); + + // Create the list of expected entries. + std::list<AuditEntry> expected_entries { + { 1, AuditEntry::INBOUND, "I am done" }, + { 1, AuditEntry::OUTBOUND, "good bye" } + }; + + // Verify the audit trail. + ASSERT_EQ(expected_entries, listener.audit_trail_->getConnectionTrail(1)); + + listener.stop(); + io_service_.poll(); +} + +// This test verifies that a connection is denied to a client +// with invalid TLS credentials. +TEST_F(TlsListenerTest, badClient) { + TcpTestListener listener(io_service_, + IOAddress(SERVER_ADDRESS), + SERVER_PORT, + serverContext(), + TcpListener::IdleTimeout(IDLE_TIMEOUT)); + + ASSERT_NO_THROW(listener.start()); + ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); + ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); + + TlsContextPtr bad_client_ctx; + configSelf(bad_client_ctx); + ASSERT_NO_THROW(startRequest("", bad_client_ctx)); + + ASSERT_NO_THROW(runIOService()); + + ASSERT_EQ(1, clients_.size()); + TcpTestClientPtr client = *clients_.begin(); + ASSERT_TRUE(client); + EXPECT_FALSE(client->receiveDone()); + // Handshake fails on the listener end which manifests itself in the client + // as an EOF rather than a failed handshake. + EXPECT_TRUE(client->expectedEof()); + EXPECT_FALSE(client->handshakeFailed()); +} + +// This test verifies that a TLS connection can receive a complete +// message that spans multiple socket reads. +TEST_F(TlsListenerTest, splitReads) { + const std::string request = "I am done"; + + // Read at most one byte at a time. + size_t read_max = 1; + TcpTestListener listener(io_service_, + IOAddress(SERVER_ADDRESS), + SERVER_PORT, + serverContext(), + TcpListener::IdleTimeout(IDLE_TIMEOUT), + 0, + read_max); + + ASSERT_NO_THROW(listener.start()); + ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); + ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); + ASSERT_NO_THROW(startRequest(request, clientContext())); + ASSERT_NO_THROW(runIOService()); + + // Fetch the client. + ASSERT_EQ(1, clients_.size()); + TcpTestClientPtr client = *clients_.begin(); + ASSERT_TRUE(client); + EXPECT_TRUE(client->receiveDone()); + EXPECT_FALSE(client->expectedEof()); + + listener.stop(); + io_service_.poll(); +} + +// This test verifies that a TLS connection can be established and used to +// transmit a streamed request and receive a streamed response. +TEST_F(TlsListenerTest, idleTimeoutTest) { + TcpTestListener listener(io_service_, + IOAddress(SERVER_ADDRESS), + SERVER_PORT, + serverContext(), + TcpListener::IdleTimeout(SHORT_IDLE_TIMEOUT)); + + ASSERT_NO_THROW(listener.start()); + ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); + ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); + // Start a client with an empty request. Empty requests tell the client to read + // without sending anything and expect the read to fail when the listener idle + // times out the socket. + ASSERT_NO_THROW(startRequest("", clientContext())); + + // Run until idle timer expires. + ASSERT_NO_THROW(runIOService()); + + ASSERT_EQ(1, clients_.size()); + TcpTestClientPtr client = *clients_.begin(); + EXPECT_FALSE(client->receiveDone()); + EXPECT_TRUE(client->expectedEof()); + + listener.stop(); + io_service_.poll(); +} + +// This test verifies that TLS connections with mulitple clients. +TEST_F(TlsListenerTest, multipleClientsListen) { + const std::string request = "I am done"; + + TcpTestListener listener(io_service_, + IOAddress(SERVER_ADDRESS), + SERVER_PORT, + serverContext(), + TcpListener::IdleTimeout(IDLE_TIMEOUT)); + + ASSERT_NO_THROW(listener.start()); + ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); + ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); + size_t num_clients = 5; + for (auto i = 0; i < num_clients; ++i) { + ASSERT_NO_THROW(startRequest(request, clientContext())); + } + + ASSERT_NO_THROW(runIOService()); + ASSERT_EQ(num_clients, clients_.size()); + + size_t connection_id = 1; + for (auto client : clients_) { + EXPECT_TRUE(client->receiveDone()); + EXPECT_FALSE(client->expectedEof()); + // Create the list of expected entries. + std::list<AuditEntry> expected_entries { + { connection_id, AuditEntry::INBOUND, "I am done" }, + { connection_id, AuditEntry::OUTBOUND, "good bye" } + }; + + // Fetch the entries for this connection. + auto entries = listener.audit_trail_->getConnectionTrail(connection_id); + ASSERT_EQ(expected_entries, entries); + ++connection_id; + } + + listener.stop(); + io_service_.poll(); +} + +// Verify that the listener handles multiple requests for multiple +// clients. +TEST_F(TlsListenerTest, multipleRequetsPerClients) { + std::list<std::string>requests{ "one", "two", "three", "I am done"}; + + TcpTestListener listener(io_service_, + IOAddress(SERVER_ADDRESS), + SERVER_PORT, + serverContext(), + TcpListener::IdleTimeout(IDLE_TIMEOUT)); + + ASSERT_NO_THROW(listener.start()); + ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); + ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); + size_t num_clients = 5; + for (auto i = 0; i < num_clients; ++i) { + ASSERT_NO_THROW(startRequests(requests, clientContext())); + } + + ASSERT_NO_THROW(runIOService()); + ASSERT_EQ(num_clients, clients_.size()); + + std::list<std::string>expected_responses{ "echo one", "echo two", + "echo three", "good bye"}; + size_t connection_id = 1; + for (auto client : clients_) { + EXPECT_TRUE(client->receiveDone()); + EXPECT_FALSE(client->expectedEof()); + EXPECT_EQ(expected_responses, client->getResponses()); + + // Verify the connection's audit trail. + // Create the list of expected entries. + std::list<AuditEntry> expected_entries { + { connection_id, AuditEntry::INBOUND, "one" }, + { connection_id, AuditEntry::OUTBOUND, "echo one" }, + { connection_id, AuditEntry::INBOUND, "two" }, + { connection_id, AuditEntry::OUTBOUND, "echo two" }, + { connection_id, AuditEntry::INBOUND, "three" }, + { connection_id, AuditEntry::OUTBOUND, "echo three" }, + { connection_id, AuditEntry::INBOUND, "I am done" }, + { connection_id, AuditEntry::OUTBOUND, "good bye" } + }; + + // Fetch the entries for this connection. + auto entries = listener.audit_trail_->getConnectionTrail(connection_id); + ASSERT_EQ(expected_entries, entries); + ++connection_id; + } + + listener.stop(); + io_service_.poll(); +} + +// Verify that connection filtering can eliminate specific connections. +TEST_F(TlsListenerTest, filterClientsTest) { + TcpTestListener listener(io_service_, + IOAddress(SERVER_ADDRESS), + SERVER_PORT, + serverContext(), + TcpListener::IdleTimeout(IDLE_TIMEOUT), + std::bind(&TlsListenerTest::connectionFilter, this, ph::_1)); + + ASSERT_NO_THROW(listener.start()); + ASSERT_EQ(SERVER_ADDRESS, listener.getLocalAddress().toText()); + ASSERT_EQ(SERVER_PORT, listener.getLocalPort()); + size_t num_clients = 5; + for (auto i = 0; i < num_clients; ++i) { + // Every other client sends nothing (i.e. waits for EOF) as + // we expect the filter to reject them. + if (i % 2 == 0) { + ASSERT_NO_THROW(startRequest("I am done", clientContext())); + } else { + ASSERT_NO_THROW(startRequest("", clientContext())); + } + } + + ASSERT_NO_THROW(runIOService()); + ASSERT_EQ(num_clients, clients_.size()); + + size_t i = 0; + for (auto client : clients_) { + if (i % 2 == 0) { + // These clients should have been accepted and received responses. + EXPECT_TRUE(client->receiveDone()); + EXPECT_FALSE(client->expectedEof()); + EXPECT_FALSE(client->handshakeFailed()); + + // Now verify the AuditTrail. + // Create the list of expected entries. + std::list<AuditEntry> expected_entries { + { i+1, AuditEntry::INBOUND, "I am done" }, + { i+1, AuditEntry::OUTBOUND, "good bye" } + }; + + auto entries = listener.audit_trail_->getConnectionTrail(i+1); + ASSERT_EQ(expected_entries, entries); + + } else { + // Connection filteriing closes the connection before the client + // initiates the handshake, causing the subsequent handshake attempt + // to fail. + EXPECT_FALSE(client->receiveDone()); + EXPECT_FALSE(client->expectedEof()); + EXPECT_TRUE(client->handshakeFailed()); + + // Verify connection recorded no audit entries. + auto entries = listener.audit_trail_->getConnectionTrail(i+1); + ASSERT_EQ(entries.size(), 0); + } + + ++i; + } + + listener.stop(); + io_service_.poll(); +} + +} |