summaryrefslogtreecommitdiffstats
path: root/test/support
diff options
context:
space:
mode:
authorJordan Borean <jborean93@gmail.com>2023-01-31 01:10:30 +0100
committerGitHub <noreply@github.com>2023-01-31 01:10:30 +0100
commitd16ec2455d8ecc98439b0bbadcc1f50409ed1dfa (patch)
tree40642cd01ec47279a5172b1f176bf0b834221f25 /test/support
parentRemove unused and unreachable unit test code (#79854) (diff)
downloadansible-d16ec2455d8ecc98439b0bbadcc1f50409ed1dfa.tar.xz
ansible-d16ec2455d8ecc98439b0bbadcc1f50409ed1dfa.zip
Add tests to cover win_reboot incidental paths (#79856)
* Add tests to cover win_reboot incidental paths * Fix sanity issues
Diffstat (limited to 'test/support')
-rw-r--r--test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_reboot.py101
-rw-r--r--test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_quote.py114
-rw-r--r--test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_reboot.py620
3 files changed, 835 insertions, 0 deletions
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_reboot.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_reboot.py
new file mode 100644
index 0000000000..f1fad4d800
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_reboot.py
@@ -0,0 +1,101 @@
+# Copyright: (c) 2018, Matt Davis <mdavis@ansible.com>
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+from ansible.errors import AnsibleError
+from ansible.module_utils.common.text.converters import to_native
+from ansible.module_utils.common.validation import check_type_str, check_type_float
+from ansible.plugins.action import ActionBase
+from ansible.utils.display import Display
+
+from ansible_collections.ansible.windows.plugins.plugin_utils._reboot import reboot_host
+
+display = Display()
+
+
+def _positive_float(val):
+ float_val = check_type_float(val)
+ if float_val < 0:
+ return 0
+
+ else:
+ return float_val
+
+
+class ActionModule(ActionBase):
+ TRANSFERS_FILES = False
+ _VALID_ARGS = frozenset((
+ 'boot_time_command',
+ 'connect_timeout',
+ 'connect_timeout_sec',
+ 'msg',
+ 'post_reboot_delay',
+ 'post_reboot_delay_sec',
+ 'pre_reboot_delay',
+ 'pre_reboot_delay_sec',
+ 'reboot_timeout',
+ 'reboot_timeout_sec',
+ 'shutdown_timeout',
+ 'shutdown_timeout_sec',
+ 'test_command',
+ ))
+
+ def run(self, tmp=None, task_vars=None):
+ self._supports_check_mode = True
+ self._supports_async = True
+
+ if self._play_context.check_mode:
+ return {'changed': True, 'elapsed': 0, 'rebooted': True}
+
+ if task_vars is None:
+ task_vars = {}
+
+ super(ActionModule, self).run(tmp, task_vars)
+
+ parameters = {}
+ for names, check_func in [
+ (['boot_time_command'], check_type_str),
+ (['connect_timeout', 'connect_timeout_sec'], _positive_float),
+ (['msg'], check_type_str),
+ (['post_reboot_delay', 'post_reboot_delay_sec'], _positive_float),
+ (['pre_reboot_delay', 'pre_reboot_delay_sec'], _positive_float),
+ (['reboot_timeout', 'reboot_timeout_sec'], _positive_float),
+ (['test_command'], check_type_str),
+ ]:
+ for name in names:
+ value = self._task.args.get(name, None)
+ if value:
+ break
+ else:
+ value = None
+
+ # Defaults are applied in reboot_action so skip adding to kwargs if the input wasn't set (None)
+ if value is not None:
+ try:
+ value = check_func(value)
+ except TypeError as e:
+ raise AnsibleError("Invalid value given for '%s': %s." % (names[0], to_native(e)))
+
+ # Setting a lower value and kill PowerShell when sending the shutdown command. Just use the defaults
+ # if this is the case.
+ if names[0] == 'pre_reboot_delay' and value < 2:
+ continue
+
+ parameters[names[0]] = value
+
+ result = reboot_host(self._task.action, self._connection, **parameters)
+
+ # Not needed for testing and collection_name kwargs causes sanity error
+ # Historical behaviour had ignore_errors=True being able to ignore unreachable hosts and not just task errors.
+ # This snippet will allow that to continue but state that it will be removed in a future version and to use
+ # ignore_unreachable to ignore unreachable hosts.
+ # if result['unreachable'] and self._task.ignore_errors and not self._task.ignore_unreachable:
+ # dep_msg = "Host was unreachable but is being skipped because ignore_errors=True is set. In the future " \
+ # "only ignore_unreachable will be able to ignore an unreachable host for %s" % self._task.action
+ # display.deprecated(dep_msg, date="2023-05-01", collection_name="ansible.windows")
+ # result['unreachable'] = False
+ # result['failed'] = True
+
+ return result
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_quote.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_quote.py
new file mode 100644
index 0000000000..718a09905e
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_quote.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2021 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+"""Quoting helpers for Windows
+
+This contains code to help with quoting values for use in the variable Windows
+shell. Right now it should only be used in ansible.windows as the interface is
+not final and could be subject to change.
+"""
+
+# FOR INTERNAL COLLECTION USE ONLY
+# The interfaces in this file are meant for use within the ansible.windows collection
+# and may not remain stable to outside uses. Changes may be made in ANY release, even a bugfix release.
+# See also: https://github.com/ansible/community/issues/539#issuecomment-780839686
+# Please open an issue if you have questions about this.
+
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+import re
+
+from ansible.module_utils.six import text_type
+
+
+_UNSAFE_C = re.compile(u'[\\s\t"]')
+_UNSAFE_CMD = re.compile(u'[\\s\\(\\)\\^\\|%!"<>&]')
+
+# PowerShell has 5 characters it uses as a single quote, we need to double up on all of them.
+# https://github.com/PowerShell/PowerShell/blob/b7cb335f03fe2992d0cbd61699de9d9aafa1d7c1/src/System.Management.Automation/engine/parser/CharTraits.cs#L265-L272
+# https://github.com/PowerShell/PowerShell/blob/b7cb335f03fe2992d0cbd61699de9d9aafa1d7c1/src/System.Management.Automation/engine/parser/CharTraits.cs#L18-L21
+_UNSAFE_PWSH = re.compile(u"(['\u2018\u2019\u201a\u201b])")
+
+
+def quote_c(s): # type: (text_type) -> text_type
+ """Quotes a value for the raw Win32 process command line.
+
+ Quotes a value to be safely used by anything that calls the Win32
+ CreateProcess API.
+
+ Args:
+ s: The string to quote.
+
+ Returns:
+ (text_type): The quoted string value.
+ """
+ # https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
+ if not s:
+ return u'""'
+
+ if not _UNSAFE_C.search(s):
+ return s
+
+ # Replace any double quotes in an argument with '\"'.
+ s = s.replace('"', '\\"')
+
+ # We need to double up on any '\' chars that preceded a double quote (now '\"').
+ s = re.sub(r'(\\+)\\"', r'\1\1\"', s)
+
+ # Double up '\' at the end of the argument so it doesn't escape out end quote.
+ s = re.sub(r'(\\+)$', r'\1\1', s)
+
+ # Finally wrap the entire argument in double quotes now we've escaped the double quotes within.
+ return u'"{0}"'.format(s)
+
+
+def quote_cmd(s): # type: (text_type) -> text_type
+ """Quotes a value for cmd.
+
+ Quotes a value to be safely used by a command prompt call.
+
+ Args:
+ s: The string to quote.
+
+ Returns:
+ (text_type): The quoted string value.
+ """
+ # https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way#a-better-method-of-quoting
+ if not s:
+ return u'""'
+
+ if not _UNSAFE_CMD.search(s):
+ return s
+
+ # Escape the metachars as we are quoting the string to stop cmd from interpreting that metachar. For example
+ # 'file &whoami.exe' would result in 'whoami.exe' being executed and then that output being used as the argument
+ # instead of the literal string.
+ # https://stackoverflow.com/questions/3411771/multiple-character-replace-with-python
+ for c in u'^()%!"<>&|': # '^' must be the first char that we scan and replace
+ if c in s:
+ # I can't find any docs that explicitly say this but to escape ", it needs to be prefixed with \^.
+ s = s.replace(c, (u"\\^" if c == u'"' else u"^") + c)
+
+ return u'^"{0}^"'.format(s)
+
+
+def quote_pwsh(s): # type: (text_type) -> text_type
+ """Quotes a value for PowerShell.
+
+ Quotes a value to be safely used by a PowerShell expression. The input
+ string because something that is safely wrapped in single quotes.
+
+ Args:
+ s: The string to quote.
+
+ Returns:
+ (text_type): The quoted string value.
+ """
+ # https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_quoting_rules?view=powershell-5.1
+ if not s:
+ return u"''"
+
+ # We should always quote values in PowerShell as it has conflicting rules where strings can and can't be quoted.
+ # This means we quote the entire arg with single quotes and just double up on the single quote equivalent chars.
+ return u"'{0}'".format(_UNSAFE_PWSH.sub(u'\\1\\1', s))
diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_reboot.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_reboot.py
new file mode 100644
index 0000000000..2399ee48ca
--- /dev/null
+++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/plugin_utils/_reboot.py
@@ -0,0 +1,620 @@
+# Copyright: (c) 2021, Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+"""Reboot action for Windows hosts
+
+This contains the code to reboot a Windows host for use by other action plugins
+in this collection. Right now it should only be used in this collection as the
+interface is not final and count be subject to change.
+"""
+
+# FOR INTERNAL COLLECTION USE ONLY
+# The interfaces in this file are meant for use within the ansible.windows collection
+# and may not remain stable to outside uses. Changes may be made in ANY release, even a bugfix release.
+# See also: https://github.com/ansible/community/issues/539#issuecomment-780839686
+# Please open an issue if you have questions about this.
+
+import datetime
+import json
+import random
+import time
+import traceback
+import uuid
+import typing as t
+
+from ansible.errors import AnsibleConnectionFailure, AnsibleError
+from ansible.module_utils.common.text.converters import to_text
+from ansible.plugins.connection import ConnectionBase
+from ansible.utils.display import Display
+
+from ansible_collections.ansible.windows.plugins.plugin_utils._quote import quote_pwsh
+
+
+# This is not ideal but the psrp connection plugin doesn't catch all these exceptions as an AnsibleConnectionFailure.
+# Until we can guarantee we are using a version of psrp that handles all this we try to handle those issues.
+try:
+ from requests.exceptions import (
+ RequestException,
+ )
+except ImportError:
+ RequestException = AnsibleConnectionFailure
+
+
+_LOGON_UI_KEY = (
+ r"HKLM:\SOFTWARE\Microsoft\Windows NT\CurrentVersion\Winlogon\AutoLogonChecked"
+)
+
+_DEFAULT_BOOT_TIME_COMMAND = (
+ "(Get-CimInstance -ClassName Win32_OperatingSystem -Property LastBootUpTime)"
+ ".LastBootUpTime.ToFileTime()"
+)
+
+T = t.TypeVar("T")
+
+display = Display()
+
+
+class _ReturnResultException(Exception):
+ """Used to sneak results back to the return dict from an exception"""
+
+ def __init__(self, msg, **result):
+ super().__init__(msg)
+ self.result = result
+
+
+class _TestCommandFailure(Exception):
+ """Differentiates between a connection failure and just a command assertion failure during the reboot loop"""
+
+
+def reboot_host(
+ task_action: str,
+ connection: ConnectionBase,
+ boot_time_command: str = _DEFAULT_BOOT_TIME_COMMAND,
+ connect_timeout: int = 5,
+ msg: str = "Reboot initiated by Ansible",
+ post_reboot_delay: int = 0,
+ pre_reboot_delay: int = 2,
+ reboot_timeout: int = 600,
+ test_command: t.Optional[str] = None,
+) -> t.Dict[str, t.Any]:
+ """Reboot a Windows Host.
+
+ Used by action plugins in ansible.windows to reboot a Windows host. It
+ takes in the connection plugin so it can run the commands on the targeted
+ host and monitor the reboot process. The return dict will have the
+ following keys set:
+
+ changed: Whether a change occurred (reboot was done)
+ elapsed: Seconds elapsed between the reboot and it coming back online
+ failed: Whether a failure occurred
+ unreachable: Whether it failed to connect to the host on the first cmd
+ rebooted: Whether the host was rebooted
+
+ When failed=True there may be more keys to give some information around
+ the failure like msg, exception. There are other keys that might be
+ returned as well but they are dependent on the failure that occurred.
+
+ Verbosity levels used:
+ 2: Message when each reboot step is completed
+ 4: Connection plugin operations and their results
+ 5: Raw commands run and the results of those commands
+ Debug: Everything, very verbose
+
+ Args:
+ task_action: The name of the action plugin that is running for logging.
+ connection: The connection plugin to run the reboot commands on.
+ boot_time_command: The command to run when getting the boot timeout.
+ connect_timeout: Override the connection timeout of the connection
+ plugin when polling the rebooted host.
+ msg: The message to display to interactive users when rebooting the
+ host.
+ post_reboot_delay: Seconds to wait after sending the reboot command
+ before checking to see if it has returned.
+ pre_reboot_delay: Seconds to wait when sending the reboot command.
+ reboot_timeout: Seconds to wait while polling for the host to come
+ back online.
+ test_command: Command to run when the host is back online and
+ determines the machine is ready for management. When not defined
+ the default command should wait until the reboot is complete and
+ all pre-login configuration has completed.
+
+ Returns:
+ (Dict[str, Any]): The return result as a dictionary. Use the 'failed'
+ key to determine if there was a failure or not.
+ """
+ result: t.Dict[str, t.Any] = {
+ "changed": False,
+ "elapsed": 0,
+ "failed": False,
+ "unreachable": False,
+ "rebooted": False,
+ }
+ host_context = {"do_close_on_reset": True}
+
+ # Get current boot time. A lot of tasks that require a reboot leave the WSMan stack in a bad place. Will try to
+ # get the initial boot time 3 times before giving up.
+ try:
+ previous_boot_time = _do_until_success_or_retry_limit(
+ task_action,
+ connection,
+ host_context,
+ "pre-reboot boot time check",
+ 3,
+ _get_system_boot_time,
+ task_action,
+ connection,
+ boot_time_command,
+ )
+
+ except Exception as e:
+ # Report a the failure based on the last exception received.
+ if isinstance(e, _ReturnResultException):
+ result.update(e.result)
+
+ if isinstance(e, AnsibleConnectionFailure):
+ result["unreachable"] = True
+ else:
+ result["failed"] = True
+
+ result["msg"] = str(e)
+ result["exception"] = traceback.format_exc()
+ return result
+
+ # Get the original connection_timeout option var so it can be reset after
+ original_connection_timeout: t.Optional[float] = None
+ try:
+ original_connection_timeout = connection.get_option("connection_timeout")
+ display.vvvv(
+ f"{task_action}: saving original connection_timeout of {original_connection_timeout}"
+ )
+ except KeyError:
+ display.vvvv(
+ f"{task_action}: connection_timeout connection option has not been set"
+ )
+
+ # Initiate reboot
+ # This command may be wrapped in other shells or command making it hard to detect what shutdown.exe actually
+ # returned. We use this hackery to return a json that contains the stdout/stderr/rc as a structured object for our
+ # code to parse and detect if something went wrong.
+ reboot_command = """$ErrorActionPreference = 'Continue'
+
+if ($%s) {
+ Remove-Item -LiteralPath '%s' -Force -ErrorAction SilentlyContinue
+}
+
+$stdout = $null
+$stderr = . { shutdown.exe /r /t %s /c %s | Set-Variable stdout } 2>&1 | ForEach-Object ToString
+
+ConvertTo-Json -Compress -InputObject @{
+ stdout = (@($stdout) -join "`n")
+ stderr = (@($stderr) -join "`n")
+ rc = $LASTEXITCODE
+}
+""" % (
+ str(not test_command),
+ _LOGON_UI_KEY,
+ int(pre_reboot_delay),
+ quote_pwsh(msg),
+ )
+
+ expected_test_result = (
+ None # We cannot have an expected result if the command is user defined
+ )
+ if not test_command:
+ # It turns out that LogonUI will create this registry key if it does not exist when it's about to show the
+ # logon prompt. Normally this is a volatile key but if someone has explicitly created it that might no longer
+ # be the case. We ensure it is not present on a reboot so we can wait until LogonUI creates it to determine
+ # the host is actually online and ready, e.g. no configurations/updates still to be applied.
+ # We echo a known successful statement to catch issues with powershell failing to start but the rc mysteriously
+ # being 0 causing it to consider a successful reboot too early (seen on ssh connections).
+ expected_test_result = f"success-{uuid.uuid4()}"
+ test_command = f"Get-Item -LiteralPath '{_LOGON_UI_KEY}' -ErrorAction Stop; '{expected_test_result}'"
+
+ start = None
+ try:
+ _perform_reboot(task_action, connection, reboot_command)
+
+ start = datetime.datetime.utcnow()
+ result["changed"] = True
+ result["rebooted"] = True
+
+ if post_reboot_delay != 0:
+ display.vv(
+ f"{task_action}: waiting an additional {post_reboot_delay} seconds"
+ )
+ time.sleep(post_reboot_delay)
+
+ # Keep on trying to run the last boot time check until it is successful or the timeout is raised
+ display.vv(f"{task_action} validating reboot")
+ _do_until_success_or_timeout(
+ task_action,
+ connection,
+ host_context,
+ "last boot time check",
+ reboot_timeout,
+ _check_boot_time,
+ task_action,
+ connection,
+ host_context,
+ previous_boot_time,
+ boot_time_command,
+ connect_timeout,
+ )
+
+ # Reset the connection plugin connection timeout back to the original
+ if original_connection_timeout is not None:
+ _set_connection_timeout(
+ task_action,
+ connection,
+ host_context,
+ original_connection_timeout,
+ )
+
+ # Run test command until ti is successful or a timeout occurs
+ display.vv(f"{task_action} running post reboot test command")
+ _do_until_success_or_timeout(
+ task_action,
+ connection,
+ host_context,
+ "post-reboot test command",
+ reboot_timeout,
+ _run_test_command,
+ task_action,
+ connection,
+ test_command,
+ expected=expected_test_result,
+ )
+
+ display.vv(f"{task_action}: system successfully rebooted")
+
+ except Exception as e:
+ if isinstance(e, _ReturnResultException):
+ result.update(e.result)
+
+ result["failed"] = True
+ result["msg"] = str(e)
+ result["exception"] = traceback.format_exc()
+
+ if start:
+ elapsed = datetime.datetime.utcnow() - start
+ result["elapsed"] = elapsed.seconds
+
+ return result
+
+
+def _check_boot_time(
+ task_action: str,
+ connection: ConnectionBase,
+ host_context: t.Dict[str, t.Any],
+ previous_boot_time: int,
+ boot_time_command: str,
+ timeout: int,
+):
+ """Checks the system boot time has been changed or not"""
+ display.vvvv("%s: attempting to get system boot time" % task_action)
+
+ # override connection timeout from defaults to custom value
+ if timeout:
+ _set_connection_timeout(task_action, connection, host_context, timeout)
+
+ # try and get boot time
+ current_boot_time = _get_system_boot_time(
+ task_action, connection, boot_time_command
+ )
+ if current_boot_time == previous_boot_time:
+ raise _TestCommandFailure("boot time has not changed")
+
+
+def _do_until_success_or_retry_limit(
+ task_action: str,
+ connection: ConnectionBase,
+ host_context: t.Dict[str, t.Any],
+ action_desc: str,
+ retries: int,
+ func: t.Callable[..., T],
+ *args: t.Any,
+ **kwargs: t.Any,
+) -> t.Optional[T]:
+ """Runs the function multiple times ignoring errors until the retry limit is hit"""
+
+ def wait_condition(idx):
+ return idx < retries
+
+ return _do_until_success_or_condition(
+ task_action,
+ connection,
+ host_context,
+ action_desc,
+ wait_condition,
+ func,
+ *args,
+ **kwargs,
+ )
+
+
+def _do_until_success_or_timeout(
+ task_action: str,
+ connection: ConnectionBase,
+ host_context: t.Dict[str, t.Any],
+ action_desc: str,
+ timeout: float,
+ func: t.Callable[..., T],
+ *args: t.Any,
+ **kwargs: t.Any,
+) -> t.Optional[T]:
+ """Runs the function multiple times ignoring errors until a timeout occurs"""
+ max_end_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=timeout)
+
+ def wait_condition(idx):
+ return datetime.datetime.utcnow() < max_end_time
+
+ try:
+ return _do_until_success_or_condition(
+ task_action,
+ connection,
+ host_context,
+ action_desc,
+ wait_condition,
+ func,
+ *args,
+ **kwargs,
+ )
+ except Exception:
+ raise Exception(
+ "Timed out waiting for %s (timeout=%s)" % (action_desc, timeout)
+ )
+
+
+def _do_until_success_or_condition(
+ task_action: str,
+ connection: ConnectionBase,
+ host_context: t.Dict[str, t.Any],
+ action_desc: str,
+ condition: t.Callable[[int], bool],
+ func: t.Callable[..., T],
+ *args: t.Any,
+ **kwargs: t.Any,
+) -> t.Optional[T]:
+ """Runs the function multiple times ignoring errors until the condition is false"""
+ fail_count = 0
+ max_fail_sleep = 12
+ reset_required = False
+ last_error = None
+
+ while fail_count == 0 or condition(fail_count):
+ try:
+ if reset_required:
+ # Keep on trying the reset until it succeeds.
+ _reset_connection(task_action, connection, host_context)
+ reset_required = False
+
+ else:
+ res = func(*args, **kwargs)
+ display.vvvvv("%s: %s success" % (task_action, action_desc))
+
+ return res
+
+ except Exception as e:
+ last_error = e
+
+ if not isinstance(e, _TestCommandFailure):
+ # The error may be due to a connection problem, just reset the connection just in case
+ reset_required = True
+
+ # Use exponential backoff with a max timeout, plus a little bit of randomness
+ random_int = random.randint(0, 1000) / 1000
+ fail_sleep = 2**fail_count + random_int
+ if fail_sleep > max_fail_sleep:
+ fail_sleep = max_fail_sleep + random_int
+
+ try:
+ error = str(e).splitlines()[-1]
+ except IndexError:
+ error = str(e)
+
+ display.vvvvv(
+ "{action}: {desc} fail {e_type} '{err}', retrying in {sleep:.4} seconds...\n{tcb}".format(
+ action=task_action,
+ desc=action_desc,
+ e_type=type(e).__name__,
+ err=error,
+ sleep=fail_sleep,
+ tcb=traceback.format_exc(),
+ )
+ )
+
+ fail_count += 1
+ time.sleep(fail_sleep)
+
+ if last_error:
+ raise last_error
+
+ return None
+
+
+def _execute_command(
+ task_action: str,
+ connection: ConnectionBase,
+ command: str,
+) -> t.Tuple[int, str, str]:
+ """Runs a command on the Windows host and returned the result"""
+ display.vvvvv(f"{task_action}: running command: {command}")
+
+ # Need to wrap the command in our PowerShell encoded wrapper. This is done to align the command input to a
+ # common shell and to allow the psrp connection plugin to report the correct exit code without manually setting
+ # $LASTEXITCODE for just that plugin.
+ command = connection._shell._encode_script(command)
+
+ try:
+ rc, stdout, stderr = connection.exec_command(
+ command, in_data=None, sudoable=False
+ )
+ except RequestException as e:
+ # The psrp connection plugin should be doing this but until we can guarantee it does we just convert it here
+ # to ensure AnsibleConnectionFailure refers to actual connection errors.
+ raise AnsibleConnectionFailure(f"Failed to connect to the host: {e}")
+
+ rc = rc or 0
+ stdout = to_text(stdout, errors="surrogate_or_strict").strip()
+ stderr = to_text(stderr, errors="surrogate_or_strict").strip()
+
+ display.vvvvv(
+ f"{task_action}: command result - rc: {rc}, stdout: {stdout}, stderr: {stderr}"
+ )
+
+ return rc, stdout, stderr
+
+
+def _get_system_boot_time(
+ task_action: str,
+ connection: ConnectionBase,
+ boot_time_command: str,
+) -> str:
+ """Gets a unique identifier to represent the boot time of the Windows host"""
+ display.vvvv(f"{task_action}: getting boot time")
+ rc, stdout, stderr = _execute_command(task_action, connection, boot_time_command)
+
+ if rc != 0:
+ msg = f"{task_action}: failed to get host boot time info"
+ raise _ReturnResultException(msg, rc=rc, stdout=stdout, stderr=stderr)
+
+ display.vvvv(f"{task_action}: last boot time: {stdout}")
+ return stdout
+
+
+def _perform_reboot(
+ task_action: str,
+ connection: ConnectionBase,
+ reboot_command: str,
+ handle_abort: bool = True,
+) -> None:
+ """Runs the reboot command"""
+ display.vv(f"{task_action}: rebooting server...")
+
+ stdout = stderr = None
+ try:
+ rc, stdout, stderr = _execute_command(task_action, connection, reboot_command)
+
+ except AnsibleConnectionFailure as e:
+ # If the connection is closed too quickly due to the system being shutdown, carry on
+ display.vvvv(f"{task_action}: AnsibleConnectionFailure caught and handled: {e}")
+ rc = 0
+
+ if stdout:
+ try:
+ reboot_result = json.loads(stdout)
+ except getattr(json.decoder, "JSONDecodeError", ValueError):
+ # While the reboot command should output json it may have failed for some other reason. We continue
+ # reporting with that output instead
+ pass
+ else:
+ stdout = reboot_result.get("stdout", stdout)
+ stderr = reboot_result.get("stderr", stderr)
+ rc = int(reboot_result.get("rc", rc))
+
+ # Test for "A system shutdown has already been scheduled. (1190)" and handle it gracefully
+ if handle_abort and (rc == 1190 or (rc != 0 and stderr and "(1190)" in stderr)):
+ display.warning("A scheduled reboot was pre-empted by Ansible.")
+
+ # Try to abort (this may fail if it was already aborted)
+ rc, stdout, stderr = _execute_command(
+ task_action, connection, "shutdown.exe /a"
+ )
+ display.vvvv(
+ f"{task_action}: result from trying to abort existing shutdown - rc: {rc}, stdout: {stdout}, stderr: {stderr}"
+ )
+
+ return _perform_reboot(
+ task_action, connection, reboot_command, handle_abort=False
+ )
+
+ if rc != 0:
+ msg = f"{task_action}: Reboot command failed"
+ raise _ReturnResultException(msg, rc=rc, stdout=stdout, stderr=stderr)
+
+
+def _reset_connection(
+ task_action: str,
+ connection: ConnectionBase,
+ host_context: t.Dict[str, t.Any],
+ ignore_errors: bool = False,
+) -> None:
+ """Resets the connection handling any errors"""
+
+ def _wrap_conn_err(func, *args, **kwargs):
+ try:
+ func(*args, **kwargs)
+
+ except (AnsibleError, RequestException) as e:
+ if ignore_errors:
+ return False
+
+ raise AnsibleError(e)
+
+ return True
+
+ # While reset() should probably better handle this some connection plugins don't clear the existing connection on
+ # reset() leaving resources still in use on the target (WSMan shells). Instead we try to manually close the
+ # connection then call reset. If it fails once we want to skip closing to avoid a perpetual loop and just hope
+ # reset() brings us back into a good state. If it's successful we still want to try it again.
+ if host_context["do_close_on_reset"]:
+ display.vvvv(f"{task_action}: closing connection plugin")
+ try:
+ success = _wrap_conn_err(connection.close)
+
+ except Exception:
+ host_context["do_close_on_reset"] = False
+ raise
+
+ host_context["do_close_on_reset"] = success
+
+ # For some connection plugins (ssh) reset actually does something more than close so we also class that
+ display.vvvv(f"{task_action}: resetting connection plugin")
+ try:
+ _wrap_conn_err(connection.reset)
+
+ except AttributeError:
+ # Not all connection plugins have reset so we just ignore those, close should have done our job.
+ pass
+
+
+def _run_test_command(
+ task_action: str,
+ connection: ConnectionBase,
+ command: str,
+ expected: t.Optional[str] = None,
+) -> None:
+ """Runs the user specified test command until the host is able to run it properly"""
+ display.vvvv(f"{task_action}: attempting post-reboot test command")
+
+ rc, stdout, stderr = _execute_command(task_action, connection, command)
+
+ if rc != 0:
+ msg = f"{task_action}: Test command failed - rc: {rc}, stdout: {stdout}, stderr: {stderr}"
+ raise _TestCommandFailure(msg)
+
+ if expected and expected not in stdout:
+ msg = f"{task_action}: Test command failed - '{expected}' was not in stdout: {stdout}"
+ raise _TestCommandFailure(msg)
+
+
+def _set_connection_timeout(
+ task_action: str,
+ connection: ConnectionBase,
+ host_context: t.Dict[str, t.Any],
+ timeout: float,
+) -> None:
+ """Sets the connection plugin connection_timeout option and resets the connection"""
+ try:
+ current_connection_timeout = connection.get_option("connection_timeout")
+ except KeyError:
+ # Not all connection plugins implement this, just ignore the setting if it doesn't work
+ return
+
+ if timeout == current_connection_timeout:
+ return
+
+ display.vvvv(f"{task_action}: setting connect_timeout {timeout}")
+ connection.set_option("connection_timeout", timeout)
+
+ _reset_connection(task_action, connection, host_context, ignore_errors=True)