diff options
author | Jelte Jansen <jelte@isc.org> | 2010-06-09 11:52:34 +0200 |
---|---|---|
committer | Jelte Jansen <jelte@isc.org> | 2010-06-09 11:52:34 +0200 |
commit | 93227d328b0f86ad9c86c95f506587f758a333d9 (patch) | |
tree | 3d9e5e3879beca46ef1683253eac6b7a78cb9738 /src | |
parent | copyright statements had wrong year (diff) | |
parent | Generate a unique session ID by using socket.gethostname() instead of socket.... (diff) | |
download | kea-93227d328b0f86ad9c86c95f506587f758a333d9.tar.xz kea-93227d328b0f86ad9c86c95f506587f758a333d9.zip |
merge to sync with trunk and make later merge back easier
updated additions in tests for wrapper api
also independently came up with the fix attached in ticket #224
git-svn-id: svn://bind10.isc.org/svn/bind10/experiments/python-binding@2097 e5f2f494-b856-4b98-b285-d166d9295462
Diffstat (limited to 'src')
89 files changed, 2772 insertions, 1602 deletions
diff --git a/src/bin/auth/Makefile.am b/src/bin/auth/Makefile.am index 0f57faaf34..6414dcac29 100644 --- a/src/bin/auth/Makefile.am +++ b/src/bin/auth/Makefile.am @@ -2,13 +2,13 @@ SUBDIRS = . tests AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib AM_CPPFLAGS += -I$(top_srcdir)/src/lib/dns -I$(top_builddir)/src/lib/dns -if GCC_WERROR_OK -AM_CPPFLAGS += -Werror -endif +AM_CPPFLAGS += -I$(top_builddir)/src/lib/cc + +AM_CXXFLAGS = $(B10_CXXFLAGS) pkglibexecdir = $(libexecdir)/@PACKAGE@ -CLEANFILES = *.gcno *.gcda auth.spec +CLEANFILES = *.gcno *.gcda auth.spec spec_config.h man_MANS = b10-auth.8 EXTRA_DIST = $(man_MANS) b10-auth.xml @@ -23,6 +23,23 @@ endif auth.spec: auth.spec.pre $(SED) -e "s|@@LOCALSTATEDIR@@|$(localstatedir)|" auth.spec.pre >$@ +spec_config.h: spec_config.h.pre + $(SED) -e "s|@@LOCALSTATEDIR@@|$(localstatedir)|" spec_config.h.pre >$@ + +# This is a wrapper library solely used for b10-auth. The ASIO header files +# have some code fragments that would hit gcc's unused-parameter warning, +# which would make the build fail with -Werror (our default setting). +# We don't want to lower the warning level for our own code just for ASIO, +# so as a workaround we extract the ASIO related code into a separate library, +# only for which we accept the unused-parameter warning. +lib_LIBRARIES = libasio_link.a +libasio_link_a_SOURCES = asio_link.cc asio_link.h +# Note: the ordering matters: -Wno-... must follow -Wextra (defined in +# B10_CXXFLAGS) +libasio_link_a_CXXFLAGS = $(AM_CXXFLAGS) -Wno-unused-parameter +libasio_link_a_CPPFLAGS = $(AM_CPPFLAGS) + +BUILT_SOURCES = spec_config.h pkglibexec_PROGRAMS = b10-auth b10_auth_SOURCES = auth_srv.cc auth_srv.h b10_auth_SOURCES += common.h @@ -32,12 +49,9 @@ b10_auth_LDADD += $(top_builddir)/src/lib/dns/.libs/libdns.a b10_auth_LDADD += $(top_builddir)/src/lib/config/.libs/libcfgclient.a b10_auth_LDADD += $(top_builddir)/src/lib/cc/libcc.a b10_auth_LDADD += $(top_builddir)/src/lib/exceptions/.libs/libexceptions.a +b10_auth_LDADD += $(top_builddir)/src/bin/auth/libasio_link.a b10_auth_LDADD += $(SQLITE_LIBS) -if HAVE_BOOST_SYSTEM b10_auth_LDADD += $(top_builddir)/src/lib/xfr/.libs/libxfr.a -endif -b10_auth_LDFLAGS = $(AM_LDFLAGS) $(BOOST_LDFLAGS) -b10_auth_LDADD += $(BOOST_SYSTEM_LIB) # TODO: config.h.in is wrong because doesn't honor pkgdatadir # and can't use @datadir@ because doesn't expand default ${prefix} diff --git a/src/bin/auth/asio_link.cc b/src/bin/auth/asio_link.cc new file mode 100644 index 0000000000..332c92d519 --- /dev/null +++ b/src/bin/auth/asio_link.cc @@ -0,0 +1,413 @@ +// Copyright (C) 2010 Internet Systems Consortium, Inc. ("ISC") +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +// $Id$ + +#include <config.h> + +#include <asio.hpp> +#include <boost/bind.hpp> + +#include <dns/buffer.h> +#include <dns/message.h> +#include <dns/messagerenderer.h> + +#include <xfr/xfrout_client.h> + +#include <asio_link.h> + +#include "spec_config.h" // for XFROUT. should not be here. +#include "auth_srv.h" + +using namespace asio; +using ip::udp; +using ip::tcp; + +using namespace std; +using namespace isc::dns; +using namespace isc::xfr; + +namespace { +// As a short term workaround, we have XFROUT specific code. We should soon +// refactor the code with some abstraction so that we can separate this level +// details from the (AS)IO module. + +// This was contained in an ifdef USE_XFROUT, but we should really check +// live if we do xfrout +//TODO. The sample way for checking axfr query, the code should be merged to auth server class +bool +check_axfr_query(char* const msg_data, const uint16_t msg_len) { + if (msg_len < 15) { + return false; + } + + const uint16_t query_type = *(uint16_t *)(msg_data + (msg_len - 4)); + if ( query_type == 0xFC00) { + return true; + } + + return false; +} + +//TODO. Send the xfr query to xfrout module, the code should be merged to auth server class +//BIGGERTODO: stop using hardcoded install-path locations! +void +dispatch_axfr_query(const int tcp_sock, char const axfr_query[], + const uint16_t query_len) +{ + string path(UNIX_SOCKET_FILE); + if (getenv("B10_FROM_BUILD")) { + path = string(getenv("B10_FROM_BUILD")) + "/auth_xfrout_conn"; + } + XfroutClient xfr_client(path); + try { + xfr_client.connect(); + xfr_client.sendXfroutRequestInfo(tcp_sock, (uint8_t *)axfr_query, + query_len); + xfr_client.disconnect(); + } + catch (const exception & err) { + //if (verbose_mode) + cerr << "error handle xfr query " << UNIX_SOCKET_FILE << ":" << err.what() << endl; + } +} +} + +namespace asio_link { +// +// Helper classes for asynchronous I/O using asio +// +class TCPClient { +public: + TCPClient(AuthSrv* auth_server, io_service& io_service) : + auth_server_(auth_server), + socket_(io_service), + response_buffer_(0), + responselen_buffer_(TCP_MESSAGE_LENGTHSIZE), + response_renderer_(response_buffer_), + dns_message_(Message::PARSE) + {} + + void start() { + // Check for queued configuration commands + if (auth_server_->configSession()->hasQueuedMsgs()) { + auth_server_->configSession()->checkCommand(); + } + async_read(socket_, asio::buffer(data_, TCP_MESSAGE_LENGTHSIZE), + boost::bind(&TCPClient::headerRead, this, + placeholders::error, + placeholders::bytes_transferred)); + } + + tcp::socket& getSocket() { return (socket_); } + + void headerRead(const asio::error_code& error, + size_t bytes_transferred) + { + if (!error) { + InputBuffer dnsbuffer(data_, bytes_transferred); + + uint16_t msglen = dnsbuffer.readUint16(); + async_read(socket_, asio::buffer(data_, msglen), + + boost::bind(&TCPClient::requestRead, this, + placeholders::error, + placeholders::bytes_transferred)); + } else { + delete this; + } + } + + void requestRead(const asio::error_code& error, + size_t bytes_transferred) + { + if (!error) { + InputBuffer dnsbuffer(data_, bytes_transferred); + if (check_axfr_query(data_, bytes_transferred)) { + dispatch_axfr_query(socket_.native(), data_, bytes_transferred); + // start to get new query ? + start(); + } else { + if (auth_server_->processMessage(dnsbuffer, dns_message_, + response_renderer_, false)) { + responselen_buffer_.writeUint16( + response_buffer_.getLength()); + async_write(socket_, + asio::buffer( + responselen_buffer_.getData(), + responselen_buffer_.getLength()), + boost::bind(&TCPClient::responseWrite, this, + placeholders::error)); + } else { + delete this; + } + } + } else { + delete this; + } + } + + void responseWrite(const asio::error_code& error) { + if (!error) { + async_write(socket_, + asio::buffer(response_buffer_.getData(), + response_buffer_.getLength()), + boost::bind(&TCPClient::handleWrite, this, + placeholders::error)); + } else { + delete this; + } + } + + void handleWrite(const asio::error_code& error) { + if (!error) { + start(); // handle next request, if any. + } else { + delete this; + } + } + +private: + AuthSrv* auth_server_; + tcp::socket socket_; + OutputBuffer response_buffer_; + OutputBuffer responselen_buffer_; + MessageRenderer response_renderer_; + Message dns_message_; + enum { MAX_LENGTH = 65535 }; + static const size_t TCP_MESSAGE_LENGTHSIZE = 2; + char data_[MAX_LENGTH]; +}; + +class TCPServer { +public: + TCPServer(AuthSrv* auth_server, io_service& io_service, + int af, short port) : + auth_server_(auth_server), io_service_(io_service), + acceptor_(io_service_), listening_(new TCPClient(auth_server_, + io_service_)) + { + tcp::endpoint endpoint(af == AF_INET6 ? tcp::v6() : tcp::v4(), port); + acceptor_.open(endpoint.protocol()); + // Set v6-only (we use a different instantiation for v4, + // otherwise asio will bind to both v4 and v6 + if (af == AF_INET6) { + acceptor_.set_option(ip::v6_only(true)); + } + acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.bind(endpoint); + acceptor_.listen(); + acceptor_.async_accept(listening_->getSocket(), + boost::bind(&TCPServer::handleAccept, this, + listening_, placeholders::error)); + } + + ~TCPServer() { delete listening_; } + + void handleAccept(TCPClient* new_client, + const asio::error_code& error) + { + if (!error) { + assert(new_client == listening_); + new_client->start(); + listening_ = new TCPClient(auth_server_, io_service_); + acceptor_.async_accept(listening_->getSocket(), + boost::bind(&TCPServer::handleAccept, + this, listening_, + placeholders::error)); + } else { + delete new_client; + } + } + +private: + AuthSrv* auth_server_; + io_service& io_service_; + tcp::acceptor acceptor_; + TCPClient* listening_; +}; + +class UDPServer { +public: + UDPServer(AuthSrv* auth_server, io_service& io_service, + int af, short port) : + auth_server_(auth_server), + io_service_(io_service), + socket_(io_service, af == AF_INET6 ? udp::v6() : udp::v4()), + response_buffer_(0), + response_renderer_(response_buffer_), + dns_message_(Message::PARSE) + { + // Set v6-only (we use a different instantiation for v4, + // otherwise asio will bind to both v4 and v6 + if (af == AF_INET6) { + socket_.set_option(asio::ip::v6_only(true)); + socket_.bind(udp::endpoint(udp::v6(), port)); + } else { + socket_.bind(udp::endpoint(udp::v4(), port)); + } + startReceive(); + } + + void handleRequest(const asio::error_code& error, + size_t bytes_recvd) + { + // Check for queued configuration commands + if (auth_server_->configSession()->hasQueuedMsgs()) { + auth_server_->configSession()->checkCommand(); + } + if (!error && bytes_recvd > 0) { + InputBuffer request_buffer(data_, bytes_recvd); + + dns_message_.clear(Message::PARSE); + response_renderer_.clear(); + if (auth_server_->processMessage(request_buffer, dns_message_, + response_renderer_, true)) { + socket_.async_send_to( + asio::buffer(response_buffer_.getData(), + response_buffer_.getLength()), + sender_endpoint_, + boost::bind(&UDPServer::sendCompleted, + this, + placeholders::error, + placeholders::bytes_transferred)); + } else { + startReceive(); + } + } else { + startReceive(); + } + } + + void sendCompleted(const asio::error_code& error UNUSED_PARAM, + size_t bytes_sent UNUSED_PARAM) + { + // Even if error occurred there's nothing to do. Simply handle + // the next request. + startReceive(); + } +private: + void startReceive() { + socket_.async_receive_from( + asio::buffer(data_, MAX_LENGTH), sender_endpoint_, + boost::bind(&UDPServer::handleRequest, this, + placeholders::error, + placeholders::bytes_transferred)); + } + +private: + AuthSrv* auth_server_; + io_service& io_service_; + udp::socket socket_; + OutputBuffer response_buffer_; + MessageRenderer response_renderer_; + Message dns_message_; + udp::endpoint sender_endpoint_; + enum { MAX_LENGTH = 4096 }; + char data_[MAX_LENGTH]; +}; + +// This is a helper structure just to make the construction of IOServiceImpl +// exception safe. If the constructor of {UDP/TCP}Server throws an exception, +// the destructor of this class will automatically perform the necessary +// cleanup. +struct ServerSet { + ServerSet() : udp4_server(NULL), udp6_server(NULL), + tcp4_server(NULL), tcp6_server(NULL) + {} + ~ServerSet() { + delete udp4_server; + delete udp6_server; + delete tcp4_server; + delete tcp6_server; + } + UDPServer* udp4_server; + UDPServer* udp6_server; + TCPServer* tcp4_server; + TCPServer* tcp6_server; +}; + +class IOServiceImpl { +public: + IOServiceImpl(AuthSrv* auth_server, const char* port, + const bool use_ipv4, const bool use_ipv6); + ~IOServiceImpl(); + asio::io_service io_service_; + AuthSrv* auth_server_; + UDPServer* udp4_server_; + UDPServer* udp6_server_; + TCPServer* tcp4_server_; + TCPServer* tcp6_server_; +}; + +IOServiceImpl::IOServiceImpl(AuthSrv* auth_server, const char* const port, + const bool use_ipv4, const bool use_ipv6) : + auth_server_(auth_server), udp4_server_(NULL), udp6_server_(NULL), + tcp4_server_(NULL), tcp6_server_(NULL) +{ + ServerSet servers; + short portnum = atoi(port); + + if (use_ipv4) { + servers.udp4_server = new UDPServer(auth_server, io_service_, + AF_INET, portnum); + servers.tcp4_server = new TCPServer(auth_server, io_service_, + AF_INET, portnum); + } + if (use_ipv6) { + servers.udp6_server = new UDPServer(auth_server, io_service_, + AF_INET6, portnum); + servers.tcp6_server = new TCPServer(auth_server, io_service_, + AF_INET6, portnum); + } + + // Now we don't have to worry about exception, and need to make sure that + // the server objects won't be accidentally cleaned up. + servers.udp4_server = NULL; + servers.udp6_server = NULL; + servers.tcp4_server = NULL; + servers.tcp6_server = NULL; +} + +IOServiceImpl::~IOServiceImpl() { + delete udp4_server_; + delete udp6_server_; + delete tcp4_server_; + delete tcp6_server_; +} + +IOService::IOService(AuthSrv* auth_server, const char* const port, + const bool use_ipv4, const bool use_ipv6) { + impl_ = new IOServiceImpl(auth_server, port, use_ipv4, use_ipv6); +} + +IOService::~IOService() { + delete impl_; +} + +void +IOService::run() { + impl_->io_service_.run(); +} + +void +IOService::stop() { + impl_->io_service_.stop(); +} + +asio::io_service& +IOService::get_io_service() { + return impl_->io_service_; +} +} diff --git a/src/bin/auth/asio_link.h b/src/bin/auth/asio_link.h new file mode 100644 index 0000000000..b5c9153f83 --- /dev/null +++ b/src/bin/auth/asio_link.h @@ -0,0 +1,41 @@ +// Copyright (C) 2010 Internet Systems Consortium, Inc. ("ISC") +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +// $Id$ + +#ifndef __ASIO_LINK_H +#define __ASIO_LINK_H 1 + +class AuthSrv; + +namespace asio_link { +struct IOServiceImpl; + +class IOService { +public: + IOService(AuthSrv* auth_server, const char* port, + const bool use_ipv4, const bool use_ipv6); + ~IOService(); + void run(); + void stop(); + asio::io_service& get_io_service(); +private: + IOServiceImpl* impl_; +}; +} // asio_link +#endif // __ASIO_LINK_H + +// Local Variables: +// mode: c++ +// End: diff --git a/src/bin/auth/auth.spec.pre.in b/src/bin/auth/auth.spec.pre.in index 4bde11af50..98d7005974 100644 --- a/src/bin/auth/auth.spec.pre.in +++ b/src/bin/auth/auth.spec.pre.in @@ -1,6 +1,7 @@ { "module_spec": { "module_name": "Auth", + "module_description": "Authoritative service", "config_data": [ { "item_name": "database_file", "item_type": "string", diff --git a/src/bin/auth/auth_srv.cc b/src/bin/auth/auth_srv.cc index bdb863160e..dee60ef3b2 100644 --- a/src/bin/auth/auth_srv.cc +++ b/src/bin/auth/auth_srv.cc @@ -139,7 +139,7 @@ makeErrorMessage(Message& message, MessageRenderer& renderer, message.toWire(renderer); if (verbose_mode) { - cerr << "sending an error response (" << + cerr << "[b10-auth] sending an error response (" << boost::lexical_cast<string>(renderer.getLength()) << " bytes):\n" << message.toText() << endl; } @@ -179,7 +179,7 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message, // Ignore all responses. if (message.getHeaderFlag(MessageFlag::QR())) { if (impl_->verbose_mode_) { - cerr << "received unexpected response, ignoring" << endl; + cerr << "[b10-auth] received unexpected response, ignoring" << endl; } return (false); } @@ -192,7 +192,7 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message, message.fromWire(request_buffer); } catch (const DNSProtocolError& error) { if (impl_->verbose_mode_) { - cerr << "returning " << error.getRcode().toText() << ": " + cerr << "[b10-auth] returning " << error.getRcode().toText() << ": " << error.what() << endl; } makeErrorMessage(message, response_renderer, error.getRcode(), @@ -200,7 +200,7 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message, return (true); } catch (const Exception& ex) { if (impl_->verbose_mode_) { - cerr << "returning SERVFAIL: " << ex.what() << endl; + cerr << "[b10-auth] returning SERVFAIL: " << ex.what() << endl; } makeErrorMessage(message, response_renderer, Rcode::SERVFAIL(), impl_->verbose_mode_); @@ -208,7 +208,7 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message, } // other exceptions will be handled at a higher layer. if (impl_->verbose_mode_) { - cerr << "[AuthSrv] received a message:\n" << message.toText() << endl; + cerr << "[b10-auth] received a message:\n" << message.toText() << endl; } // Perform further protocol-level validation. @@ -216,7 +216,7 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message, // In this implementation, we only support normal queries if (message.getOpcode() != Opcode::QUERY()) { if (impl_->verbose_mode_) { - cerr << "unsupported opcode" << endl; + cerr << "[b10-auth] unsupported opcode" << endl; } makeErrorMessage(message, response_renderer, Rcode::NOTIMP(), impl_->verbose_mode_); @@ -243,7 +243,7 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message, impl_->data_sources_.doQuery(query); } catch (const Exception& ex) { if (impl_->verbose_mode_) { - cerr << "Internal error, returning SERVFAIL: " << ex.what() << endl; + cerr << "[b10-auth] Internal error, returning SERVFAIL: " << ex.what() << endl; } makeErrorMessage(message, response_renderer, Rcode::SERVFAIL(), impl_->verbose_mode_); @@ -253,7 +253,7 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message, response_renderer.setLengthLimit(udp_buffer ? remote_bufsize : 65535); message.toWire(response_renderer); if (impl_->verbose_mode_) { - cerr << "sending a response (" << + cerr << "[b10-auth] sending a response (" << boost::lexical_cast<string>(response_renderer.getLength()) << " bytes):\n" << message.toText() << endl; } @@ -281,7 +281,7 @@ AuthSrvImpl::setDbFile(const isc::data::ElementPtr config) { } if (verbose_mode_) { - cerr << "[AuthSrv] Data source database file: " << db_file_ << endl; + cerr << "[b10-auth] Data source database file: " << db_file_ << endl; } // create SQL data source @@ -313,7 +313,7 @@ AuthSrv::updateConfig(isc::data::ElementPtr new_config) { return answer; } catch (const isc::Exception& error) { if (impl_->verbose_mode_) { - cerr << "[AuthSrv] error: " << error.what() << endl; + cerr << "[b10-auth] error: " << error.what() << endl; } return isc::config::createAnswer(1, error.what()); } diff --git a/src/bin/auth/main.cc b/src/bin/auth/main.cc index 7ca6390028..2e7073699f 100644 --- a/src/bin/auth/main.cc +++ b/src/bin/auth/main.cc @@ -28,10 +28,6 @@ #include <iostream> #include <boost/foreach.hpp> -#ifdef HAVE_BOOST_SYSTEM -#include <boost/bind.hpp> -#include <boost/asio.hpp> -#endif // HAVE_BOOST_SYSTEM #include <exceptions/exceptions.h> @@ -43,26 +39,12 @@ #include <cc/data.h> #include <config/ccsession.h> -#if defined(HAVE_BOOST_SYSTEM) -#define USE_XFROUT -#include <xfr/xfrout_client.h> -#endif - #include "spec_config.h" #include "common.h" #include "auth_srv.h" +#include "asio_link.h" using namespace std; -#ifdef USE_XFROUT -using namespace isc::xfr; -#endif - -#ifdef HAVE_BOOST_SYSTEM -using namespace boost::asio; -using ip::udp; -using ip::tcp; -#endif // HAVE_BOOST_SYSTEM - using namespace isc::data; using namespace isc::cc; using namespace isc::config; @@ -79,13 +61,8 @@ const char* DNSPORT = "5300"; * todo: turn this around, and put handlers in the authserver * class itself? */ AuthSrv *auth_server; -#ifdef HAVE_BOOST_SYSTEM -// TODO: this should be a property of AuthSrv, and AuthSrv needs -// a stop() method (so the shutdown command can be handled) -boost::asio::io_service io_service_; -#else -bool running; -#endif // HAVE_BOOST_SYSTEM + +asio_link::IOService* io_service; ElementPtr my_config_handler(ElementPtr new_config) { @@ -101,605 +78,12 @@ my_command_handler(const string& command, const ElementPtr args) { /* let's add that message to our answer as well */ answer->get("result")->add(args); } else if (command == "shutdown") { -#ifdef HAVE_BOOST_SYSTEM - io_service_.stop(); -#else - running = false; -#endif // HAVE_BOOST_SYSTEM + io_service->stop(); } return answer; } -#ifdef USE_XFROUT -//TODO. The sample way for checking axfr query, the code should be merged to auth server class -static bool -check_axfr_query(char *msg_data, uint16_t msg_len) -{ - if (msg_len < 15) - return false; - - uint16_t query_type = *(uint16_t *)(msg_data + (msg_len - 4)); - if ( query_type == 0xFC00) - return true; - - return false; -} - -//TODO. Send the xfr query to xfrout module, the code should be merged to auth server class -static void -dispatch_axfr_query(int tcp_sock, char axfr_query[], uint16_t query_len) -{ - std::string path; - if (getenv("B10_FROM_SOURCE")) { - path = string(getenv("B10_FROM_SOURCE")) + - "/auth_xfrout_conn"; - } else { - path = string(UNIX_SOCKET_FILE); - } - (void)tcp_sock; - (void)axfr_query; - (void)query_len; - XfroutClient xfr_client(path); - try { - xfr_client.connect(); - xfr_client.sendXfroutRequestInfo(tcp_sock, (uint8_t *)axfr_query, query_len); - xfr_client.disconnect(); - } - catch (const std::exception & err) { - //if (verbose_mode) - cerr << "error handle xfr query:" << err.what() << endl; - } -} -#endif - -#ifdef HAVE_BOOST_SYSTEM -// -// Helper classes for asynchronous I/O using boost::asio -// -class TCPClient { -public: - TCPClient(io_service& io_service) : - socket_(io_service), - response_buffer_(0), - responselen_buffer_(TCP_MESSAGE_LENGTHSIZE), - response_renderer_(response_buffer_), - dns_message_(Message::PARSE) - {} - - void start() { - async_read(socket_, boost::asio::buffer(data_, TCP_MESSAGE_LENGTHSIZE), - boost::bind(&TCPClient::headerRead, this, - placeholders::error, - placeholders::bytes_transferred)); - } - - tcp::socket& getSocket() { return (socket_); } - - void headerRead(const boost::system::error_code& error, - size_t bytes_transferred) - { - if (!error) { - InputBuffer dnsbuffer(data_, bytes_transferred); - - uint16_t msglen = dnsbuffer.readUint16(); - async_read(socket_, boost::asio::buffer(data_, msglen), - - boost::bind(&TCPClient::requestRead, this, - placeholders::error, - placeholders::bytes_transferred)); - } else { - delete this; - } - } - - void requestRead(const boost::system::error_code& error, - size_t bytes_transferred) - { - if (!error) { - InputBuffer dnsbuffer(data_, bytes_transferred); -#ifdef USE_XFROUT - if (check_axfr_query(data_, bytes_transferred)) { - dispatch_axfr_query(socket_.native(), data_, bytes_transferred); - // start to get new query ? - start(); - } else { -#endif - if (auth_server->processMessage(dnsbuffer, dns_message_, - response_renderer_, false)) { - responselen_buffer_.writeUint16(response_buffer_.getLength()); - async_write(socket_, - boost::asio::buffer( - responselen_buffer_.getData(), - responselen_buffer_.getLength()), - boost::bind(&TCPClient::responseWrite, this, - placeholders::error)); - } else { - delete this; - } -#ifdef USE_XFROUT - } -#endif - } else { - delete this; - } - } - - void responseWrite(const boost::system::error_code& error) { - if (!error) { - async_write(socket_, - boost::asio::buffer(response_buffer_.getData(), - response_buffer_.getLength()), - boost::bind(&TCPClient::handleWrite, this, - placeholders::error)); - } else { - delete this; - } - } - - void handleWrite(const boost::system::error_code& error) { - if (!error) { - start(); // handle next request, if any. - } else { - delete this; - } - } - -private: - tcp::socket socket_; - OutputBuffer response_buffer_; - OutputBuffer responselen_buffer_; - MessageRenderer response_renderer_; - Message dns_message_; - enum { MAX_LENGTH = 65535 }; - static const size_t TCP_MESSAGE_LENGTHSIZE = 2; - char data_[MAX_LENGTH]; -}; - -class TCPServer { -public: - TCPServer(io_service& io_service, int af, short port) : - io_service_(io_service), acceptor_(io_service_), - listening_(new TCPClient(io_service_)) - { - tcp::endpoint endpoint(af == AF_INET6 ? tcp::v6() : tcp::v4(), port); - acceptor_.open(endpoint.protocol()); - // Set v6-only (we use a different instantiation for v4, - // otherwise asio will bind to both v4 and v6 - if (af == AF_INET6) { - acceptor_.set_option(ip::v6_only(true)); - } - acceptor_.set_option(tcp::acceptor::reuse_address(true)); - acceptor_.bind(endpoint); - acceptor_.listen(); - acceptor_.async_accept(listening_->getSocket(), - boost::bind(&TCPServer::handleAccept, this, - listening_, placeholders::error)); - } - - ~TCPServer() { delete listening_; } - - void handleAccept(TCPClient* new_client, - const boost::system::error_code& error) - { - if (!error) { - assert(new_client == listening_); - new_client->start(); - listening_ = new TCPClient(io_service_); - acceptor_.async_accept(listening_->getSocket(), - boost::bind(&TCPServer::handleAccept, - this, listening_, - placeholders::error)); - } else { - delete new_client; - } - } - -private: - io_service& io_service_; - tcp::acceptor acceptor_; - TCPClient* listening_; -}; - -class UDPServer { -public: - UDPServer(io_service& io_service, int af, short port) : - io_service_(io_service), - socket_(io_service, af == AF_INET6 ? udp::v6() : udp::v4()), - response_buffer_(0), - response_renderer_(response_buffer_), - dns_message_(Message::PARSE) - { - // Set v6-only (we use a different instantiation for v4, - // otherwise asio will bind to both v4 and v6 - if (af == AF_INET6) { - socket_.set_option(boost::asio::ip::v6_only(true)); - socket_.bind(udp::endpoint(udp::v6(), port)); - } else { - socket_.bind(udp::endpoint(udp::v4(), port)); - } - startReceive(); - } - - void handleRequest(const boost::system::error_code& error, - size_t bytes_recvd) - { - if (!error && bytes_recvd > 0) { - InputBuffer request_buffer(data_, bytes_recvd); - - dns_message_.clear(Message::PARSE); - response_renderer_.clear(); - if (auth_server->processMessage(request_buffer, dns_message_, - response_renderer_, true)) { - socket_.async_send_to( - boost::asio::buffer(response_buffer_.getData(), - response_buffer_.getLength()), - sender_endpoint_, - boost::bind(&UDPServer::sendCompleted, - this, - placeholders::error, - placeholders::bytes_transferred)); - } else { - startReceive(); - } - } else { - startReceive(); - } - } - - void sendCompleted(const boost::system::error_code& error UNUSED_PARAM, - size_t bytes_sent UNUSED_PARAM) - { - // Even if error occurred there's nothing to do. Simply handle - // the next request. - startReceive(); - } -private: - void startReceive() { - socket_.async_receive_from( - boost::asio::buffer(data_, MAX_LENGTH), sender_endpoint_, - boost::bind(&UDPServer::handleRequest, this, - placeholders::error, - placeholders::bytes_transferred)); - } - -private: - io_service& io_service_; - udp::socket socket_; - OutputBuffer response_buffer_; - MessageRenderer response_renderer_; - Message dns_message_; - udp::endpoint sender_endpoint_; - enum { MAX_LENGTH = 4096 }; - char data_[MAX_LENGTH]; -}; - -struct ServerSet { - ServerSet() : udp4_server(NULL), udp6_server(NULL), - tcp4_server(NULL), tcp6_server(NULL) - {} - ~ServerSet() { - delete udp4_server; - delete udp6_server; - delete tcp4_server; - delete tcp6_server; - } - UDPServer* udp4_server; - UDPServer* udp6_server; - TCPServer* tcp4_server; - TCPServer* tcp6_server; -}; - -void -run_server(const char* port, const bool use_ipv4, const bool use_ipv6, - AuthSrv* srv UNUSED_PARAM) -{ - ServerSet servers; - short portnum = atoi(port); - - if (use_ipv4) { - servers.udp4_server = new UDPServer(io_service_, AF_INET, portnum); - servers.tcp4_server = new TCPServer(io_service_, AF_INET, portnum); - } - if (use_ipv6) { - servers.udp6_server = new UDPServer(io_service_, AF_INET6, portnum); - servers.tcp6_server = new TCPServer(io_service_, AF_INET6, portnum); - } - - cout << "Server started." << endl; - io_service_.run(); -} -#else // !HAVE_BOOST_SYSTEM -struct SocketSet { - SocketSet() : ups4(-1), tps4(-1), ups6(-1), tps6(-1) {} - ~SocketSet() { - if (ups4 >= 0) { - close(ups4); - } - if (tps4 >= 0) { - close(tps4); - } - if (ups6 >= 0) { - close(ups6); - } - if (tps4 >= 0) { - close(tps6); - } - } - int ups4, tps4, ups6, tps6; -}; - -int -getUDPSocket(int af, const char* port) { - struct addrinfo hints, *res; - - memset(&hints, 0, sizeof(hints)); - hints.ai_family = af; - hints.ai_socktype = SOCK_DGRAM; - hints.ai_flags = AI_PASSIVE; - hints.ai_protocol = IPPROTO_UDP; - - int error = getaddrinfo(NULL, port, &hints, &res); - if (error != 0) { - isc_throw(FatalError, "getaddrinfo failed: " << gai_strerror(error)); - } - - int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol); - if (s < 0) { - isc_throw(FatalError, "failed to open socket"); - } - - if (af == AF_INET6) { - int on = 1; - if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) { - cerr << "couldn't set IPV6_V6ONLY socket option" << endl; - // proceed anyway - } - } - - if (bind(s, res->ai_addr, res->ai_addrlen) < 0) { - isc_throw(FatalError, "binding socket failure"); - } - - return (s); -} - -int -getTCPSocket(int af, const char* port) { - struct addrinfo hints, *res; - - memset(&hints, 0, sizeof(hints)); - hints.ai_family = af; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = AI_PASSIVE; - hints.ai_protocol = IPPROTO_TCP; - - int error = getaddrinfo(NULL, port, &hints, &res); - if (error != 0) { - isc_throw(FatalError, "getaddrinfo failed: " << gai_strerror(error)); - } - - int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol); - if (s < 0) { - isc_throw(FatalError, "failed to open socket"); - } - - int on = 1; - if (af == AF_INET6) { - if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) { - cerr << "couldn't set IPV6_V6ONLY socket option" << endl; - } - // proceed anyway - } - - if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) { - cerr << "couldn't set SO_REUSEADDR socket option" << endl; - } - - if (bind(s, res->ai_addr, res->ai_addrlen) < 0) { - isc_throw(FatalError, "binding socket failure"); - } - - if (listen(s, 100) < 0) { - isc_throw(FatalError, "failed to listen on a TCP socket"); - } - return (s); -} - -void -processMessageUDP(const int fd, Message& dns_message, - MessageRenderer& response_renderer) -{ - struct sockaddr_storage ss; - socklen_t sa_len = sizeof(ss); - struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss); - char recvbuf[4096]; - int cc; - - dns_message.clear(Message::PARSE); - response_renderer.clear(); - if ((cc = recvfrom(fd, recvbuf, sizeof(recvbuf), 0, sa, &sa_len)) > 0) { - InputBuffer buffer(recvbuf, cc); - if (auth_server->processMessage(buffer, dns_message, response_renderer, - true)) { - cc = sendto(fd, response_renderer.getData(), - response_renderer.getLength(), 0, sa, sa_len); - if (cc != response_renderer.getLength()) { - cerr << "UDP send error" << endl; - } - } - } else if (verbose_mode) { - cerr << "UDP receive error" << endl; - } -} - -// XXX: this function does not handle partial reads or partial writes, -// and is VERY UNSAFE - will probably be removed or rewritten -void -processMessageTCP(const int fd, Message& dns_message, - MessageRenderer& response_renderer) -{ - struct sockaddr_storage ss; - socklen_t sa_len = sizeof(ss); - struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss); - char sizebuf[2]; - int cc; - - int ts = accept(fd, sa, &sa_len); - if (ts < 0) { - if (verbose_mode) { - cerr << "[XX] TCP accept failure:" << endl; - return; - } - } - - if (verbose_mode) { - cerr << "[XX] process TCP" << endl; - } - cc = recv(ts, sizebuf, 2, 0); - if (cc < 0) { - if (verbose_mode) { - cerr << "[XX] TCP recv failure:" << endl; - } - close(ts); - return; - } - if (verbose_mode) { - cerr << "[XX] got: " << cc << endl; - } - uint16_t size, size_n; - memcpy(&size_n, sizebuf, 2); - size = ntohs(size_n); - if (verbose_mode) { - cerr << "[XX] got: " << size << endl; - } - - vector<char> message_buffer; - message_buffer.reserve(size); - cc = 0; - while (cc < size) { - if (verbose_mode) { - cerr << "[XX] cc now: " << cc << " of " << size << endl; - } - const int cc0 = recv(ts, &message_buffer[0] + cc, size - cc, 0); - if (cc0 < 0) { - if (verbose_mode) { - cerr << "TCP receive error" << endl; - close(ts); - return; - } - } - if (cc0 == 0) { - // client closed connection - close(ts); - return; - } - cc += cc0; - } - - InputBuffer buffer(&message_buffer[0], size); - dns_message.clear(Message::PARSE); - response_renderer.clear(); - if (auth_server->processMessage(buffer, dns_message, response_renderer, - false)) { - size = response_renderer.getLength(); - size_n = htons(size); - if (send(ts, &size_n, 2, 0) == 2) { - cc = send(ts, response_renderer.getData(), - response_renderer.getLength(), 0); - if (cc == -1) { - if (verbose_mode) { - cerr << "[AuthSrv] error in sending TCP response message" << - endl; - } - } else { - if (verbose_mode) { - cerr << "[XX] sent TCP response: " << cc << " bytes" - << endl; - } - } - } else { - if (verbose_mode) { - cerr << "TCP send error" << endl; - } - } - } - - // TODO: we don't check for more queries on the stream atm - close(ts); -} - -void -run_server(const char* port, const bool use_ipv4, const bool use_ipv6, - AuthSrv* srv) -{ - SocketSet socket_set; - fd_set fds_base; - int nfds = -1; - - FD_ZERO(&fds_base); - if (use_ipv4) { - socket_set.ups4 = getUDPSocket(AF_INET, port); - FD_SET(socket_set.ups4, &fds_base); - nfds = max(nfds, socket_set.ups4); - socket_set.tps4 = getTCPSocket(AF_INET, port); - FD_SET(socket_set.tps4, &fds_base); - nfds = max(nfds, socket_set.tps4); - } - if (use_ipv6) { - socket_set.ups6 = getUDPSocket(AF_INET6, port); - FD_SET(socket_set.ups6, &fds_base); - nfds = max(nfds, socket_set.ups6); - socket_set.tps6 = getTCPSocket(AF_INET6, port); - FD_SET(socket_set.tps6, &fds_base); - nfds = max(nfds, socket_set.tps6); - } - ++nfds; - - cout << "Server started." << endl; - - if (srv->configSession() == NULL) { - isc_throw(FatalError, "Config session not initalized"); - } - - int ss = srv->configSession()->getSocket(); - Message dns_message(Message::PARSE); - OutputBuffer resonse_buffer(0); - MessageRenderer response_renderer(resonse_buffer); - - running = true; - while (running) { - fd_set fds = fds_base; - FD_SET(ss, &fds); - ++nfds; - - int n = select(nfds, &fds, NULL, NULL, NULL); - if (n < 0) { - if (errno != EINTR) { - isc_throw(FatalError, "select error"); - } - continue; - } - - if (socket_set.ups4 >= 0 && FD_ISSET(socket_set.ups4, &fds)) { - processMessageUDP(socket_set.ups4, dns_message, response_renderer); - } - if (socket_set.ups6 >= 0 && FD_ISSET(socket_set.ups6, &fds)) { - processMessageUDP(socket_set.ups6, dns_message, response_renderer); - } - if (socket_set.tps4 >= 0 && FD_ISSET(socket_set.tps4, &fds)) { - processMessageTCP(socket_set.tps4, dns_message, response_renderer); - } - if (socket_set.tps6 >= 0 && FD_ISSET(socket_set.tps6, &fds)) { - processMessageTCP(socket_set.tps6, dns_message, response_renderer); - } - if (FD_ISSET(ss, &fds)) { - srv->configSession()->checkCommand(); - } - } -} -#endif // HAVE_BOOST_SYSTEM - void usage() { cerr << "Usage: b10-auth [-p port] [-4|-6]" << endl; @@ -743,7 +127,7 @@ main(int argc, char* argv[]) { } if (!use_ipv4 && !use_ipv6) { - cerr << "-4 and -6 can't coexist" << endl; + cerr << "[b10-auth] Error: -4 and -6 can't coexist" << endl; usage(); } @@ -751,8 +135,8 @@ main(int argc, char* argv[]) { int ret = 0; try { string specfile; - if (getenv("B10_FROM_SOURCE")) { - specfile = string(getenv("B10_FROM_SOURCE")) + + if (getenv("B10_FROM_BUILD")) { + specfile = string(getenv("B10_FROM_BUILD")) + "/src/bin/auth/auth.spec"; } else { specfile = string(AUTH_SPECFILE_LOCATION); @@ -761,22 +145,23 @@ main(int argc, char* argv[]) { auth_server = new AuthSrv; auth_server->setVerbose(verbose_mode); -#ifdef HAVE_BOOST_SYSTEM - ModuleCCSession cs(specfile, io_service_, my_config_handler, - my_command_handler); -#else - ModuleCCSession cs(specfile, my_config_handler, my_command_handler); -#endif + io_service = new asio_link::IOService(auth_server, port, use_ipv4, + use_ipv6); + + ModuleCCSession cs(specfile, io_service->get_io_service(), my_config_handler, my_command_handler); auth_server->setConfigSession(&cs); auth_server->updateConfig(ElementPtr()); - run_server(port, use_ipv4, use_ipv6, auth_server); + + cout << "[b10-auth] Server started." << endl; + io_service->run(); } catch (const std::exception& ex) { - cerr << ex.what() << endl; + cerr << "[b10-auth] " << ex.what() << endl; ret = 1; } + delete io_service; delete auth_server; return (ret); } diff --git a/src/bin/auth/spec_config.h.in b/src/bin/auth/spec_config.h.pre.in index da9d025cdc..52581ddbc6 100644 --- a/src/bin/auth/spec_config.h.in +++ b/src/bin/auth/spec_config.h.pre.in @@ -1,16 +1,16 @@ -// Copyright (C) 2009 Internet Systems Consortium, Inc. ("ISC") -// -// Permission to use, copy, modify, and/or distribute this software for any -// purpose with or without fee is hereby granted, provided that the above -// copyright notice and this permission notice appear in all copies. -// -// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH -// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY -// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, -// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM -// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE -// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR -// PERFORMANCE OF THIS SOFTWARE. - -#define AUTH_SPECFILE_LOCATION "@prefix@/share/@PACKAGE@/auth.spec" -#define UNIX_SOCKET_FILE "@prefix@/var/auth_xfrout_conn" +// Copyright (C) 2009 Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#define AUTH_SPECFILE_LOCATION "@prefix@/share/@PACKAGE@/auth.spec"
+#define UNIX_SOCKET_FILE "@@LOCALSTATEDIR@@/auth_xfrout_conn"
diff --git a/src/bin/auth/tests/Makefile.am b/src/bin/auth/tests/Makefile.am index f89803fac8..ed9deb552f 100644 --- a/src/bin/auth/tests/Makefile.am +++ b/src/bin/auth/tests/Makefile.am @@ -1,7 +1,10 @@ AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib AM_CPPFLAGS += -I$(top_builddir)/src/lib/dns -I$(top_srcdir)/src/bin +AM_CPPFLAGS += -I$(top_builddir)/src/lib/cc AM_CPPFLAGS += -DTEST_DATA_DIR=\"$(srcdir)/testdata\" +AM_CXXFLAGS = $(B10_CXXFLAGS) + CLEANFILES = *.gcno *.gcda TESTS = @@ -21,10 +24,6 @@ run_unittests_LDADD += $(top_builddir)/src/lib/dns/.libs/libdns.a run_unittests_LDADD += $(top_builddir)/src/lib/config/.libs/libcfgclient.a run_unittests_LDADD += $(top_builddir)/src/lib/cc/libcc.a run_unittests_LDADD += $(top_builddir)/src/lib/exceptions/.libs/libexceptions.a -if HAVE_BOOST_SYSTEM -run_unittests_LDFLAGS += $(BOOST_LDFLAGS) -run_unittests_LDADD += $(BOOST_SYSTEM_LIB) -endif endif noinst_PROGRAMS = $(TESTS) diff --git a/src/bin/bind10/TODO b/src/bin/bind10/TODO index d99b5504e2..eb0abcd3dc 100644 --- a/src/bin/bind10/TODO +++ b/src/bin/bind10/TODO @@ -1,4 +1,5 @@ -- Read msgq configuration from configuration manager +- Read msgq configuration from configuration manager (Trac #213) + https://bind10.isc.org/ticket/213 - Provide more administrator options: - Get process list - Get information on a process (returns list of times started & stopped, diff --git a/src/bin/bind10/bind10.py.in b/src/bin/bind10/bind10.py.in index d76bb05aef..040f14bce1 100644 --- a/src/bin/bind10/bind10.py.in +++ b/src/bin/bind10/bind10.py.in @@ -49,8 +49,6 @@ else: DATAROOTDIR = "@datarootdir@" SPECFILE_LOCATION = "@datadir@/@PACKAGE@/bob.spec".replace("${datarootdir}", DATAROOTDIR).replace("${prefix}", PREFIX) -# TODO: start up statistics thingy - import subprocess import signal import re @@ -63,10 +61,7 @@ from optparse import OptionParser, OptionValueError import isc.cc # This is the version that gets displayed to the user. -__version__ = "v20100310" - -# Nothing at all to do with the 1990-12-10 article here: -# http://www.subgenius.com/subg-digest/v2/0056.html +__version__ = "v20100531" class RestartSchedule: """ @@ -116,7 +111,18 @@ to avoid being restarted at exactly 10 seconds.""" class ProcessInfo: """Information about a process""" - dev_null = open("/dev/null", "w") + dev_null = open(os.devnull, "w") + + def __init__(self, name, args, env={}, dev_null_stdout=False, + dev_null_stderr=False): + self.name = name + self.args = args + self.env = env + self.dev_null_stdout = dev_null_stdout + self.dev_null_stderr = dev_null_stderr + self.restart_schedule = RestartSchedule() + self._spawn() + def _spawn(self): if self.dev_null_stdout: @@ -143,73 +149,53 @@ class ProcessInfo: self.pid = self.process.pid self.restart_schedule.set_run_start_time() - def __init__(self, name, args, env={}, dev_null_stdout=False, - dev_null_stderr=False): - self.name = name - self.args = args - self.env = env - self.dev_null_stdout = dev_null_stdout - self.dev_null_stderr = dev_null_stderr - self.restart_schedule = RestartSchedule() - self._spawn() - def respawn(self): self._spawn() class BoB: """Boss of BIND class.""" - def __init__(self, c_channel_port=9912, auth_port=5300, verbose=False): + + def __init__(self, msgq_socket_file=None, auth_port=5300, verbose=False): """Initialize the Boss of BIND. This is a singleton (only one can run). - The c_channel_port specifies the TCP/IP port that the msgq - process listens on. If verbose is True, then the boss reports - what it is doing. + The msgq_socket_file specifies the UNIX domain socket file + that the msgq process listens on. + If verbose is True, then the boss reports what it is doing. """ self.verbose = verbose - self.c_channel_port = c_channel_port + self.msgq_socket_file = msgq_socket_file self.auth_port = auth_port self.cc_session = None self.ccs = None self.processes = {} self.dead_processes = {} self.runnable = False - - os.environ['ISC_MSGQ_PORT'] = str(self.c_channel_port) def config_handler(self, new_config): if self.verbose: - print("[bind10] handling new config:") - print(new_config) + sys.stdout.write("[bind10] handling new config:\n") + sys.stdout.write(new_config + "\n") answer = isc.config.ccsession.create_answer(0) return answer # TODO def command_handler(self, command, args): if self.verbose: - print("[bind10] Boss got command:") - print(command) + sys.stdout.write("[bind10] Boss got command:\n") + sys.stdout.write(command + "\n") answer = isc.config.ccsession.create_answer(1, "command not implemented") if type(command) != str: answer = isc.config.ccsession.create_answer(1, "bad command") else: cmd = command if cmd == "shutdown": - print("[bind10] got shutdown command") + sys.stdout.write("[bind10] got shutdown command\n") self.runnable = False answer = isc.config.ccsession.create_answer(0) - elif cmd == "print_message": - if args: - print(args) - answer = isc.config.ccsession.create_answer(0, args) - elif cmd == "print_settings": - print("[bind10] Full Config:") - full_config = self.ccs.get_full_config() - for item in full_config: - print(item + ": " + str(full_config[item])) - answer = isc.config.ccsession.create_answer(0) else: - answer = isc.config.ccsession.create_answer(1, "Unknown command") + answer = isc.config.ccsession.create_answer(1, + "Unknown command") return answer def startup(self): @@ -220,20 +206,23 @@ class BoB: """ # try to connect to the c-channel daemon, # to see if it is already running - c_channel_env = { "ISC_MSGQ_PORT": str(self.c_channel_port), } + c_channel_env = {} + if self.msgq_socket_file is not None: + c_channel_env["BIND10_MSGQ_SOCKET_FILE"] = self.msgq_socket_file if self.verbose: - sys.stdout.write("Checking for already running b10-msgq\n") + sys.stdout.write("[bind10] Checking for already running b10-msgq\n") # try to connect, and if we can't wait a short while try: - self.cc_session = isc.cc.Session(self.c_channel_port) - return "b10-msgq already running, cannot start" + self.cc_session = isc.cc.Session(self.msgq_socket_file) + return "b10-msgq already running, or socket file not cleaned , cannot start" except isc.cc.session.SessionError: + # this is the case we want, where the msgq is not running pass # start the c-channel daemon if self.verbose: - sys.stdout.write("Starting b10-msgq using port %d\n" % - self.c_channel_port) + if self.msgq_socket_file: + sys.stdout.write("[bind10] Starting b10-msgq\n") try: c_channel = ProcessInfo("b10-msgq", ["b10-msgq"], c_channel_env, True, not self.verbose) @@ -241,7 +230,7 @@ class BoB: return "Unable to start b10-msgq; " + str(e) self.processes[c_channel.pid] = c_channel if self.verbose: - sys.stdout.write("Started b10-msgq (PID %d)\n" % c_channel.pid) + sys.stdout.write("[bind10] Started b10-msgq (PID %d)\n" % c_channel.pid) # now connect to the c-channel cc_connect_start = time.time() @@ -252,63 +241,64 @@ class BoB: return "Unable to connect to c-channel after 5 seconds" # try to connect, and if we can't wait a short while try: - self.cc_session = isc.cc.Session(self.c_channel_port) + self.cc_session = isc.cc.Session(self.msgq_socket_file) except isc.cc.session.SessionError: time.sleep(0.1) - #self.cc_session.group_subscribe("Boss", "boss") # start the configuration manager if self.verbose: sys.stdout.write("[bind10] Starting b10-cfgmgr\n") try: bind_cfgd = ProcessInfo("b10-cfgmgr", ["b10-cfgmgr"], - { 'ISC_MSGQ_PORT': str(self.c_channel_port)}) + c_channel_env) except Exception as e: c_channel.process.kill() return "Unable to start b10-cfgmgr; " + str(e) self.processes[bind_cfgd.pid] = bind_cfgd if self.verbose: - sys.stdout.write("[bind10] Started b10-cfgmgr (PID %d)\n" % bind_cfgd.pid) + sys.stdout.write("[bind10] Started b10-cfgmgr (PID %d)\n" % + bind_cfgd.pid) - # TODO: once this interface is done, replace self.cc_session - # by this one # sleep until b10-cfgmgr is fully up and running, this is a good place # to have a (short) timeout on synchronized groupsend/receive # TODO: replace the sleep by a listen for ConfigManager started # message time.sleep(1) if self.verbose: - print("[bind10] starting ccsession") - self.ccs = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler) + sys.stdout.write("[bind10] starting ccsession\n") + self.ccs = isc.config.ModuleCCSession(SPECFILE_LOCATION, + self.config_handler, self.command_handler) self.ccs.start() if self.verbose: - print("[bind10] ccsession started") + sys.stdout.write("[bind10] ccsession started\n") - # start the xfrout before auth-server, to make sure every xfr-query can be - # processed properly. + # start the xfrout before auth-server, to make sure every xfr-query can + # be processed properly. + xfrout_args = ['b10-xfrout'] if self.verbose: - sys.stdout.write("Starting b10-xfrout\n") + sys.stdout.write("[bind10] Starting b10-xfrout\n") + xfrout_args += ['-v'] try: - xfrout = ProcessInfo("b10-xfrout", ["b10-xfrout"], - { 'ISC_MSGQ_PORT': str(self.c_channel_port)}) + xfrout = ProcessInfo("b10-xfrout", xfrout_args, + c_channel_env ) except Exception as e: c_channel.process.kill() bind_cfgd.process.kill() return "Unable to start b10-xfrout; " + str(e) self.processes[xfrout.pid] = xfrout if self.verbose: - sys.stdout.write("Started b10-xfrout (PID %d)\n" % xfrout.pid) + sys.stdout.write("[bind10] Started b10-xfrout (PID %d)\n" % xfrout.pid) # start b10-auth # XXX: this must be read from the configuration manager in the future authargs = ['b10-auth', '-p', str(self.auth_port)] if self.verbose: - sys.stdout.write("Starting b10-auth using port %d\n" % + sys.stdout.write("[bind10] Starting b10-auth using port %d\n" % self.auth_port) authargs += ['-v'] try: auth = ProcessInfo("b10-auth", authargs, - { 'ISC_MSGQ_PORT': str(self.c_channel_port)}) + c_channel_env) except Exception as e: c_channel.process.kill() bind_cfgd.process.kill() @@ -316,14 +306,16 @@ class BoB: return "Unable to start b10-auth; " + str(e) self.processes[auth.pid] = auth if self.verbose: - sys.stdout.write("Started b10-auth (PID %d)\n" % auth.pid) + sys.stdout.write("[bind10] Started b10-auth (PID %d)\n" % auth.pid) # start b10-xfrin + xfrin_args = ['b10-xfrin'] if self.verbose: - sys.stdout.write("Starting b10-xfrin\n") + sys.stdout.write("[bind10] Starting b10-xfrin\n") + xfrin_args += ['-v'] try: - xfrind = ProcessInfo("b10-xfrin", ['b10-xfrin'], - { 'ISC_MSGQ_PORT': str(self.c_channel_port)}) + xfrind = ProcessInfo("b10-xfrin", xfrin_args, + c_channel_env) except Exception as e: c_channel.process.kill() bind_cfgd.process.kill() @@ -332,15 +324,17 @@ class BoB: return "Unable to start b10-xfrin; " + str(e) self.processes[xfrind.pid] = xfrind if self.verbose: - sys.stdout.write("Started b10-xfrin (PID %d)\n" % xfrind.pid) + sys.stdout.write("[bind10] Started b10-xfrin (PID %d)\n" % xfrind.pid) # start the b10-cmdctl # XXX: we hardcode port 8080 + cmdctl_args = ['b10-cmdctl'] if self.verbose: - sys.stdout.write("Starting b10-cmdctl on port 8080\n") + sys.stdout.write("[bind10] Starting b10-cmdctl on port 8080\n") + cmdctl_args += ['-v'] try: - cmd_ctrld = ProcessInfo("b10-cmdctl", ['b10-cmdctl'], - { 'ISC_MSGQ_PORT': str(self.c_channel_port)}) + cmd_ctrld = ProcessInfo("b10-cmdctl", cmdctl_args, + c_channel_env) except Exception as e: c_channel.process.kill() bind_cfgd.process.kill() @@ -350,7 +344,7 @@ class BoB: return "Unable to start b10-cmdctl; " + str(e) self.processes[cmd_ctrld.pid] = cmd_ctrld if self.verbose: - sys.stdout.write("Started b10-cmdctl (PID %d)\n" % cmd_ctrld.pid) + sys.stdout.write("[bind10] Started b10-cmdctl (PID %d)\n" % cmd_ctrld.pid) self.runnable = True @@ -373,7 +367,7 @@ class BoB: def shutdown(self): """Stop the BoB instance.""" if self.verbose: - sys.stdout.write("Stopping the server.\n") + sys.stdout.write("[bind10] Stopping the server.\n") # first try using the BIND 10 request to stop try: self.stop_all_processes() @@ -386,7 +380,7 @@ class BoB: processes_to_stop = list(self.processes.values()) for proc_info in processes_to_stop: if self.verbose: - sys.stdout.write("Sending SIGTERM to %s (PID %d).\n" % + sys.stdout.write("[bind10] Sending SIGTERM to %s (PID %d).\n" % (proc_info.name, proc_info.pid)) try: proc_info.process.terminate() @@ -402,7 +396,7 @@ class BoB: processes_to_stop = list(self.processes.values()) for proc_info in processes_to_stop: if self.verbose: - sys.stdout.write("Sending SIGKILL to %s (PID %d).\n" % + sys.stdout.write("[bind10] Sending SIGKILL to %s (PID %d).\n" % (proc_info.name, proc_info.pid)) try: proc_info.process.kill() @@ -411,7 +405,7 @@ class BoB: # finally exited) pass if self.verbose: - sys.stdout.write("All processes ended, server done.\n") + sys.stdout.write("[bind10] All processes ended, server done.\n") def reap_children(self): """Check to see if any of our child processes have exited, @@ -430,70 +424,15 @@ class BoB: proc_info.restart_schedule.set_run_stop_time() self.dead_processes[proc_info.pid] = proc_info if self.verbose: - sys.stdout.write("Process %s (PID %d) died.\n" % + sys.stdout.write("[bind10] Process %s (PID %d) died.\n" % (proc_info.name, proc_info.pid)) if proc_info.name == "b10-msgq": if self.verbose and self.runnable: sys.stdout.write( - "The b10-msgq process died, shutting down.\n") + "[bind10] The b10-msgq process died, shutting down.\n") self.runnable = False else: - sys.stdout.write("Unknown child pid %d exited.\n" % pid) - - # 'old' command style, uncommented for now - # move the handling below move to command_handler please - #def recv_and_process_cc_msg(self): - #"""Receive and process the next message on the c-channel, - #if any.""" - #self.ccs.checkCommand() - #msg, envelope = self.cc_session.group_recvmsg(False) - #print(msg) - #if msg is None: - # return - #if not ((type(msg) is dict) and (type(envelope) is dict)): - # if self.verbose: - # sys.stdout.write("Non-dictionary message\n") - # return - #if not "command" in msg: - # if self.verbose: - # if "msg" in envelope: - # del envelope['msg'] - # sys.stdout.write("Unknown message received\n") - # sys.stdout.write(pprint.pformat(envelope) + "\n") - # sys.stdout.write(pprint.pformat(msg) + "\n") - # return - - #cmd = msg['command'] - #if not (type(cmd) is list): - # if self.verbose: - # sys.stdout.write("Non-list command\n") - # return - # - # done checking and extracting... time to execute the command - #if cmd[0] == "shutdown": - # if self.verbose: - # sys.stdout.write("shutdown command received\n") - # self.runnable = False - # # XXX: reply here? - #elif cmd[0] == "getProcessList": - # if self.verbose: - # sys.stdout.write("getProcessList command received\n") - # live_processes = [ ] - # for proc_info in processes: - # live_processes.append({ "name": proc_info.name, - # "args": proc_info.args, - # "pid": proc_info.pid, }) - # dead_processes = [ ] - # for proc_info in dead_processes: - # dead_processes.append({ "name": proc_info.name, - # "args": proc_info.args, }) - # cc.group_reply(envelope, { "response": cmd, - # "sent": msg["sent"], - # "live_processes": live_processes, - # "dead_processes": dead_processes, }) - #else: - # if self.verbose: - # sys.stdout.write("Unknown command %s\n" % str(cmd)) + sys.stdout.write("[bind10] Unknown child pid %d exited.\n" % pid) def restart_processes(self): """Restart any dead processes.""" @@ -507,10 +446,6 @@ class BoB: for proc_info in self.dead_processes.values(): restart_time = proc_info.restart_schedule.get_restart_time(now) if restart_time > now: -# if self.verbose: -# sys.stdout.write("Dead %s process waiting %.1f seconds "\ -# "for resurrection\n" % -# (proc_info.name, (restart_time-now))) if (next_restart is None) or (next_restart > restart_time): next_restart = restart_time still_dead[proc_info.pid] = proc_info @@ -582,9 +517,9 @@ def main(): parser.add_option("-p", "--port", dest="auth_port", type="string", action="callback", callback=check_port, default="5300", help="port the b10-auth daemon will use (default 5300)") - parser.add_option("-m", "--msgq-port", dest="msgq_port", type="string", - action="callback", callback=check_port, default="9912", - help="port the b10-msgq daemon will use (default 9912)") + parser.add_option("-m", "--msgq-socket-file", dest="msgq_socket_file", + type="string", default=None, + help="UNIX domain socket file the b10-msgq daemon will use") (options, args) = parser.parse_args() # Announce startup. @@ -607,11 +542,11 @@ def main(): signal.signal(signal.SIGTERM, fatal_signal) # Go bob! - boss_of_bind = BoB(int(options.msgq_port), int(options.auth_port), + boss_of_bind = BoB(options.msgq_socket_file, int(options.auth_port), options.verbose) startup_result = boss_of_bind.startup() if startup_result: - sys.stderr.write("Error on startup: %s\n" % startup_result) + sys.stderr.write("[bind10] Error on startup: %s\n" % startup_result) sys.exit(1) # In our main loop, we check for dead processes or messages @@ -637,7 +572,7 @@ def main(): if err.args[0] == errno.EINTR: (rlist, wlist, xlist) = ([], [], []) else: - sys.stderr.write("Error with select(); %s\n" % err) + sys.stderr.write("[bind10] Error with select(); %s\n" % err) break for fd in rlist + xlist: diff --git a/src/bin/bind10/bob.spec b/src/bin/bind10/bob.spec index 796bea043f..b890487559 100644 --- a/src/bin/bind10/bob.spec +++ b/src/bin/bind10/bob.spec @@ -1,37 +1,11 @@ { "module_spec": { "module_name": "Boss", + "module_description": "Master process", "config_data": [ - { - "item_name": "example_string", - "item_type": "string", - "item_optional": False, - "item_default": "Just an example string configuration value" - }, - { - "item_name": "example_int", - "item_type": "integer", - "item_optional": False, - "item_default": 1 - } ], "commands": [ { - "command_name": "print_message", - "command_description": "Print the given message to stdout", - "command_args": [ { - "item_name": "message", - "item_type": "string", - "item_optional": False, - "item_default": "" - } ] - }, - { - "command_name": "print_settings", - "command_description": "Print some_string and some_int to stdout", - "command_args": [] - }, - { "command_name": "shutdown", "command_description": "Shut down BIND 10", "command_args": [] diff --git a/src/bin/bind10/run_bind10.sh.in b/src/bin/bind10/run_bind10.sh.in index 69e8c56615..3544d6fad5 100644 --- a/src/bin/bind10/run_bind10.sh.in +++ b/src/bin/bind10/run_bind10.sh.in @@ -24,10 +24,19 @@ PATH=@abs_top_builddir@/src/bin/msgq:@abs_top_builddir@/src/bin/auth:@abs_top_bu export PATH PYTHONPATH=@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/lib/dns/python/.libs:@abs_top_builddir@/src/lib/xfr/.libs +#PYTHONPATH=@abs_top_srcdir@/src/lib/python:@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/lib/dns/.libs:@abs_top_builddir@/src/lib/xfr/.libs export PYTHONPATH B10_FROM_SOURCE=@abs_top_srcdir@ export B10_FROM_SOURCE +# TODO: We need to do this feature based (ie. no general from_source) +# But right now we need a second one because some spec files are +# generated and hence end up under builddir +B10_FROM_BUILD=@abs_top_builddir@ +export B10_FROM_BUILD + +BIND10_MSGQ_SOCKET_FILE=@abs_top_builddir@/msgq_socket +export BIND10_MSGQ_SOCKET_FILE cd ${BIND10_PATH} exec ${PYTHON_EXEC} -O bind10 $* diff --git a/src/bin/bind10/tests/Makefile.am b/src/bin/bind10/tests/Makefile.am index c003e069ee..d13993a3ee 100644 --- a/src/bin/bind10/tests/Makefile.am +++ b/src/bin/bind10/tests/Makefile.am @@ -7,6 +7,6 @@ PYCOVERAGE = $(PYTHON) check-local: for pytest in $(PYTESTS) ; do \ echo Running test: $$pytest ; \ - env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/bin/bind10 \ + env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_builddir)/src/bin/bind10 \ $(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \ done diff --git a/src/bin/bind10/tests/bind10_test.py b/src/bin/bind10/tests/bind10_test.py index 5d2ffb46d9..3f305283e4 100644 --- a/src/bin/bind10/tests/bind10_test.py +++ b/src/bin/bind10/tests/bind10_test.py @@ -75,16 +75,16 @@ class TestBoB(unittest.TestCase): def test_init(self): bob = BoB() self.assertEqual(bob.verbose, False) - self.assertEqual(bob.c_channel_port, 9912) + self.assertEqual(bob.msgq_socket_file, None) self.assertEqual(bob.cc_session, None) self.assertEqual(bob.processes, {}) self.assertEqual(bob.dead_processes, {}) self.assertEqual(bob.runnable, False) - def test_init_alternate_port(self): - bob = BoB(2199) + def test_init_alternate_socket(self): + bob = BoB("alt_socket_file") self.assertEqual(bob.verbose, False) - self.assertEqual(bob.c_channel_port, 2199) + self.assertEqual(bob.msgq_socket_file, "alt_socket_file") self.assertEqual(bob.cc_session, None) self.assertEqual(bob.processes, {}) self.assertEqual(bob.dead_processes, {}) diff --git a/src/bin/bindctl/bindcmd.py b/src/bin/bindctl/bindcmd.py index e13d673d95..138024f05f 100644 --- a/src/bin/bindctl/bindcmd.py +++ b/src/bin/bindctl/bindcmd.py @@ -49,7 +49,7 @@ try: except ImportError: my_readline = sys.stdin.readline - +CONFIG_MODULE_NAME = 'config' CONST_BINDCTL_HELP = """ usage: <module name> <command name> [param1 = value1 [, param2 = value2]] Type Tab character to get the hint of module/command/parameters. @@ -87,8 +87,8 @@ class BindCmdInterpreter(Cmd): '''Generate one session id for the connection. ''' rand = os.urandom(16) now = time.time() - ip = socket.gethostbyname(socket.gethostname()) - session_id = sha1(("%s%s%s" %(rand, now, ip)).encode()) + session_id = sha1(("%s%s%s" %(rand, now, + socket.gethostname())).encode()) digest = session_id.hexdigest() return digest @@ -180,13 +180,9 @@ class BindCmdInterpreter(Cmd): def _update_commands(self): - '''Get the commands of all modules. ''' - cmd_spec = self.send_GET('/command_spec') - if not cmd_spec: - return - - for module_name in cmd_spec.keys(): - self._prepare_module_commands(module_name, cmd_spec[module_name]) + '''Update the commands of all modules. ''' + for module_name in self.config_data.get_config_item_list(): + self._prepare_module_commands(self.config_data.get_module_spec(module_name)) def send_GET(self, url, body = None): '''Send GET request to cmdctl, session id is send with the name @@ -222,17 +218,18 @@ class BindCmdInterpreter(Cmd): self.prompt = self.location + self.prompt_end return stop - def _prepare_module_commands(self, module_name, module_commands): + def _prepare_module_commands(self, module_spec): '''Prepare the module commands''' - module = ModuleInfo(name = module_name, - desc = "same here") - for command in module_commands: + module = ModuleInfo(name = module_spec.get_module_name(), + desc = module_spec.get_module_description()) + for command in module_spec.get_commands_spec(): cmd = CommandInfo(name = command["command_name"], desc = command["command_description"]) for arg in command["command_args"]: param = ParamInfo(name = arg["item_name"], type = arg["item_type"], - optional = bool(arg["item_optional"])) + optional = bool(arg["item_optional"]), + param_spec = arg) if ("item_default" in arg): param.default = arg["item_default"] cmd.add_param(param) @@ -305,12 +302,20 @@ class BindCmdInterpreter(Cmd): if not name in params and not param_nr in params: raise CmdMissParamSyntaxError(cmd.module, cmd.command, name) param_nr += 1 + + # Convert parameter value according parameter spec file. + # Ignore check for commands belongs to module 'config' + if cmd.module != CONFIG_MODULE_NAME: + for param_name in cmd.params: + param_spec = command_info.get_param_with_name(param_name).param_spec + cmd.params[param_name] = isc.config.config_data.convert_type(param_spec, cmd.params[param_name]) + def _handle_cmd(self, cmd): '''Handle a command entered by the user''' if cmd.command == "help" or ("help" in cmd.params.keys()): self._handle_help(cmd) - elif cmd.module == "config": + elif cmd.module == CONFIG_MODULE_NAME: self.apply_config_cmd(cmd) else: self.apply_cmd(cmd) @@ -361,7 +366,7 @@ class BindCmdInterpreter(Cmd): else: hints = self._get_param_startswith(cmd.module, cmd.command, text) - if cmd.module == "config": + if cmd.module == CONFIG_MODULE_NAME: # grm text has been stripped of slashes... my_text = self.location + "/" + cur_line.rpartition(" ")[2] list = self.config_data.get_config_item_list(my_text.rpartition("/")[0], True) @@ -439,6 +444,9 @@ class BindCmdInterpreter(Cmd): except BindCtlException as e: print("Error! ", e) self._print_correct_usage(e) + except isc.cc.data.DataTypeError as e: + print("Error! ", e) + self._print_correct_usage(e) def _print_correct_usage(self, ept): diff --git a/src/bin/bindctl/bindctl-source.py.in b/src/bin/bindctl/bindctl-source.py.in index 4e522a30b9..bbe3b50273 100644 --- a/src/bin/bindctl/bindctl-source.py.in +++ b/src/bin/bindctl/bindctl-source.py.in @@ -29,7 +29,7 @@ __version__ = 'Bindctl' def prepare_config_commands(tool): '''Prepare fixed commands for local configuration editing''' - module = ModuleInfo(name = "config", desc = "Configuration commands") + module = ModuleInfo(name = CONFIG_MODULE_NAME, desc = "Configuration commands") cmd = CommandInfo(name = "show", desc = "Show configuration") param = ParamInfo(name = "identifier", type = "string", optional=True) cmd.add_param(param) diff --git a/src/bin/bindctl/cmdparse.py b/src/bin/bindctl/cmdparse.py index 6e021b2c0e..e911cd2200 100644 --- a/src/bin/bindctl/cmdparse.py +++ b/src/bin/bindctl/cmdparse.py @@ -24,8 +24,8 @@ except ImportError: from bindctl.mycollections import OrderedDict param_name_str = "^\s*(?P<param_name>[\w]+)\s*=\s*" -param_value_str = "(?P<param_value>[\w\./-]+)" -param_value_with_quota_str = "[\"\'](?P<param_value>[\w\., /-]+)[\"\']" +param_value_str = "(?P<param_value>[\w\.:/-]+)" +param_value_with_quota_str = "[\"\'](?P<param_value>[\w\.:, /-]+)[\"\']" next_params_str = "(?P<blank>\s*)(?P<comma>,?)(?P<next_params>.*)$" PARAM_WITH_QUOTA_PATTERN = re.compile(param_name_str + @@ -116,10 +116,3 @@ class BindCmdParse: if not groups.group('blank') and \ not groups.group('comma'): raise CmdParamFormatError(self.module, self.command) - - - - - - - diff --git a/src/bin/bindctl/exception.py b/src/bin/bindctl/exception.py index bfb38426d7..1409f69526 100644 --- a/src/bin/bindctl/exception.py +++ b/src/bin/bindctl/exception.py @@ -62,7 +62,7 @@ class CmdParamFormatError(CmdFormatError): self.command = command def __str__(self): - return "Parameter format error, it should like 'key = value'" + return "Parameter format error, it should be 'key = value'" # Begin define the exception for syntax @@ -115,5 +115,3 @@ class CmdMissParamSyntaxError(CmdSyntaxError): def __str__(self): return str("Parameter '%s' is missed for command '%s' of module '%s'" % (self.param, self.command, self.module)) - - diff --git a/src/bin/bindctl/moduleinfo.py b/src/bin/bindctl/moduleinfo.py index 40326c7c03..015ef16cf6 100644 --- a/src/bin/bindctl/moduleinfo.py +++ b/src/bin/bindctl/moduleinfo.py @@ -33,17 +33,21 @@ PARAM_NODE_NAME = 'param' class ParamInfo: """One parameter of one command. - Each command parameter has four attributes: - parameter name, parameter type, parameter value, and parameter description + Each command parameter has five attributes: + parameter name, parameter type, parameter value, + parameter description and paramter's spec(got from + module spec file). """ def __init__(self, name, desc = '', type = STRING_TYPE, - optional = False, value = '', default_value = ''): + optional = False, value = '', default_value = '', + param_spec = None): self.name = name self.type = type self.value = value self.default_value = default_value self.desc = desc self.is_optional = optional + self.param_spec = param_spec def __str__(self): return str("\t%s <type: %s> \t(%s)" % (self.name, self.type, self.desc)) diff --git a/src/bin/bindctl/run_bindctl.sh.in b/src/bin/bindctl/run_bindctl.sh.in index e1e878753a..aa570224d1 100644 --- a/src/bin/bindctl/run_bindctl.sh.in +++ b/src/bin/bindctl/run_bindctl.sh.in @@ -18,9 +18,9 @@ PYTHON_EXEC=${PYTHON_EXEC:-@PYTHON@} export PYTHON_EXEC -BINDCTL_PATH=@abs_top_srcdir@/src/bin/bindctl +BINDCTL_PATH=@abs_top_builddir@/src/bin/bindctl -PYTHONPATH=@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/bin +PYTHONPATH=@abs_top_srcdir@/src/bin:@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/bin:@abs_top_srcdir@/src/lib/python export PYTHONPATH B10_FROM_SOURCE=@abs_top_srcdir@ diff --git a/src/bin/bindctl/tests/Makefile.am b/src/bin/bindctl/tests/Makefile.am index ecc931a544..e0ce57692f 100644 --- a/src/bin/bindctl/tests/Makefile.am +++ b/src/bin/bindctl/tests/Makefile.am @@ -7,6 +7,6 @@ PYCOVERAGE = $(PYTHON) check-local: for pytest in $(PYTESTS) ; do \ echo Running test: $$pytest ; \ - env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_srcdir)/src/bin \ + env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_srcdir)/src/bin \ $(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \ done diff --git a/src/bin/bindctl/tests/bindctl_test.py b/src/bin/bindctl/tests/bindctl_test.py index 9d4f6bdfe9..90cf6e4bc8 100644 --- a/src/bin/bindctl/tests/bindctl_test.py +++ b/src/bin/bindctl/tests/bindctl_test.py @@ -15,6 +15,7 @@ import unittest +import isc.cc.data from bindctl import cmdparse from bindctl import bindcmd from bindctl.moduleinfo import * @@ -92,14 +93,21 @@ class TestCmdSyntax(unittest.TestCase): """Create one bindcmd""" tool = bindcmd.BindCmdInterpreter() - zone_file_param = ParamInfo(name = "zone_file") - zone_name = ParamInfo(name = 'zone_name') + string_spec = { 'item_type' : 'string', + 'item_optional' : False, + 'item_default' : ''} + int_spec = { 'item_type' : 'integer', + 'item_optional' : False, + 'item_default' : 10} + zone_file_param = ParamInfo(name = "zone_file", param_spec = string_spec) + zone_name = ParamInfo(name = 'zone_name', param_spec = string_spec) load_cmd = CommandInfo(name = "load") load_cmd.add_param(zone_file_param) load_cmd.add_param(zone_name) - param_master = ParamInfo(name = "master", optional = True) - param_allow_update = ParamInfo(name = "allow_update", optional = True) + param_master = ParamInfo(name = "master", optional = True, param_spec = string_spec) + param_master = ParamInfo(name = "port", optional = True, param_spec = int_spec) + param_allow_update = ParamInfo(name = "allow_update", optional = True, param_spec = string_spec) set_cmd = CommandInfo(name = "set") set_cmd.add_param(param_master) set_cmd.add_param(param_allow_update) @@ -138,6 +146,7 @@ class TestCmdSyntax(unittest.TestCase): self.no_assert_raise("zone help help='dd' ") self.no_assert_raise("zone set allow_update='1.1.1.1' zone_name='cn'") self.no_assert_raise("zone set zone_name='cn'") + self.my_assert_raise(isc.cc.data.DataTypeError, "zone set zone_name ='cn', port='cn'") self.no_assert_raise("zone reload_all") diff --git a/src/bin/cfgmgr/b10-cfgmgr.py.in b/src/bin/cfgmgr/b10-cfgmgr.py.in index b09a793037..563bbcdf7e 100644 --- a/src/bin/cfgmgr/b10-cfgmgr.py.in +++ b/src/bin/cfgmgr/b10-cfgmgr.py.in @@ -50,6 +50,6 @@ if __name__ == "__main__": print("[b10-cfgmgr] Error creating config manager, " "is the command channel daemon running?") except KeyboardInterrupt as kie: - print("Got ctrl-c, exit") + print("[b10-cfgmgr] Interrupted, exiting") if cm: cm.write_config() diff --git a/src/bin/cmdctl/cmdctl.py.in b/src/bin/cmdctl/cmdctl.py.in index 8e845f69c2..4e5a73351d 100644 --- a/src/bin/cmdctl/cmdctl.py.in +++ b/src/bin/cmdctl/cmdctl.py.in @@ -219,8 +219,7 @@ class CommandControl(): self._verbose = verbose self.cc = isc.cc.Session() self.cc.group_subscribe('Cmd-Ctrld') - self.command_spec = self.get_cmd_specification() - self.config_spec = self.get_data_specification() + self.module_spec = self.get_module_specification() self.config_data = self.get_config_data() def _parse_command_result(self, rcode, reply): @@ -229,10 +228,6 @@ class CommandControl(): return {} return reply - def get_cmd_specification(self): - rcode, reply = self.send_command('ConfigManager', isc.config.ccsession.COMMAND_GET_COMMANDS_SPEC) - return self._parse_command_result(rcode, reply) - def get_config_data(self): '''Get config data for all modules from configmanager ''' rcode, reply = self.send_command('ConfigManager', isc.config.ccsession.COMMAND_GET_CONFIG) @@ -244,7 +239,7 @@ class CommandControl(): if module_name == 'ConfigManager' and command_name == isc.config.ccsession.COMMAND_SET_CONFIG: self.config_data = self.get_config_data() - def get_data_specification(self): + def get_module_specification(self): rcode, reply = self.send_command('ConfigManager', isc.config.ccsession.COMMAND_GET_MODULE_SPEC) return self._parse_command_result(rcode, reply) @@ -253,10 +248,8 @@ class CommandControl(): (message, env) = self.cc.group_recvmsg(True) command, arg = isc.config.ccsession.parse_command(message) while command: - if command == isc.config.ccsession.COMMAND_COMMANDS_UPDATE: - self.command_spec[arg[0]] = arg[1] - elif command == isc.config.ccsession.COMMAND_SPECIFICATION_UPDATE: - self.config_spec[arg[0]] = arg[1] + if command == isc.config.ccsession.COMMAND_MODULE_SPECIFICATION_UPDATE: + self.module_spec[arg[0]] = arg[1] elif command == isc.config.ccsession.COMMAND_SHUTDOWN: return False (message, env) = self.cc.group_recvmsg(True) @@ -270,27 +263,20 @@ class CommandControl(): Return rcode, dict. rcode = 0: dict is the correct returned value. rcode > 0: dict is : { 'error' : 'error reason' } - - TODO. add check for parameters. ''' # core module ConfigManager does not have a specification file if module_name == 'ConfigManager': return self.send_command(module_name, command_name, params) - if module_name not in self.command_spec.keys(): + if module_name not in self.module_spec.keys(): return 1, {'error' : 'unknown module'} - - cmd_valid = False - commands = self.command_spec[module_name] - for cmd in commands: - if cmd['command_name'] == command_name: - cmd_valid = True - break - - if not cmd_valid: - return 1, {'error' : 'unknown command'} - + + spec_obj = isc.config.module_spec.ModuleSpec(self.module_spec[module_name], False) + errors = [] + if not spec_obj.validate_command(command_name, params, errors): + return 1, {'error': errors[0]} + return self.send_command(module_name, command_name, params) def send_command(self, module_name, command_name, params = None): @@ -326,7 +312,7 @@ class CommandControl(): return 1, {'error': errstr} def log_info(self, msg): - sys.stdout.write(msg) + sys.stdout.write("[b10-cmdctl] %s\n" % str(msg)) class SecureHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer): '''Make the server address can be reused.''' @@ -383,12 +369,10 @@ class SecureHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer): '''Currently only support the following three url GET request ''' rcode, reply = http.client.NO_CONTENT, [] if not module: - if id == 'command_spec': - rcode, reply = http.client.OK, self.cmdctrl.command_spec - elif id == 'config_data': + if id == 'config_data': rcode, reply = http.client.OK, self.cmdctrl.config_data - elif id == 'config_spec': - rcode, reply = http.client.OK, self.cmdctrl.config_spec + elif id == 'module_spec': + rcode, reply = http.client.OK, self.cmdctrl.module_spec return rcode, reply @@ -416,7 +400,7 @@ class SecureHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer): return self.cmdctrl.send_command_with_check(module_name, command_name, params) def log_info(self, msg): - sys.stdout.write(msg) + sys.stdout.write("[b10-cmdctl] %s\n" % str(msg)) httpd = None diff --git a/src/bin/cmdctl/cmdctl.spec b/src/bin/cmdctl/cmdctl.spec index e4379736e6..4b28a2a9ef 100644 --- a/src/bin/cmdctl/cmdctl.spec +++ b/src/bin/cmdctl/cmdctl.spec @@ -1,6 +1,7 @@ { "module_spec": { "module_name": "Cmdctl", + "module_description": "Interface for command and control", "config_data": [ { "item_name": "key_file", diff --git a/src/bin/cmdctl/tests/Makefile.am b/src/bin/cmdctl/tests/Makefile.am index 0e120df491..79e8827648 100644 --- a/src/bin/cmdctl/tests/Makefile.am +++ b/src/bin/cmdctl/tests/Makefile.am @@ -7,6 +7,6 @@ PYCOVERAGE = $(PYTHON) check-local: for pytest in $(PYTESTS) ; do \ echo Running test: $$pytest ; \ - env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/bin/cmdctl \ + env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_builddir)/src/bin/cmdctl \ $(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \ done diff --git a/src/bin/cmdctl/tests/cmdctl_test.py b/src/bin/cmdctl/tests/cmdctl_test.py index 363a2caedd..93f9c75df2 100644 --- a/src/bin/cmdctl/tests/cmdctl_test.py +++ b/src/bin/cmdctl/tests/cmdctl_test.py @@ -51,9 +51,8 @@ class MySecureHTTPServer(SecureHTTPServer): class MyCommandControl(CommandControl): def __init__(self): - self.command_spec = {} - self.config_spec = {} self.config_data = {} + self.module_spec = {} def send_command(self, mod, cmd, param): return 0, {} @@ -66,6 +65,11 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): self.handler.server.user_sessions = {} self.handler.server.user_infos = {} self.handler.headers = {} + self.handler.rfile = open("check.tmp", 'w+b') + + def tearDown(self): + self.handler.rfile.close() + os.remove('check.tmp') def test_parse_request_path(self): self.handler.path = '' @@ -120,7 +124,7 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): def test_do_GET_3(self): self.handler.headers['cookie'] = 12346 self.handler.server.user_sessions[12346] = time.time() + 1000000 - path_vec = ['command_spec', 'config_data', 'config_spec'] + path_vec = ['config_data', 'module_spec'] for path in path_vec: self.handler.path = '/' + path self.handler.do_GET() @@ -145,7 +149,6 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): self.assertEqual(msg, ['invalid username or password']) def test_check_user_name_and_pwd_1(self): - self.handler.rfile = open("check.tmp", 'w+b') user_info = {'username':'root', 'password':'abc123'} len = self.handler.rfile.write(json.dumps(user_info).encode()) self.handler.headers['Content-Length'] = len @@ -155,11 +158,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): ret, msg = self.handler._check_user_name_and_pwd() self.assertTrue(ret == False) self.assertEqual(msg, ['password doesn\'t match']) - self.handler.rfile.close() - os.remove('check.tmp') def test_check_user_name_and_pwd_2(self): - self.handler.rfile = open("check.tmp", 'w+b') user_info = {'username':'root', 'password':'abc123'} len = self.handler.rfile.write(json.dumps(user_info).encode()) self.handler.headers['Content-Length'] = len - 1 @@ -168,11 +168,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): ret, msg = self.handler._check_user_name_and_pwd() self.assertTrue(ret == False) self.assertEqual(msg, ['invalid username or password']) - self.handler.rfile.close() - os.remove('check.tmp') def test_check_user_name_and_pwd_3(self): - self.handler.rfile = open("check.tmp", 'w+b') user_info = {'usernae':'root', 'password':'abc123'} len = self.handler.rfile.write(json.dumps(user_info).encode()) self.handler.headers['Content-Length'] = len @@ -181,11 +178,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): ret, msg = self.handler._check_user_name_and_pwd() self.assertTrue(ret == False) self.assertEqual(msg, ['need user name']) - self.handler.rfile.close() - os.remove('check.tmp') def test_check_user_name_and_pwd_4(self): - self.handler.rfile = open("check.tmp", 'w+b') user_info = {'username':'root', 'pssword':'abc123'} len = self.handler.rfile.write(json.dumps(user_info).encode()) self.handler.headers['Content-Length'] = len @@ -195,11 +189,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): ret, msg = self.handler._check_user_name_and_pwd() self.assertTrue(ret == False) self.assertEqual(msg, ['need password']) - self.handler.rfile.close() - os.remove('check.tmp') def test_check_user_name_and_pwd_5(self): - self.handler.rfile = open("check.tmp", 'w+b') user_info = {'username':'root', 'password':'abc123'} len = self.handler.rfile.write(json.dumps(user_info).encode()) self.handler.headers['Content-Length'] = len @@ -208,8 +199,6 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): ret, msg = self.handler._check_user_name_and_pwd() self.assertTrue(ret == False) self.assertEqual(msg, ['user doesn\'t exist']) - self.handler.rfile.close() - os.remove('check.tmp') def test_do_POST(self): self.handler.headers = {} @@ -235,20 +224,45 @@ class TestSecureHTTPRequestHandler(unittest.TestCase): rcode, reply = self.handler._handle_post_request() self.assertEqual(http.client.BAD_REQUEST, rcode) + def _gen_module_spec(self): + spec = { 'commands': [ + { 'command_name' :'command', + 'command_args': [ { + 'item_name' : 'param1', + 'item_type' : 'integer', + 'item_optional' : False, + 'item_default' : 0 + } ], + 'command_description' : 'cmd description' + } + ] + } + + return spec + def test_handle_post_request_2(self): - self.handler.rfile = open("check.tmp", 'w+b') - params = {123:'param data'} + params = {'param1':123} len = self.handler.rfile.write(json.dumps(params).encode()) self.handler.headers['Content-Length'] = len - self.handler.rfile.seek(0, 0) - self.handler.rfile.close() - os.remove('check.tmp') + self.handler.rfile.seek(0, 0) self.handler.path = '/module/command' - self.handler.server.cmdctrl.command_spec = {} - self.handler.server.cmdctrl.command_spec['module'] = [{'command_name':'command'}, {'command_name': ['data1']} ] + self.handler.server.cmdctrl.module_spec = {} + self.handler.server.cmdctrl.module_spec['module'] = self._gen_module_spec() rcode, reply = self.handler._handle_post_request() self.assertEqual(http.client.OK, rcode) + def test_handle_post_request_3(self): + params = {'param1':'abc'} + len = self.handler.rfile.write(json.dumps(params).encode()) + self.handler.headers['Content-Length'] = len + + self.handler.rfile.seek(0, 0) + self.handler.path = '/module/command' + self.handler.server.cmdctrl.module_spec = {} + self.handler.server.cmdctrl.module_spec['module'] = self._gen_module_spec() + rcode, reply = self.handler._handle_post_request() + self.assertEqual(http.client.BAD_REQUEST, rcode) + if __name__== "__main__": unittest.main() diff --git a/src/bin/host/Makefile.am b/src/bin/host/Makefile.am index a1b09f0c35..39e4b8928b 100644 --- a/src/bin/host/Makefile.am +++ b/src/bin/host/Makefile.am @@ -1,6 +1,8 @@ AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib AM_CPPFLAGS += -I$(top_srcdir)/src/lib/dns -I$(top_builddir)/src/lib/dns +AM_CXXFLAGS = $(B10_CXXFLAGS) + CLEANFILES = *.gcno *.gcda bin_PROGRAMS = host diff --git a/src/bin/loadzone/b10-loadzone.py.in b/src/bin/loadzone/b10-loadzone.py.in index ebb8ec1274..98525985d6 100644 --- a/src/bin/loadzone/b10-loadzone.py.in +++ b/src/bin/loadzone/b10-loadzone.py.in @@ -77,4 +77,4 @@ def main(): exit(1) if __name__ == "__main__": - main(datasrc + main() diff --git a/src/bin/msgq/msgq.py.in b/src/bin/msgq/msgq.py.in index cef69f1082..df6995b400 100644 --- a/src/bin/msgq/msgq.py.in +++ b/src/bin/msgq/msgq.py.in @@ -47,11 +47,11 @@ class SubscriptionManager: """Add a subscription.""" target = ( group, instance ) if target in self.subscriptions: - print("Appending to existing target") + print("[b10-msgq] Appending to existing target") if socket not in self.subscriptions[target]: self.subscriptions[target].append(socket) else: - print("Creating new target") + print("[b10-msgq] Creating new target") self.subscriptions[target] = [ socket ] def unsubscribe(self, group, instance, socket): @@ -86,25 +86,33 @@ class SubscriptionManager: class MsgQ: """Message Queue class.""" - def __init__(self, port=0, verbose=False): + # did we find a better way to do this? + SOCKET_FILE = os.path.join("@localstatedir@", + "@PACKAGE_NAME@", + "msgq_socket").replace("${prefix}", + "@prefix@") + + def __init__(self, socket_file=None, verbose=False): """Initialize the MsgQ master. - The port specifies the TCP/IP port that the msgq - process listens on. If verbose is True, then the MsgQ reports + The socket_file specifies the path to the UNIX domain socket + that the msgq process listens on. If it is None, the + environment variable BIND10_MSGQ_SOCKET_FILE is used. If that + is not set, it will default to + @localstatedir@/@PACKAGE_NAME@/msg_socket. + If verbose is True, then the MsgQ reports what it is doing. """ - if port == 0: - if 'ISC_MSGQ_PORT' in os.environ: - port = int(os.environ["ISC_MSGQ_PORT"]) - else: - port = 9912 - - - print(port) + if socket_file is None: + if "BIND10_MSGQ_SOCKET_FILE" in os.environ: + self.socket_file = os.environ["BIND10_MSGQ_SOCKET_FILE"] + else: + self.socket_file = self.SOCKET_FILE + else: + self.socket_file = socket_file self.verbose = verbose - self.c_channel_port = port self.poller = None self.kqueue = None self.runnable = False @@ -131,10 +139,19 @@ class MsgQ: def setup_listener(self): """Set up the listener socket. Internal function.""" - self.listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.listen_socket.bind(("127.0.0.1", self.c_channel_port)) - self.listen_socket.listen(1024) + self.listen_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + if os.path.exists(self.socket_file): + os.remove(self.socket_file) + try: + self.listen_socket.bind(self.socket_file) + self.listen_socket.listen(1024) + except Exception as e: + # remove the file again if something goes wrong + # (note this is a catch-all, but we reraise it) + if os.path.exists(self.socket_file): + os.remove(self.socket_file) + raise e if self.poller: self.poller.register(self.listen_socket, select.POLLIN) @@ -142,20 +159,25 @@ class MsgQ: self.add_kqueue_socket(self.listen_socket) def setup(self): - """Configure listener socket, polling, etc.""" + """Configure listener socket, polling, etc. + Raises a socket.error if the socket_file cannot be + created. + """ self.setup_poller() self.setup_listener() if self.verbose: - sys.stdout.write("Listening\n") + sys.stdout.write("[b10-msgq] Listening\n") self.runnable = True def process_accept(self): """Process an accept on the listening socket.""" newsocket, ipaddr = self.listen_socket.accept() - sys.stderr.write("Connection\n") + # TODO: When we have logging, we might want + # to add a debug message here that a new connection + # was made self.sockets[newsocket.fileno()] = newsocket lname = self.newlname() self.lnames[lname] = newsocket @@ -169,9 +191,9 @@ class MsgQ: """Process a read on a socket.""" sock = self.sockets[fd] if sock == None: - sys.stderr.write("Got read on Strange Socket fd %d\n" % fd) + sys.stderr.write("[b10-msgq] Got read on Strange Socket fd %d\n" % fd) return -# sys.stderr.write("Got read on fd %d\n" %fd) +# sys.stderr.write("[b10-msgq] Got read on fd %d\n" %fd) self.process_packet(fd, sock) def kill_socket(self, fd, sock): @@ -183,7 +205,7 @@ class MsgQ: del self.lnames[lname] sock.close() self.sockets[fd] = None - sys.stderr.write("Closing socket fd %d\n" % fd) + sys.stderr.write("[b10-msgq] Closing socket fd %d\n" % fd) def getbytes(self, fd, sock, length): """Get exactly the requested bytes, or raise an exception if @@ -223,14 +245,14 @@ class MsgQ: routing, data = self.read_packet(fd, sock) except MsgQReceiveError as err: self.kill_socket(fd, sock) - sys.stderr.write("Receive error: %s\n" % err) + sys.stderr.write("[b10-msgq] Receive error: %s\n" % err) return try: routingmsg = isc.cc.message.from_wire(routing) except DecodeError as err: self.kill_socket(fd, sock) - sys.stderr.write("Routing decode error: %s\n" % err) + sys.stderr.write("[b10-msgq] Routing decode error: %s\n" % err) return # sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n") @@ -241,8 +263,9 @@ class MsgQ: def process_command(self, fd, sock, routing, data): """Process a single command. This will split out into one of the other functions.""" - print("[XX] got command: ") - print(routing) + # TODO: A print statement got removed here (one that prints the + # routing envelope). When we have logging with multiple levels, + # we might want to re-add that on a high debug verbosity. cmd = routing["type"] if cmd == 'send': self.process_command_send(sock, routing, data) @@ -253,7 +276,7 @@ class MsgQ: elif cmd == 'getlname': self.process_command_getlname(sock, routing, data) else: - sys.stderr.write("Invalid command: %s\n" % cmd) + sys.stderr.write("[b10-msgq] Invalid command: %s\n" % cmd) def preparemsg(self, env, msg = None): if type(env) == dict: @@ -338,7 +361,7 @@ class MsgQ: if err.args[0] == errno.EINTR: events = [] else: - sys.stderr.write("Error with poll(): %s\n" % err) + sys.stderr.write("[b10-msgq] Error with poll(): %s\n" % err) break for (fd, event) in events: if fd == self.listen_socket.fileno(): @@ -364,8 +387,10 @@ class MsgQ: def shutdown(self): """Stop the MsgQ master.""" if self.verbose: - sys.stdout.write("Stopping the server.\n") + sys.stdout.write("[b10-msgq] Stopping the server.\n") self.listen_socket.close() + if os.path.exists(self.socket_file): + os.remove(self.socket_file) # can signal handling and calling a destructor be done without a # global variable? @@ -389,22 +414,22 @@ if __name__ == "__main__": parser = OptionParser(version=__version__) parser.add_option("-v", "--verbose", dest="verbose", action="store_true", help="display more about what is going on") - parser.add_option("-m", "--msgq-port", dest="msgq_port", type="string", - action="callback", callback=check_port, default="0", - help="port the msgq daemon will use") + parser.add_option("-s", "--socket-file", dest="msgq_socket_file", + type="string", default=None, + help="UNIX domain socket file the msgq daemon will use") (options, args) = parser.parse_args() signal.signal(signal.SIGTERM, signal_handler) # Announce startup. if options.verbose: - sys.stdout.write("MsgQ %s\n" % __version__) + sys.stdout.write("[b10-msgq] MsgQ %s\n" % __version__) - msgq = MsgQ(int(options.msgq_port), options.verbose) + msgq = MsgQ(options.msgq_socket_file, options.verbose) setup_result = msgq.setup() if setup_result: - sys.stderr.write("Error on startup: %s\n" % setup_result) + sys.stderr.write("[b10-msgq] Error on startup: %s\n" % setup_result) sys.exit(1) try: diff --git a/src/bin/msgq/tests/Makefile.am b/src/bin/msgq/tests/Makefile.am index a97a422909..9abf774edd 100644 --- a/src/bin/msgq/tests/Makefile.am +++ b/src/bin/msgq/tests/Makefile.am @@ -7,7 +7,7 @@ PYCOVERAGE = $(PYTHON) check-local: for pytest in $(PYTESTS) ; do \ echo Running test: $$pytest ; \ - env PYTHONPATH=$(abs_top_builddir)/src/bin/msgq:$(abs_top_srcdir)/src/lib/python \ + env PYTHONPATH=$(abs_top_builddir)/src/bin/msgq:$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python \ $(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \ done diff --git a/src/bin/msgq/tests/msgq_test.py b/src/bin/msgq/tests/msgq_test.py index 3e91a608d3..b0d0fefe21 100644 --- a/src/bin/msgq/tests/msgq_test.py +++ b/src/bin/msgq/tests/msgq_test.py @@ -1,6 +1,8 @@ from msgq import SubscriptionManager, MsgQ import unittest +import os +import socket # # Currently only the subscription part is implemented... I'd have to mock @@ -58,5 +60,46 @@ class TestSubscriptionManager(unittest.TestCase): self.sm.subscribe('g1', '*', 's2') self.assertEqual(self.sm.find_sub("g1", "i1"), [ 's1' ]) + def test_open_socket_parameter(self): + self.assertFalse(os.path.exists("./my_socket_file")) + msgq = MsgQ("./my_socket_file"); + msgq.setup() + self.assertTrue(os.path.exists("./my_socket_file")) + msgq.shutdown(); + self.assertFalse(os.path.exists("./my_socket_file")) + + def test_open_socket_environment_variable(self): + self.assertFalse(os.path.exists("my_socket_file")) + os.environ["BIND10_MSGQ_SOCKET_FILE"] = "./my_socket_file" + msgq = MsgQ(); + msgq.setup() + self.assertTrue(os.path.exists("./my_socket_file")) + msgq.shutdown(); + self.assertFalse(os.path.exists("./my_socket_file")) + + def test_open_socket_default(self): + env_var = None + if "BIND10_MSGQ_SOCKET_FILE" in os.environ: + env_var = os.environ["BIND10_MSGQ_SOCKET_FILE"] + del os.environ["BIND10_MSGQ_SOCKET_FILE"] + socket_file = MsgQ.SOCKET_FILE + self.assertFalse(os.path.exists(socket_file)) + msgq = MsgQ(); + try: + msgq.setup() + self.assertTrue(os.path.exists(socket_file)) + msgq.shutdown(); + self.assertFalse(os.path.exists(socket_file)) + except socket.error: + # ok, the install path doesn't exist at all, + # so we can't check any further + pass + if env_var is not None: + os.environ["BIND10_MSGQ_SOCKET_FILE"] = env_var + + def test_open_socket_bad(self): + msgq = MsgQ("/does/not/exist") + self.assertRaises(socket.error, msgq.setup) + if __name__ == '__main__': unittest.main() diff --git a/src/bin/xfrin/TODO b/src/bin/xfrin/TODO index 05cb5a0c31..6395d53ec0 100644 --- a/src/bin/xfrin/TODO +++ b/src/bin/xfrin/TODO @@ -1 +1,65 @@ -1. When xfrin's config data is changed, new config data should be applied.
\ No newline at end of file +1. When xfrin's config data is changed, new config data should be applied. +2. mutex on recorder is not sufficient. race can happen if two xfrin requests + occur at the same time. (but testing it would be very difficult) +3. It wouldn't support IPv6 because of the following line: + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + [FIXED in r1851] +4. Xfrin.retransfer and refresh share most of the code. should be unified. + [FIXED in r1861] +5. class IN is hardcoded. bad. + query_question = question(name(self._zone_name), rr_class.IN(), query_type) + [FIXED in r1889] + Note: we still hardcode it as the fixed default value for + retransfer/refresh commands. + we should fix this so that this is specifiable, so this TODO item is + still open. +6. QID 0 should be allowed: + query_id = random.randint(1, 0xFFFF) + [FIXED in r1880] +7. what if xfrin fails after opening a new DB? looks like garbage + (intermediate) data remains in the DB file, although it's more about + the data source implementation. check it, and fix it if it's the case. +8. Xfrin.command_handler() ignores unknown commands. should return an error. + [FIXED in r1882] +9. XfrinConnection can leak sockets. (same problem as that Jelte mentioned + on xfrout?) + [FIXED in r1908] +10. The following line of _check_soa_serial() is incorrect. + soa_reply = self._get_request_response(int(data_size)) + Unpack the data and convert it in the host by order. + [FIXED in r1866] +11. if do_xfrin fails it should probably return a non "OK" value. + (it's currently ignored anyway, though) + [FIXED in r1887] +12. XfrinConnection should probably define handle_close(). Also, the + following part should be revised because this can also happen when the + master closes the connection. + if self._recv_time_out: + raise XfrinException('receive data from socket time out.') +13. according to the source code xfrin cannot quickly terminate on shutdown + if some of the xfr connections stall. on a related note, the use of + threading.Event() is questionable: since no threads wait() on the event, + it actually just works as a global flag shared by all threads. + this implementation should be refactored so that a shutdown command is + propagate to all threads immediately, whether it's via a builtin mechanism + of the threading module or not (it's probably "not", see below). +14. the current use of asyncore seems to be thread unsafe because it + relies on a global channel map (which is the implicit default). + each thread should probably use its own map: + asyncore.dispatcher.__init__(self, map=sock_map) + # where sock_map is thread specific and is passed to + # XfrinConnection.__init__(). +15. but in the first place, it's not clear why we need asyncore. + since each thread is responsible for a single xfr connection, + socket operations can safely block (with timeouts). this should + be easily implemented using the bear socket module, and the code + would look like more straightforward by avoiding complicated logic + for asynchrony. in fact, that simplicity should be a major + advantage with thread over event-driven (the model asyncore + implements), so this mixture of two models seems awkward to me. +16. having said all that, asyncore may still be necessary to address + item #13: we'd need an explicit communication channel (e.g. a + pipe) between the parent thread and xfr connection thread, through + which a shutdown notification would be sent to the child. With + this approach each thread needs to watch at least two channels, + and then it would need some asynchronous communication mechanism. diff --git a/src/bin/xfrin/tests/xfrin_test.in b/src/bin/xfrin/tests/xfrin_test.in index d8b3323d36..37bfb13548 100644 --- a/src/bin/xfrin/tests/xfrin_test.in +++ b/src/bin/xfrin/tests/xfrin_test.in @@ -19,9 +19,8 @@ PYTHON_EXEC=${PYTHON_EXEC:-@PYTHON@} export PYTHON_EXEC TEST_PATH=@abs_top_srcdir@/src/bin/xfrin/tests -PYTHONPATH=@abs_top_srcdir@/src/bin/xfrin:@abs_top_srcdir@/src/lib/python +PYTHONPATH=@abs_top_srcdir@/src/bin/xfrin:@abs_top_srcdir@/src/lib/python:@abs_top_builddir@/src/lib/dns/.libs export PYTHONPATH cd ${TEST_PATH} exec ${PYTHON_EXEC} -O xfrin_test.py $* - diff --git a/src/bin/xfrin/tests/xfrin_test.py b/src/bin/xfrin/tests/xfrin_test.py index aea0a46a94..0707465f2b 100644 --- a/src/bin/xfrin/tests/xfrin_test.py +++ b/src/bin/xfrin/tests/xfrin_test.py @@ -13,99 +13,533 @@ # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# $Id$ import unittest import socket from xfrin import * -# An axfr response of the simple zone "example.com(without soa record at the end)." -axfr_response1 = b'\x84\x00\x00\x01\x00\x06\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01\xc0\x0c\x00\x06\x00\x01\x00\x00\x0e\x10\x00$\x05dns01\xc0\x0c\x05admin\xc0\x0c\x00\x00\x04\xd2\x00\x00\x0e\x10\x00\x00\x07\x08\x00$\xea\x00\x00\x00\x1c \xc0\x0c\x00\x02\x00\x01\x00\x00\x0e\x10\x00\x02\xc0)\xc0)\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\xa8\x02\x02\x04sql1\xc0\x0c\x00\x02\x00\x01\x00\x00\x0e\x10\x00\x02\xc0)\x04sql2\xc0\x0c\x00\x02\x00\x01\x00\x00\x0e\x10\x00\x02\xc0)\x03ns1\x07subzone\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\xa8\x03\x01' +# +# Commonly used (mostly constant) test parameters +# +TEST_ZONE_NAME = "example.com" +TEST_RRCLASS = RRClass.IN() +TEST_DB_FILE = 'db_file' +TEST_MASTER_IPV4_ADDRESS = '127.0.0.1' +TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM, + socket.IPPROTO_TCP, '', + (TEST_MASTER_IPV4_ADDRESS, 53)) +TEST_MASTER_IPV6_ADDRESS = '::1' +TEST_MASTER_IPV6_ADDRINFO = (socket.AF_INET6, socket.SOCK_STREAM, + socket.IPPROTO_TCP, '', + (TEST_MASTER_IPV6_ADDRESS, 53)) +# XXX: This should be a non priviledge port that is unlikely to be used. +# If some other process uses this port test will fail. +TEST_MASTER_PORT = '53535' + +soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS, + 'master.example.com. admin.example.com ' + + '1234 3600 1800 2419200 7200') +soa_rrset = RRset(Name(TEST_ZONE_NAME), TEST_RRCLASS, RRType.SOA(), + RRTTL(3600)) +soa_rrset.add_rdata(soa_rdata) +example_axfr_question = Question(Name(TEST_ZONE_NAME), TEST_RRCLASS, + RRType.AXFR()) +example_soa_question = Question(Name(TEST_ZONE_NAME), TEST_RRCLASS, + RRType.SOA()) +default_questions = [example_axfr_question] +default_answers = [soa_rrset] -# The second axfr response with only the end soa record. -axfr_response2 = b'\x84\x00\x00\x00\x00\x01\x00\x00\x00\x00\x07example\x03com\x00\x00\x06\x00\x01\x00\x00\x0e\x10\x00$\x05dns01\xc0\x0c\x05admin\xc0\x0c\x00\x00\x04\xd2\x00\x00\x0e\x10\x00\x00\x07\x08\x00$\xea\x00\x00\x00\x1c ' +class XfrinTestException(Exception): + pass -DB_FILE = 'db_file' -# Rewrite the class for unittest. -class MyXfrin(Xfrin): - def __init__(self): +class MockXfrin(Xfrin): + # This is a class attribute of a callable object that specifies a non + # default behavior triggered in _cc_check_command(). Specific test methods + # are expected to explicitly set this attribute before creating a + # MockXfrin object (when it needs a non default behavior). + # See the TestMain class. + check_command_hook = None + + def _cc_setup(self): pass + + def _cc_check_command(self): + self._shutdown_event.set() + if MockXfrin.check_command_hook: + MockXfrin.check_command_hook() -class MyXfrinConnection(XfrinConnection): - query_data = b'' - eply_data = b'' +class MockXfrinConnection(XfrinConnection): + def __init__(self, sock_map, zone_name, rrclass, db_file, shutdown_event, + master_addr): + super().__init__(sock_map, zone_name, rrclass, db_file, shutdown_event, + master_addr) + self.query_data = b'' + self.reply_data = b'' + self.force_time_out = False + self.force_close = False + self.qlen = None + self.qid = None + self.response_generator = None - def _handle_xfrin_response(self): - for rr in super()._handle_xfrin_response(): - pass + def _asyncore_loop(self): + if self.force_close: + self.handle_close() + elif not self.force_time_out: + self.handle_read() + + def connect_to_master(self): + return True - def _get_request_response(self, size): - ret = self.reply_data[:size] + def recv(self, size): + data = self.reply_data[:size] self.reply_data = self.reply_data[size:] - if (len(ret) < size): - raise XfrinException('cannot get reply data') - return ret + if len(data) < size: + raise XfrinTestException('cannot get reply data') + return data def send(self, data): + if self.qlen != None and len(self.query_data) >= self.qlen: + # This is a new query. reset the internal state. + self.qlen = None + self.qid = None + self.query_data = b'' self.query_data += data + + # when the outgoing data is sufficiently large to contain the length + # and the QID fields (4 octets or more), extract these fields. + # The length will be reset the internal query data to support multiple + # queries in a single test. + # The QID will be used to construct a matching response. + if len(self.query_data) >= 4 and self.qid == None: + self.qlen = socket.htons(struct.unpack('H', + self.query_data[0:2])[0]) + self.qid = socket.htons(struct.unpack('H', self.query_data[2:4])[0]) + # if the response generator method is specified, invoke it now. + if self.response_generator != None: + self.response_generator() return len(data) - def create_response_data(self, data): - reply_data = self.query_data[2:4] + data - size = socket.htons(len(reply_data)) - reply_data = struct.pack('H', size) + reply_data - return reply_data + def create_response_data(self, response = True, bad_qid = False, + rcode = Rcode.NOERROR(), + questions = default_questions, + answers = default_answers): + resp = Message(Message.RENDER) + qid = self.qid + if bad_qid: + qid += 1 + resp.set_qid(qid) + resp.set_opcode(Opcode.QUERY()) + resp.set_rcode(rcode) + if response: + resp.set_header_flag(MessageFlag.QR()) + [resp.add_question(q) for q in questions] + [resp.add_rrset(Section.ANSWER(), a) for a in answers] + renderer = MessageRenderer() + resp.to_wire(renderer) + reply_data = struct.pack('H', socket.htons(renderer.get_length())) + reply_data += renderer.get_data() + + return reply_data class TestXfrinConnection(unittest.TestCase): def setUp(self): - self.conn = MyXfrinConnection('example.com.', DB_FILE, threading.Event(), '1.1.1.1') + if os.path.exists(TEST_DB_FILE): + os.remove(TEST_DB_FILE) + self.sock_map = {} + self.conn = MockXfrinConnection(self.sock_map, 'example.com.', + TEST_RRCLASS, TEST_DB_FILE, + threading.Event(), + TEST_MASTER_IPV4_ADDRINFO) + self.axfr_after_soa = False + self.soa_response_params = { + 'questions': [example_soa_question], + 'bad_qid': False, + 'response': True, + 'rcode': Rcode.NOERROR(), + 'axfr_after_soa': self._create_normal_response_data + } + + def tearDown(self): + self.conn.close() + if os.path.exists(TEST_DB_FILE): + os.remove(TEST_DB_FILE) + + def test_close(self): + # we shouldn't be using the global asyncore map. + self.assertEqual(len(asyncore.socket_map), 0) + # there should be exactly one entry in our local map + self.assertEqual(len(self.sock_map), 1) + # once closing the dispatch the map should become empty + self.conn.close() + self.assertEqual(len(self.sock_map), 0) + + def test_init_ip6(self): + # This test simply creates a new XfrinConnection object with an + # IPv6 address, tries to bind it to an IPv6 wildcard address/port + # to confirm an AF_INET6 socket has been created. A naive application + # tends to assume it's IPv4 only and hardcode AF_INET. This test + # uncovers such a bug. + c = MockXfrinConnection({}, 'example.com.', TEST_RRCLASS, TEST_DB_FILE, + threading.Event(), + TEST_MASTER_IPV6_ADDRINFO) + c.bind(('::', 0)) + c.close() + + def test_init_chclass(self): + c = XfrinConnection({}, 'example.com.', RRClass.CH(), TEST_DB_FILE, + threading.Event(), TEST_MASTER_IPV4_ADDRINFO) + axfrmsg = c._create_query(RRType.AXFR()) + self.assertEqual(axfrmsg.get_question()[0].get_class(), + RRClass.CH()) + c.close() def test_response_with_invalid_msg(self): - self.conn.data_exchange = b'aaaxxxx' - self.assertRaises(Exception, self.conn._handle_xfrin_response) + self.conn.reply_data = b'aaaxxxx' + self.assertRaises(XfrinTestException, self._handle_xfrin_response) def test_response_without_end_soa(self): self.conn._send_query(RRType.AXFR()) - self.conn.reply_data = self.conn.create_response_data(axfr_response1) - self.assertRaises(XfrinException, self.conn._handle_xfrin_response) + self.conn.reply_data = self.conn.create_response_data() + self.assertRaises(XfrinTestException, self._handle_xfrin_response) + + def test_response_bad_qid(self): + self.conn._send_query(RRType.AXFR()) + self.conn.reply_data = self.conn.create_response_data(bad_qid = True) + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_non_response(self): + self.conn._send_query(RRType.AXFR()) + self.conn.reply_data = self.conn.create_response_data(response = False) + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_error_code(self): + self.conn._send_query(RRType.AXFR()) + self.conn.reply_data = self.conn.create_response_data( + rcode=Rcode.SERVFAIL()) + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_multi_question(self): + self.conn._send_query(RRType.AXFR()) + self.conn.reply_data = self.conn.create_response_data( + questions=[example_axfr_question, example_axfr_question]) + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_empty_answer(self): + self.conn._send_query(RRType.AXFR()) + self.conn.reply_data = self.conn.create_response_data(answers=[]) + # Should an empty answer trigger an exception? Even though it's very + # unusual it's not necessarily invalid. Need to revisit. + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_non_response(self): + self.conn._send_query(RRType.AXFR()) + self.conn.reply_data = self.conn.create_response_data(response = False) + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_soacheck(self): + # we need to defer the creation until we know the QID, which is + # determined in _check_soa_serial(), so we use response_generator. + self.conn.response_generator = self._create_soa_response_data + self.assertEqual(self.conn._check_soa_serial(), XFRIN_OK) + + def test_soacheck_with_bad_response(self): + self.conn.response_generator = self._create_broken_response_data + self.assertRaises(MessageTooShort, self.conn._check_soa_serial) + + def test_soacheck_badqid(self): + self.soa_response_params['bad_qid'] = True + self.conn.response_generator = self._create_soa_response_data + self.assertRaises(XfrinException, self.conn._check_soa_serial) + + def test_soacheck_non_response(self): + self.soa_response_params['response'] = False + self.conn.response_generator = self._create_soa_response_data + self.assertRaises(XfrinException, self.conn._check_soa_serial) + + def test_soacheck_error_code(self): + self.soa_response_params['rcode'] = Rcode.SERVFAIL() + self.conn.response_generator = self._create_soa_response_data + self.assertRaises(XfrinException, self.conn._check_soa_serial) + + def test_response_shutdown(self): + self.conn.response_generator = self._create_normal_response_data + self.conn._shutdown_event.set() + self.conn._send_query(RRType.AXFR()) + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_timeout(self): + self.conn.response_generator = self._create_normal_response_data + self.conn.force_time_out = True + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_remote_close(self): + self.conn.response_generator = self._create_normal_response_data + self.conn.force_close = True + self.assertRaises(XfrinException, self._handle_xfrin_response) + + def test_response_bad_message(self): + self.conn.response_generator = self._create_broken_response_data + self.conn._send_query(RRType.AXFR()) + self.assertRaises(Exception, self._handle_xfrin_response) def test_response(self): + # normal case. + self.conn.response_generator = self._create_normal_response_data self.conn._send_query(RRType.AXFR()) - self.conn.reply_data = self.conn.create_response_data(axfr_response1) - self.conn.reply_data += self.conn.create_response_data(axfr_response2) - self.conn._handle_xfrin_response() + # two SOAs, and only these have been transfered. the 2nd SOA is just + # a marker, so only 1 RR has been provided in the iteration. + self.assertEqual(self._handle_xfrin_response(), 1) + + def test_do_xfrin(self): + self.conn.response_generator = self._create_normal_response_data + self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK) + + def test_do_xfrin_empty_response(self): + # skipping the creation of response data, so the transfer will fail. + self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL) + + def test_do_xfrin_bad_response(self): + self.conn.response_generator = self._create_broken_response_data + self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL) + + def test_do_xfrin_dberror(self): + # DB file is under a non existent directory, so its creation will fail, + # which will make the transfer fail. + self.conn._db_file = "not_existent/" + TEST_DB_FILE + self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL) + + def test_do_soacheck_and_xfrin(self): + self.conn.response_generator = self._create_soa_response_data + self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK) + + def test_do_soacheck_broken_response(self): + self.conn.response_generator = self._create_broken_response_data + # XXX: TODO: this test failed here, should xfr not raise an + # exception but simply drop and return FAIL? + #self.assertEqual(self.conn.do_xfrin(True), XFRIN_FAIL) + self.assertRaises(MessageTooShort, self.conn.do_xfrin, True) + + def test_do_soacheck_badqid(self): + # the QID mismatch would internally trigger a XfrinException exception, + # and covers part of the code that other tests can't. + self.soa_response_params['bad_qid'] = True + self.conn.response_generator = self._create_soa_response_data + self.assertEqual(self.conn.do_xfrin(True), XFRIN_FAIL) + + def _handle_xfrin_response(self): + # This helper methods iterates over all RRs (excluding the ending SOA) + # transferred, and simply returns the number of RRs. The return value + # may be used an assertion value for test cases. + rrs = 0 + for rr in self.conn._handle_xfrin_response(): + rrs += 1 + return rrs + + def _create_normal_response_data(self): + # This helper method creates a simple sequence of DNS messages that + # forms a valid XFR transaction. It consists of two messages, each + # containing just a single SOA RR. + self.conn.reply_data = self.conn.create_response_data() + self.conn.reply_data += self.conn.create_response_data() + + def _create_soa_response_data(self): + # This helper method creates a DNS message that is supposed to be + # used a valid response to SOA queries prior to XFR. + # If axfr_after_soa is True, it resets the response_generator so that + # a valid XFR messages will follow. + self.conn.reply_data = self.conn.create_response_data( + bad_qid=self.soa_response_params['bad_qid'], + response=self.soa_response_params['response'], + rcode=self.soa_response_params['rcode'], + questions=self.soa_response_params['questions']) + if self.soa_response_params['axfr_after_soa'] != None: + self.conn.response_generator = self.soa_response_params['axfr_after_soa'] + + def _create_broken_response_data(self): + # This helper method creates a bogus "DNS message" that only contains + # 4 octets of data. The DNS message parser will raise an exception. + bogus_data = b'xxxx' + self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data))) + self.conn.reply_data += bogus_data + +class TestXfrinRecorder(unittest.TestCase): + def setUp(self): + self.recorder = XfrinRecorder() + + def test_increment(self): + self.assertEqual(self.recorder.count(), 0) + self.recorder.increment(TEST_ZONE_NAME) + self.assertEqual(self.recorder.count(), 1) + # duplicate "increment" should probably be rejected. but it's not + # checked at this moment + self.recorder.increment(TEST_ZONE_NAME) + self.assertEqual(self.recorder.count(), 2) + + def test_decrement(self): + self.assertEqual(self.recorder.count(), 0) + self.recorder.increment(TEST_ZONE_NAME) + self.assertEqual(self.recorder.count(), 1) + self.recorder.decrement(TEST_ZONE_NAME) + self.assertEqual(self.recorder.count(), 0) + + def test_decrement_from_empty(self): + self.assertEqual(self.recorder.count(), 0) + self.recorder.decrement(TEST_ZONE_NAME) + self.assertEqual(self.recorder.count(), 0) + + def test_inprogress(self): + self.assertEqual(self.recorder.count(), 0) + self.recorder.increment(TEST_ZONE_NAME) + self.assertEqual(self.recorder.xfrin_in_progress(TEST_ZONE_NAME), True) + self.recorder.decrement(TEST_ZONE_NAME) + self.assertEqual(self.recorder.xfrin_in_progress(TEST_ZONE_NAME), False) class TestXfrin(unittest.TestCase): + def setUp(self): + self.xfr = MockXfrin() + self.args = {} + self.args['zone_name'] = TEST_ZONE_NAME + self.args['port'] = TEST_MASTER_PORT + self.args['master'] = TEST_MASTER_IPV4_ADDRESS + self.args['db_file'] = TEST_DB_FILE + + def tearDown(self): + self.xfr.shutdown() + + def _do_parse(self): + return self.xfr._parse_cmd_params(self.args) + def test_parse_cmd_params(self): - xfr = MyXfrin() - args = {} - args['zone_name'] = 'sd.cn.' - args['port'] = '12345' - args['master'] = '218.241.108.122' - args['db_file'] = '/home/tt' - - name, master, port, db_file = xfr._parse_cmd_params(args) - self.assertEqual(port, 12345) - self.assertEqual(name, 'sd.cn.') - self.assertEqual(master, '218.241.108.122') - self.assertEqual(db_file, '/home/tt') - - def test_parse_cmd_params_1(self): - xfr = MyXfrin() - args = {} - args['port'] = '12345' - args['master'] = '218.241.108.122' - args['db_file'] = '/home/tt' - - self.assertRaises(XfrinException, xfr._parse_cmd_params, args) - self.assertRaises(XfrinException, xfr._parse_cmd_params, {'zone_name':'ds.cn.', 'master':'3.3.3'}) - self.assertRaises(XfrinException, xfr._parse_cmd_params, {'zone_name':'ds.cn.'}) - self.assertRaises(XfrinException, xfr._parse_cmd_params, {'master':'ds.cn.'}) + name, master_addrinfo, db_file = self._do_parse() + self.assertEqual(master_addrinfo[4][1], int(TEST_MASTER_PORT)) + self.assertEqual(name, TEST_ZONE_NAME) + self.assertEqual(master_addrinfo[4][0], TEST_MASTER_IPV4_ADDRESS) + self.assertEqual(db_file, TEST_DB_FILE) + + def test_parse_cmd_params_default_port(self): + del self.args['port'] + master_addrinfo = self._do_parse()[1] + self.assertEqual(master_addrinfo[4][1], 53) + + def test_parse_cmd_params_ip6master(self): + self.args['master'] = TEST_MASTER_IPV6_ADDRESS + master_addrinfo = self._do_parse()[1] + self.assertEqual(master_addrinfo[4][0], TEST_MASTER_IPV6_ADDRESS) + + def test_parse_cmd_params_nozone(self): + # zone name is mandatory. + del self.args['zone_name'] + self.assertRaises(XfrinException, self._do_parse) + + def test_parse_cmd_params_nomaster(self): + # master address is mandatory. + del self.args['master'] + self.assertRaises(XfrinException, self._do_parse) + + def test_parse_cmd_params_bad_ip4(self): + self.args['master'] = '3.3.3.3.3' + self.assertRaises(XfrinException, self._do_parse) + + def test_parse_cmd_params_bad_ip6(self): + self.args['master'] = '1::1::1' + self.assertRaises(XfrinException, self._do_parse) + + def test_parse_cmd_params_bad_port(self): + self.args['port'] = '-1' + self.assertRaises(XfrinException, self._do_parse) + + self.args['port'] = '65536' + self.assertRaises(XfrinException, self._do_parse) + + self.args['port'] = 'http' + self.assertRaises(XfrinException, self._do_parse) + + def test_command_handler_shutdown(self): + self.assertEqual(self.xfr.command_handler("shutdown", + None)['result'][0], 0) + # shutdown command doesn't expect an argument, but accepts it if any. + self.assertEqual(self.xfr.command_handler("shutdown", + "unused")['result'][0], 0) + + def test_command_handler_retransfer(self): + self.assertEqual(self.xfr.command_handler("retransfer", + self.args)['result'][0], 0) + + def test_command_handler_retransfer_badcommand(self): + self.args['master'] = 'invalid' + self.assertEqual(self.xfr.command_handler("retransfer", + self.args)['result'][0], 1) + + def test_command_handler_retransfer_quota(self): + for i in range(self.xfr._max_transfers_in - 1): + self.xfr.recorder.increment(str(i) + TEST_ZONE_NAME) + # there can be one more outstanding transfer. + self.assertEqual(self.xfr.command_handler("retransfer", + self.args)['result'][0], 0) + # make sure the # xfrs would excceed the quota + self.xfr.recorder.increment(str(self.xfr._max_transfers_in) + TEST_ZONE_NAME) + # this one should fail + self.assertEqual(self.xfr.command_handler("retransfer", + self.args)['result'][0], 1) + + def test_command_handler_retransfer_inprogress(self): + self.xfr.recorder.increment(TEST_ZONE_NAME) + self.assertEqual(self.xfr.command_handler("retransfer", + self.args)['result'][0], 1) + + def test_command_handler_retransfer_nomodule(self): + dns_module = sys.modules['libdns_python'] # this must exist + del sys.modules['libdns_python'] + self.assertEqual(self.xfr.command_handler("retransfer", + self.args)['result'][0], 1) + # sys.modules is global, so we must recover it + sys.modules['libdns_python'] = dns_module + + def test_command_handler_refresh(self): + # at this level, refresh is no different than retransfer. + # just confirm the successful case with a different family of address. + self.args['master'] = TEST_MASTER_IPV6_ADDRESS + self.assertEqual(self.xfr.command_handler("refresh", + self.args)['result'][0], 0) + + def test_command_handler_unknown(self): + self.assertEqual(self.xfr.command_handler("xxx", None)['result'][0], 1) + +def raise_interrupt(): + raise KeyboardInterrupt() + +def raise_ccerror(): + raise isc.cc.session.SessionError('test error') + +def raise_excpetion(): + raise Exception('test exception') + +class TestMain(unittest.TestCase): + def setUp(self): + MockXfrin.check_command_hook = None + + def tearDown(self): + MockXfrin.check_command_hook = None + + def test_startup(self): + main(MockXfrin, False) + + def test_startup_interrupt(self): + MockXfrin.check_command_hook = raise_interrupt + main(MockXfrin, False) + + def test_startup_ccerror(self): + MockXfrin.check_command_hook = raise_ccerror + main(MockXfrin, False) + + def test_startup_generalerror(self): + MockXfrin.check_command_hook = raise_excpetion + main(MockXfrin, False) if __name__== "__main__": try: unittest.main() - os.remove(DB_FILE) except KeyboardInterrupt as e: print(e) - diff --git a/src/bin/xfrin/xfrin.py.in b/src/bin/xfrin/xfrin.py.in index de4be7953f..4e2fccc0b6 100644 --- a/src/bin/xfrin/xfrin.py.in +++ b/src/bin/xfrin/xfrin.py.in @@ -35,11 +35,11 @@ except ImportError as e: # must keep running, so we warn about it and move forward. sys.stderr.write('[b10-xfrin] failed to import DNS module: %s\n' % str(e)) -# If B10_FROM_SOURCE is set in the environment, we use data files +# If B10_FROM_BUILD is set in the environment, we use data files # from a directory relative to that, otherwise we use the ones # installed on the system -if "B10_FROM_SOURCE" in os.environ: - SPECFILE_PATH = os.environ["B10_FROM_SOURCE"] + "/src/bin/xfrin" +if "B10_FROM_BUILD" in os.environ: + SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/xfrin" else: PREFIX = "@prefix@" DATAROOTDIR = "@datarootdir@" @@ -50,11 +50,12 @@ SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec" __version__ = 'BIND10' # define xfrin rcode XFRIN_OK = 0 +XFRIN_FAIL = 1 + +DEFAULT_MASTER_PORT = '53' def log_error(msg): - sys.stderr.write("[b10-xfrin] ") - sys.stderr.write(str(msg)) - sys.stderr.write('\n') + sys.stderr.write("[b10-xfrin] %s\n" % str(msg)) class XfrinException(Exception): pass @@ -62,50 +63,49 @@ class XfrinException(Exception): class XfrinConnection(asyncore.dispatcher): '''Do xfrin in this class. ''' - def __init__(self, - zone_name, db_file, shutdown_event, master_addr, - port = 53, verbose = False, idle_timeout = 60): + def __init__(self, + sock_map, zone_name, rrclass, db_file, shutdown_event, + master_addrinfo, verbose = False, idle_timeout = 60): ''' idle_timeout: max idle time for read data from socket. db_file: specify the data source file. check_soa: when it's true, check soa first before sending xfr query ''' - asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + asyncore.dispatcher.__init__(self, map=sock_map) + self.create_socket(master_addrinfo[0], master_addrinfo[1]) self._zone_name = zone_name + self._sock_map = sock_map + self._rrclass = rrclass self._db_file = db_file self._soa_rr_count = 0 self._idle_timeout = idle_timeout self.setblocking(1) self._shutdown_event = shutdown_event self._verbose = verbose - self._master_addr = master_addr - self._port = port + self._master_address = master_addrinfo[4] def connect_to_master(self): '''Connect to master in TCP.''' try: - self.connect((self._master_addr, self._port)) + self.connect(self._master_address) return True except socket.error as e: - self.log_msg('Failed to connect:(%s:%d), %s' % (self._master_addr, self._port, str(e))) + self.log_msg('Failed to connect:(%s), %s' % (self._master_address, + str(e))) return False def _create_query(self, query_type): '''Create dns query message. ''' msg = Message(Message.RENDER) - query_id = random.randint(1, 0xFFFF) + query_id = random.randint(0, 0xFFFF) self._query_id = query_id msg.set_qid(query_id) msg.set_opcode(Opcode.QUERY()) msg.set_rcode(Rcode.NOERROR()) - query_question = Question(Name(self._zone_name), RRClass("IN"), query_type) - try: - msg.add_question(query_question) - except Exception as err: - raise err + query_question = Question(Name(self._zone_name), self._rrclass, query_type) + msg.add_question(query_question) return msg def _send_data(self, data): @@ -125,6 +125,13 @@ class XfrinConnection(asyncore.dispatcher): self._send_data(header_len) self._send_data(render.get_data()) + + def _asyncore_loop(self): + ''' + This method is a trivial wrapper for asyncore.loop(). It's extracted from + _get_request_response so that we can test the rest of the code without + involving actual communication with a remote server.''' + asyncore.loop(self._idle_timeout, map=self._sock_map, count=1) def _get_request_response(self, size): recv_size = 0 @@ -132,7 +139,7 @@ class XfrinConnection(asyncore.dispatcher): while recv_size < size: self._recv_time_out = True self._need_recv_size = size - recv_size - asyncore.loop(self._idle_timeout, count = 1) + self._asyncore_loop() if self._recv_time_out: raise XfrinException('receive data from socket time out.') @@ -149,10 +156,19 @@ class XfrinConnection(asyncore.dispatcher): ''' self._send_query(RRType("SOA")) - data_size = self._get_request_response(2) - soa_reply = self._get_request_response(int(data_size)) - #TODO, need select soa record from data source then compare the two - #serial, current just return OK, since this function hasn't been used now + data_len = self._get_request_response(2) + msg_len = socket.htons(struct.unpack('H', data_len)[0]) + soa_response = self._get_request_response(msg_len) + msg = Message(Message.PARSE) + msg.from_wire(soa_response) + + # perform some minimal level validation. It's an open issue how + # strict we should be (see the comment in _check_response_header()) + self._check_response_header(msg) + + # TODO, need select soa record from data source then compare the two + # serial, current just return OK, since this function hasn't been used + # now. return XFRIN_OK def do_xfrin(self, check_soa, ixfr_first = False): @@ -161,6 +177,7 @@ class XfrinConnection(asyncore.dispatcher): try: ret = XFRIN_OK if check_soa: + logstr = 'SOA check for \'%s\' ' % self._zone_name ret = self._check_soa_serial() logstr = 'transfer of \'%s\': AXFR ' % self._zone_name @@ -172,23 +189,41 @@ class XfrinConnection(asyncore.dispatcher): self._handle_xfrin_response) self.log_msg(logstr + 'succeeded') + ret = XFRIN_OK except XfrinException as e: self.log_msg(e) self.log_msg(logstr + 'failed') + ret = XFRIN_FAIL #TODO, recover data source. except isc.datasrc.sqlite3_ds.Sqlite3DSError as e: self.log_msg(e) self.log_msg(logstr + 'failed') + ret = XFRIN_FAIL + except UserWarning as e: + # XXX: this is an exception from our C++ library via the + # Boost.Python binding. It would be better to have more more + # specific exceptions, but at this moment this is the finest + # granularity. + self.log_msg(e) + self.log_msg(logstr + 'failed') + ret = XFRIN_FAIL finally: self.close() return ret - - def _check_response_status(self, msg): - '''Check validation of xfr response. ''' - #TODO, check more? + def _check_response_header(self, msg): + '''Perform minimal validation on responses''' + + # It's not clear how strict we should be about response validation. + # BIND 9 ignores some cases where it would normally be considered a + # bogus response. For example, it accepts a response even if its + # opcode doesn't match that of the corresponding request. + # According to an original developer of BIND 9 some of the missing + # checks are deliberate to be kind to old implementations that would + # cause interoperability trouble with stricter checks. + msg_rcode = msg.get_rcode() if msg_rcode != Rcode.NOERROR(): raise XfrinException('error response: %s' % msg_rcode.to_text()) @@ -199,6 +234,11 @@ class XfrinConnection(asyncore.dispatcher): if msg.get_qid() != self._query_id: raise XfrinException('bad query id') + def _check_response_status(self, msg): + '''Check validation of xfr response. ''' + + self._check_response_header(msg) + if msg.get_rr_count(Section.ANSWER()) == 0: raise XfrinException('answer section is empty') @@ -270,24 +310,22 @@ class XfrinConnection(asyncore.dispatcher): def log_msg(self, msg): if self._verbose: - sys.stdout.write('[b10-xfrin] ') - sys.stdout.write(str(msg)) - sys.stdout.write('\n') + sys.stdout.write('[b10-xfrin] %s\n' % str(msg)) -def process_xfrin(xfrin_recorder, zone_name, db_file, - shutdown_event, master_addr, port, check_soa, verbose): - port = int(port) +def process_xfrin(xfrin_recorder, zone_name, rrclass, db_file, + shutdown_event, master_addrinfo, check_soa, verbose): xfrin_recorder.increment(zone_name) - conn = XfrinConnection(zone_name, db_file, shutdown_event, - master_addr, port, verbose) + sock_map = {} + conn = XfrinConnection(sock_map, zone_name, rrclass, db_file, + shutdown_event, master_addrinfo, verbose) if conn.connect_to_master(): conn.do_xfrin(check_soa) xfrin_recorder.decrement(zone_name) -class XfrinRecorder(): +class XfrinRecorder: def __init__(self): self._lock = threading.Lock() self._zones = [] @@ -315,15 +353,31 @@ class XfrinRecorder(): self._lock.release() return ret -class Xfrin(): +class Xfrin: def __init__(self, verbose = False): - self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler) - self._cc.start() + self._cc_setup() self._max_transfers_in = 10 self.recorder = XfrinRecorder() self._shutdown_event = threading.Event() self._verbose = verbose + def _cc_setup(self): + ''' +This method is used only as part of initialization, but is implemented +separately for convenience of unit tests; by letting the test code override +this method we can test most of this class without requiring a command channel. +''' + self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, + self.config_handler, + self.command_handler) + self._cc.start() + + def _cc_check_command(self): + ''' +This is a straightforward wrapper for cc.check_command, but provided as +a separate method for the convenience of unit tests. +''' + self._cc.check_command() def config_handler(self, new_config): # TODO, process new config data @@ -343,20 +397,20 @@ class Xfrin(): def command_handler(self, command, args): answer = create_answer(0) - cmd = command try: - if cmd == 'shutdown': + if command == 'shutdown': self._shutdown_event.set() - - elif cmd == 'retransfer': - zone_name, master, port, db_file = self._parse_cmd_params(args) - ret = self.xfrin_start(zone_name, db_file, master, port, False) - answer = create_answer(ret[0], ret[1]) - - elif cmd == 'refresh': - zone_name, master, port, db_file = self._parse_cmd_params(args) - ret = self.xfrin_start(zone_name, db_file, master, port) + elif command == 'retransfer' or command == 'refresh': + # The default RR class is IN. We should fix this so that + # the class is passed in the command arg (where we specify + # the default) + rrclass = RRClass.IN() + zone_name, master_addr, db_file = self._parse_cmd_params(args) + ret = self.xfrin_start(zone_name, rrclass, db_file, master_addr, + False if command == 'retransfer' else True) answer = create_answer(ret[0], ret[1]) + else: + answer = create_answer(1, 'unknown command: ' + command) except XfrinException as err: answer = create_answer(1, str(err)) @@ -372,28 +426,23 @@ class Xfrin(): if not master: raise XfrinException('master address should be provided') - check_addr(master) - port = 53 port_str = args.get('port') - if port_str: - port = int(port_str) - check_port(port) + if not port_str: + port_str = DEFAULT_MASTER_PORT + master_addrinfo = check_addr_port(master, port_str) db_file = args.get('db_file') if not db_file: #TODO, the db file path should be got in auth server's configuration db_file = '@@LOCALSTATEDIR@@/@PACKAGE@/zone.sqlite3' - return (zone_name, master, port, db_file) - + return (zone_name, master_addrinfo, db_file) def startup(self): while not self._shutdown_event.is_set(): - self._cc.check_command() - + self._cc_check_command() - def xfrin_start(self, zone_name, db_file, master_addr, - port = 53, + def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo, check_soa = True): if "libdns_python" not in sys.modules: return (1, "xfrin failed, can't load dns message python library: 'libdns_python'") @@ -407,12 +456,12 @@ class Xfrin(): xfrin_thread = threading.Thread(target = process_xfrin, args = (self.recorder, - zone_name, + zone_name, rrclass, db_file, self._shutdown_event, - master_addr, - port, check_soa, self._verbose)) - + master_addrinfo, check_soa, + self._verbose)) + xfrin_thread.start() return (0, 'zone xfrin is started') @@ -428,34 +477,53 @@ def set_signal_handler(): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) -def check_port(value): - if (value < 0) or (value > 65535): - raise XfrinException('requires a port number (0-65535)') - -def check_addr(ipstr): - ip_family = socket.AF_INET - if (ipstr.find(':') != -1): - ip_family = socket.AF_INET6 - +def check_addr_port(addrstr, portstr): + # XXX: Linux (glibc)'s getaddrinfo incorrectly accepts numeric port + # string larger than 65535. So we need to explicit validate it separately. try: - socket.inet_pton(ip_family, ipstr) - except: - raise XfrinException("%s invalid ip address" % ipstr) + portnum = int(portstr) + if portnum < 0 or portnum > 65535: + raise ValueError("invalid port number (out of range): " + portstr) + except ValueError as err: + raise XfrinException("failed to resolve master address/port=%s/%s: %s" % + (addrstr, portstr, str(err))) + try: + addrinfo = socket.getaddrinfo(addrstr, portstr, socket.AF_UNSPEC, + socket.SOCK_STREAM, socket.IPPROTO_TCP, + socket.AI_NUMERICHOST| + socket.AI_NUMERICSERV) + except socket.gaierror as err: + raise XfrinException("failed to resolve master address/port=%s/%s: %s" % + (addrstr, portstr, str(err))) + if len(addrinfo) != 1: + # with the parameters above the result must be uniquely determined. + errmsg = "unexpected result for address/port resolution for %s:%s" + raise XfrinException(errmsg % (addrstr, portstr)) + return addrinfo[0] def set_cmd_options(parser): parser.add_option("-v", "--verbose", dest="verbose", action="store_true", help="display more about what is going on") - -if __name__ == '__main__': +def main(xfrin_class, use_signal = True): + """The main loop of the Xfrin daemon. + + @param xfrin_class: A class of the Xfrin object. This is normally Xfrin, + but can be a subclass of it for customization. + @param use_signal: True if this process should catch signals. This is + normally True, but may be disabled when this function is called in a + testing context.""" + global xfrind + try: parser = OptionParser(version = __version__) set_cmd_options(parser) (options, args) = parser.parse_args() - set_signal_handler() - xfrind = Xfrin(verbose = options.verbose) + if use_signal: + set_signal_handler() + xfrind = xfrin_class(verbose = options.verbose) xfrind.startup() except KeyboardInterrupt: log_error("exit b10-xfrin") @@ -467,3 +535,6 @@ if __name__ == '__main__': if xfrind: xfrind.shutdown() + +if __name__ == '__main__': + main(Xfrin) diff --git a/src/bin/xfrin/xfrin.spec.pre.in b/src/bin/xfrin/xfrin.spec.pre.in index 640ed29c63..97099ed823 100644 --- a/src/bin/xfrin/xfrin.spec.pre.in +++ b/src/bin/xfrin/xfrin.spec.pre.in @@ -1,6 +1,7 @@ { "module_spec": { "module_name": "Xfrin", + "module_description": "XFR in daemon", "config_data": [ { "item_name": "transfers_in", @@ -47,5 +48,3 @@ ] } } - - diff --git a/src/bin/xfrout/tests/xfrout_test.py b/src/bin/xfrout/tests/xfrout_test.py index de7bfa2579..725b31a5fe 100644 --- a/src/bin/xfrout/tests/xfrout_test.py +++ b/src/bin/xfrout/tests/xfrout_test.py @@ -276,6 +276,70 @@ class TestUnixSockServer(unittest.TestCase): self.unix.decrease_transfers_counter() self.assertEqual(count - 1, self.unix._transfers_counter) + def _remove_file(self, sock_file): + try: + os.remove(sock_file) + except OSError: + pass + + def test_sock_file_in_use_file_exist(self): + sock_file = 'temp.sock.file' + self._remove_file(sock_file) + self.assertFalse(self.unix._sock_file_in_use(sock_file)) + self.assertFalse(os.path.exists(sock_file)) + + def test_sock_file_in_use_file_not_exist(self): + self.assertFalse(self.unix._sock_file_in_use('temp.sock.file')) + + def _start_unix_sock_server(self, sock_file): + serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler) + serv_thread = threading.Thread(target=serv.serve_forever) + serv_thread.setDaemon(True) + serv_thread.start() + + def test_sock_file_in_use(self): + sock_file = 'temp.sock.file' + self._remove_file(sock_file) + self.assertFalse(self.unix._sock_file_in_use(sock_file)) + self._start_unix_sock_server(sock_file) + + old_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + self.assertTrue(self.unix._sock_file_in_use(sock_file)) + sys.stdout = old_stdout + + def test_remove_unused_sock_file_in_use(self): + sock_file = 'temp.sock.file' + self._remove_file(sock_file) + self.assertFalse(self.unix._sock_file_in_use(sock_file)) + self._start_unix_sock_server(sock_file) + old_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + try: + self.unix._remove_unused_sock_file(sock_file) + except SystemExit: + pass + else: + # This should never happen + self.assertTrue(False) + + sys.stdout = old_stdout + + def test_remove_unused_sock_file_dir(self): + import tempfile + dir_name = tempfile.mkdtemp() + old_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + try: + self.unix._remove_unused_sock_file(dir_name) + except SystemExit: + pass + else: + # This should never happen + self.assertTrue(False) + + sys.stdout = old_stdout + os.rmdir(dir_name) if __name__== "__main__": unittest.main() diff --git a/src/bin/xfrout/xfrout.py.in b/src/bin/xfrout/xfrout.py.in index c7e1b6c07b..4c91970a12 100644 --- a/src/bin/xfrout/xfrout.py.in +++ b/src/bin/xfrout/xfrout.py.in @@ -28,6 +28,7 @@ import os from isc.config.ccsession import * from isc.cc import SessionError import socket +import errno from optparse import OptionParser, OptionValueError try: from libxfr_python import * @@ -37,14 +38,14 @@ except ImportError as e: # must keep running, so we warn about it and move forward. sys.stderr.write('[b10-xfrout] failed to import DNS or XFR module: %s\n' % str(e)) -if "B10_FROM_SOURCE" in os.environ: - SPECFILE_PATH = os.environ["B10_FROM_SOURCE"] + "/src/bin/xfrout" +if "B10_FROM_BUILD" in os.environ: + SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/xfrout" UNIX_SOCKET_FILE = os.environ["B10_FROM_SOURCE"] + "/auth_xfrout_conn" else: PREFIX = "@prefix@" DATAROOTDIR = "@datarootdir@" SPECFILE_PATH = "@datadir@/@PACKAGE@".replace("${datarootdir}", DATAROOTDIR).replace("${prefix}", PREFIX) - UNIX_SOCKET_FILE = "@localstatedir@".replace("${prefix}", PREFIX) + "/auth_xfrout_conn" + UNIX_SOCKET_FILE = "@@LOCALSTATEDIR@@/auth_xfrout_conn" SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec" MAX_TRANSFERS_OUT = 10 @@ -59,7 +60,13 @@ class XfroutSession(BaseRequestHandler): fd = recv_fd(self.request.fileno()) if fd < 0: - raise XfroutException("failed to receive the FD for XFR connection") + # This may happen when one xfrout process try to connect to + # xfrout unix socket server, to check whether there is another + # xfrout running. + print("[b10-xfrout] Failed to receive the FD for XFR connection, " + "maybe because another xfrout process was started.") + return + data_len = self.request.recv(2) msg_len = struct.unpack('!H', data_len)[0] msgdata = self.request.recv(msg_len) @@ -182,7 +189,7 @@ class XfroutSession(BaseRequestHandler): self.log_msg("transfer of '%s/IN': AXFR end" % zone_name) except TmpException as err: if verbose_mode: - sys.stderr.write(str(err)) + sys.stderr.write("[b10-xfrout] %s\n" % str(err)) self.server.decrease_transfers_counter() return @@ -282,18 +289,44 @@ class UnixSockServer(ThreadingUnixStreamServer): '''The unix domain socket server which accept xfr query sent from auth server.''' def __init__(self, sock_file, handle_class, shutdown_event, config_data): - try: - os.unlink(sock_file) - except: - pass - + self._remove_unused_sock_file(sock_file) self._sock_file = sock_file ThreadingUnixStreamServer.__init__(self, sock_file, handle_class) self._lock = threading.Lock() self._transfers_counter = 0 self._shutdown_event = shutdown_event self.update_config_data(config_data) - + + def _remove_unused_sock_file(self, sock_file): + '''Try to remove the socket file. If the file is being used + by one running xfrout process, exit from python. + If it's not a socket file or nobody is listening + , it will be removed. If it can't be removed, exit from python. ''' + if self._sock_file_in_use(sock_file): + print("[b10-xfrout] Fail to start xfrout process, unix socket" + " file '%s' is being used by another xfrout process" % sock_file) + sys.exit(0) + else: + if not os.path.exists(sock_file): + return + + try: + os.unlink(sock_file) + except OSError as err: + print('[b10-xfrout] Fail to remove file ' + sock_file, err) + sys.exit(0) + + def _sock_file_in_use(self, sock_file): + '''Check whether the socket file 'sock_file' exists and + is being used by one running xfrout process. If it is, + return True, or else return False. ''' + try: + sock = socket.socket(socket.AF_UNIX) + sock.connect(sock_file) + except socket.error as err: + return False + else: + return True def shutdown(self): ThreadingUnixStreamServer.shutdown(self) @@ -396,7 +429,7 @@ class XfroutServer: def command_handler(self, cmd, args): if cmd == "shutdown": if verbose_mode: - log_msg("Received shutdown command") + print("[b10-xfrout] Received shutdown command") self.shutdown() answer = create_answer(0) else: diff --git a/src/lib/cc/Makefile.am b/src/lib/cc/Makefile.am index 4cad3db7c5..fe61cd25f2 100644 --- a/src/lib/cc/Makefile.am +++ b/src/lib/cc/Makefile.am @@ -1,19 +1,36 @@ AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib +AM_CPPFLAGS += -I$(top_srcdir)/src/lib/dns -I$(top_builddir)/src/lib/dns + +AM_CXXFLAGS = $(B10_CXXFLAGS) +# ASIO header files used in session.cc will trigger "unused-parameter" +# error. Unfortunately there doesn't seem to be an easy way to selectively +# avoid the error. As a short term workaround we suppress this warning +# for the entire this module. See also src/bin/auth/Makefile.am. +AM_CXXFLAGS += -Wno-unused-parameter lib_LIBRARIES = libcc.a libcc_a_SOURCES = data.cc data.h session.cc session.h -CLEANFILES = *.gcno *.gcda +CLEANFILES = *.gcno *.gcda session_config.h + +session_config.h: session_config.h.pre + $(SED) -e "s|@@LOCALSTATEDIR@@|$(localstatedir)|" session_config.h.pre >$@ + +BUILT_SOURCES = session_config.h TESTS = if HAVE_GTEST TESTS += run_unittests -run_unittests_SOURCES = data_unittests.cc run_unittests.cc +# (TODO: these need to be completed and moved to tests/) +run_unittests_SOURCES = data_unittests.cc session_unittests.cc run_unittests.cc run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES) -run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS) +# TODO: remove PTHREAD_LDFLAGS (and from configure too) +run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS) $(PTHREAD_LDFLAGS) + run_unittests_LDADD = libcc.a $(GTEST_LDADD) run_unittests_LDADD += $(top_builddir)/src/lib/dns/.libs/libdns.a run_unittests_LDADD += $(top_builddir)/src/lib/exceptions/.libs/libexceptions.a + endif noinst_PROGRAMS = $(TESTS) diff --git a/src/lib/cc/data.cc b/src/lib/cc/data.cc index 0722127ffb..2d43e754ea 100644 --- a/src/lib/cc/data.cc +++ b/src/lib/cc/data.cc @@ -204,47 +204,27 @@ bool operator==(const isc::data::ElementPtr a, const isc::data::ElementPtr b) { // ElementPtr Element::create(const int i) { - try { - return ElementPtr(new IntElement(i)); - } catch (std::bad_alloc) { - return ElementPtr(); - } + return ElementPtr(new IntElement(i)); } ElementPtr Element::create(const double d) { - try { - return ElementPtr(new DoubleElement(d)); - } catch (std::bad_alloc) { - return ElementPtr(); - } + return ElementPtr(new DoubleElement(d)); } ElementPtr Element::create(const std::string& s) { - try { - return ElementPtr(new StringElement(s)); - } catch (std::bad_alloc) { - return ElementPtr(); - } + return ElementPtr(new StringElement(s)); } ElementPtr Element::create(const bool b) { - try { - return ElementPtr(new BoolElement(b)); - } catch (std::bad_alloc) { - return ElementPtr(); - } + return ElementPtr(new BoolElement(b)); } ElementPtr Element::create(const std::vector<ElementPtr>& v) { - try { - return ElementPtr(new ListElement(v)); - } catch (std::bad_alloc) { - return ElementPtr(); - } + return ElementPtr(new ListElement(v)); } ElementPtr @@ -255,11 +235,7 @@ Element::create(const std::map<std::string, ElementPtr>& m) { isc_throw(TypeError, "Map tag is too long"); } } - try { - return ElementPtr(new MapElement(m)); - } catch (std::bad_alloc) { - return ElementPtr(); - } + return ElementPtr(new MapElement(m)); } diff --git a/src/lib/cc/session.cc b/src/lib/cc/session.cc index 68da5a83ea..66bc2717ab 100644 --- a/src/lib/cc/session.cc +++ b/src/lib/cc/session.cc @@ -14,7 +14,8 @@ // $Id$ -#include "config.h" +#include <config.h> +#include "session_config.h" #include <stdint.h> @@ -23,11 +24,14 @@ #include <iostream> #include <sstream> -#ifdef HAVE_BOOST_SYSTEM +#include <sys/un.h> + #include <boost/bind.hpp> #include <boost/function.hpp> -#include <boost/asio.hpp> -#endif + +#include <asio.hpp> +#include <asio/error_code.hpp> +#include <asio/system_error.hpp> #include <exceptions/exceptions.h> @@ -38,12 +42,10 @@ using namespace std; using namespace isc::cc; using namespace isc::data; -#ifdef HAVE_BOOST_SYSTEM -// some of the boost::asio names conflict with socket API system calls -// (e.g. write(2)) so we don't import the entire boost::asio namespace. -using boost::asio::io_service; -using boost::asio::ip::tcp; -#endif +// some of the asio names conflict with socket API system calls +// (e.g. write(2)) so we don't import the entire asio namespace. +using asio::io_service; +using asio::ip::tcp; #include <sys/types.h> #include <sys/socket.h> @@ -54,27 +56,27 @@ namespace cc { class SessionImpl { public: - SessionImpl() : sequence_(-1) {} + SessionImpl() : sequence_(-1) { queue_ = Element::createFromString("[]"); } virtual ~SessionImpl() {} - virtual void establish() = 0; + virtual void establish(const char& socket_file) = 0; virtual int getSocket() = 0; virtual void disconnect() = 0; virtual void writeData(const void* data, size_t datalen) = 0; virtual size_t readDataLength() = 0; virtual void readData(void* data, size_t datalen) = 0; virtual void startRead(boost::function<void()> user_handler) = 0; - + int sequence_; // the next sequence number to use std::string lname_; + ElementPtr queue_; }; -#ifdef HAVE_BOOST_SYSTEM class ASIOSession : public SessionImpl { public: ASIOSession(io_service& io_service) : io_service_(io_service), socket_(io_service_), data_length_(0) {} - virtual void establish(); + virtual void establish(const char& socket_file); virtual void disconnect(); virtual int getSocket() { return (socket_.native()); } virtual void writeData(const void* data, size_t datalen); @@ -82,23 +84,28 @@ public: virtual void readData(void* data, size_t datalen); virtual void startRead(boost::function<void()> user_handler); private: - void internalRead(const boost::system::error_code& error, + void internalRead(const asio::error_code& error, size_t bytes_transferred); private: io_service& io_service_; - tcp::socket socket_; + asio::local::stream_protocol::socket socket_; uint32_t data_length_; boost::function<void()> user_handler_; - boost::system::error_code error_; + asio::error_code error_; }; + + void -ASIOSession::establish() { - socket_.connect(tcp::endpoint(boost::asio::ip::address_v4::loopback(), - 9912), error_); +ASIOSession::establish(const char& socket_file) { + try { + socket_.connect(asio::local::stream_protocol::endpoint(&socket_file), error_); + } catch (asio::system_error& se) { + isc_throw(SessionError, se.what()); + } if (error_) { - isc_throw(SessionError, "Unable to connect to message queue"); + isc_throw(SessionError, "Unable to connect to message queue."); } } @@ -111,9 +118,9 @@ ASIOSession::disconnect() { void ASIOSession::writeData(const void* data, size_t datalen) { try { - boost::asio::write(socket_, boost::asio::buffer(data, datalen)); - } catch (const boost::system::system_error& boost_ex) { - isc_throw(SessionError, "ASIO write failed: " << boost_ex.what()); + asio::write(socket_, asio::buffer(data, datalen)); + } catch (const asio::system_error& asio_ex) { + isc_throw(SessionError, "ASIO write failed: " << asio_ex.what()); } } @@ -136,11 +143,11 @@ ASIOSession::readDataLength() { void ASIOSession::readData(void* data, size_t datalen) { try { - boost::asio::read(socket_, boost::asio::buffer(data, datalen)); - } catch (const boost::system::system_error& boost_ex) { + asio::read(socket_, asio::buffer(data, datalen)); + } catch (const asio::system_error& asio_ex) { // to hide boost specific exceptions, we catch them explicitly // and convert it to SessionError. - isc_throw(SessionError, "ASIO read failed: " << boost_ex.what()); + isc_throw(SessionError, "ASIO read failed: " << asio_ex.what()); } } @@ -148,15 +155,15 @@ void ASIOSession::startRead(boost::function<void()> user_handler) { data_length_ = 0; user_handler_ = user_handler; - async_read(socket_, boost::asio::buffer(&data_length_, + async_read(socket_, asio::buffer(&data_length_, sizeof(data_length_)), boost::bind(&ASIOSession::internalRead, this, - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred)); + asio::placeholders::error, + asio::placeholders::bytes_transferred)); } void -ASIOSession::internalRead(const boost::system::error_code& error, +ASIOSession::internalRead(const asio::error_code& error, size_t bytes_transferred) { if (!error) { @@ -170,14 +177,13 @@ ASIOSession::internalRead(const boost::system::error_code& error, isc_throw(SessionError, "asynchronous read failed"); } } -#endif class SocketSession : public SessionImpl { public: SocketSession() : sock_(-1) {} virtual ~SocketSession() { disconnect(); } virtual int getSocket() { return (sock_); } - void establish(); + void establish(const char& socket_file); virtual void disconnect() { if (sock_ >= 0) { @@ -212,29 +218,25 @@ public: } void -SocketSession::establish() { - int s; - struct sockaddr_in sin; +SocketSession::establish(const char& socket_file) { + struct sockaddr_un s_un; +#ifdef HAVE_SA_LEN + s_un.sun_len = sizeof(struct sockaddr_un); +#endif + + if (strlen(&socket_file) >= sizeof(s_un.sun_path)) { + isc_throw(SessionError, "Unable to connect to message queue; " + "socket file path too long: " << socket_file); + } + s_un.sun_family = AF_UNIX; + strncpy(s_un.sun_path, &socket_file, sizeof(s_un.sun_path) - 1); - s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + int s = socket(AF_UNIX, SOCK_STREAM, 0); if (s < 0) { isc_throw(SessionError, "socket() failed"); } - - int port = atoi(getenv("ISC_MSGQ_PORT")); - if (port == 0) { - port = 9912; - } - - sin.sin_family = AF_INET; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = INADDR_ANY; - -#ifdef HAVE_SIN_LEN - sin.sin_len = sizeof(struct sockaddr_in); -#endif - if (connect(s, (struct sockaddr *)&sin, sizeof(sin)) < 0) { + if (connect(s, (struct sockaddr *)&s_un, sizeof(s_un)) < 0) { close(s); isc_throw(SessionError, "Unable to connect to message queue"); } @@ -270,10 +272,8 @@ SocketSession::readData(void* data, const size_t datalen) { Session::Session() : impl_(new SocketSession) {} -#ifdef HAVE_BOOST_SYSTEM Session::Session(io_service& io_service) : impl_(new ASIOSession(io_service)) {} -#endif Session::~Session() { delete impl_; @@ -295,8 +295,15 @@ Session::startRead(boost::function<void()> read_callback) { } void -Session::establish() { - impl_->establish(); +Session::establish(const char* socket_file) { + if (socket_file == NULL) { + socket_file = getenv("BIND10_MSGQ_SOCKET_FILE"); + } + if (socket_file == NULL) { + socket_file = BIND10_MSGQ_SOCKET_FILE; + } + + impl_->establish(*socket_file); // once established, encapsulate the implementation object so that we // can safely release the internal resource when exception happens @@ -314,7 +321,6 @@ Session::establish() { recvmsg(routing, msg, false); impl_->lname_ = msg->get("lname")->stringValue(); - cout << "My local name is: " << impl_->lname_ << endl; // At this point there's no risk of resource leak. session_holder.clear(); @@ -352,35 +358,35 @@ Session::sendmsg(ElementPtr& env, ElementPtr& msg) { } bool -Session::recvmsg(ElementPtr& msg, bool nonblock UNUSED_PARAM) { - size_t length = impl_->readDataLength(); - - unsigned short header_length_net; - impl_->readData(&header_length_net, sizeof(header_length_net)); - - unsigned short header_length = ntohs(header_length_net); - if (header_length != length) { - isc_throw(SessionError, "Length parameters invalid: total=" << length - << ", header=" << header_length); - } - - std::vector<char> buffer(length); - impl_->readData(&buffer[0], length); - - std::string wire = std::string(&buffer[0], length); - std::stringstream wire_stream; - wire_stream << wire; - - msg = Element::fromWire(wire_stream, length); - - return (true); - // XXXMLG handle non-block here, and return false for short reads +Session::recvmsg(ElementPtr& msg, bool nonblock, int seq) { + ElementPtr l_env; + return recvmsg(l_env, msg, nonblock, seq); } bool -Session::recvmsg(ElementPtr& env, ElementPtr& msg, bool nonblock UNUSED_PARAM) { +Session::recvmsg(ElementPtr& env, ElementPtr& msg, + bool nonblock, int seq) { size_t length = impl_->readDataLength(); - + ElementPtr l_env, l_msg; + if (hasQueuedMsgs()) { + ElementPtr q_el; + for (int i = 0; i < impl_->queue_->size(); i++) { + q_el = impl_->queue_->get(i); + if (( seq == -1 && + !q_el->get(0)->contains("reply") + ) || ( + q_el->get(0)->contains("reply") && + q_el->get(0)->get("reply")->intValue() == seq + ) + ) { + env = q_el->get(0); + msg = q_el->get(1); + impl_->queue_->remove(i); + return true; + } + } + } + unsigned short header_length_net; impl_->readData(&header_length_net, sizeof(header_length_net)); @@ -400,13 +406,28 @@ Session::recvmsg(ElementPtr& env, ElementPtr& msg, bool nonblock UNUSED_PARAM) { length - header_length); std::stringstream header_wire_stream; header_wire_stream << header_wire; - env = Element::fromWire(header_wire_stream, header_length); + l_env = Element::fromWire(header_wire_stream, header_length); std::stringstream body_wire_stream; body_wire_stream << body_wire; - msg = Element::fromWire(body_wire_stream, length - header_length); - - return (true); + l_msg = Element::fromWire(body_wire_stream, length - header_length); + if ((seq == -1 && + !l_env->contains("reply") + ) || ( + l_env->contains("reply") && + l_env->get("reply")->intValue() == seq + ) + ) { + env = l_env; + msg = l_msg; + return true; + } else { + ElementPtr q_el = Element::createFromString("[]"); + q_el->add(l_env); + q_el->add(l_msg); + impl_->queue_->add(q_el); + return recvmsg(env, msg, nonblock, seq); + } // XXXMLG handle non-block here, and return false for short reads } @@ -432,47 +453,55 @@ Session::unsubscribe(std::string group, std::string instance) { sendmsg(env); } -unsigned int +int Session::group_sendmsg(ElementPtr msg, std::string group, std::string instance, std::string to) { ElementPtr env = Element::create(std::map<std::string, ElementPtr>()); - + int nseq = ++impl_->sequence_; + env->set("type", Element::create("send")); env->set("from", Element::create(impl_->lname_)); env->set("to", Element::create(to)); env->set("group", Element::create(group)); env->set("instance", Element::create(instance)); - env->set("seq", Element::create(impl_->sequence_)); + env->set("seq", Element::create(nseq)); //env->set("msg", Element::create(msg->toWire())); sendmsg(env, msg); - - return (++impl_->sequence_); + return nseq; } bool Session::group_recvmsg(ElementPtr& envelope, ElementPtr& msg, - bool nonblock) + bool nonblock, int seq) { - return (recvmsg(envelope, msg, nonblock)); + return (recvmsg(envelope, msg, nonblock, seq)); } -unsigned int +int Session::reply(ElementPtr& envelope, ElementPtr& newmsg) { ElementPtr env = Element::create(std::map<std::string, ElementPtr>()); - + int nseq = ++impl_->sequence_; + env->set("type", Element::create("send")); env->set("from", Element::create(impl_->lname_)); env->set("to", Element::create(envelope->get("from")->stringValue())); env->set("group", Element::create(envelope->get("group")->stringValue())); env->set("instance", Element::create(envelope->get("instance")->stringValue())); - env->set("seq", Element::create(impl_->sequence_)); + env->set("seq", Element::create(nseq)); env->set("reply", Element::create(envelope->get("seq")->intValue())); sendmsg(env, newmsg); - return (++impl_->sequence_); + return nseq; } + +bool +Session::hasQueuedMsgs() +{ + return (impl_->queue_->size() > 0); +} + } } diff --git a/src/lib/cc/session.h b/src/lib/cc/session.h index 509cf35feb..ee3b387d67 100644 --- a/src/lib/cc/session.h +++ b/src/lib/cc/session.h @@ -24,12 +24,11 @@ #include <exceptions/exceptions.h> #include "data.h" +#include "session_config.h" -namespace boost { namespace asio { class io_service; } -} namespace isc { namespace cc { @@ -51,7 +50,7 @@ namespace isc { public: Session(); - Session(boost::asio::io_service& ioservice); + Session(asio::io_service& ioservice); ~Session(); // XXX: quick hack to allow the user to watch the socket directly. @@ -59,29 +58,33 @@ namespace isc { void startRead(boost::function<void()> read_callback); - void establish(); + void establish(const char* socket_file = NULL); void disconnect(); void sendmsg(isc::data::ElementPtr& msg); void sendmsg(isc::data::ElementPtr& env, isc::data::ElementPtr& msg); bool recvmsg(isc::data::ElementPtr& msg, - bool nonblock = true); + bool nonblock = true, + int seq = -1); bool recvmsg(isc::data::ElementPtr& env, isc::data::ElementPtr& msg, - bool nonblock = true); + bool nonblock = true, + int seq = -1); void subscribe(std::string group, std::string instance = "*"); void unsubscribe(std::string group, std::string instance = "*"); - unsigned int group_sendmsg(isc::data::ElementPtr msg, + int group_sendmsg(isc::data::ElementPtr msg, std::string group, std::string instance = "*", std::string to = "*"); bool group_recvmsg(isc::data::ElementPtr& envelope, isc::data::ElementPtr& msg, - bool nonblock = true); - unsigned int reply(isc::data::ElementPtr& envelope, + bool nonblock = true, + int seq = -1); + int reply(isc::data::ElementPtr& envelope, isc::data::ElementPtr& newmsg); + bool hasQueuedMsgs(); }; } // namespace cc } // namespace isc diff --git a/src/lib/cc/session_config.h.pre.in b/src/lib/cc/session_config.h.pre.in new file mode 100644 index 0000000000..96bcba0901 --- /dev/null +++ b/src/lib/cc/session_config.h.pre.in @@ -0,0 +1,2 @@ +#define BIND10_MSGQ_SOCKET_FILE "@@LOCALSTATEDIR@@/@PACKAGE@/msgq_socket" + diff --git a/src/lib/cc/session_unittests.cc b/src/lib/cc/session_unittests.cc new file mode 100644 index 0000000000..e67363efda --- /dev/null +++ b/src/lib/cc/session_unittests.cc @@ -0,0 +1,63 @@ +// Copyright (C) 2009 Internet Systems Consortium, Inc. ("ISC") +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +// $Id: data_unittests.cc 1899 2010-05-21 12:03:59Z jelte $ + +#include "config.h" +#include <gtest/gtest.h> +#include <session.h> + +#include <asio.hpp> +#include <exceptions/exceptions.h> + +using namespace isc::cc; + +TEST(AsioSession, establish) { + asio::io_service io_service_; + Session sess(io_service_); + + EXPECT_THROW( + sess.establish("/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + ), isc::cc::SessionError + ); + +} + +TEST(Session, establish) { + Session sess; + + EXPECT_THROW( + sess.establish("/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + "/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/aaaaaaaaaa/" + ), isc::cc::SessionError + ); + +} diff --git a/src/lib/config/Makefile.am b/src/lib/config/Makefile.am index 05b19db4e6..56416332da 100644 --- a/src/lib/config/Makefile.am +++ b/src/lib/config/Makefile.am @@ -1,4 +1,6 @@ -AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib -Wno-strict-aliasing +AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib +AM_CPPFLAGS += -I$(top_builddir)/src/lib/cc +AM_CXXFLAGS = $(B10_CXXFLAGS) -Wno-strict-aliasing lib_LTLIBRARIES = libcfgclient.la libcfgclient_la_SOURCES = config_data.h config_data.cc module_spec.h module_spec.cc ccsession.cc ccsession.h @@ -46,3 +48,6 @@ EXTRA_DIST += testdata/spec21.spec EXTRA_DIST += testdata/spec22.spec EXTRA_DIST += testdata/spec23.spec EXTRA_DIST += testdata/spec24.spec +EXTRA_DIST += testdata/spec25.spec +EXTRA_DIST += testdata/spec26.spec +EXTRA_DIST += testdata/spec27.spec diff --git a/src/lib/config/ccsession.cc b/src/lib/config/ccsession.cc index abe6ba42af..e0020ab2eb 100644 --- a/src/lib/config/ccsession.cc +++ b/src/lib/config/ccsession.cc @@ -32,9 +32,7 @@ #include <sstream> #include <cerrno> -#ifdef HAVE_BOOST_SYSTEM #include <boost/bind.hpp> -#endif #include <boost/foreach.hpp> #include <cc/data.h> @@ -188,7 +186,6 @@ ModuleCCSession::readModuleSpecification(const std::string& filename) { return module_spec; } -#ifdef HAVE_BOOST_SYSTEM void ModuleCCSession::startCheck() { // data available on the command channel. process it in the synchronous @@ -201,7 +198,7 @@ ModuleCCSession::startCheck() { ModuleCCSession::ModuleCCSession( std::string spec_file_name, - boost::asio::io_service& io_service, + asio::io_service& io_service, isc::data::ElementPtr(*config_handler)(isc::data::ElementPtr new_config), isc::data::ElementPtr(*command_handler)( const std::string& command, const isc::data::ElementPtr args) @@ -213,7 +210,6 @@ ModuleCCSession::ModuleCCSession( // register callback for asynchronous read session_.startRead(boost::bind(&ModuleCCSession::startCheck, this)); } -#endif ModuleCCSession::ModuleCCSession( std::string spec_file_name, @@ -248,8 +244,8 @@ ModuleCCSession::init( //session_.subscribe("statistics", "*"); // send the data specification ElementPtr spec_msg = createCommand("module_spec", module_specification_.getFullSpec()); - session_.group_sendmsg(spec_msg, "ConfigManager"); - session_.group_recvmsg(env, answer, false); + unsigned int seq = session_.group_sendmsg(spec_msg, "ConfigManager"); + session_.group_recvmsg(env, answer, false, seq); int rcode; ElementPtr err = parseAnswer(rcode, answer); if (rcode != 0) { @@ -260,8 +256,8 @@ ModuleCCSession::init( // get any stored configuration from the manager if (config_handler_) { ElementPtr cmd = Element::createFromString("{ \"command\": [\"get_config\", {\"module_name\":\"" + module_name_ + "\"} ] }"); - session_.group_sendmsg(cmd, "ConfigManager"); - session_.group_recvmsg(env, answer, false); + seq = session_.group_sendmsg(cmd, "ConfigManager"); + session_.group_recvmsg(env, answer, false, seq); ElementPtr new_config = parseAnswer(rcode, answer); if (rcode == 0) { handleConfigUpdate(new_config); @@ -310,6 +306,12 @@ ModuleCCSession::getSocket() return (session_.getSocket()); } +bool +ModuleCCSession::hasQueuedMsgs() +{ + return (session_.hasQueuedMsgs()); +} + int ModuleCCSession::checkCommand() { @@ -365,8 +367,8 @@ ModuleCCSession::addRemoteConfig(const std::string& spec_file_name) ElementPtr env, answer; int rcode; - session_.group_sendmsg(cmd, "ConfigManager"); - session_.group_recvmsg(env, answer, false); + unsigned int seq = session_.group_sendmsg(cmd, "ConfigManager"); + session_.group_recvmsg(env, answer, false, seq); ElementPtr new_config = parseAnswer(rcode, answer); if (rcode == 0) { rmod_config.setLocalConfig(new_config); diff --git a/src/lib/config/ccsession.h b/src/lib/config/ccsession.h index 4030364d21..08cfde99df 100644 --- a/src/lib/config/ccsession.h +++ b/src/lib/config/ccsession.h @@ -24,11 +24,9 @@ #include <cc/session.h> #include <cc/data.h> -namespace boost { namespace asio { class io_service; } -} namespace isc { namespace config { @@ -133,7 +131,7 @@ public: isc::data::ElementPtr(*command_handler)(const std::string& command, const isc::data::ElementPtr args) = NULL ) throw (isc::cc::SessionError); ModuleCCSession(std::string spec_file_name, - boost::asio::io_service& io_service, + asio::io_service& io_service, isc::data::ElementPtr(*config_handler)(isc::data::ElementPtr new_config) = NULL, isc::data::ElementPtr(*command_handler)(const std::string& command, const isc::data::ElementPtr args) = NULL ) throw (isc::cc::SessionError); @@ -150,6 +148,16 @@ public: int getSocket(); /** + * Optional optimization for checkCommand loop; returns true + * if there are unhandled queued messages in the cc session. + * (if either this is true or there is data on the socket found + * by the select() call on getSocket(), run checkCommand()) + * + * @return true if there are unhandled queued messages + */ + bool hasQueuedMsgs(); + + /** * Check if there is a command or config change on the command * session. If so, the appropriate handler is called if set. * If not set, a default answer is returned. diff --git a/src/lib/config/config_data.h b/src/lib/config/config_data.h index a6858863e0..359a6e154d 100644 --- a/src/lib/config/config_data.h +++ b/src/lib/config/config_data.h @@ -40,6 +40,7 @@ public: /// Constructs a ConfigData option with no specification and an /// empty configuration. ConfigData() { _config = Element::createFromString("{}"); }; + /// Constructs a ConfigData option with the given specification /// and an empty configuration. /// \param module_spec A ModuleSpec for the relevant module @@ -70,17 +71,21 @@ public: /// Returns the ModuleSpec associated with this ConfigData object const ModuleSpec getModuleSpec() { return _module_spec; }; + /// Set the ModuleSpec associated with this ConfigData object void setModuleSpec(ModuleSpec module_spec) { _module_spec = module_spec; }; + /// Set the local configuration (i.e. all non-default values) /// \param config An ElementPtr pointing to a MapElement containing /// *all* non-default configuration values. Existing values /// will be removed. void setLocalConfig(ElementPtr config) { _config = config; } + /// Returns the local (i.e. non-default) configuration. /// \returns An ElementPtr pointing to a MapElement containing all /// non-default configuration options. ElementPtr getLocalConfig() { return _config; } + /// Returns a list of all possible configuration options as specified /// by the ModuleSpec. /// \param identifier If given, show the items at the given identifier @@ -92,6 +97,7 @@ public: /// location (or all possible identifiers if identifier=="" /// and recurse==false) ElementPtr getItemList(const std::string& identifier = "", bool recurse = false); + /// Returns all current configuration settings (both non-default and default). /// \return An ElementPtr pointing to a MapElement containing /// string->value elements, where the string is the diff --git a/src/lib/config/module_spec.cc b/src/lib/config/module_spec.cc index dc38ca2f4d..c6bf51b16f 100644 --- a/src/lib/config/module_spec.cc +++ b/src/lib/config/module_spec.cc @@ -145,6 +145,7 @@ check_command_list(const ElementPtr& spec) { static void check_data_specification(const ElementPtr& spec) { check_leaf_item(spec, "module_name", Element::string, true); + check_leaf_item(spec, "module_description", Element::string, false); // config_data is not mandatory; module could just define // commands and have no config if (spec->contains("config_data")) { @@ -204,6 +205,16 @@ ModuleSpec::getModuleName() const return module_specification->get("module_name")->stringValue(); } +const std::string +ModuleSpec::getModuleDescription() const +{ + if (module_specification->contains("module_description")) { + return module_specification->get("module_description")->stringValue(); + } else { + return std::string(""); + } +} + bool ModuleSpec::validate_config(const ElementPtr data, const bool full) { diff --git a/src/lib/config/module_spec.h b/src/lib/config/module_spec.h index 908eef9120..948670ad3c 100644 --- a/src/lib/config/module_spec.h +++ b/src/lib/config/module_spec.h @@ -77,6 +77,10 @@ namespace isc { namespace config { /// Returns the module name as specified by the specification const std::string getModuleName() const; + /// Returns the module description as specified by the specification + /// returns an empty string if there is no description + const std::string getModuleDescription() const; + // returns true if the given element conforms to this data // configuration specification /// Validates the given configuration data for this specification. diff --git a/src/lib/config/testdata/spec25.spec b/src/lib/config/testdata/spec25.spec new file mode 100644 index 0000000000..6a174d5356 --- /dev/null +++ b/src/lib/config/testdata/spec25.spec @@ -0,0 +1,7 @@ +{ + "module_spec": { + "module_name": "Spec25", + "module_description": "Just an empty module" + } +} + diff --git a/src/lib/config/testdata/spec26.spec b/src/lib/config/testdata/spec26.spec new file mode 100644 index 0000000000..27f3c5b4de --- /dev/null +++ b/src/lib/config/testdata/spec26.spec @@ -0,0 +1,6 @@ +{ + "module_spec": { + "module_name": "Spec26", + "module_description": 1 + } +} diff --git a/src/lib/config/testdata/spec27.spec b/src/lib/config/testdata/spec27.spec new file mode 100644 index 0000000000..ee29d80093 --- /dev/null +++ b/src/lib/config/testdata/spec27.spec @@ -0,0 +1,121 @@ +{ + "module_spec": { + "module_name": "Spec27", + "commands": [ + { + 'command_name': 'cmd1', + "command_description": "command_for_unittest", + 'command_args': [ + { + "item_name": "value1", + "item_type": "integer", + "item_optional": False, + "item_default": 9 + }, + { "item_name": "value2", + "item_type": "real", + "item_optional": False, + "item_default": 9.9 + }, + { "item_name": "value3", + "item_type": "boolean", + "item_optional": False, + "item_default": False + }, + { "item_name": "value4", + "item_type": "string", + "item_optional": False, + "item_default": "default_string" + }, + { "item_name": "value5", + "item_type": "list", + "item_optional": False, + "item_default": [ "a", "b" ], + "list_item_spec": { + "item_name": "list_element", + "item_type": "integer", + "item_optional": False, + "item_default": 8 + } + }, + { "item_name": "value6", + "item_type": "map", + "item_optional": False, + "item_default": {}, + "map_item_spec": [ + { "item_name": "v61", + "item_type": "string", + "item_optional": False, + "item_default": "def" + }, + { "item_name": "v62", + "item_type": "boolean", + "item_optional": False, + "item_default": False + } + ] + }, + { "item_name": "value7", + "item_type": "list", + "item_optional": True, + "item_default": [ ], + "list_item_spec": { + "item_name": "list_element", + "item_type": "any", + "item_optional": True + } + }, + { "item_name": "value8", + "item_type": "list", + "item_optional": True, + "item_default": [ ], + "list_item_spec": { + "item_name": "list_element", + "item_type": "map", + "item_optional": True, + "item_default": { "a": "b" }, + "map_item_spec": [ + { "item_name": "a", + "item_type": "string", + "item_optional": True, + "item_default": "empty" + } + ] + } + }, + { "item_name": "value9", + "item_type": "map", + "item_optional": False, + "item_default": {}, + "map_item_spec": [ + { "item_name": "v91", + "item_type": "string", + "item_optional": False, + "item_default": "def" + }, + { "item_name": "v92", + "item_type": "map", + "item_optional": False, + "item_default": {}, + "map_item_spec": [ + { "item_name": "v92a", + "item_type": "string", + "item_optional": False, + "item_default": "Hello" + } , + { + "item_name": "v92b", + "item_type": "integer", + "item_optional": False, + "item_default": 47806 + } + ] + } + ] + } + ] + } + ] + } +} + diff --git a/src/lib/config/tests/Makefile.am b/src/lib/config/tests/Makefile.am index 1cfda66e5c..436d962937 100644 --- a/src/lib/config/tests/Makefile.am +++ b/src/lib/config/tests/Makefile.am @@ -1,5 +1,9 @@ AM_CPPFLAGS = -I$(top_srcdir)/src/lib +AM_CXXFLAGS = $(B10_CXXFLAGS) +# see src/lib/cc/Makefile.am for -Wno-unused-parameter +AM_CXXFLAGS += -Wno-unused-parameter + CLEANFILES = *.gcno *.gcda lib_LTLIBRARIES = libfake_session.la @@ -14,14 +18,9 @@ run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES) run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS) run_unittests_LDADD = $(GTEST_LDADD) run_unittests_LDADD += $(top_builddir)/src/lib/exceptions/libexceptions.la +run_unittests_LDADD += libfake_session.la run_unittests_LDADD += $(top_builddir)/src/lib/config/libcfgclient.la run_unittests_LDADD += $(top_builddir)/src/lib/cc/data.o -run_unittests_LDADD += libfake_session.la - -if HAVE_BOOST_SYSTEM -run_unittests_LDFLAGS += $(AM_LDFLAGS) $(BOOST_LDFLAGS) -run_unittests_LDADD += $(BOOST_SYSTEM_LIB) -endif endif diff --git a/src/lib/config/tests/fake_session.cc b/src/lib/config/tests/fake_session.cc index 2a83bb18c7..cad2e054d4 100644 --- a/src/lib/config/tests/fake_session.cc +++ b/src/lib/config/tests/fake_session.cc @@ -23,12 +23,6 @@ #include <iostream> #include <sstream> -#ifdef HAVE_BOOST_SYSTEM -#include <boost/bind.hpp> -#include <boost/function.hpp> -#include <boost/asio.hpp> -#endif - #include <boost/foreach.hpp> #include <exceptions/exceptions.h> @@ -40,13 +34,6 @@ using namespace std; using namespace isc::cc; using namespace isc::data; -#ifdef HAVE_BOOST_SYSTEM -// some of the boost::asio names conflict with socket API system calls -// (e.g. write(2)) so we don't import the entire boost::asio namespace. -using boost::asio::io_service; -using boost::asio::ip::tcp; -#endif - #include <sys/types.h> #include <sys/socket.h> #include <netinet/in.h> @@ -144,15 +131,18 @@ Session::Session() { } -#ifdef HAVE_BOOST_SYSTEM -Session::Session(io_service& io_service UNUSED_PARAM) +Session::Session(asio::io_service& io_service UNUSED_PARAM) { } -#endif Session::~Session() { } +bool +Session::connect() { + return true; +} + void Session::disconnect() { } @@ -167,7 +157,7 @@ Session::startRead(boost::function<void()> read_callback UNUSED_PARAM) { } void -Session::establish() { +Session::establish(const char* socket_file) { } // @@ -188,7 +178,7 @@ Session::sendmsg(ElementPtr& env, ElementPtr& msg) { } bool -Session::recvmsg(ElementPtr& msg, bool nonblock UNUSED_PARAM) { +Session::recvmsg(ElementPtr& msg, bool nonblock UNUSED_PARAM, int seq UNUSED_PARAM) { //cout << "[XX] client asks for message " << endl; if (initial_messages && initial_messages->getType() == Element::list && @@ -202,7 +192,7 @@ Session::recvmsg(ElementPtr& msg, bool nonblock UNUSED_PARAM) { } bool -Session::recvmsg(ElementPtr& env, ElementPtr& msg, bool nonblock UNUSED_PARAM) { +Session::recvmsg(ElementPtr& env, ElementPtr& msg, bool nonblock UNUSED_PARAM, int seq UNUSED_PARAM) { //cout << "[XX] client asks for message and env" << endl; env = ElementPtr(); if (initial_messages && @@ -269,9 +259,9 @@ Session::group_sendmsg(ElementPtr msg, std::string group, bool Session::group_recvmsg(ElementPtr& envelope, ElementPtr& msg, - bool nonblock) + bool nonblock, int seq) { - return (recvmsg(envelope, msg, nonblock)); + return (recvmsg(envelope, msg, nonblock, seq)); } unsigned int @@ -282,5 +272,10 @@ Session::reply(ElementPtr& envelope, ElementPtr& newmsg) { return 1; } +bool +Session::hasQueuedMsgs() { + return false; +} + } } diff --git a/src/lib/config/tests/fake_session.h b/src/lib/config/tests/fake_session.h index 18ee92ef0b..195c266178 100644 --- a/src/lib/config/tests/fake_session.h +++ b/src/lib/config/tests/fake_session.h @@ -25,11 +25,9 @@ #include <cc/data.h> -namespace boost { namespace asio { class io_service; } -} // global variables so tests can insert // update and check, before, during and after @@ -65,7 +63,7 @@ namespace isc { // public so tests can inspect them Session(); - Session(boost::asio::io_service& ioservice); + Session(asio::io_service& ioservice); ~Session(); // XXX: quick hack to allow the user to watch the socket directly. @@ -73,16 +71,17 @@ namespace isc { void startRead(boost::function<void()> read_callback); - void establish(); + void establish(const char* socket_file = NULL); + bool connect(); void disconnect(); void sendmsg(isc::data::ElementPtr& msg); void sendmsg(isc::data::ElementPtr& env, isc::data::ElementPtr& msg); bool recvmsg(isc::data::ElementPtr& msg, - bool nonblock = true); + bool nonblock = true, int seq = -1); bool recvmsg(isc::data::ElementPtr& env, isc::data::ElementPtr& msg, - bool nonblock = true); + bool nonblock = true, int seq = -1); void subscribe(std::string group, std::string instance = "*"); void unsubscribe(std::string group, @@ -93,9 +92,11 @@ namespace isc { std::string to = "*"); bool group_recvmsg(isc::data::ElementPtr& envelope, isc::data::ElementPtr& msg, - bool nonblock = true); + bool nonblock = true, + int seq = -1); unsigned int reply(isc::data::ElementPtr& envelope, isc::data::ElementPtr& newmsg); + bool hasQueuedMsgs(); }; } // namespace cc diff --git a/src/lib/config/tests/module_spec_unittests.cc b/src/lib/config/tests/module_spec_unittests.cc index fee7300b22..658c46cedb 100644 --- a/src/lib/config/tests/module_spec_unittests.cc +++ b/src/lib/config/tests/module_spec_unittests.cc @@ -60,6 +60,12 @@ TEST(ModuleSpec, ReadingSpecfiles) { dd = moduleSpecFromFile(specfile("spec2.spec")); EXPECT_EQ("[ {\"command_args\": [ {\"item_default\": \"\", \"item_name\": \"message\", \"item_optional\": False, \"item_type\": \"string\"} ], \"command_description\": \"Print the given message to stdout\", \"command_name\": \"print_message\"}, {\"command_args\": [ ], \"command_description\": \"Shut down BIND 10\", \"command_name\": \"shutdown\"} ]", dd.getCommandsSpec()->str()); EXPECT_EQ("Spec2", dd.getModuleName()); + EXPECT_EQ("", dd.getModuleDescription()); + + dd = moduleSpecFromFile(specfile("spec25.spec")); + EXPECT_EQ("Spec25", dd.getModuleName()); + EXPECT_EQ("Just an empty module", dd.getModuleDescription()); + EXPECT_THROW(moduleSpecFromFile(specfile("spec26.spec")), ModuleSpecError); std::ifstream file; file.open(specfile("spec1.spec").c_str()); diff --git a/src/lib/datasrc/Makefile.am b/src/lib/datasrc/Makefile.am index 2c00263df5..f61849cc91 100644 --- a/src/lib/datasrc/Makefile.am +++ b/src/lib/datasrc/Makefile.am @@ -4,6 +4,8 @@ AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib AM_CPPFLAGS += -I$(top_srcdir)/src/lib/dns -I$(top_builddir)/src/lib/dns AM_CPPFLAGS += $(SQLITE_CFLAGS) +AM_CXXFLAGS = $(B10_CXXFLAGS) + CLEANFILES = *.gcno *.gcda lib_LTLIBRARIES = libdatasrc.la diff --git a/src/lib/datasrc/tests/Makefile.am b/src/lib/datasrc/tests/Makefile.am index 62aa475ec6..382e4885ab 100644 --- a/src/lib/datasrc/tests/Makefile.am +++ b/src/lib/datasrc/tests/Makefile.am @@ -2,6 +2,8 @@ AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib AM_CPPFLAGS += -I$(top_builddir)/src/lib/dns -I$(top_srcdir)/src/lib/dns AM_CPPFLAGS += -DTEST_DATA_DIR=\"$(srcdir)/testdata\" +AM_CXXFLAGS = $(B10_CXXFLAGS) + CLEANFILES = *.gcno *.gcda TESTS = diff --git a/src/lib/datasrc/tests/sqlite3_unittest.cc b/src/lib/datasrc/tests/sqlite3_unittest.cc index 1e3f1f83d2..ed38a05e5f 100644 --- a/src/lib/datasrc/tests/sqlite3_unittest.cc +++ b/src/lib/datasrc/tests/sqlite3_unittest.cc @@ -376,8 +376,10 @@ TEST_F(Sqlite3DataSourceTest, reOpen) { NameMatch name_match(www_name); data_source.findClosestEnclosure(name_match, rrclass); - EXPECT_EQ(NULL, name_match.closestName()); - EXPECT_EQ(NULL, name_match.bestDataSrc()); + // XXX: some deviant compilers seem to fail to recognize a NULL as a + // pointer type. This explicit cast works around such compilers. + EXPECT_EQ(static_cast<void*>(NULL), name_match.closestName()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.bestDataSrc()); } TEST_F(Sqlite3DataSourceTest, openFail) { @@ -441,15 +443,15 @@ TEST_F(Sqlite3DataSourceTest, findClosestEnclosureAtDelegation) { TEST_F(Sqlite3DataSourceTest, findClosestEnclosureNoMatch) { NameMatch name_match(nomatch_name); data_source.findClosestEnclosure(name_match, rrclass); - EXPECT_EQ(NULL, name_match.closestName()); - EXPECT_EQ(NULL, name_match.bestDataSrc()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.closestName()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.bestDataSrc()); } TEST_F(Sqlite3DataSourceTest, findClosestClassMismatch) { NameMatch name_match(www_name); data_source.findClosestEnclosure(name_match, rrclass_notmatch); - EXPECT_EQ(NULL, name_match.closestName()); - EXPECT_EQ(NULL, name_match.bestDataSrc()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.closestName()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.bestDataSrc()); } // If the query class is ANY, the result should be the same as the case where diff --git a/src/lib/datasrc/tests/static_unittest.cc b/src/lib/datasrc/tests/static_unittest.cc index 86bb99ccbe..583775f579 100644 --- a/src/lib/datasrc/tests/static_unittest.cc +++ b/src/lib/datasrc/tests/static_unittest.cc @@ -214,8 +214,9 @@ TEST_F(StaticDataSourceTest, findClosestEnclosureForVersionClassAny) { TEST_F(StaticDataSourceTest, findClosestEnclosureForVersionClassMismatch) { NameMatch name_match(version_name); data_source.findClosestEnclosure(name_match, RRClass::IN()); - EXPECT_EQ(NULL, name_match.closestName()); - EXPECT_EQ(NULL, name_match.bestDataSrc()); + // XXX: see sqlite3_unittest.cc about the cast. + EXPECT_EQ(static_cast<void*>(NULL), name_match.closestName()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.bestDataSrc()); } TEST_F(StaticDataSourceTest, findClosestEnclosureForVersionPartial) { @@ -242,8 +243,8 @@ TEST_F(StaticDataSourceTest, findClosestEnclosureForAuthorsPartial) { TEST_F(StaticDataSourceTest, findClosestEnclosureNoMatch) { NameMatch name_match(nomatch_name); data_source.findClosestEnclosure(name_match, rrclass); - EXPECT_EQ(NULL, name_match.closestName()); - EXPECT_EQ(NULL, name_match.bestDataSrc()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.closestName()); + EXPECT_EQ(static_cast<void*>(NULL), name_match.bestDataSrc()); } TEST_F(StaticDataSourceTest, findRRsetVersionTXT) { diff --git a/src/lib/dns/Makefile.am b/src/lib/dns/Makefile.am index 58bc514409..8b4c2ec9ca 100644 --- a/src/lib/dns/Makefile.am +++ b/src/lib/dns/Makefile.am @@ -1,9 +1,7 @@ SUBDIRS = . tests python AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib -if GCC_WERROR_OK -AM_CPPFLAGS += -Werror -endif +AM_CXXFLAGS = $(B10_CXXFLAGS) CLEANFILES = *.gcno *.gcda CLEANFILES += rrclass.h rrtype.h rrparamregistry.cc rdataclass.h rdataclass.cc @@ -79,6 +77,7 @@ libdns_la_SOURCES += question.h question.cc libdns_la_SOURCES += sha1.h sha1.cc libdns_la_SOURCES += tsig.h tsig.cc + #if HAVE_BOOST_PYTHON ## This is a loadable module for python scripts, so we use the prefix "pyexec" ## to make sure the object files will be installed in the appropriate place @@ -86,10 +85,11 @@ libdns_la_SOURCES += tsig.h tsig.cc #pyexec_LTLIBRARIES = bind10_dns.la #bind10_dns_la_SOURCES = python_dns.cc #bind10_dns_la_CPPFLAGS = $(AM_CPPFLAGS) $(PYTHON_INCLUDES) +#bind10_dns_la_CXXFLAGS = $(AM_CXXFLAGS) $(B10_CXXFLAGS) #if GCC_WERROR_OK ## XXX: Boost.Python triggers strict aliasing violation, so if we use -Werror ## we need to suppress the warnings. -#bind10_dns_la_CPPFLAGS += -fno-strict-aliasing +#bind10_dns_la_CXXFLAGS += -fno-strict-aliasing #endif #bind10_dns_la_LDFLAGS = $(BOOST_LDFLAGS) $(PYTHON_LDFLAGS) ## Python prefers .so, while some OSes (specifically MacOS) use a different @@ -107,3 +107,29 @@ rrtype.h: rrtype-placeholder.h rrparamregistry.cc: rrparamregistry-placeholder.cc rrclass.h rrtype.h rrparamregistry.cc rdataclass.h rdataclass.cc: Makefile ./gen-rdatacode.py + +libdns_includedir = $(includedir)/dns +libdns_include_HEADERS = \ + buffer.h \ + dnssectime.h \ + exceptions.h \ + message.h \ + messagerenderer.h \ + name.h \ + question.h \ + rdata.h \ + rdataclass.h \ + rrclass.h \ + rrparamregistry.h \ + rrset.h \ + rrsetlist.h \ + rrttl.h \ + rrtype.h \ + tsig.h +# Purposely not installing these headers: +# base32.h # used only internally, and not actually DNS specific +# base64.h # used only internally, and not actually DNS specific +# hex.h # used only internally, and not actually DNS specific +# sha1.h # used only internally, and not actually DNS specific +# rrclass-placeholder.h +# rrtype-placeholder.h diff --git a/src/lib/dns/dnssectime.cc b/src/lib/dns/dnssectime.cc index 71e42cf80f..856130aeac 100644 --- a/src/lib/dns/dnssectime.cc +++ b/src/lib/dns/dnssectime.cc @@ -18,21 +18,12 @@ #include <iomanip> #include <iostream> #include <sstream> -#include <vector> #include <stdio.h> #include <time.h> #include <exceptions/exceptions.h> -#include <dns/base64.h> -#include <dns/buffer.h> -#include <dns/messagerenderer.h> -#include <dns/name.h> -#include <dns/rrtype.h> -#include <dns/rrttl.h> -#include <dns/rdata.h> -#include <dns/rdataclass.h> #include <dns/dnssectime.h> using namespace std; diff --git a/src/lib/dns/name.cc b/src/lib/dns/name.cc index 03eb38249d..e27255526b 100644 --- a/src/lib/dns/name.cc +++ b/src/lib/dns/name.cc @@ -666,7 +666,7 @@ Name::reverse() const { } Name -Name::split(unsigned int first, unsigned int n) const { +Name::split(const unsigned int first, const unsigned int n) const { if (n == 0 || n > labelcount_ || first > labelcount_ - n) { isc_throw(OutOfRange, "Name::split: invalid split range"); } @@ -702,6 +702,16 @@ Name::split(unsigned int first, unsigned int n) const { return (retname); } +Name +Name::split(const unsigned level) const { + if (level >= getLabelCount()) { + isc_throw(OutOfRange, "invalid level for name split (" << level + << ") for name " << *this); + } + + return (split(level, getLabelCount() - level)); +} + Name& Name::downcase() { unsigned int nlen = length_; diff --git a/src/lib/dns/name.h b/src/lib/dns/name.h index 20e3b2af56..e73d6b8b58 100644 --- a/src/lib/dns/name.h +++ b/src/lib/dns/name.h @@ -84,7 +84,7 @@ public: /// /// \brief A standard DNS module exception that is thrown if the name parser -/// finds the input (string or wire-format data) is incomplete. +/// finds the input (string or wire-format %data) is incomplete. /// /// An attempt of constructing a name from an empty string will trigger this /// exception. @@ -168,13 +168,13 @@ private: /// /// The \c Name class encapsulates DNS names. /// -/// It provides interfaces to construct a name from string or wire-format data, -/// transform a name into a string or wire-format data, compare two names, get +/// It provides interfaces to construct a name from string or wire-format %data, +/// transform a name into a string or wire-format %data, compare two names, get /// access to various properties of a name, etc. /// -/// Notes to developers: Internally, a name object maintains the name data +/// Notes to developers: Internally, a name object maintains the name %data /// in wire format as an instance of \c std::string. Since many string -/// implementations adopt copy-on-write data sharing, we expect this approach +/// implementations adopt copy-on-write %data sharing, we expect this approach /// will make copying a name less expensive in typical cases. If this is /// found to be a significant performance bottleneck later, we may reconsider /// the internal representation or perhaps the API. @@ -187,9 +187,9 @@ private: /// included. In the BIND9 DNS library from which this implementation is /// derived, the offsets are optional, probably due to performance /// considerations (in fact, offsets can always be calculated from the name -/// data, and in that sense are redundant). In our implementation, however, +/// %data, and in that sense are redundant). In our implementation, however, /// we always build and maintain the offsets. We believe we need more low -/// level, specialized data structure and interface where we really need to +/// level, specialized %data structure and interface where we really need to /// pursue performance, and would rather keep this generic API and /// implementation simpler. /// @@ -233,21 +233,21 @@ public: /// \param namestr A string representation of the name to be constructed. /// \param downcase Whether to convert upper case alphabets to lower case. explicit Name(const std::string& namestr, bool downcase = false); - /// Constructor from wire-format data. + /// Constructor from wire-format %data. /// /// The \c buffer parameter normally stores a complete DNS message /// containing the name to be constructed. The current read position of /// the buffer points to the head of the name. /// - /// The input data may or may not be compressed; if it's compressed, this + /// The input %data may or may not be compressed; if it's compressed, this /// method will automatically decompress it. /// - /// If the given data does not represent a valid DNS name, an exception + /// If the given %data does not represent a valid DNS name, an exception /// of class \c DNSMessageFORMERR will be thrown. /// In addition, if resource allocation for the new name fails, a /// corresponding standard exception will be thrown. /// - /// \param buffer A buffer storing the wire format data. + /// \param buffer A buffer storing the wire format %data. /// \param downcase Whether to convert upper case alphabets to lower case. explicit Name(InputBuffer& buffer, bool downcase = false); /// @@ -260,35 +260,35 @@ public: /// \name Getter Methods /// //@{ - /// \brief Provides one-byte name data in wire format at the specified + /// \brief Provides one-byte name %data in wire format at the specified /// position. /// /// This method returns the unsigned 8-bit value of wire-format \c Name - /// data at the given position from the head. + /// %data at the given position from the head. /// /// For example, if \c n is a \c Name object for "example.com", /// \c n.at(3) would return \c 'a', and \c n.at(7) would return \c 'e'. /// Note that \c n.at(0) would be 7 (decimal), the label length of - /// "example", instead of \c 'e', because it returns a data portion + /// "example", instead of \c 'e', because it returns a %data portion /// in wire-format. Likewise, \c n.at(8) would return 3 (decimal) /// instead of <code>'.'</code> /// /// This method would be useful for an application to examine the - /// wire-format name data without dumping the data into a buffer, - /// which would involve data copies and would be less efficient. + /// wire-format name %data without dumping the %data into a buffer, + /// which would involve %data copies and would be less efficient. /// One common usage of this method would be something like this: /// \code for (size_t i = 0; i < name.getLength(); ++i) { /// uint8_t c = name.at(i); /// // do something with c /// } \endcode /// - /// Parameter \c pos must be in the valid range of the name data, that is, + /// Parameter \c pos must be in the valid range of the name %data, that is, /// must be less than \c Name.getLength(). Otherwise, an exception of /// class \c OutOfRange will be thrown. /// This method never throws an exception in other ways. /// - /// \param pos The position in the wire format name data to be returned. - /// \return An unsigned 8-bit integer corresponding to the name data + /// \param pos The position in the wire format name %data to be returned. + /// \return An unsigned 8-bit integer corresponding to the name %data /// at the position of \c pos. uint8_t at(size_t pos) const { @@ -360,7 +360,7 @@ public: /// <code>buffer.getCapacity() - buffer.getLength() >= Name::MAX_WIRE</code> /// then this method should not throw an exception. /// - /// \param buffer An output buffer to store the wire data. + /// \param buffer An output buffer to store the wire %data. void toWire(OutputBuffer& buffer) const; //@} @@ -502,6 +502,72 @@ public: /// labels including and following the <code>first</code> label. Name split(unsigned int first, unsigned int n) const; + /// \brief Extract a specified super domain name of Name. + /// + /// This function constructs a new \c Name object that is a super domain + /// of \c this name. + /// The new name is \c level labels upper than \c this name. + /// For example, when \c name is www.example.com, + /// <code>name.split(1)</code> will return a \c Name object for example.com. + /// \c level can be 0, in which case this method returns a copy of + /// \c this name. + /// The possible maximum value for \c level is + /// <code>this->getLabelCount()-1</code>, in which case this method + /// returns a root name. + /// + /// One common expected usage of this method is to iterate over super + /// domains of a given name, label by label, as shown in the following + /// sample code: + /// \code // if name is www.example.com... + /// for (int i = 0; i < name.getLabelCount(); ++i) { + /// Name upper_name(name.split(i)); + /// // upper_name'll be www.example.com., example.com., com., and then . + /// } + /// \endcode + /// + /// \c level must be smaller than the number of labels of \c this name; + /// otherwise an exception of class \c OutOfRange will be thrown. + /// In addition, if resource allocation for the new name fails, a + /// corresponding standard exception will be thrown. + /// + /// Note to developers: probably as easily imagined, this method is a + /// simple wrapper to one usage of the other + /// <code>split(unsigned int, unsigned int) const</code> method and is + /// redundant in some sense. + /// We provide the "redundant" method for convenience, however, because + /// the expected usage shown above seems to be common, and the parameters + /// to the other \c split(unsigned int, unsigned int) const to implement + /// it may not be very intuitive. + /// + /// We are also aware that it is generally discouraged to add a public + /// member function that could be implemented using other member functions. + /// We considered making it a non member function, but we could not come + /// up with an intuitive function name to represent the specific service. + /// Some other BIND 10 developers argued, probably partly because of the + /// counter intuitive function name, a different signature of \c split + /// would be better to improve code readability. + /// While that may be a matter of personal preference, we accepted the + /// argument. One major goal of public APIs like this is wider acceptance + /// from internal/external developers, so unless there is a clear advantage + /// it would be better to respect the preference of the API users. + /// + /// Since this method doesn't have to be a member function in other way, + /// it is intentionally implemented only using public interfaces of the + /// \c Name class; it doesn't refer to private members of the class even if + /// it could. + /// This way we hope we can avoid damaging the class encapsulation, + /// which is a major drawback of public member functions. + /// As such if and when this "method" has to be extended, it should be + /// implemented without the privilege of being a member function unless + /// there is a very strong reason to do so. In particular a minor + /// performance advantage shouldn't justify that approach. + /// + /// \param level The number of labels to be removed from \c this name to + /// create the super domain name. + /// (0 <= \c level < <code>this->getLabelCount()</code>) + /// \return A new \c Name object to be created. + Name split(unsigned int level) const; + /// \brief Reverse the labels of a name /// /// This method reverses the labels of a name. For example, if diff --git a/src/lib/dns/rdata/generic/nsec_47.cc b/src/lib/dns/rdata/generic/nsec_47.cc index 2e52ae64bb..2cd94c94a9 100644 --- a/src/lib/dns/rdata/generic/nsec_47.cc +++ b/src/lib/dns/rdata/generic/nsec_47.cc @@ -112,7 +112,7 @@ NSEC::NSEC(InputBuffer& buffer, size_t rdata_len) { for (int i = 0; i < rdata_len; i += len) { if (i + 2 > rdata_len) { isc_throw(DNSMessageFORMERR, "NSEC RDATA from wire: " - "incomplete bit map filed"); + "incomplete bit map field"); } block = typebits[i]; len = typebits[i + 1]; @@ -182,7 +182,7 @@ NSEC::toText() const { assert(i + 2 <= impl_->typebits_.size()); const int block = impl_->typebits_.at(i); len = impl_->typebits_.at(i + 1); - assert(len >= 0 && len <= 32); + assert(len > 0 && len <= 32); i += 2; for (int j = 0; j < len; j++) { if (impl_->typebits_.at(i + j) == 0) { diff --git a/src/lib/dns/tests/Makefile.am b/src/lib/dns/tests/Makefile.am index 82c111feb9..29a174755b 100644 --- a/src/lib/dns/tests/Makefile.am +++ b/src/lib/dns/tests/Makefile.am @@ -1,6 +1,7 @@ AM_CPPFLAGS = -I$(top_builddir)/src/lib -I$(top_srcdir)/src/lib AM_CPPFLAGS += -I$(top_srcdir)/src/lib/dns -I$(top_builddir)/src/lib/dns AM_CPPFLAGS += -DTEST_DATA_DIR=\"$(srcdir)/testdata\" +AM_CXXFLAGS = $(B10_CXXFLAGS) CLEANFILES = *.gcno *.gcda diff --git a/src/lib/dns/tests/dnssectime_unittest.cc b/src/lib/dns/tests/dnssectime_unittest.cc index ae760102e4..68b4f8515c 100644 --- a/src/lib/dns/tests/dnssectime_unittest.cc +++ b/src/lib/dns/tests/dnssectime_unittest.cc @@ -64,7 +64,8 @@ TEST(DNSSECTimeTest, toText) { TEST(DNSSECTimeTest, overflow) { // Jan 1, Year 10,000. if (sizeof(time_t) > 4) { - EXPECT_THROW(timeToText(253402300800LL), InvalidTime); + EXPECT_THROW(timeToText(static_cast<time_t>(253402300800LL)), + InvalidTime); } } diff --git a/src/lib/dns/tests/name_unittest.cc b/src/lib/dns/tests/name_unittest.cc index 8a7cf1ce7e..f5dc84e113 100644 --- a/src/lib/dns/tests/name_unittest.cc +++ b/src/lib/dns/tests/name_unittest.cc @@ -500,6 +500,18 @@ TEST_F(NameTest, split) { OutOfRange); } +TEST_F(NameTest, split_for_suffix) { + EXPECT_PRED_FORMAT2(UnitTestUtil::matchName, example_name.split(1), + Name("example.com")); + EXPECT_PRED_FORMAT2(UnitTestUtil::matchName, example_name.split(0), + example_name); + EXPECT_PRED_FORMAT2(UnitTestUtil::matchName, example_name.split(3), + Name(".")); + + // Invalid case: the level must be less than the original label count. + EXPECT_THROW(example_name.split(4), OutOfRange); +} + TEST_F(NameTest, downcase) { // usual case: all-upper case name to all-lower case compareInWireFormat(example_name_upper.downcase(), example_name); diff --git a/src/lib/exceptions/Makefile.am b/src/lib/exceptions/Makefile.am index ea8136cd83..cc5158dd19 100644 --- a/src/lib/exceptions/Makefile.am +++ b/src/lib/exceptions/Makefile.am @@ -1,3 +1,4 @@ +AM_CXXFLAGS=$(B10_CXXFLAGS) lib_LTLIBRARIES = libexceptions.la libexceptions_la_SOURCES = exceptions.h exceptions.cc @@ -15,3 +16,6 @@ run_unittests_LDADD = .libs/libexceptions.a $(GTEST_LDADD) endif noinst_PROGRAMS = $(TESTS) + +libexceptions_includedir = $(includedir)/exceptions +libexceptions_include_HEADERS = exceptions.h diff --git a/src/lib/python/Makefile.am b/src/lib/python/Makefile.am index df8fa15425..f7eb3335d0 100644 --- a/src/lib/python/Makefile.am +++ b/src/lib/python/Makefile.am @@ -1 +1,12 @@ SUBDIRS = isc + +python_PYTHON = bind10_config.py + +# Explicitly define DIST_COMMON so ${python_PYTHON} is not included +# as we don't want the generated file included in distributed tarfile. +DIST_COMMON = $(srcdir)/Makefile.am $(srcdir)/Makefile.in bind10_config.py.in + +# When setting DIST_COMMON, then need to add the .in file too. +EXTRA_DIST = bind10_config.py.in + +CLEANFILES = bind10_config.pyc diff --git a/src/lib/python/bind10_config.py.in b/src/lib/python/bind10_config.py.in new file mode 100644 index 0000000000..3f2947d7cb --- /dev/null +++ b/src/lib/python/bind10_config.py.in @@ -0,0 +1,23 @@ +# Copyright (C) 2010 Internet Systems Consortium. +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM +# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL +# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT, +# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING +# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION +# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +# This is a base-level module intended to provide configure-time +# variables to python scripts and libraries. +import os + +BIND10_MSGQ_SOCKET_FILE = os.path.join("@localstatedir@", + "@PACKAGE_NAME@", + "msgq_socket").replace("${prefix}", + "@prefix@") diff --git a/src/lib/python/isc/cc/message.py b/src/lib/python/isc/cc/message.py index 07c884b1d0..800365ff32 100644 --- a/src/lib/python/isc/cc/message.py +++ b/src/lib/python/isc/cc/message.py @@ -26,6 +26,7 @@ _ITEM_LIST = 0x03 _ITEM_NULL = 0x04 _ITEM_BOOL = 0x05 _ITEM_INT = 0x06 +_ITEM_REAL = 0x07 _ITEM_UTF8 = 0x08 _ITEM_MASK = 0x0f @@ -77,6 +78,10 @@ def _pack_int(item): """Pack an integer and its type/length prefix.""" return (_encode_length_and_type(bytes(str(item), 'utf-8'), _ITEM_INT)) +def _pack_real(item): + """Pack an integer and its type/length prefix.""" + return (_encode_length_and_type(bytes(str(item), 'utf-8'), _ITEM_REAL)) + def _pack_array(item): """Pack a list (array) and its type/length prefix.""" return (_encode_length_and_type(_encode_array(item), _ITEM_LIST)) @@ -98,6 +103,8 @@ def _encode_item(item): return (_pack_bool(item)) elif type(item) == int: return (_pack_int(item)) + elif type(item) == float: + return (_pack_real(item)) elif type(item) == dict: return (_pack_hash(item)) elif type(item) == list: @@ -186,6 +193,8 @@ def _decode_item(data): value = _decode_bool(item) elif item_type == _ITEM_INT: value = _decode_int(item) + elif item_type == _ITEM_REAL: + value = _decode_real(item) elif item_type == _ITEM_UTF8: value = str(item, 'utf-8') elif item_type == _ITEM_HASH: @@ -205,6 +214,9 @@ def _decode_bool(data): def _decode_int(data): return int(str(data, 'utf-8')) +def _decode_real(data): + return float(str(data, 'utf-8')) + def _decode_hash(data): ret = {} while len(data) > 0: diff --git a/src/lib/python/isc/cc/session.py b/src/lib/python/isc/cc/session.py index 119a384563..30c7be6195 100644 --- a/src/lib/python/isc/cc/session.py +++ b/src/lib/python/isc/cc/session.py @@ -17,6 +17,8 @@ import sys import socket import struct import os +import threading +import bind10_config import isc.cc.message @@ -25,7 +27,7 @@ class NetworkError(Exception): pass class SessionError(Exception): pass class Session: - def __init__(self, port=0): + def __init__(self, socket_file=None): self._socket = None self._lname = None self._recvbuffer = bytearray() @@ -33,17 +35,20 @@ class Session: self._sequence = 1 self._closed = False self._queue = [] + self._lock = threading.RLock() - if port == 0: - if 'ISC_MSGQ_PORT' in os.environ: - port = int(os.environ["ISC_MSGQ_PORT"]) - else: - port = 9912 + if socket_file is None: + if "BIND10_MSGQ_SOCKET_FILE" in os.environ: + self.socket_file = os.environ["BIND10_MSGQ_SOCKET_FILE"] + else: + self.socket_file = bind10_config.BIND10_MSGQ_SOCKET_FILE + else: + self.socket_file = socket_file + try: - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.connect(tuple(['127.0.0.1', port])) - + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._socket.connect(self.socket_file) self.sendmsg({ "type": "getlname" }) env, msg = self.recvmsg(False) if not env: @@ -64,56 +69,54 @@ class Session: self._closed = True def sendmsg(self, env, msg = None): - XXmsg = msg - XXenv = env - if self._closed: - raise SessionError("Session has been closed.") - if type(env) == dict: - env = isc.cc.message.to_wire(env) - if type(msg) == dict: - msg = isc.cc.message.to_wire(msg) - self._socket.setblocking(1) - length = 2 + len(env); - if msg: - length += len(msg) - self._socket.send(struct.pack("!I", length)) - self._socket.send(struct.pack("!H", len(env))) - self._socket.send(env) - if msg: - self._socket.send(msg) + with self._lock: + if self._closed: + raise SessionError("Session has been closed.") + if type(env) == dict: + env = isc.cc.message.to_wire(env) + if type(msg) == dict: + msg = isc.cc.message.to_wire(msg) + self._socket.setblocking(1) + length = 2 + len(env); + if msg: + length += len(msg) + self._socket.send(struct.pack("!I", length)) + self._socket.send(struct.pack("!H", len(env))) + self._socket.send(env) + if msg: + self._socket.send(msg) def recvmsg(self, nonblock = True, seq = None): - #print("[XX] queue len: " + str(len(self._queue))) - if len(self._queue) > 0: - if seq == None: - #print("[XX] return first") - return self._queue.pop(0) - else: + with self._lock: + if len(self._queue) > 0: i = 0; - #print("[XX] check rest") for env, msg in self._queue: - if "reply" in env and seq == env["reply"]: + if seq != None and "reply" in env and seq == env["reply"]: + return self._queue.pop(i) + elif seq == None and "reply" not in env: return self._queue.pop(i) else: i = i + 1 - #print("[XX] not found") - if self._closed: - raise SessionError("Session has been closed.") - data = self._receive_full_buffer(nonblock) - if data and len(data) > 2: - header_length = struct.unpack('>H', data[0:2])[0] - data_length = len(data) - 2 - header_length - if data_length > 0: - env = isc.cc.message.from_wire(data[2:header_length+2]) - msg = isc.cc.message.from_wire(data[header_length + 2:]) - if seq == None or "reply" in env and seq == env["reply"]: - return env, msg + if self._closed: + raise SessionError("Session has been closed.") + data = self._receive_full_buffer(nonblock) + if data and len(data) > 2: + header_length = struct.unpack('>H', data[0:2])[0] + data_length = len(data) - 2 - header_length + if data_length > 0: + env = isc.cc.message.from_wire(data[2:header_length+2]) + msg = isc.cc.message.from_wire(data[header_length + 2:]) + if (seq == None and "reply" not in env) or (seq != None and "reply" in env and seq == env["reply"]): + return env, msg + else: + tmp = None + if "reply" in env: + tmp = env["reply"] + self._queue.append((env,msg)) + return self.recvmsg(nonblock, seq) else: - self._queue.append((env,msg)) - return self.recvmsg(nonblock, seq) - else: - return isc.cc.message.from_wire(data[2:header_length+2]), None - return None, None + return isc.cc.message.from_wire(data[2:header_length+2]), None + return None, None def _receive_full_buffer(self, nonblock): if nonblock: @@ -130,7 +133,6 @@ class Session: return None if data == "": # server closed connection raise ProtocolError("Read of 0 bytes: connection closed") - self._recvbuffer += data if len(self._recvbuffer) < 4: return None @@ -182,6 +184,9 @@ class Session: }, isc.cc.message.to_wire(msg)) return seq + def has_queued_msgs(self): + return len(self._queue) > 0 + def group_recvmsg(self, nonblock = True, seq = None): env, msg = self.recvmsg(nonblock, seq) if env == None: diff --git a/src/lib/python/isc/cc/tests/Makefile.am b/src/lib/python/isc/cc/tests/Makefile.am index c46ed3fe6f..0828520ba2 100644 --- a/src/lib/python/isc/cc/tests/Makefile.am +++ b/src/lib/python/isc/cc/tests/Makefile.am @@ -10,6 +10,6 @@ PYCOVERAGE = $(PYTHON) check-local: for pytest in $(PYTESTS) ; do \ echo Running test: $$pytest ; \ - env PYTHONPATH=$(abs_top_srcdir)/src/lib/python \ + env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python \ $(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \ done diff --git a/src/lib/python/isc/cc/tests/session_test.py b/src/lib/python/isc/cc/tests/session_test.py index 413d1d3911..284fe5d78f 100644 --- a/src/lib/python/isc/cc/tests/session_test.py +++ b/src/lib/python/isc/cc/tests/session_test.py @@ -98,6 +98,7 @@ class MySession(Session): self._sequence = 1 self._closed = False self._queue = [] + self._lock = threading.RLock() try: self._socket = MySocket(socket.AF_INET, socket.SOCK_STREAM) @@ -178,66 +179,78 @@ class testSession(unittest.TestCase): # sending message {'to': 'someone', 'reply': 1}, {"hello": "a"} #print("sending message {'to': 'someone', 'reply': 1}, {'hello': 'a'}") - # simply get the message without asking for a specific sequence number reply + # get no message without asking for a specific sequence number reply + self.assertFalse(sess.has_queued_msgs()) sess._socket.addrecv(b'\x00\x00\x00(\x00\x19Skan\x02to(\x07someone\x05reply&\x011Skan\x05hello(\x01a') env, msg = sess.recvmsg(False) - self.assertEqual({'to': 'someone', 'reply': 1}, env) - self.assertEqual({"hello": "a"}, msg) - - # simply get the message, asking for a specific sequence number reply - sess._socket.addrecv(b'\x00\x00\x00(\x00\x19Skan\x02to(\x07someone\x05reply&\x011Skan\x05hello(\x01a') + self.assertEqual(None, env) + self.assertTrue(sess.has_queued_msgs()) env, msg = sess.recvmsg(False, 1) self.assertEqual({'to': 'someone', 'reply': 1}, env) self.assertEqual({"hello": "a"}, msg) + self.assertFalse(sess.has_queued_msgs()) # ask for a differe sequence number reply (that doesn't exist) # then ask for the one that is there + self.assertFalse(sess.has_queued_msgs()) sess._socket.addrecv(b'\x00\x00\x00(\x00\x19Skan\x02to(\x07someone\x05reply&\x011Skan\x05hello(\x01a') env, msg = sess.recvmsg(False, 2) self.assertEqual(None, env) self.assertEqual(None, msg) + self.assertTrue(sess.has_queued_msgs()) env, msg = sess.recvmsg(False, 1) self.assertEqual({'to': 'someone', 'reply': 1}, env) self.assertEqual({"hello": "a"}, msg) + self.assertFalse(sess.has_queued_msgs()) # ask for a differe sequence number reply (that doesn't exist) # then ask for any message + self.assertFalse(sess.has_queued_msgs()) sess._socket.addrecv(b'\x00\x00\x00(\x00\x19Skan\x02to(\x07someone\x05reply&\x011Skan\x05hello(\x01a') env, msg = sess.recvmsg(False, 2) self.assertEqual(None, env) self.assertEqual(None, msg) - env, msg = sess.recvmsg(False) + self.assertTrue(sess.has_queued_msgs()) + env, msg = sess.recvmsg(False, 1) self.assertEqual({'to': 'someone', 'reply': 1}, env) self.assertEqual({"hello": "a"}, msg) + self.assertFalse(sess.has_queued_msgs()) #print("sending message {'to': 'someone', 'reply': 1}, {'hello': 'a'}") # ask for a differe sequence number reply (that doesn't exist) - # send a new message, ask for any message (get the first) + # send a new message, ask for specific message (get the first) # then ask for any message (get the second) + self.assertFalse(sess.has_queued_msgs()) sess._socket.addrecv(b'\x00\x00\x00(\x00\x19Skan\x02to(\x07someone\x05reply&\x011Skan\x05hello(\x01a') env, msg = sess.recvmsg(False, 2) self.assertEqual(None, env) self.assertEqual(None, msg) + self.assertTrue(sess.has_queued_msgs()) sess._socket.addrecv(b'\x00\x00\x00\x1f\x00\x10Skan\x02to(\x07someoneSkan\x05hello(\x01b') - env, msg = sess.recvmsg(False) - self.assertEqual({'to': 'someone', 'reply': 1}, env) + env, msg = sess.recvmsg(False, 1) + self.assertEqual({'to': 'someone', 'reply': 1 }, env) self.assertEqual({"hello": "a"}, msg) + self.assertFalse(sess.has_queued_msgs()) env, msg = sess.recvmsg(False) self.assertEqual({'to': 'someone'}, env) self.assertEqual({"hello": "b"}, msg) + self.assertFalse(sess.has_queued_msgs()) # send a message, then one with specific reply value # ask for that specific message (get the second) # then ask for any message (get the first) + self.assertFalse(sess.has_queued_msgs()) sess._socket.addrecv(b'\x00\x00\x00\x1f\x00\x10Skan\x02to(\x07someoneSkan\x05hello(\x01b') sess._socket.addrecv(b'\x00\x00\x00(\x00\x19Skan\x02to(\x07someone\x05reply&\x011Skan\x05hello(\x01a') env, msg = sess.recvmsg(False, 1) self.assertEqual({'to': 'someone', 'reply': 1}, env) self.assertEqual({"hello": "a"}, msg) + self.assertTrue(sess.has_queued_msgs()) env, msg = sess.recvmsg(False) self.assertEqual({'to': 'someone'}, env) self.assertEqual({"hello": "b"}, msg) + self.assertFalse(sess.has_queued_msgs()) def test_next_sequence(self): sess = MySession() diff --git a/src/lib/python/isc/config/ccsession.py b/src/lib/python/isc/config/ccsession.py index 5bce857378..9128286ed6 100644 --- a/src/lib/python/isc/config/ccsession.py +++ b/src/lib/python/isc/config/ccsession.py @@ -81,8 +81,7 @@ def create_answer(rcode, arg = None): # 'fixed' commands """Fixed names for command and configuration messages""" COMMAND_CONFIG_UPDATE = "config_update" -COMMAND_COMMANDS_UPDATE = "commands_update" -COMMAND_SPECIFICATION_UPDATE = "specification_update" +COMMAND_MODULE_SPECIFICATION_UPDATE = "module_specification_update" COMMAND_GET_COMMANDS_SPEC = "get_commands_spec" COMMAND_GET_CONFIG = "get_config" @@ -314,16 +313,9 @@ class UIModuleCCSession(MultiConfigData): # this step should be unnecessary but is the current way cmdctl returns stuff # so changes are needed there to make this clean (we need a command to simply get the # full specs for everything, including commands etc, not separate gets for that) - specs = self._conn.send_GET('/config_spec') - commands = self._conn.send_GET('/commands') + specs = self._conn.send_GET('/module_spec') for module in specs.keys(): - cur_spec = { 'module_name': module } - if module in specs and specs[module]: - cur_spec['config_data'] = specs[module] - if module in commands and commands[module]: - cur_spec['commands'] = commands[module] - - self.set_specification(isc.config.ModuleSpec(cur_spec)) + self.set_specification(isc.config.ModuleSpec(specs[module])) def request_current_config(self): """Requests the current configuration from the configuration diff --git a/src/lib/python/isc/config/cfgmgr.py b/src/lib/python/isc/config/cfgmgr.py index e7dce0fd4e..78c9193aca 100644 --- a/src/lib/python/isc/config/cfgmgr.py +++ b/src/lib/python/isc/config/cfgmgr.py @@ -25,7 +25,9 @@ import ast import pprint import os import copy +import tempfile from isc.cc import data +from isc.config import ccsession class ConfigManagerDataReadError(Exception): """This exception is thrown when there is an error while reading @@ -84,24 +86,35 @@ class ConfigManagerData: """Writes the current configuration data to a file. If output_file_name is not specified, the file used in read_from_file is used.""" + filename = None try: - tmp_filename = self.db_filename + ".tmp" - file = open(tmp_filename, 'w'); + file = tempfile.NamedTemporaryFile(mode='w', + prefix="b10-config.db.", + dir=self.data_path, + delete=False) + filename = file.name pp = pprint.PrettyPrinter(indent=4) s = pp.pformat(self.data) file.write(s) file.write("\n") file.close() if output_file_name: - os.rename(tmp_filename, output_file_name) + os.rename(filename, output_file_name) else: - os.rename(tmp_filename, self.db_filename) + os.rename(filename, self.db_filename) except IOError as ioe: # TODO: log this (level critical) print("[b10-cfgmgr] Unable to write config file; configuration not stored: " + str(ioe)) + # TODO: debug option to keep file? except OSError as ose: # TODO: log this (level critical) print("[b10-cfgmgr] Unable to write config file; configuration not stored: " + str(ose)) + try: + if filename and os.path.exists(filename): + os.remove(filename) + except OSError: + # Ok if we really can't delete it anymore, leave it + pass def __eq__(self, other): """Returns True if the data contained is equal. data_path and @@ -148,11 +161,23 @@ class ConfigManager: if module_name in self.module_specs: del self.module_specs[module_name] - def get_module_spec(self, module_name): + def get_module_spec(self, module_name = None): """Returns the full ModuleSpec for the module with the given - module_name""" - if module_name in self.module_specs: - return self.module_specs[module_name] + module_name. If no module name is given, a dict will + be returned with 'name': module_spec values. If the + module name is given, but does not exist, an empty dict + is returned""" + if module_name: + if module_name in self.module_specs: + return self.module_specs[module_name] + else: + # TODO: log error? + return {} + else: + result = {} + for module in self.module_specs: + result[module] = self.module_specs[module].get_full_spec() + return result def get_config_spec(self, name = None): """Returns a dict containing 'module_name': config_spec for @@ -201,95 +226,112 @@ class ConfigManager: if type(cmd) == dict: if 'module_name' in cmd and cmd['module_name'] != '': module_name = cmd['module_name'] - answer = isc.config.ccsession.create_answer(0, self.get_config_spec(module_name)) + answer = ccsession.create_answer(0, self.get_module_spec(module_name)) else: - answer = isc.config.ccsession.create_answer(1, "Bad module_name in get_module_spec command") + answer = ccsession.create_answer(1, "Bad module_name in get_module_spec command") else: - answer = isc.config.ccsession.create_answer(1, "Bad get_module_spec command, argument not a dict") + answer = ccsession.create_answer(1, "Bad get_module_spec command, argument not a dict") else: - answer = isc.config.ccsession.create_answer(0, self.get_config_spec()) + answer = ccsession.create_answer(0, self.get_module_spec()) return answer + def _handle_get_config_dict(self, cmd): + """Private function that handles the 'get_config' command + where the command has been checked to be a dict""" + if 'module_name' in cmd and cmd['module_name'] != '': + module_name = cmd['module_name'] + try: + return ccsession.create_answer(0, data.find(self.config.data, module_name)) + except data.DataNotFoundError as dnfe: + # no data is ok, that means we have nothing that + # deviates from default values + return ccsession.create_answer(0, { 'version': self.config.CONFIG_VERSION }) + else: + return ccsession.create_answer(1, "Bad module_name in get_config command") + def _handle_get_config(self, cmd): """Private function that handles the 'get_config' command""" - answer = {} if cmd != None: if type(cmd) == dict: - if 'module_name' in cmd and cmd['module_name'] != '': - module_name = cmd['module_name'] - try: - answer = isc.config.ccsession.create_answer(0, data.find(self.config.data, module_name)) - except data.DataNotFoundError as dnfe: - # no data is ok, that means we have nothing that - # deviates from default values - answer = isc.config.ccsession.create_answer(0, { 'version': self.config.CONFIG_VERSION }) - else: - answer = isc.config.ccsession.create_answer(1, "Bad module_name in get_config command") + return self._handle_get_config_dict(cmd) else: - answer = isc.config.ccsession.create_answer(1, "Bad get_config command, argument not a dict") + return ccsession.create_answer(1, "Bad get_config command, argument not a dict") + else: + return ccsession.create_answer(0, self.config.data) + + def _handle_set_config_module(self, cmd): + # the answer comes (or does not come) from the relevant module + # so we need a variable to see if we got it + answer = None + # todo: use api (and check the data against the definition?) + old_data = copy.deepcopy(self.config.data) + module_name = cmd[0] + conf_part = data.find_no_exc(self.config.data, module_name) + if conf_part: + data.merge(conf_part, cmd[1]) + update_cmd = ccsession.create_command(ccsession.COMMAND_CONFIG_UPDATE, + conf_part) + seq = self.cc.group_sendmsg(update_cmd, module_name) + answer, env = self.cc.group_recvmsg(False, seq) else: - answer = isc.config.ccsession.create_answer(0, self.config.data) + conf_part = data.set(self.config.data, module_name, {}) + data.merge(conf_part[module_name], cmd[1]) + # send out changed info + update_cmd = ccsession.create_command(ccsession.COMMAND_CONFIG_UPDATE, + conf_part[module_name]) + seq = self.cc.group_sendmsg(update_cmd, module_name) + # replace 'our' answer with that of the module + answer, env = self.cc.group_recvmsg(False, seq) + if answer: + rcode, val = ccsession.parse_answer(answer) + if rcode == 0: + self.write_config() + else: + self.config.data = old_data return answer + def _handle_set_config_all(self, cmd): + old_data = copy.deepcopy(self.config.data) + data.merge(self.config.data, cmd[0]) + # send out changed info + got_error = False + err_list = [] + for module in self.config.data: + if module != "version" and \ + (module not in old_data or self.config.data[module] != old_data[module]): + update_cmd = ccsession.create_command(ccsession.COMMAND_CONFIG_UPDATE, + self.config.data[module]) + seq = self.cc.group_sendmsg(update_cmd, module) + answer, env = self.cc.group_recvmsg(False, seq) + if answer == None: + got_error = True + err_list.append("No answer message from " + module) + else: + rcode, val = ccsession.parse_answer(answer) + if rcode != 0: + got_error = True + err_list.append(val) + if not got_error: + self.write_config() + return ccsession.create_answer(0) + else: + # TODO rollback changes that did get through, should we re-send update? + self.config.data = old_data + return ccsession.create_answer(1, " ".join(err_list)) + def _handle_set_config(self, cmd): """Private function that handles the 'set_config' command""" answer = None if cmd == None: - return isc.config.ccsession.create_answer(1, "Wrong number of arguments") + return ccsession.create_answer(1, "Wrong number of arguments") if len(cmd) == 2: - # todo: use api (and check the data against the definition?) - old_data = copy.deepcopy(self.config.data) - module_name = cmd[0] - conf_part = data.find_no_exc(self.config.data, module_name) - if conf_part: - data.merge(conf_part, cmd[1]) - update_cmd = isc.config.ccsession.create_command(isc.config.ccsession.COMMAND_CONFIG_UPDATE, conf_part) - seq = self.cc.group_sendmsg(update_cmd, module_name) - answer, env = self.cc.group_recvmsg(False, seq) - else: - conf_part = data.set(self.config.data, module_name, {}) - data.merge(conf_part[module_name], cmd[1]) - # send out changed info - update_cmd = isc.config.ccsession.create_command(isc.config.ccsession.COMMAND_CONFIG_UPDATE, conf_part[module_name]) - seq = self.cc.group_sendmsg(update_cmd, module_name) - # replace 'our' answer with that of the module - answer, env = self.cc.group_recvmsg(False, seq) - if answer: - rcode, val = isc.config.ccsession.parse_answer(answer) - if rcode == 0: - self.write_config() - else: - self.config.data = old_data + answer = self._handle_set_config_module(cmd) elif len(cmd) == 1: - old_data = copy.deepcopy(self.config.data) - data.merge(self.config.data, cmd[0]) - # send out changed info - got_error = False - err_list = [] - for module in self.config.data: - if module != "version" and (module not in old_data or self.config.data[module] != old_data[module]): - update_cmd = isc.config.ccsession.create_command(isc.config.ccsession.COMMAND_CONFIG_UPDATE, self.config.data[module]) - seq = self.cc.group_sendmsg(update_cmd, module) - answer, env = self.cc.group_recvmsg(False, seq) - if answer == None: - got_error = True - err_list.append("No answer message from " + module) - else: - rcode, val = isc.config.ccsession.parse_answer(answer) - if rcode != 0: - got_error = True - err_list.append(val) - if not got_error: - self.write_config() - answer = isc.config.ccsession.create_answer(0) - else: - # TODO rollback changes that did get through, should we re-send update? - self.config.data = old_data - answer = isc.config.ccsession.create_answer(1, " ".join(err_list)) + answer = self._handle_set_config_all(cmd) else: - answer = isc.config.ccsession.create_answer(1, "Wrong number of arguments") + answer = ccsession.create_answer(1, "Wrong number of arguments") if not answer: - answer = isc.config.ccsession.create_answer(1, "No answer message from " + cmd[0]) + answer = ccsession.create_answer(1, "No answer message from " + cmd[0]) return answer @@ -303,42 +345,38 @@ class ConfigManager: # We should make one general 'spec update for module' that # passes both specification and commands at once - spec_update = isc.config.ccsession.create_command(isc.config.ccsession.COMMAND_SPECIFICATION_UPDATE, - [ spec.get_module_name(), spec.get_config_spec() ]) + spec_update = ccsession.create_command(ccsession.COMMAND_MODULE_SPECIFICATION_UPDATE, + [ spec.get_module_name(), spec.get_full_spec() ]) self.cc.group_sendmsg(spec_update, "Cmd-Ctrld") - cmds_update = isc.config.ccsession.create_command(isc.config.ccsession.COMMAND_COMMANDS_UPDATE, - [ spec.get_module_name(), spec.get_commands_spec() ]) - self.cc.group_sendmsg(cmds_update, "Cmd-Ctrld") - answer = isc.config.ccsession.create_answer(0) - return answer + return ccsession.create_answer(0) def handle_msg(self, msg): """Handle a command from the cc channel to the configuration manager""" answer = {} - cmd, arg = isc.config.ccsession.parse_command(msg) + cmd, arg = ccsession.parse_command(msg) if cmd: - if cmd == isc.config.ccsession.COMMAND_GET_COMMANDS_SPEC: - answer = isc.config.ccsession.create_answer(0, self.get_commands_spec()) - elif cmd == isc.config.ccsession.COMMAND_GET_MODULE_SPEC: + if cmd == ccsession.COMMAND_GET_COMMANDS_SPEC: + answer = ccsession.create_answer(0, self.get_commands_spec()) + elif cmd == ccsession.COMMAND_GET_MODULE_SPEC: answer = self._handle_get_module_spec(arg) - elif cmd == isc.config.ccsession.COMMAND_GET_CONFIG: + elif cmd == ccsession.COMMAND_GET_CONFIG: answer = self._handle_get_config(arg) - elif cmd == isc.config.ccsession.COMMAND_SET_CONFIG: + elif cmd == ccsession.COMMAND_SET_CONFIG: answer = self._handle_set_config(arg) - elif cmd == isc.config.ccsession.COMMAND_SHUTDOWN: + elif cmd == ccsession.COMMAND_SHUTDOWN: # TODO: logging #print("[b10-cfgmgr] Received shutdown command") self.running = False - answer = isc.config.ccsession.create_answer(0) - elif cmd == isc.config.ccsession.COMMAND_MODULE_SPEC: + answer = ccsession.create_answer(0) + elif cmd == ccsession.COMMAND_MODULE_SPEC: try: answer = self._handle_module_spec(isc.config.ModuleSpec(arg)) except isc.config.ModuleSpecError as dde: - answer = isc.config.ccsession.create_answer(1, "Error in data definition: " + str(dde)) + answer = ccsession.create_answer(1, "Error in data definition: " + str(dde)) else: - answer = isc.config.ccsession.create_answer(1, "Unknown command: " + str(cmd)) + answer = ccsession.create_answer(1, "Unknown command: " + str(cmd)) else: - answer = isc.config.ccsession.create_answer(1, "Unknown message format: " + str(msg)) + answer = ccsession.create_answer(1, "Unknown message format: " + str(msg)) return answer def run(self): diff --git a/src/lib/python/isc/config/config_data.py b/src/lib/python/isc/config/config_data.py index f6ed2cd9c6..f18980158c 100644 --- a/src/lib/python/isc/config/config_data.py +++ b/src/lib/python/isc/config/config_data.py @@ -53,6 +53,51 @@ def check_type(spec_part, value): # todo: check types of map contents too raise isc.cc.data.DataTypeError(str(value) + " is not a map") +def convert_type(spec_part, value): + """Convert the give value(type is string) according specification + part relevant for the value. Raises an isc.cc.data.DataTypeError + exception if conversion failed. + """ + if type(spec_part) == dict and 'item_type' in spec_part: + data_type = spec_part['item_type'] + else: + raise isc.cc.data.DataTypeError(str("Incorrect specification part for type convering")) + + try: + if data_type == "integer": + return int(value) + elif data_type == "real": + return float(value) + elif data_type == "boolean": + return str.lower(str(value)) != 'false' + elif data_type == "string": + return str(value) + elif data_type == "list": + ret = [] + if type(value) == list: + for item in value: + ret.append(convert_type(spec_part['list_item_spec'], item)) + elif type(value) == str: + value = value.split(',') + for item in value: + sub_value = item.split() + for sub_item in sub_value: + ret.append(convert_type(spec_part['list_item_spec'], sub_item)) + + if ret == []: + raise isc.cc.data.DataTypeError(str(value) + " is not a list") + + return ret + elif data_type == "map": + return dict(value) + # todo: check types of map contents too + else: + return value + except ValueError as err: + raise isc.cc.data.DataTypeError(str(err)) + except TypeError as err: + raise isc.cc.data.DataTypeError(str(err)) + def find_spec_part(element, identifier): """find the data definition for the given identifier returns either a map with 'item_name' etc, or a list of those""" @@ -382,7 +427,6 @@ class MultiConfigData: else: entry['default'] = False result.append(entry) - #print(spec) return result def set_value(self, identifier, value): diff --git a/src/lib/python/isc/config/module_spec.py b/src/lib/python/isc/config/module_spec.py index a512cb5e9b..8eaec1cb80 100644 --- a/src/lib/python/isc/config/module_spec.py +++ b/src/lib/python/isc/config/module_spec.py @@ -80,15 +80,47 @@ class ModuleSpec: return _validate_spec_list(data_def, full, data, errors) else: # no spec, always bad - errors.append("No config_data specification") + if errors != None: + errors.append("No config_data specification") return False + def validate_command(self, cmd_name, cmd_params, errors = None): + '''Check whether the given piece of command conforms to this + command definition. If so, it reutrns True. If not, it will + return False. If errors is given, and is an array, a string + describing the error will be appended to it. The current version + stops as soon as there is one error. + cmd_name is command name to be validated, cmd_params includes + command's parameters needs to be validated. cmd_params must + be a map, with the format like: + {param1_name: param1_value, param2_name: param2_value} + ''' + cmd_spec = self.get_commands_spec() + if not cmd_spec: + return False + + for cmd in cmd_spec: + if cmd['command_name'] != cmd_name: + continue + return _validate_spec_list(cmd['command_args'], True, cmd_params, errors) + + return False def get_module_name(self): """Returns a string containing the name of the module as - specified by the specification given at __init__""" + specified by the specification given at __init__()""" return self._module_spec['module_name'] + def get_module_description(self): + """Returns a string containing the description of the module as + specified by the specification given at __init__(). + Returns an empty string if there is no description. + """ + if 'module_description' in self._module_spec: + return self._module_spec['module_description'] + else: + return "" + def get_full_spec(self): """Returns a dict representation of the full module specification""" return self._module_spec @@ -123,6 +155,9 @@ def _check(module_spec): raise ModuleSpecError("data specification not a dict") if "module_name" not in module_spec: raise ModuleSpecError("no module_name in module_spec") + if "module_description" in module_spec and \ + type(module_spec["module_description"]) != str: + raise ModuleSpecError("module_description is not a string") if "config_data" in module_spec: _check_config_spec(module_spec["config_data"]) if "commands" in module_spec: diff --git a/src/lib/python/isc/config/tests/Makefile.am b/src/lib/python/isc/config/tests/Makefile.am index dc43747a9d..18d378b15b 100644 --- a/src/lib/python/isc/config/tests/Makefile.am +++ b/src/lib/python/isc/config/tests/Makefile.am @@ -9,7 +9,7 @@ PYCOVERAGE = $(PYTHON) check-local: for pytest in $(PYTESTS) ; do \ echo Running test: $$pytest ; \ - env PYTHONPATH=$(abs_top_srcdir)/src/lib/python \ + env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python \ CONFIG_TESTDATA_PATH=$(abs_top_srcdir)/src/lib/config/testdata \ $(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \ done diff --git a/src/lib/python/isc/config/tests/ccsession_test.py b/src/lib/python/isc/config/tests/ccsession_test.py index 97deaaf8f2..3b45c845df 100644 --- a/src/lib/python/isc/config/tests/ccsession_test.py +++ b/src/lib/python/isc/config/tests/ccsession_test.py @@ -13,6 +13,8 @@ # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# $Id$ + # # Tests for the ConfigData and MultiConfigData classes # @@ -375,7 +377,7 @@ class fakeUIConn(): if name in self.get_answers: return self.get_answers[name] else: - return None + return {} def send_POST(self, name, arg = None): if name in self.post_answers: @@ -396,23 +398,20 @@ class TestUIModuleCCSession(unittest.TestCase): def create_uccs2(self, fake_conn): module_spec = isc.config.module_spec_from_file(self.spec_file("spec2.spec")) - fake_conn.set_get_answer('/config_spec', { module_spec.get_module_name(): module_spec.get_config_spec()}) - fake_conn.set_get_answer('/commands', { module_spec.get_module_name(): module_spec.get_commands_spec()}) + fake_conn.set_get_answer('/module_spec', { module_spec.get_module_name(): module_spec.get_full_spec()}) fake_conn.set_get_answer('/config_data', { 'version': 1 }) return UIModuleCCSession(fake_conn) def test_init(self): fake_conn = fakeUIConn() - fake_conn.set_get_answer('/config_spec', {}) - fake_conn.set_get_answer('/commands', {}) + fake_conn.set_get_answer('/module_spec', {}) fake_conn.set_get_answer('/config_data', { 'version': 1 }) uccs = UIModuleCCSession(fake_conn) self.assertEqual({}, uccs._specifications) self.assertEqual({ 'version': 1}, uccs._current_config) module_spec = isc.config.module_spec_from_file(self.spec_file("spec2.spec")) - fake_conn.set_get_answer('/config_spec', { module_spec.get_module_name(): module_spec.get_config_spec()}) - fake_conn.set_get_answer('/commands', { module_spec.get_module_name(): module_spec.get_commands_spec()}) + fake_conn.set_get_answer('/module_spec', { module_spec.get_module_name(): module_spec.get_full_spec()}) fake_conn.set_get_answer('/config_data', { 'version': 1 }) uccs = UIModuleCCSession(fake_conn) self.assertEqual(module_spec._module_spec, uccs._specifications['Spec2']._module_spec) diff --git a/src/lib/python/isc/config/tests/cfgmgr_test.py b/src/lib/python/isc/config/tests/cfgmgr_test.py index be24d09d52..76ce0735ec 100644 --- a/src/lib/python/isc/config/tests/cfgmgr_test.py +++ b/src/lib/python/isc/config/tests/cfgmgr_test.py @@ -13,6 +13,8 @@ # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# $Id$ + # # Tests for the configuration manager module # @@ -262,7 +264,7 @@ class TestConfigManager(unittest.TestCase): {'result': [0]}) self._handle_msg_helper({ "command": [ "module_spec", { 'foo': 1 } ] }, {'result': [1, 'Error in data definition: no module_name in module_spec']}) - self._handle_msg_helper({ "command": [ "get_module_spec" ] }, { 'result': [ 0, { self.spec.get_module_name(): self.spec.get_config_spec() } ]}) + self._handle_msg_helper({ "command": [ "get_module_spec" ] }, { 'result': [ 0, { self.spec.get_module_name(): self.spec.get_full_spec() } ]}) self._handle_msg_helper({ "command": [ "get_commands_spec" ] }, { 'result': [ 0, { self.spec.get_module_name(): self.spec.get_commands_spec() } ]}) # re-add this once we have new way to propagate spec changes (1 instead of the current 2 messages) #self.assertEqual(len(self.fake_session.message_queue), 2) diff --git a/src/lib/python/isc/config/tests/config_data_test.py b/src/lib/python/isc/config/tests/config_data_test.py index 20fa19dd72..5213635aeb 100644 --- a/src/lib/python/isc/config/tests/config_data_test.py +++ b/src/lib/python/isc/config/tests/config_data_test.py @@ -13,6 +13,8 @@ # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# $Id$ + # # Tests for the ConfigData and MultiConfigData classes # @@ -93,6 +95,71 @@ class TestConfigData(unittest.TestCase): self.assertRaises(isc.cc.data.DataTypeError, check_type, config_spec, 1) + def test_convert_type(self): + config_spec = isc.config.module_spec_from_file(self.data_path + os.sep + "spec22.spec").get_config_spec() + spec_part = find_spec_part(config_spec, "value1") + self.assertEqual(1, convert_type(spec_part, '1')) + self.assertEqual(2, convert_type(spec_part, 2.1)) + self.assertEqual(2, convert_type(spec_part, '2')) + self.assertEqual(3, convert_type(spec_part, '3')) + self.assertEqual(1, convert_type(spec_part, True)) + + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, "a") + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, [ 1, 2 ]) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, { "a": 1 }) + + spec_part = find_spec_part(config_spec, "value2") + self.assertEqual(1.1, convert_type(spec_part, '1.1')) + self.assertEqual(123.0, convert_type(spec_part, '123')) + self.assertEqual(1.0, convert_type(spec_part, True)) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, "a") + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, [ 1, 2 ]) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, { "a": 1 }) + + spec_part = find_spec_part(config_spec, "value3") + self.assertEqual(True, convert_type(spec_part, 'True')) + self.assertEqual(False, convert_type(spec_part, 'False')) + self.assertEqual(True, convert_type(spec_part, 1)) + self.assertEqual(True, convert_type(spec_part, 1.1)) + self.assertEqual(True, convert_type(spec_part, 'a')) + self.assertEqual(True, convert_type(spec_part, [1, 2])) + self.assertEqual(True, convert_type(spec_part, {'a' : 1})) + + spec_part = find_spec_part(config_spec, "value4") + self.assertEqual('asdf', convert_type(spec_part, "asdf")) + self.assertEqual('1', convert_type(spec_part, 1)) + self.assertEqual('1.1', convert_type(spec_part, 1.1)) + self.assertEqual('True', convert_type(spec_part, True)) + + spec_part = find_spec_part(config_spec, "value5") + self.assertEqual([1, 2], convert_type(spec_part, '1, 2')) + self.assertEqual([1, 2, 3], convert_type(spec_part, '1 2 3')) + self.assertEqual([1, 2, 3,4], convert_type(spec_part, '1 2 3, 4')) + self.assertEqual([1], convert_type(spec_part, [1,])) + self.assertEqual([1,2], convert_type(spec_part, [1,2])) + self.assertEqual([1,2], convert_type(spec_part, ['1', '2'])) + + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, 1.1) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, True) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, "a") + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, [ "a", "b" ]) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, [ "1", "b" ]) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, { "a": 1 }) + + spec_part = find_spec_part(config_spec, "value7") + self.assertEqual(['1', '2'], convert_type(spec_part, '1, 2')) + self.assertEqual(['1', '2', '3'], convert_type(spec_part, '1 2 3')) + self.assertEqual(['1', '2', '3','4'], convert_type(spec_part, '1 2 3, 4')) + self.assertEqual([1], convert_type(spec_part, [1,])) + self.assertEqual([1,2], convert_type(spec_part, [1,2])) + self.assertEqual(['1','2'], convert_type(spec_part, ['1', '2'])) + + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, 1.1) + self.assertRaises(isc.cc.data.DataTypeError, convert_type, spec_part, True) + self.assertEqual(['a'], convert_type(spec_part, "a")) + self.assertEqual(['a', 'b'], convert_type(spec_part, ["a", "b" ])) + self.assertEqual([1, 'b'], convert_type(spec_part, [1, "b" ])) + def test_find_spec_part(self): config_spec = self.cd.get_module_spec().get_config_spec() spec_part = find_spec_part(config_spec, "item1") diff --git a/src/lib/python/isc/config/tests/module_spec_test.py b/src/lib/python/isc/config/tests/module_spec_test.py index 7653d8f535..5ed2d59cbb 100644 --- a/src/lib/python/isc/config/tests/module_spec_test.py +++ b/src/lib/python/isc/config/tests/module_spec_test.py @@ -13,6 +13,8 @@ # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# $Id$ + # # Tests for the module_spec module # @@ -73,6 +75,7 @@ class TestModuleSpec(unittest.TestCase): self.assertRaises(ModuleSpecError, self.read_spec_file, "spec19.spec") self.assertRaises(ModuleSpecError, self.read_spec_file, "spec20.spec") self.assertRaises(ModuleSpecError, self.read_spec_file, "spec21.spec") + self.assertRaises(ModuleSpecError, self.read_spec_file, "spec26.spec") def validate_data(self, specfile_name, datafile_name): dd = self.read_spec_file(specfile_name); @@ -91,11 +94,33 @@ class TestModuleSpec(unittest.TestCase): self.assertEqual(True, self.validate_data("spec22.spec", "data22_7.data")) self.assertEqual(False, self.validate_data("spec22.spec", "data22_8.data")) + def validate_command_params(self, specfile_name, datafile_name, cmd_name): + dd = self.read_spec_file(specfile_name); + data_file = open(self.spec_file(datafile_name)) + data_str = data_file.read() + params = isc.cc.data.parse_value_str(data_str) + return dd.validate_command(cmd_name, params) + + def test_command_validation(self): + self.assertEqual(True, self.validate_command_params("spec27.spec", "data22_1.data", 'cmd1')) + self.assertEqual(False, self.validate_command_params("spec27.spec", "data22_2.data",'cmd1')) + self.assertEqual(False, self.validate_command_params("spec27.spec", "data22_3.data", 'cmd1')) + self.assertEqual(False, self.validate_command_params("spec27.spec", "data22_4.data", 'cmd1')) + self.assertEqual(False, self.validate_command_params("spec27.spec", "data22_5.data", 'cmd1')) + self.assertEqual(True, self.validate_command_params("spec27.spec", "data22_6.data", 'cmd1')) + self.assertEqual(True, self.validate_command_params("spec27.spec", "data22_7.data", 'cmd1')) + self.assertEqual(False, self.validate_command_params("spec27.spec", "data22_8.data", 'cmd1')) + self.assertEqual(False, self.validate_command_params("spec27.spec", "data22_8.data", 'cmd2')) + def test_init(self): self.assertRaises(ModuleSpecError, ModuleSpec, 1) module_spec = isc.config.module_spec_from_file(self.spec_file("spec1.spec"), False) self.spec1(module_spec) + module_spec = isc.config.module_spec_from_file(self.spec_file("spec25.spec"), True) + self.assertEqual("Spec25", module_spec.get_module_name()) + self.assertEqual("Just an empty module", module_spec.get_module_description()) + def test_str(self): module_spec = isc.config.module_spec_from_file(self.spec_file("spec1.spec"), False) self.assertEqual(module_spec.__str__(), "{'module_name': 'Spec1'}") diff --git a/src/lib/xfr/Makefile.am b/src/lib/xfr/Makefile.am index 75e7a4f577..77bef8f76e 100644 --- a/src/lib/xfr/Makefile.am +++ b/src/lib/xfr/Makefile.am @@ -1,12 +1,8 @@ -SUBDIRS = . - AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib AM_CPPFLAGS += -I$(top_srcdir)/src/lib/dns -I$(top_builddir)/src/lib/dns -AM_CPPFLAGS += -I$(top_srcdir)/ext -Wno-strict-aliasing -if GCC_WERROR_OK -AM_CPPFLAGS += -Werror -endif +AM_CXXFLAGS = $(B10_CXXFLAGS) -Wno-strict-aliasing +AM_CXXFLAGS += -Wno-unused-parameter # see src/lib/cc/Makefile.am CLEANFILES = *.gcno *.gcda @@ -15,28 +11,12 @@ libxfr_la_SOURCES = xfrout_client.h xfrout_client.cc libxfr_la_SOURCES += fd_share.h fd_share.cc pyexec_LTLIBRARIES = libxfr_python.la -libxfr_python_la_SOURCES = fdshare_python.cc +libxfr_python_la_SOURCES = fdshare_python.cc fd_share.cc fd_share.h libxfr_python_la_CPPFLAGS = $(AM_CPPFLAGS) $(PYTHON_INCLUDES) -libxfr_python_la_LDFLAGS = $(PYTHON_LDFLAGS) -# (still need boost for asio) -libxfr_python_la_LDFLAGS += $(BOOST_LDFLAGS) $(PYTHON_LDFLAGS) -libxfr_python_la_LDFLAGS += -module - -libxfr_python_la_LIBADD = $(top_builddir)/src/lib/xfr/libxfr.la -libxfr_python_la_LIBADD += $(top_builddir)/src/lib/exceptions/libexceptions.la -libxfr_python_la_LIBADD += $(BOOST_SYSTEM_LIB) $(PYTHON_LIB) -libxfr_python_la_LIBADD += $(PYTHON_LIB) +libxfr_python_la_CXXFLAGS = $(AM_CXXFLAGS) -#if HAVE_BOOST_PYTHON -#pyexec_LTLIBRARIES += bind10_xfr.la -#bind10_xfr_la_SOURCES = python_xfr.cc fd_share.cc fd_share.h -#bind10_xfr_la_CPPFLAGS = $(AM_CPPFLAGS) $(PYTHON_INCLUDES) -#if GCC_WERROR_OK -# XXX: Boost.Python triggers strict aliasing violation, so if we use -Werror -# we need to suppress the warnings. -#bind10_xfr_la_CPPFLAGS += -fno-strict-aliasing -#endif #bind10_xfr_la_LDFLAGS = $(BOOST_LDFLAGS) $(PYTHON_LDFLAGS) + # Python prefers .so, while some OSes (specifically MacOS) use a different # suffix for dynamic objects. -module is necessary to work this around. #bind10_xfr_la_LDFLAGS += -module diff --git a/src/lib/xfr/xfrout_client.cc b/src/lib/xfr/xfrout_client.cc index 1d5afe405b..7194f5444e 100644 --- a/src/lib/xfr/xfrout_client.cc +++ b/src/lib/xfr/xfrout_client.cc @@ -17,46 +17,72 @@ #include <cstdlib> #include <cstring> #include <iostream> + +#include <asio.hpp> + #include "fd_share.h" #include "xfrout_client.h" -using boost::asio::local::stream_protocol; +using namespace std; +using asio::local::stream_protocol; namespace isc { namespace xfr { +struct XfroutClientImpl { + XfroutClientImpl(const string& file); + const std::string file_path_; + asio::io_service io_service_; + // The socket used to communicate with the xfrout server. + stream_protocol::socket socket_; +}; + +XfroutClientImpl::XfroutClientImpl(const string& file) : + file_path_(file), socket_(io_service_) +{} + +XfroutClient::XfroutClient(const string& file) : + impl_(new XfroutClientImpl(file)) +{} + +XfroutClient::~XfroutClient() +{ + delete impl_; +} + void XfroutClient::connect() { - socket_.connect(stream_protocol::endpoint(file_path_)); + impl_->socket_.connect(stream_protocol::endpoint(impl_->file_path_)); } void XfroutClient::disconnect() { - socket_.close(); + impl_->socket_.close(); } int XfroutClient::sendXfroutRequestInfo(const int tcp_sock, uint8_t* msg_data, const uint16_t msg_len) { - if (-1 == send_fd(socket_.native(), tcp_sock)) { + if (-1 == send_fd(impl_->socket_.native(), tcp_sock)) { isc_throw(XfroutError, "Fail to send socket descriptor to xfrout module"); } // XXX: this shouldn't be blocking send, even though it's unlikely to block. const uint8_t lenbuf[2] = { msg_len >> 8, msg_len & 0xff }; - if (send(socket_.native(), lenbuf, sizeof(lenbuf), 0) != sizeof(lenbuf)) { + if (send(impl_->socket_.native(), lenbuf, sizeof(lenbuf), 0) != + sizeof(lenbuf)) { isc_throw(XfroutError, "failed to send XFR request length to xfrout module"); } - if (send(socket_.native(), msg_data, msg_len, 0) != msg_len) { + if (send(impl_->socket_.native(), msg_data, msg_len, 0) != msg_len) { isc_throw(XfroutError, "failed to send XFR request data to xfrout module"); } int databuf = 0; - if (recv(socket_.native(), &databuf, sizeof(int), 0) != 0) { + if (recv(impl_->socket_.native(), &databuf, sizeof(int), 0) != 0) { isc_throw(XfroutError, "xfr query hasn't been processed properly by xfrout module"); } diff --git a/src/lib/xfr/xfrout_client.h b/src/lib/xfr/xfrout_client.h index 36ad7904f0..f148b8036f 100644 --- a/src/lib/xfr/xfrout_client.h +++ b/src/lib/xfr/xfrout_client.h @@ -17,36 +17,38 @@ #ifndef _XFROUT_CLIENT_H #define _XFROUT_CLIENT_H +#include <stdint.h> + #include <string> -#include <boost/asio.hpp> #include <exceptions/exceptions.h> namespace isc { namespace xfr { +struct XfroutClientImpl; + class XfroutError: public Exception { public: XfroutError(const char *file, size_t line, const char *what): isc::Exception(file, line, what) {} }; -using boost::asio::local::stream_protocol; class XfroutClient { public: - XfroutClient(const std::string& file): - socket_(io_service_), file_path_(file) {} - + XfroutClient(const std::string& file); + ~XfroutClient(); +private: + // make this class non copyable + XfroutClient(const XfroutClient& source); + XfroutClient& operator=(const XfroutClient& source); +public: void connect(); void disconnect(); int sendXfroutRequestInfo(int tcp_sock, uint8_t* msg_data, uint16_t msg_len); - private: - boost::asio::io_service io_service_; - // The socket used to communicate with the xfrout server. - stream_protocol::socket socket_; - const std::string file_path_; + XfroutClientImpl* impl_; }; } // End for namespace xfr |