diff options
Diffstat (limited to 'bin/ansible-connection')
-rwxr-xr-x | bin/ansible-connection | 101 |
1 files changed, 78 insertions, 23 deletions
diff --git a/bin/ansible-connection b/bin/ansible-connection index 062a12127f..2165dbc69f 100755 --- a/bin/ansible-connection +++ b/bin/ansible-connection @@ -37,6 +37,8 @@ import struct import sys import time import traceback +import syslog +import datetime from io import BytesIO @@ -47,6 +49,7 @@ from ansible.playbook.play_context import PlayContext from ansible.plugins import connection_loader from ansible.utils.path import unfrackpath, makedirs_safe + def do_fork(): ''' Does the required double fork for a daemon process. Based on @@ -97,25 +100,46 @@ def recv_data(s): data += d return data + class Server(): + def __init__(self, path, play_context): self.path = path self.play_context = play_context - # FIXME: the connection loader here is created brand new, - # so it will not see any custom paths loaded (ie. via - # roles), so we will need to serialize the connection - # loader and send it over as we do the PlayContext - # in the main() method below. - self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin) - self.conn._connect() + self._start_time = datetime.datetime.now() + + try: + # FIXME: the connection loader here is created brand new, + # so it will not see any custom paths loaded (ie. via + # roles), so we will need to serialize the connection + # loader and send it over as we do the PlayContext + # in the main() method below. + self.log('loading connection plugin %s' % str(play_context.connection)) + self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin) + self.conn._connect() + + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.socket.bind(path) + self.socket.listen(1) - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.socket.bind(path) - self.socket.listen(1) + except Exception as exc: + self.log(exc) + return signal.signal(signal.SIGALRM, self.alarm_handler) + def log(self, msg): + syslog_msg = '[%s] %s' % (self.play_context.remote_addr, msg) + facility = getattr(syslog, C.DEFAULT_SYSLOG_FACILITY, syslog.LOG_USER) + syslog.openlog('ansible-connection', 0, facility) + syslog.syslog(syslog.LOG_INFO, syslog_msg) + + def dispatch(self, obj, name, *args, **kwargs): + meth = getattr(obj, name, None) + if meth: + return meth(*args, **kwargs) + def alarm_handler(self, signum, frame): ''' Alarm handler @@ -124,6 +148,9 @@ class Server(): # areas of code to check, so they can terminate # earlier than the socket going back to the accept # call and failing there. + # + # hooks the connection plugin to handle any cleanup + self.dispatch(self.conn, 'alarm_handler', signum, frame) self.socket.close() def run(self): @@ -150,6 +177,8 @@ class Server(): if not data: break + signal.alarm(C.DEFAULT_TIMEOUT) + rc = 255 try: if data.startswith(b'EXEC: '): @@ -166,6 +195,18 @@ class Server(): rc = 0 except: pass + elif data.startswith(b'CONTEXT: '): + pc_data = data.split(b'CONTEXT: ')[1] + + src = StringIO(pc_data) + pc_data = cPickle.load(src) + src.close() + + pc = PlayContext() + pc.deserialize(pc_data) + + self.dispatch(self.conn, 'update_play_context', pc) + continue else: stdout = '' stderr = 'Invalid action specified' @@ -173,19 +214,25 @@ class Server(): stdout = '' stderr = traceback.format_exc() + signal.alarm(0) + send_data(s, to_bytes(str(rc))) send_data(s, to_bytes(stdout)) send_data(s, to_bytes(stderr)) s.close() except Exception as e: # FIXME: proper logging and error handling here - print("run exception: %s" % e) + self.log('runtime exception: %s' % e) print(traceback.format_exc()) finally: # when done, close the connection properly and cleanup # the socket file so it can be recreated + end_time = datetime.datetime.now() + delta = end_time - self._start_time + self.log('shutting down connection, connection was active for %s secs' % delta) try: self.conn.close() + self.socket.close() except Exception as e: pass os.remove(self.path) @@ -205,7 +252,7 @@ def main(): cur_line = sys.stdin.readline() src = BytesIO(to_bytes(init_data)) pc_data = cPickle.load(src) - src.close() + #src.close() pc = PlayContext() pc.deserialize(pc_data) @@ -236,11 +283,11 @@ def main(): if not os.path.exists(sf_path): pid = do_fork() if pid == 0: - server = Server(sf_path, pc) - fcntl.lockf(lock_fd, fcntl.LOCK_UN) - os.close(lock_fd) - server.run() - sys.exit(0) + server = Server(sf_path, pc) + fcntl.lockf(lock_fd, fcntl.LOCK_UN) + os.close(lock_fd) + server.run() + sys.exit(0) fcntl.lockf(lock_fd, fcntl.LOCK_UN) os.close(lock_fd) @@ -262,24 +309,32 @@ def main(): break except socket.error: # FIXME: better error handling/logging/message here - # FIXME: make # of retries configurable? - time.sleep(0.1) + time.sleep(C.PERSISTENT_CONNECT_INTERVAL) attempts += 1 - if attempts > 10: - sys.stderr.write("failed to connect to the host, connection timeout\n") + if attempts > C.PERSISTENT_CONNECT_RETRIES: + sys.stderr.write("failed to connect to the host, connection timeout") sys.exit(255) + # + # send the play_context back into the connection so the connection + # can handle any privilege escalation or deescalation activities + # + pc_data = 'CONTEXT: %s' % src.getvalue() + send_data(sf, to_bytes(pc_data)) + src.close() + send_data(sf, to_bytes(data.strip())) + rc = int(recv_data(sf), 10) stdout = recv_data(sf) stderr = recv_data(sf) + sys.stdout.write(to_native(stdout)) sys.stderr.write(to_native(stderr)) - #sys.stdout.flush() - #sys.stderr.flush() sf.close() break + sys.exit(rc) if __name__ == '__main__': |