diff options
author | Sam Doran <sdoran@redhat.com> | 2021-03-19 20:09:18 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-19 20:09:18 +0100 |
commit | abacf6a108b038571a0c3daeae63da0897c8fcb6 (patch) | |
tree | c9da1813642dde72ff13f89ac03e4fee0e043f39 | |
parent | find - set proper default based on use_regex (#73961) (diff) | |
download | ansible-abacf6a108b038571a0c3daeae63da0897c8fcb6.tar.xz ansible-abacf6a108b038571a0c3daeae63da0897c8fcb6.zip |
Use ArgumentSpecValidator in AnsibleModule (#73703)
* Begin using ArgumentSpecValidator in AnsibleModule
* Add check parameters to ArgumentSpecValidator
Add additional parameters for specifying required and mutually exclusive parameters.
Add code to the .validate() method that runs these additional checks.
* Make errors related to unsupported parameters match existing behavior
Update the punctuation in the message slightly to make it more readable.
Add a property to ArgumentSpecValidator to hold valid parameter names.
* Set default values after performining checks
* FIx sanity test failure
* Use correct parameters when checking sub options
* Use a dict when iterating over check functions
Referencing by key names makes things a bit more readable IMO.
* Fix bug in comparison for sub options evaluation
* Add options_context to check functions
This allows the parent parameter to be added the the error message if a validation
error occurs in a sub option.
* Fix bug in apply_defaults behavior of sub spec validation
* Accept options_conext in get_unsupported_parameters()
If options_context is supplied, a tuple of parent key names of unsupported parameter will be
created. This allows the full "path" to the unsupported parameter to be reported.
* Build path to the unsupported parameter for error messages.
* Remove unused import
* Update recursive finder test
* Skip if running in check mode
This was done in the _check_arguments() method. That was moved to a function that has no
way of calling fail_json(), so it must be done outside of validation.
This is a silght change in behavior, but I believe the correct one.
Previously, only unsupported parameters would cause a failure. All other checks would not be executed
if the modlue did not support check mode. This would hide validation failures in check mode.
* The great purge
Remove all methods related to argument spec validation from AnsibleModule
* Keep _name and kind in the caller and out of the validator
This seems a bit awkward since this means the caller could end up with {name} and {kind} in
the error message if they don't run the messages through the .format() method
with name and kind parameters.
* Double moustaches work
I wasn't sure if they get stripped or not. Looks like they do. Neat trick.
* Add changelog
* Update unsupported parameter test
The error message changed to include name and kind.
* Remove unused import
* Add better documentation for ArgumentSpecValidator class
* Fix example
* Few more docs fixes
* Mark required and mutually exclusive attributes as private
* Mark validate functions as private
* Reorganize functions in validation.py
* Remove unused imports in basic.py related to argument spec validation
* Create errors is module_utils
We have errors in lib/ansible/errors/ but those cannot be used by modules.
* Update recursive finder test
* Move errors to file rather than __init__.py
* Change ArgumentSpecValidator.validate() interface
Raise AnsibleValidationErrorMultiple on validation error which contains all AnsibleValidationError
exceptions for validation failures.
Return the validated parameters if validation is successful rather than True/False.
Update docs and tests.
* Get attribute in loop so that the attribute name can also be used as a parameter
* Shorten line
* Update calling code in AnsibleModule for new validator interface
* Update calling code in validate_argument_spec based in new validation interface
* Base custom exception class off of Exception
* Call the __init__ method of the base Exception class to populate args
* Ensure no_log values are always updated
* Make custom exceptions more hierarchical
This redefines AnsibleError from lib/ansible/errors with a different signature since that cannot
be used by modules. This may be a bad idea. Maybe lib/ansible/errors should be moved to
module_utils, or AnsibleError defined in this commit should use the same signature as the original.
* Just go back to basing off Exception
* Return ValidationResult object on successful validation
Create a ValidationResult class.
Return a ValidationResult from ArgumentSpecValidator.validate() when validation is successful.
Update class and method docs.
Update unit tests based on interface change.
* Make it easier to get error objects from AnsibleValidationResultMultiple
This makes the interface cleaner when getting individual error objects contained in a single
AnsibleValidationResultMultiple instance.
* Define custom exception for each type of validation failure
These errors indicate where a validation error occured. Currently they are empty but could
contain specific data for each exception type in the future.
* Update tests based on (yet another) interface change
* Mark several more functions as private
These are all doing rather "internal" things. The ArgumentSpecValidator class is the preferred
public interface.
* Move warnings and deprecations to result object
Rather than calling deprecate() and warn() directly, store them on the result object so the
caller can decide what to do with them.
* Use subclass for module arg spec validation
The subclass uses global warning and deprecations feature
* Fix up docs
* Remove legal_inputs munging from _handle_aliases()
This is done in AnsibleModule by the _set_internal_properties() method. It only makes sense
to do that for an AnsibleModule instance (it should update the parameters before performing
validation) and shouldn't be done by the validator.
Create a private function just for getting legal inputs since that is done in a couple of places.
It may make sense store that on the ValidationResult object.
* Increase test coverage
* Remove unnecessary conditional
ci_complete
* Mark warnings and deprecations as private in the ValidationResult
They can be made public once we come up with a way to make them more generally useful,
probably by creating cusom objects to store the data in more structure way.
* Mark valid_parameter_names as private and populate it during initialization
* Use a global for storing the list of additonal checks to perform
This list is used by the main validate method as well as the sub spec validation.
19 files changed, 1034 insertions, 1118 deletions
diff --git a/changelogs/fragments/use-validator-in-ansiblemodule.yml b/changelogs/fragments/use-validator-in-ansiblemodule.yml new file mode 100644 index 0000000000..b5e31fb965 --- /dev/null +++ b/changelogs/fragments/use-validator-in-ansiblemodule.yml @@ -0,0 +1,5 @@ +major_changes: + - >- + AnsibleModule - use ``ArgumentSpecValidator`` class for validating argument spec and remove + private methods related to argument spec validation. Any modules using private methods + should now use the ``ArgumentSpecValidator`` class or the appropriate validation function. diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index 015c08ff7d..cd2a055334 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -90,6 +90,8 @@ from ansible.module_utils.common.text.converters import ( container_to_text as json_dict_bytes_to_unicode, ) +from ansible.module_utils.common.arg_spec import ModuleArgumentSpecValidator + from ansible.module_utils.common.text.formatters import ( lenient_lowercase, bytes_to_human, @@ -155,25 +157,15 @@ from ansible.module_utils.common.sys_info import ( ) from ansible.module_utils.pycompat24 import get_exception, literal_eval from ansible.module_utils.common.parameters import ( - _remove_values_conditions, - _sanitize_keys_conditions, - sanitize_keys, env_fallback, - get_unsupported_parameters, - get_type_validator, - handle_aliases, - list_deprecations, - list_no_log_values, remove_values, - set_defaults, - set_fallbacks, - validate_argument_types, - AnsibleFallbackNotFound, + sanitize_keys, DEFAULT_TYPE_VALIDATORS, PASS_VARS, PASS_BOOLS, ) +from ansible.module_utils.errors import AnsibleFallbackNotFound, AnsibleValidationErrorMultiple, UnsupportedError from ansible.module_utils.six import ( PY2, PY3, @@ -187,24 +179,6 @@ from ansible.module_utils.six import ( from ansible.module_utils.six.moves import map, reduce, shlex_quote from ansible.module_utils.common.validation import ( check_missing_parameters, - check_mutually_exclusive, - check_required_arguments, - check_required_by, - check_required_if, - check_required_one_of, - check_required_together, - count_terms, - check_type_bool, - check_type_bits, - check_type_bytes, - check_type_float, - check_type_int, - check_type_jsonarg, - check_type_list, - check_type_dict, - check_type_path, - check_type_raw, - check_type_str, safe_eval, ) from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses @@ -507,48 +481,43 @@ class AnsibleModule(object): # Save parameter values that should never be logged self.no_log_values = set() - self._load_params() - self._set_fallbacks() - - # append to legal_inputs and then possibly check against them - try: - self.aliases = self._handle_aliases() - except (ValueError, TypeError) as e: - # Use exceptions here because it isn't safe to call fail_json until no_log is processed - print('\n{"failed": true, "msg": "Module alias error: %s"}' % to_native(e)) - sys.exit(1) - - self._handle_no_log_values() - # check the locale as set by the current environment, and reset to # a known valid (LANG=C) if it's an invalid/unavailable locale self._check_locale() + self._load_params() self._set_internal_properties() - self._check_arguments() - # check exclusive early - if not bypass_checks: - self._check_mutually_exclusive(mutually_exclusive) + self.validator = ModuleArgumentSpecValidator(self.argument_spec, + self.mutually_exclusive, + self.required_together, + self.required_one_of, + self.required_if, + self.required_by, + ) - self._set_defaults(pre=True) + self.validation_result = self.validator.validate(self.params) + self.params.update(self.validation_result.validated_parameters) + self.no_log_values.update(self.validation_result._no_log_values) - # This is for backwards compatibility only. - self._CHECK_ARGUMENT_TYPES_DISPATCHER = DEFAULT_TYPE_VALIDATORS + try: + error = self.validation_result.errors[0] + except IndexError: + error = None - if not bypass_checks: - self._check_required_arguments() - self._check_argument_types() - self._check_argument_values() - self._check_required_together(required_together) - self._check_required_one_of(required_one_of) - self._check_required_if(required_if) - self._check_required_by(required_by) + # Fail for validation errors, even in check mode + if error: + msg = self.validation_result.errors.msg + if isinstance(error, UnsupportedError): + msg = "Unsupported parameters for ({name}) {kind}: {msg}".format(name=self._name, kind='module', msg=msg) - self._set_defaults(pre=False) + self.fail_json(msg=msg) + + if self.check_mode and not self.supports_check_mode: + self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name) - # deal with options sub-spec - self._handle_options() + # This is for backwards compatibility only. + self._CHECK_ARGUMENT_TYPES_DISPATCHER = DEFAULT_TYPE_VALIDATORS if not self.no_log: self._log_invocation() @@ -1274,42 +1243,6 @@ class AnsibleModule(object): self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" % to_native(e), exception=traceback.format_exc()) - def _handle_aliases(self, spec=None, param=None, option_prefix=''): - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - - # this uses exceptions as it happens before we can safely call fail_json - alias_warnings = [] - alias_deprecations = [] - alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings, alias_deprecations) - for option, alias in alias_warnings: - warn('Both option %s and its alias %s are set.' % (option_prefix + option, option_prefix + alias)) - - for deprecation in alias_deprecations: - deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], - version=deprecation.get('version'), date=deprecation.get('date'), - collection_name=deprecation.get('collection_name')) - - return alias_results - - def _handle_no_log_values(self, spec=None, param=None): - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - - try: - self.no_log_values.update(list_no_log_values(spec, param)) - except TypeError as te: - self.fail_json(msg="Failure when processing no_log parameters. Module invocation will be hidden. " - "%s" % to_native(te), invocation={'module_args': 'HIDDEN DUE TO FAILURE'}) - - for message in list_deprecations(spec, param): - deprecate(message['msg'], version=message.get('version'), date=message.get('date'), - collection_name=message.get('collection_name')) - def _set_internal_properties(self, argument_spec=None, module_parameters=None): if argument_spec is None: argument_spec = self.argument_spec @@ -1333,344 +1266,9 @@ class AnsibleModule(object): if not hasattr(self, PASS_VARS[k][0]): setattr(self, PASS_VARS[k][0], PASS_VARS[k][1]) - def _check_arguments(self, spec=None, param=None, legal_inputs=None): - unsupported_parameters = set() - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - if legal_inputs is None: - legal_inputs = self._legal_inputs - - unsupported_parameters = get_unsupported_parameters(spec, param, legal_inputs) - - if unsupported_parameters: - msg = "Unsupported parameters for (%s) module: %s" % (self._name, ', '.join(sorted(list(unsupported_parameters)))) - if self._options_context: - msg += " found in %s." % " -> ".join(self._options_context) - supported_parameters = list() - for key in sorted(spec.keys()): - if 'aliases' in spec[key] and spec[key]['aliases']: - supported_parameters.append("%s (%s)" % (key, ', '.join(sorted(spec[key]['aliases'])))) - else: - supported_parameters.append(key) - msg += " Supported parameters include: %s" % (', '.join(supported_parameters)) - self.fail_json(msg=msg) - - if self.check_mode and not self.supports_check_mode: - self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name) - - def _count_terms(self, check, param=None): - if param is None: - param = self.params - return count_terms(check, param) - - def _check_mutually_exclusive(self, spec, param=None): - if param is None: - param = self.params - - try: - check_mutually_exclusive(spec, param) - except TypeError as e: - msg = to_native(e) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - - def _check_required_one_of(self, spec, param=None): - if spec is None: - return - - if param is None: - param = self.params - - try: - check_required_one_of(spec, param) - except TypeError as e: - msg = to_native(e) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - - def _check_required_together(self, spec, param=None): - if spec is None: - return - if param is None: - param = self.params - - try: - check_required_together(spec, param) - except TypeError as e: - msg = to_native(e) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - - def _check_required_by(self, spec, param=None): - if spec is None: - return - if param is None: - param = self.params - - try: - check_required_by(spec, param) - except TypeError as e: - self.fail_json(msg=to_native(e)) - - def _check_required_arguments(self, spec=None, param=None): - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - - try: - check_required_arguments(spec, param) - except TypeError as e: - msg = to_native(e) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - - def _check_required_if(self, spec, param=None): - ''' ensure that parameters which conditionally required are present ''' - if spec is None: - return - if param is None: - param = self.params - - try: - check_required_if(spec, param) - except TypeError as e: - msg = to_native(e) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - - def _check_argument_values(self, spec=None, param=None): - ''' ensure all arguments have the requested values, and there are no stray arguments ''' - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - for (k, v) in spec.items(): - choices = v.get('choices', None) - if choices is None: - continue - if isinstance(choices, SEQUENCETYPE) and not isinstance(choices, (binary_type, text_type)): - if k in param: - # Allow one or more when type='list' param with choices - if isinstance(param[k], list): - diff_list = ", ".join([item for item in param[k] if item not in choices]) - if diff_list: - choices_str = ", ".join([to_native(c) for c in choices]) - msg = "value of %s must be one or more of: %s. Got no match for: %s" % (k, choices_str, diff_list) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - elif param[k] not in choices: - # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking - # the value. If we can't figure this out, module author is responsible. - lowered_choices = None - if param[k] == 'False': - lowered_choices = lenient_lowercase(choices) - overlap = BOOLEANS_FALSE.intersection(choices) - if len(overlap) == 1: - # Extract from a set - (param[k],) = overlap - - if param[k] == 'True': - if lowered_choices is None: - lowered_choices = lenient_lowercase(choices) - overlap = BOOLEANS_TRUE.intersection(choices) - if len(overlap) == 1: - (param[k],) = overlap - - if param[k] not in choices: - choices_str = ", ".join([to_native(c) for c in choices]) - msg = "value of %s must be one of: %s, got: %s" % (k, choices_str, param[k]) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - else: - msg = "internal error: choices for argument %s are not iterable: %s" % (k, choices) - if self._options_context: - msg += " found in %s" % " -> ".join(self._options_context) - self.fail_json(msg=msg) - def safe_eval(self, value, locals=None, include_exceptions=False): return safe_eval(value, locals, include_exceptions) - def _check_type_str(self, value, param=None, prefix=''): - opts = { - 'error': False, - 'warn': False, - 'ignore': True - } - - # Ignore, warn, or error when converting to a string. - allow_conversion = opts.get(self._string_conversion_action, True) - try: - return check_type_str(value, allow_conversion) - except TypeError: - common_msg = 'quote the entire value to ensure it does not change.' - from_msg = '{0!r}'.format(value) - to_msg = '{0!r}'.format(to_text(value)) - - if param is not None: - if prefix: - param = '{0}{1}'.format(prefix, param) - - from_msg = '{0}: {1!r}'.format(param, value) - to_msg = '{0}: {1!r}'.format(param, to_text(value)) - - if self._string_conversion_action == 'error': - msg = common_msg.capitalize() - raise TypeError(to_native(msg)) - elif self._string_conversion_action == 'warn': - msg = ('The value "{0}" (type {1.__class__.__name__}) was converted to "{2}" (type string). ' - 'If this does not look like what you expect, {3}').format(from_msg, value, to_msg, common_msg) - self.warn(to_native(msg)) - return to_native(value, errors='surrogate_or_strict') - - def _check_type_list(self, value): - return check_type_list(value) - - def _check_type_dict(self, value): - return check_type_dict(value) - - def _check_type_bool(self, value): - return check_type_bool(value) - - def _check_type_int(self, value): - return check_type_int(value) - - def _check_type_float(self, value): - return check_type_float(value) - - def _check_type_path(self, value): - return check_type_path(value) - - def _check_type_jsonarg(self, value): - return check_type_jsonarg(value) - - def _check_type_raw(self, value): - return check_type_raw(value) - - def _check_type_bytes(self, value): - return check_type_bytes(value) - - def _check_type_bits(self, value): - return check_type_bits(value) - - def _handle_options(self, argument_spec=None, params=None, prefix=''): - ''' deal with options to create sub spec ''' - if argument_spec is None: - argument_spec = self.argument_spec - if params is None: - params = self.params - - for (k, v) in argument_spec.items(): - wanted = v.get('type', None) - if wanted == 'dict' or (wanted == 'list' and v.get('elements', '') == 'dict'): - spec = v.get('options', None) - if v.get('apply_defaults', False): - if spec is not None: - if params.get(k) is None: - params[k] = {} - else: - continue - elif spec is None or k not in params or params[k] is None: - continue - - self._options_context.append(k) - - if isinstance(params[k], dict): - elements = [params[k]] - else: - elements = params[k] - - for idx, param in enumerate(elements): - if not isinstance(param, dict): - self.fail_json(msg="value of %s must be of type dict or list of dict" % k) - - new_prefix = prefix + k - if wanted == 'list': - new_prefix += '[%d]' % idx - new_prefix += '.' - - self._set_fallbacks(spec, param) - options_aliases = self._handle_aliases(spec, param, option_prefix=new_prefix) - - options_legal_inputs = list(spec.keys()) + list(options_aliases.keys()) - - self._check_arguments(spec, param, options_legal_inputs) - - # check exclusive early - if not self.bypass_checks: - self._check_mutually_exclusive(v.get('mutually_exclusive', None), param) - - self._set_defaults(pre=True, spec=spec, param=param) - - if not self.bypass_checks: - self._check_required_arguments(spec, param) - self._check_argument_types(spec, param, new_prefix) - self._check_argument_values(spec, param) - - self._check_required_together(v.get('required_together', None), param) - self._check_required_one_of(v.get('required_one_of', None), param) - self._check_required_if(v.get('required_if', None), param) - self._check_required_by(v.get('required_by', None), param) - - self._set_defaults(pre=False, spec=spec, param=param) - - # handle multi level options (sub argspec) - self._handle_options(spec, param, new_prefix) - self._options_context.pop() - - def _get_wanted_type(self, wanted, k): - # Use the private method for 'str' type to handle the string conversion warning. - if wanted == 'str': - type_checker, wanted = self._check_type_str, 'str' - else: - type_checker, wanted = get_type_validator(wanted) - if type_checker is None: - self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) - - return type_checker, wanted - - def _check_argument_types(self, spec=None, param=None, prefix=''): - ''' ensure all arguments have the requested type ''' - - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - - errors = [] - validate_argument_types(spec, param, errors=errors) - - if errors: - self.fail_json(msg=errors[0]) - - def _set_defaults(self, pre=True, spec=None, param=None): - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - - # The interface for set_defaults is different than _set_defaults() - # The third parameter controls whether or not defaults are actually set. - set_default = not pre - self.no_log_values.update(set_defaults(spec, param, set_default)) - - def _set_fallbacks(self, spec=None, param=None): - if spec is None: - spec = self.argument_spec - if param is None: - param = self.params - - self.no_log_values.update(set_fallbacks(spec, param)) - def _load_params(self): ''' read the input and set the params attribute. diff --git a/lib/ansible/module_utils/common/arg_spec.py b/lib/ansible/module_utils/common/arg_spec.py index 54bf80a587..c4d4a247ed 100644 --- a/lib/ansible/module_utils/common/arg_spec.py +++ b/lib/ansible/module_utils/common/arg_spec.py @@ -5,71 +5,146 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type - from copy import deepcopy -from ansible.module_utils.common._collections_compat import ( - Sequence, -) - from ansible.module_utils.common.parameters import ( - get_unsupported_parameters, - handle_aliases, - list_no_log_values, - remove_values, - set_defaults, + _ADDITIONAL_CHECKS, + _get_legal_inputs, + _get_unsupported_parameters, + _handle_aliases, + _list_no_log_values, + _set_defaults, + _validate_argument_types, + _validate_argument_values, + _validate_sub_spec, set_fallbacks, - validate_argument_types, - validate_argument_values, - validate_sub_spec, ) from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.warnings import deprecate, warn + from ansible.module_utils.common.validation import ( + check_mutually_exclusive, check_required_arguments, + check_required_by, + check_required_if, + check_required_one_of, + check_required_together, +) + +from ansible.module_utils.errors import ( + AliasError, + AnsibleValidationErrorMultiple, + MutuallyExclusiveError, + NoLogError, + RequiredByError, + RequiredDefaultError, + RequiredError, + RequiredIfError, + RequiredOneOfError, + RequiredTogetherError, + UnsupportedError, ) -from ansible.module_utils.six import string_types +class ValidationResult: + """Result of argument spec validation. -class ArgumentSpecValidator(): - """Argument spec validation class""" + :param parameters: Terms to be validated and coerced to the correct type. + :type parameters: dict - def __init__(self, argument_spec, parameters): - self._error_messages = [] + """ + + def __init__(self, parameters): self._no_log_values = set() - self.argument_spec = argument_spec - # Make a copy of the original parameters to avoid changing them - self._validated_parameters = deepcopy(parameters) self._unsupported_parameters = set() - - @property - def error_messages(self): - return self._error_messages + self._validated_parameters = deepcopy(parameters) + self._deprecations = [] + self._warnings = [] + self.errors = AnsibleValidationErrorMultiple() @property def validated_parameters(self): return self._validated_parameters - def _add_error(self, error): - if isinstance(error, string_types): - self._error_messages.append(error) - elif isinstance(error, Sequence): - self._error_messages.extend(error) - else: - raise ValueError('Error messages must be a string or sequence not a %s' % type(error)) + @property + def unsupported_parameters(self): + return self._unsupported_parameters + + @property + def error_messages(self): + return self.errors.messages + + +class ArgumentSpecValidator: + """Argument spec validation class + + Creates a validator based on the ``argument_spec`` that can be used to + validate a number of parameters using the ``validate()`` method. + + :param argument_spec: Specification of valid parameters and their type. May + include nested argument specs. + :type argument_spec: dict + + :param mutually_exclusive: List or list of lists of terms that should not + be provided together. + :type mutually_exclusive: list, optional + + :param required_together: List of lists of terms that are required together. + :type required_together: list, optional - def _sanitize_error_messages(self): - self._error_messages = remove_values(self._error_messages, self._no_log_values) + :param required_one_of: List of lists of terms, one of which in each list + is required. + :type required_one_of: list, optional - def validate(self, *args, **kwargs): - """Validate module parameters against argument spec. + :param required_if: List of lists of ``[parameter, value, [parameters]]`` where + one of [parameters] is required if ``parameter`` == ``value``. + :type required_if: list, optional + + :param required_by: Dictionary of parameter names that contain a list of + parameters required by each key in the dictionary. + :type required_by: dict, optional + """ + + def __init__(self, argument_spec, + mutually_exclusive=None, + required_together=None, + required_one_of=None, + required_if=None, + required_by=None, + ): + + self._mutually_exclusive = mutually_exclusive + self._required_together = required_together + self._required_one_of = required_one_of + self._required_if = required_if + self._required_by = required_by + self._valid_parameter_names = set() + self.argument_spec = argument_spec + + for key in sorted(self.argument_spec.keys()): + aliases = self.argument_spec[key].get('aliases') + if aliases: + self._valid_parameter_names.update(["{key} ({aliases})".format(key=key, aliases=", ".join(sorted(aliases)))]) + else: + self._valid_parameter_names.update([key]) + + def validate(self, parameters, *args, **kwargs): + """Validate module parameters against argument spec. Returns a + ValidationResult object. + + Error messages in the ValidationResult may contain no_log values and should be + sanitized before logging or displaying. :Example: - validator = ArgumentSpecValidator(argument_spec, parameters) - passeded = validator.validate() + validator = ArgumentSpecValidator(argument_spec) + result = validator.validate(parameters) + + if result.error_messages: + sys.exit("Validation failed: {0}".format(", ".join(result.error_messages)) + + valid_params = result.validated_parameters :param argument_spec: Specification of parameters, type, and valid values :type argument_spec: dict @@ -77,58 +152,104 @@ class ArgumentSpecValidator(): :param parameters: Parameters provided to the role :type parameters: dict - :returns: True if no errors were encountered, False if any errors were encountered. - :rtype: bool + :return: Object containing validated parameters. + :rtype: ValidationResult """ - self._no_log_values.update(set_fallbacks(self.argument_spec, self._validated_parameters)) + result = ValidationResult(parameters) + + result._no_log_values.update(set_fallbacks(self.argument_spec, result._validated_parameters)) alias_warnings = [] alias_deprecations = [] try: - alias_results, legal_inputs = handle_aliases(self.argument_spec, self._validated_parameters, alias_warnings, alias_deprecations) + aliases = _handle_aliases(self.argument_spec, result._validated_parameters, alias_warnings, alias_deprecations) except (TypeError, ValueError) as e: - alias_results = {} - legal_inputs = None - self._add_error(to_native(e)) + aliases = {} + result.errors.append(AliasError(to_native(e))) + + legal_inputs = _get_legal_inputs(self.argument_spec, result._validated_parameters, aliases) for option, alias in alias_warnings: - warn('Both option %s and its alias %s are set.' % (option, alias)) + result._warnings.append({'option': option, 'alias': alias}) for deprecation in alias_deprecations: - deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], - version=deprecation.get('version'), date=deprecation.get('date'), - collection_name=deprecation.get('collection_name')) + result._deprecations.append({ + 'name': deprecation['name'], + 'version': deprecation.get('version'), + 'date': deprecation.get('date'), + 'collection_name': deprecation.get('collection_name'), + }) - self._no_log_values.update(list_no_log_values(self.argument_spec, self._validated_parameters)) + try: + result._no_log_values.update(_list_no_log_values(self.argument_spec, result._validated_parameters)) + except TypeError as te: + result.errors.append(NoLogError(to_native(te))) + + try: + result._unsupported_parameters.update(_get_unsupported_parameters(self.argument_spec, result._validated_parameters, legal_inputs)) + except TypeError as te: + result.errors.append(RequiredDefaultError(to_native(te))) + except ValueError as ve: + result.errors.append(AliasError(to_native(ve))) - if legal_inputs is None: - legal_inputs = list(alias_results.keys()) + list(self.argument_spec.keys()) - self._unsupported_parameters.update(get_unsupported_parameters(self.argument_spec, self._validated_parameters, legal_inputs)) + try: + check_mutually_exclusive(self._mutually_exclusive, result._validated_parameters) + except TypeError as te: + result.errors.append(MutuallyExclusiveError(to_native(te))) - self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters, False)) + result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters, False)) try: - check_required_arguments(self.argument_spec, self._validated_parameters) + check_required_arguments(self.argument_spec, result._validated_parameters) except TypeError as e: - self._add_error(to_native(e)) + result.errors.append(RequiredError(to_native(e))) + + _validate_argument_types(self.argument_spec, result._validated_parameters, errors=result.errors) + _validate_argument_values(self.argument_spec, result._validated_parameters, errors=result.errors) + + for check in _ADDITIONAL_CHECKS: + try: + check['func'](getattr(self, "_{attr}".format(attr=check['attr'])), result._validated_parameters) + except TypeError as te: + result.errors.append(check['err'](to_native(te))) + + result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters)) + + _validate_sub_spec(self.argument_spec, result._validated_parameters, + errors=result.errors, + no_log_values=result._no_log_values, + unsupported_parameters=result._unsupported_parameters) + + if result._unsupported_parameters: + flattened_names = [] + for item in result._unsupported_parameters: + if isinstance(item, tuple): + flattened_names.append(".".join(item)) + else: + flattened_names.append(item) + + unsupported_string = ", ".join(sorted(list(flattened_names))) + supported_string = ", ".join(self._valid_parameter_names) + result.errors.append( + UnsupportedError("{0}. Supported parameters include: {1}.".format(unsupported_string, supported_string))) + + return result - validate_argument_types(self.argument_spec, self._validated_parameters, errors=self._error_messages) - validate_argument_values(self.argument_spec, self._validated_parameters, errors=self._error_messages) - self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters)) +class ModuleArgumentSpecValidator(ArgumentSpecValidator): + def __init__(self, *args, **kwargs): + super(ModuleArgumentSpecValidator, self).__init__(*args, **kwargs) - validate_sub_spec(self.argument_spec, self._validated_parameters, - errors=self._error_messages, - no_log_values=self._no_log_values, - unsupported_parameters=self._unsupported_parameters) + def validate(self, parameters): + result = super(ModuleArgumentSpecValidator, self).validate(parameters) - if self._unsupported_parameters: - self._add_error('Unsupported parameters: %s' % ', '.join(sorted(list(self._unsupported_parameters)))) + for d in result._deprecations: + deprecate("Alias '{name}' is deprecated. See the module docs for more information".format(name=d['name']), + version=d.get('version'), date=d.get('date'), + collection_name=d.get('collection_name')) - self._sanitize_error_messages() + for w in result._warnings: + warn('Both option {option} and its alias {alias} are set.'.format(option=w['option'], alias=w['alias'])) - if self.error_messages: - return False - else: - return True + return result diff --git a/lib/ansible/module_utils/common/parameters.py b/lib/ansible/module_utils/common/parameters.py index 4fa5dab84c..e297573410 100644 --- a/lib/ansible/module_utils/common/parameters.py +++ b/lib/ansible/module_utils/common/parameters.py @@ -15,6 +15,22 @@ from ansible.module_utils.common.collections import is_iterable from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.formatters import lenient_lowercase from ansible.module_utils.common.warnings import warn +from ansible.module_utils.errors import ( + AliasError, + AnsibleFallbackNotFound, + AnsibleValidationErrorMultiple, + ArgumentTypeError, + ArgumentValueError, + ElementError, + MutuallyExclusiveError, + NoLogError, + RequiredByError, + RequiredError, + RequiredIfError, + RequiredOneOfError, + RequiredTogetherError, + SubParameterTypeError, +) from ansible.module_utils.parsing.convert_bool import BOOLEANS_FALSE, BOOLEANS_TRUE from ansible.module_utils.common._collections_compat import ( @@ -59,6 +75,13 @@ from ansible.module_utils.common.validation import ( # Python2 & 3 way to get NoneType NoneType = type(None) +_ADDITIONAL_CHECKS = ( + {'func': check_required_together, 'attr': 'required_together', 'err': RequiredTogetherError}, + {'func': check_required_one_of, 'attr': 'required_one_of', 'err': RequiredOneOfError}, + {'func': check_required_if, 'attr': 'required_if', 'err': RequiredIfError}, + {'func': check_required_by, 'attr': 'required_by', 'err': RequiredByError}, +) + # if adding boolean attribute, also add to PASS_BOOL # some of this dupes defaults from controller config PASS_VARS = { @@ -97,8 +120,221 @@ DEFAULT_TYPE_VALIDATORS = { } -class AnsibleFallbackNotFound(Exception): - pass +def _get_type_validator(wanted): + """Returns the callable used to validate a wanted type and the type name. + + :arg wanted: String or callable. If a string, get the corresponding + validation function from DEFAULT_TYPE_VALIDATORS. If callable, + get the name of the custom callable and return that for the type_checker. + + :returns: Tuple of callable function or None, and a string that is the name + of the wanted type. + """ + + # Use one our our builtin validators. + if not callable(wanted): + if wanted is None: + # Default type for parameters + wanted = 'str' + + type_checker = DEFAULT_TYPE_VALIDATORS.get(wanted) + + # Use the custom callable for validation. + else: + type_checker = wanted + wanted = getattr(wanted, '__name__', to_native(type(wanted))) + + return type_checker, wanted + + +def _get_legal_inputs(argument_spec, parameters, aliases=None): + if aliases is None: + aliases = _handle_aliases(argument_spec, parameters) + + return list(aliases.keys()) + list(argument_spec.keys()) + + +def _get_unsupported_parameters(argument_spec, parameters, legal_inputs=None, options_context=None): + """Check keys in parameters against those provided in legal_inputs + to ensure they contain legal values. If legal_inputs are not supplied, + they will be generated using the argument_spec. + + :arg argument_spec: Dictionary of parameters, their type, and valid values. + :arg parameters: Dictionary of parameters. + :arg legal_inputs: List of valid key names property names. Overrides values + in argument_spec. + :arg options_context: List of parent keys for tracking the context of where + a parameter is defined. + + :returns: Set of unsupported parameters. Empty set if no unsupported parameters + are found. + """ + + if legal_inputs is None: + legal_inputs = _get_legal_inputs(argument_spec, parameters) + + unsupported_parameters = set() + for k in parameters.keys(): + if k not in legal_inputs: + context = k + if options_context: + context = tuple(options_context + [k]) + + unsupported_parameters.add(context) + + return unsupported_parameters + + +def _handle_aliases(argument_spec, parameters, alias_warnings=None, alias_deprecations=None): + """Process aliases from an argument_spec including warnings and deprecations. + + Modify ``parameters`` by adding a new key for each alias with the supplied + value from ``parameters``. + + If a list is provided to the alias_warnings parameter, it will be filled with tuples + (option, alias) in every case where both an option and its alias are specified. + + If a list is provided to alias_deprecations, it will be populated with dictionaries, + each containing deprecation information for each alias found in argument_spec. + + :param argument_spec: Dictionary of parameters, their type, and valid values. + :type argument_spec: dict + + :param parameters: Dictionary of parameters. + :type parameters: dict + + :param alias_warnings: + :type alias_warnings: list + + :param alias_deprecations: + :type alias_deprecations: list + """ + + aliases_results = {} # alias:canon + + for (k, v) in argument_spec.items(): + aliases = v.get('aliases', None) + default = v.get('default', None) + required = v.get('required', False) + + if alias_deprecations is not None: + for alias in argument_spec[k].get('deprecated_aliases', []): + if alias.get('name') in parameters: + alias_deprecations.append(alias) + + if default is not None and required: + # not alias specific but this is a good place to check this + raise ValueError("internal error: required and default are mutually exclusive for %s" % k) + + if aliases is None: + continue + + if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)): + raise TypeError('internal error: aliases must be a list or tuple') + + for alias in aliases: + aliases_results[alias] = k + if alias in parameters: + if k in parameters and alias_warnings is not None: + alias_warnings.append((k, alias)) + parameters[k] = parameters[alias] + + return aliases_results + + +def _list_deprecations(argument_spec, parameters, prefix=''): + """Return a list of deprecations + + :arg argument_spec: An argument spec dictionary + :arg parameters: Dictionary of parameters + + :returns: List of dictionaries containing a message and version in which + the deprecated parameter will be removed, or an empty list:: + + [{'msg': "Param 'deptest' is deprecated. See the module docs for more information", 'version': '2.9'}] + """ + + deprecations = [] + for arg_name, arg_opts in argument_spec.items(): + if arg_name in parameters: + if prefix: + sub_prefix = '%s["%s"]' % (prefix, arg_name) + else: + sub_prefix = arg_name + if arg_opts.get('removed_at_date') is not None: + deprecations.append({ + 'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix, + 'date': arg_opts.get('removed_at_date'), + 'collection_name': arg_opts.get('removed_from_collection'), + }) + elif arg_opts.get('removed_in_version') is not None: + deprecations.append({ + 'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix, + 'version': arg_opts.get('removed_in_version'), + 'collection_name': arg_opts.get('removed_from_collection'), + }) + # Check sub-argument spec + sub_argument_spec = arg_opts.get('options') + if sub_argument_spec is not None: + sub_arguments = parameters[arg_name] + if isinstance(sub_arguments, Mapping): + sub_arguments = [sub_arguments] + if isinstance(sub_arguments, list): + for sub_params in sub_arguments: + if isinstance(sub_params, Mapping): + deprecations.extend(_list_deprecations(sub_argument_spec, sub_params, prefix=sub_prefix)) + + return deprecations + + +def _list_no_log_values(argument_spec, params): + """Return set of no log values + + :arg argument_spec: An argument spec dictionary + :arg params: Dictionary of all parameters + + :returns: Set of strings that should be hidden from output:: + + {'secret_dict_value', 'secret_list_item_one', 'secret_list_item_two', 'secret_string'} + """ + + no_log_values = set() + for arg_name, arg_opts in argument_spec.items(): + if arg_opts.get('no_log', False): + # Find the value for the no_log'd param + no_log_object = params.get(arg_name, None) + + if no_log_object: + try: + no_log_values.update(_return_datastructure_name(no_log_object)) + except TypeError as e: + raise TypeError('Failed to convert "%s": %s' % (arg_name, to_native(e))) + + # Get no_log values from suboptions + sub_argument_spec = arg_opts.get('options') + if sub_argument_spec is not None: + wanted_type = arg_opts.get('type') + sub_parameters = params.get(arg_name) + + if sub_parameters is not None: + if wanted_type == 'dict' or (wanted_type == 'list' and arg_opts.get('elements', '') == 'dict'): + # Sub parameters can be a dict or list of dicts. Ensure parameters are always a list. + if not isinstance(sub_parameters, list): + sub_parameters = [sub_parameters] + + for sub_param in sub_parameters: + # Validate dict fields in case they came in as strings + + if isinstance(sub_param, string_types): + sub_param = check_type_dict(sub_param) + + if not isinstance(sub_param, Mapping): + raise TypeError("Value '{1}' in the sub parameter field '{0}' must by a {2}, " + "not '{1.__class__.__name__}'".format(arg_name, sub_param, wanted_type)) + + no_log_values.update(_list_no_log_values(sub_argument_spec, sub_param)) + + return no_log_values def _return_datastructure_name(obj): @@ -217,79 +453,7 @@ def _remove_values_conditions(value, no_log_strings, deferred_removals): return value -def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals): - """ Helper method to sanitize_keys() to build deferred_removals and avoid deep recursion. """ - if isinstance(value, (text_type, binary_type)): - return value - - if isinstance(value, Sequence): - if isinstance(value, MutableSequence): - new_value = type(value)() - else: - new_value = [] # Need a mutable value - deferred_removals.append((value, new_value)) - return new_value - - if isinstance(value, Set): - if isinstance(value, MutableSet): - new_value = type(value)() - else: - new_value = set() # Need a mutable value - deferred_removals.append((value, new_value)) - return new_value - - if isinstance(value, Mapping): - if isinstance(value, MutableMapping): - new_value = type(value)() - else: - new_value = {} # Need a mutable value - deferred_removals.append((value, new_value)) - return new_value - - if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): - return value - - if isinstance(value, (datetime.datetime, datetime.date)): - return value - - raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) - - -def env_fallback(*args, **kwargs): - """Load value from environment variable""" - - for arg in args: - if arg in os.environ: - return os.environ[arg] - raise AnsibleFallbackNotFound - - -def set_fallbacks(argument_spec, parameters): - no_log_values = set() - for param, value in argument_spec.items(): - fallback = value.get('fallback', (None,)) - fallback_strategy = fallback[0] - fallback_args = [] - fallback_kwargs = {} - if param not in parameters and fallback_strategy is not None: - for item in fallback[1:]: - if isinstance(item, dict): - fallback_kwargs = item - else: - fallback_args = item - try: - fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs) - except AnsibleFallbackNotFound: - continue - else: - if value.get('no_log', False) and fallback_value: - no_log_values.add(fallback_value) - parameters[param] = fallback_value - - return no_log_values - - -def set_defaults(argument_spec, parameters, set_default=True): +def _set_defaults(argument_spec, parameters, set_default=True): """Set default values for parameters when no value is supplied. Modifies parameters directly. @@ -326,284 +490,50 @@ def set_defaults(argument_spec, parameters, set_default=True): return no_log_values -def list_no_log_values(argument_spec, params): - """Return set of no log values - - :arg argument_spec: An argument spec dictionary from a module - :arg params: Dictionary of all parameters - - :returns: Set of strings that should be hidden from output:: - - {'secret_dict_value', 'secret_list_item_one', 'secret_list_item_two', 'secret_string'} - """ - - no_log_values = set() - for arg_name, arg_opts in argument_spec.items(): - if arg_opts.get('no_log', False): - # Find the value for the no_log'd param - no_log_object = params.get(arg_name, None) - - if no_log_object: - try: - no_log_values.update(_return_datastructure_name(no_log_object)) - except TypeError as e: - raise TypeError('Failed to convert "%s": %s' % (arg_name, to_native(e))) - - # Get no_log values from suboptions - sub_argument_spec = arg_opts.get('options') - if sub_argument_spec is not None: - wanted_type = arg_opts.get('type') - sub_parameters = params.get(arg_name) - - if sub_parameters is not None: - if wanted_type == 'dict' or (wanted_type == 'list' and arg_opts.get('elements', '') == 'dict'): - # Sub parameters can be a dict or list of dicts. Ensure parameters are always a list. - if not isinstance(sub_parameters, list): - sub_parameters = [sub_parameters] - - for sub_param in sub_parameters: - # Validate dict fields in case they came in as strings - - if isinstance(sub_param, string_types): - sub_param = check_type_dict(sub_param) - - if not isinstance(sub_param, Mapping): - raise TypeError("Value '{1}' in the sub parameter field '{0}' must by a {2}, " - "not '{1.__class__.__name__}'".format(arg_name, sub_param, wanted_type)) - - no_log_values.update(list_no_log_values(sub_argument_spec, sub_param)) - - return no_log_values - - -def list_deprecations(argument_spec, parameters, prefix=''): - """Return a list of deprecations - - :arg argument_spec: An argument spec dictionary from a module - :arg parameters: Dictionary of parameters - - :returns: List of dictionaries containing a message and version in which - the deprecated parameter will be removed, or an empty list:: - - [{'msg': "Param 'deptest' is deprecated. See the module docs for more information", 'version': '2.9'}] - """ - - deprecations = [] - for arg_name, arg_opts in argument_spec.items(): - if arg_name in parameters: - if prefix: - sub_prefix = '%s["%s"]' % (prefix, arg_name) - else: - sub_prefix = arg_name - if arg_opts.get('removed_at_date') is not None: - deprecations.append({ - 'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix, - 'date': arg_opts.get('removed_at_date'), - 'collection_name': arg_opts.get('removed_from_collection'), - }) - elif arg_opts.get('removed_in_version') is not None: - deprecations.append({ - 'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix, - 'version': arg_opts.get('removed_in_version'), - 'collection_name': arg_opts.get('removed_from_collection'), - }) - # Check sub-argument spec - sub_argument_spec = arg_opts.get('options') - if sub_argument_spec is not None: - sub_arguments = parameters[arg_name] - if isinstance(sub_arguments, Mapping): - sub_arguments = [sub_arguments] - if isinstance(sub_arguments, list): - for sub_params in sub_arguments: - if isinstance(sub_params, Mapping): - deprecations.extend(list_deprecations(sub_argument_spec, sub_params, prefix=sub_prefix)) - - return deprecations - - -def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()): - """ Sanitize the keys in a container object by removing no_log values from key names. - - This is a companion function to the `remove_values()` function. Similar to that function, - we make use of deferred_removals to avoid hitting maximum recursion depth in cases of - large data structures. - - :param obj: The container object to sanitize. Non-container objects are returned unmodified. - :param no_log_strings: A set of string values we do not want logged. - :param ignore_keys: A set of string values of keys to not sanitize. - - :returns: An object with sanitized keys. - """ - - deferred_removals = deque() - - no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] - new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals) - - while deferred_removals: - old_data, new_data = deferred_removals.popleft() +def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals): + """ Helper method to sanitize_keys() to build deferred_removals and avoid deep recursion. """ + if isinstance(value, (text_type, binary_type)): + return value - if isinstance(new_data, Mapping): - for old_key, old_elem in old_data.items(): - if old_key in ignore_keys or old_key.startswith('_ansible'): - new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) - else: - # Sanitize the old key. We take advantage of the sanitizing code in - # _remove_values_conditions() rather than recreating it here. - new_key = _remove_values_conditions(old_key, no_log_strings, None) - new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) + if isinstance(value, Sequence): + if isinstance(value, MutableSequence): + new_value = type(value)() else: - for elem in old_data: - new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals) - if isinstance(new_data, MutableSequence): - new_data.append(new_elem) - elif isinstance(new_data, MutableSet): - new_data.add(new_elem) - else: - raise TypeError('Unknown container type encountered when removing private values from keys') - - return new_value - - -def remove_values(value, no_log_strings): - """ Remove strings in no_log_strings from value. If value is a container - type, then remove a lot more. - - Use of deferred_removals exists, rather than a pure recursive solution, - because of the potential to hit the maximum recursion depth when dealing with - large amounts of data (see issue #24560). - """ - - deferred_removals = deque() - - no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] - new_value = _remove_values_conditions(value, no_log_strings, deferred_removals) + new_value = [] # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value - while deferred_removals: - old_data, new_data = deferred_removals.popleft() - if isinstance(new_data, Mapping): - for old_key, old_elem in old_data.items(): - new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals) - new_data[old_key] = new_elem + if isinstance(value, Set): + if isinstance(value, MutableSet): + new_value = type(value)() else: - for elem in old_data: - new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals) - if isinstance(new_data, MutableSequence): - new_data.append(new_elem) - elif isinstance(new_data, MutableSet): - new_data.add(new_elem) - else: - raise TypeError('Unknown container type encountered when removing private values from output') - - return new_value - - -def handle_aliases(argument_spec, parameters, alias_warnings=None, alias_deprecations=None): - """Return a two item tuple. The first is a dictionary of aliases, the second is - a list of legal inputs. - - Modify supplied parameters by adding a new key for each alias. - - If a list is provided to the alias_warnings parameter, it will be filled with tuples - (option, alias) in every case where both an option and its alias are specified. - - If a list is provided to alias_deprecations, it will be populated with dictionaries, - each containing deprecation information for each alias found in argument_spec. - """ - - legal_inputs = ['_ansible_%s' % k for k in PASS_VARS] - aliases_results = {} # alias:canon - - for (k, v) in argument_spec.items(): - legal_inputs.append(k) - aliases = v.get('aliases', None) - default = v.get('default', None) - required = v.get('required', False) - - if alias_deprecations is not None: - for alias in argument_spec[k].get('deprecated_aliases', []): - if alias.get('name') in parameters: - alias_deprecations.append(alias) - - if default is not None and required: - # not alias specific but this is a good place to check this - raise ValueError("internal error: required and default are mutually exclusive for %s" % k) - - if aliases is None: - continue - - if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)): - raise TypeError('internal error: aliases must be a list or tuple') - - for alias in aliases: - legal_inputs.append(alias) - aliases_results[alias] = k - if alias in parameters: - if k in parameters and alias_warnings is not None: - alias_warnings.append((k, alias)) - parameters[k] = parameters[alias] - - return aliases_results, legal_inputs - - -def get_unsupported_parameters(argument_spec, parameters, legal_inputs=None): - """Check keys in parameters against those provided in legal_inputs - to ensure they contain legal values. If legal_inputs are not supplied, - they will be generated using the argument_spec. - - :arg argument_spec: Dictionary of parameters, their type, and valid values. - :arg parameters: Dictionary of parameters. - :arg legal_inputs: List of valid key names property names. Overrides values - in argument_spec. - - :returns: Set of unsupported parameters. Empty set if no unsupported parameters - are found. - """ - - if legal_inputs is None: - aliases, legal_inputs = handle_aliases(argument_spec, parameters) - - unsupported_parameters = set() - for k in parameters.keys(): - if k not in legal_inputs: - unsupported_parameters.add(k) - - return unsupported_parameters - - -def get_type_validator(wanted): - """Returns the callable used to validate a wanted type and the type name. - - :arg wanted: String or callable. If a string, get the corresponding - validation function from DEFAULT_TYPE_VALIDATORS. If callable, - get the name of the custom callable and return that for the type_checker. - - :returns: Tuple of callable function or None, and a string that is the name - of the wanted type. - """ + new_value = set() # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value - # Use one our our builtin validators. - if not callable(wanted): - if wanted is None: - # Default type for parameters - wanted = 'str' + if isinstance(value, Mapping): + if isinstance(value, MutableMapping): + new_value = type(value)() + else: + new_value = {} # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value - type_checker = DEFAULT_TYPE_VALIDATORS.get(wanted) + if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): + return value - # Use the custom callable for validation. - else: - type_checker = wanted - wanted = getattr(wanted, '__name__', to_native(type(wanted))) + if isinstance(value, (datetime.datetime, datetime.date)): + return value - return type_checker, wanted + raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) -def validate_elements(wanted_type, parameter, values, options_context=None, errors=None): +def _validate_elements(wanted_type, parameter, values, options_context=None, errors=None): if errors is None: - errors = [] + errors = AnsibleValidationErrorMultiple() - type_checker, wanted_element_type = get_type_validator(wanted_type) + type_checker, wanted_element_type = _get_type_validator(wanted_type) validated_parameters = [] # Get param name for strings so we can later display this value in a useful error message if needed # Only pass 'kwargs' to our checkers and ignore custom callable checkers @@ -622,11 +552,11 @@ def validate_elements(wanted_type, parameter, values, options_context=None, erro if options_context: msg += " found in '%s'" % " -> ".join(options_context) msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_element_type, to_native(e)) - errors.append(msg) + errors.append(ElementError(msg)) return validated_parameters -def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None): +def _validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None): """Validate that parameter types match the type in the argument spec. Determine the appropriate type checker function and run each @@ -637,7 +567,7 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_contex :param argument_spec: Argument spec :type argument_spec: dict - :param parameters: Parameters passed to module + :param parameters: Parameters :type parameters: dict :param prefix: Name of the parent key that contains the spec. Used in the error message @@ -653,7 +583,7 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_contex """ if errors is None: - errors = [] + errors = AnsibleValidationErrorMultiple() for param, spec in argument_spec.items(): if param not in parameters: @@ -664,7 +594,7 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_contex continue wanted_type = spec.get('type') - type_checker, wanted_name = get_type_validator(wanted_type) + type_checker, wanted_name = _get_type_validator(wanted_type) # Get param name for strings so we can later display this value in a useful error message if needed # Only pass 'kwargs' to our checkers and ignore custom callable checkers kwargs = {} @@ -685,22 +615,22 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_contex if options_context: msg += " found in '%s'." % " -> ".join(options_context) msg += ", elements value check is supported only with 'list' type" - errors.append(msg) - parameters[param] = validate_elements(elements_wanted_type, param, elements, options_context, errors) + errors.append(ArgumentTypeError(msg)) + parameters[param] = _validate_elements(elements_wanted_type, param, elements, options_context, errors) except (TypeError, ValueError) as e: msg = "argument '%s' is of type %s" % (param, type(value)) if options_context: msg += " found in '%s'." % " -> ".join(options_context) msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e)) - errors.append(msg) + errors.append(ArgumentTypeError(msg)) -def validate_argument_values(argument_spec, parameters, options_context=None, errors=None): +def _validate_argument_values(argument_spec, parameters, options_context=None, errors=None): """Ensure all arguments have the requested values, and there are no stray arguments""" if errors is None: - errors = [] + errors = AnsibleValidationErrorMultiple() for param, spec in argument_spec.items(): choices = spec.get('choices') @@ -716,8 +646,8 @@ def validate_argument_values(argument_spec, parameters, options_context=None, er choices_str = ", ".join([to_native(c) for c in choices]) msg = "value of %s must be one or more of: %s. Got no match for: %s" % (param, choices_str, diff_list) if options_context: - msg += " found in %s" % " -> ".join(options_context) - errors.append(msg) + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + errors.append(ArgumentValueError(msg)) elif parameters[param] not in choices: # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking # the value. If we can't figure this out, module author is responsible. @@ -740,23 +670,23 @@ def validate_argument_values(argument_spec, parameters, options_context=None, er choices_str = ", ".join([to_native(c) for c in choices]) msg = "value of %s must be one of: %s, got: %s" % (param, choices_str, parameters[param]) if options_context: - msg += " found in %s" % " -> ".join(options_context) - errors.append(msg) + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + errors.append(ArgumentValueError(msg)) else: msg = "internal error: choices for argument %s are not iterable: %s" % (param, choices) if options_context: - msg += " found in %s" % " -> ".join(options_context) - errors.append(msg) + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + errors.append(ArgumentTypeError(msg)) -def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None): +def _validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None): """Validate sub argument spec. This function is recursive.""" if options_context is None: options_context = [] if errors is None: - errors = [] + errors = AnsibleValidationErrorMultiple() if no_log_values is None: no_log_values = set() @@ -766,11 +696,11 @@ def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None for param, value in argument_spec.items(): wanted = value.get('type') - if wanted == 'dict' or (wanted == 'list' and value.get('elements', '') == dict): + if wanted == 'dict' or (wanted == 'list' and value.get('elements', '') == 'dict'): sub_spec = value.get('options') if value.get('apply_defaults', False): if sub_spec is not None: - if parameters.get(value) is None: + if parameters.get(param) is None: parameters[param] = {} else: continue @@ -788,7 +718,7 @@ def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None for idx, sub_parameters in enumerate(elements): if not isinstance(sub_parameters, dict): - errors.append("value of '%s' must be of type dict or list of dicts" % param) + errors.append(SubParameterTypeError("value of '%s' must be of type dict or list of dicts" % param)) # Set prefix for warning messages new_prefix = prefix + param @@ -799,53 +729,159 @@ def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None no_log_values.update(set_fallbacks(sub_spec, sub_parameters)) alias_warnings = [] + alias_deprecations = [] try: - options_aliases, legal_inputs = handle_aliases(sub_spec, sub_parameters, alias_warnings) + options_aliases = _handle_aliases(sub_spec, sub_parameters, alias_warnings, alias_deprecations) except (TypeError, ValueError) as e: options_aliases = {} - legal_inputs = None - errors.append(to_native(e)) + errors.append(AliasError(to_native(e))) for option, alias in alias_warnings: warn('Both option %s and its alias %s are set.' % (option, alias)) - no_log_values.update(list_no_log_values(sub_spec, sub_parameters)) + try: + no_log_values.update(_list_no_log_values(sub_spec, sub_parameters)) + except TypeError as te: + errors.append(NoLogError(to_native(te))) - if legal_inputs is None: - legal_inputs = list(options_aliases.keys()) + list(sub_spec.keys()) - unsupported_parameters.update(get_unsupported_parameters(sub_spec, sub_parameters, legal_inputs)) + legal_inputs = _get_legal_inputs(sub_spec, sub_parameters, options_aliases) + unsupported_parameters.update(_get_unsupported_parameters(sub_spec, sub_parameters, legal_inputs, options_context)) try: - check_mutually_exclusive(value.get('mutually_exclusive'), sub_parameters) + check_mutually_exclusive(value.get('mutually_exclusive'), sub_parameters, options_context) except TypeError as e: - errors.append(to_native(e)) + errors.append(MutuallyExclusiveError(to_native(e))) - no_log_values.update(set_defaults(sub_spec, sub_parameters, False)) + no_log_values.update(_set_defaults(sub_spec, sub_parameters, False)) try: - check_required_arguments(sub_spec, sub_parameters) + check_required_arguments(sub_spec, sub_parameters, options_context) except TypeError as e: - errors.append(to_native(e)) + errors.append(RequiredError(to_native(e))) - validate_argument_types(sub_spec, sub_parameters, new_prefix, options_context, errors=errors) - validate_argument_values(sub_spec, sub_parameters, options_context, errors=errors) + _validate_argument_types(sub_spec, sub_parameters, new_prefix, options_context, errors=errors) + _validate_argument_values(sub_spec, sub_parameters, options_context, errors=errors) - checks = [ - (check_required_together, 'required_together'), - (check_required_one_of, 'required_one_of'), - (check_required_if, 'required_if'), - (check_required_by, 'required_by'), - ] - - for check in checks: + for check in _ADDITIONAL_CHECKS: try: - check[0](value.get(check[1]), parameters) + check['func'](value.get(check['attr']), sub_parameters, options_context) except TypeError as e: - errors.append(to_native(e)) + errors.append(check['err'](to_native(e))) - no_log_values.update(set_defaults(sub_spec, sub_parameters)) + no_log_values.update(_set_defaults(sub_spec, sub_parameters)) # Handle nested specs - validate_sub_spec(sub_spec, sub_parameters, new_prefix, options_context, errors, no_log_values, unsupported_parameters) + _validate_sub_spec(sub_spec, sub_parameters, new_prefix, options_context, errors, no_log_values, unsupported_parameters) options_context.pop() + + +def env_fallback(*args, **kwargs): + """Load value from environment variable""" + + for arg in args: + if arg in os.environ: + return os.environ[arg] + raise AnsibleFallbackNotFound + + +def set_fallbacks(argument_spec, parameters): + no_log_values = set() + for param, value in argument_spec.items(): + fallback = value.get('fallback', (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if param not in parameters and fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs) + except AnsibleFallbackNotFound: + continue + else: + if value.get('no_log', False) and fallback_value: + no_log_values.add(fallback_value) + parameters[param] = fallback_value + + return no_log_values + + +def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()): + """ Sanitize the keys in a container object by removing no_log values from key names. + + This is a companion function to the `remove_values()` function. Similar to that function, + we make use of deferred_removals to avoid hitting maximum recursion depth in cases of + large data structures. + + :param obj: The container object to sanitize. Non-container objects are returned unmodified. + :param no_log_strings: A set of string values we do not want logged. + :param ignore_keys: A set of string values of keys to not sanitize. + + :returns: An object with sanitized keys. + """ + + deferred_removals = deque() + + no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] + new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals) + + while deferred_removals: + old_data, new_data = deferred_removals.popleft() + + if isinstance(new_data, Mapping): + for old_key, old_elem in old_data.items(): + if old_key in ignore_keys or old_key.startswith('_ansible'): + new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) + else: + # Sanitize the old key. We take advantage of the sanitizing code in + # _remove_values_conditions() rather than recreating it here. + new_key = _remove_values_conditions(old_key, no_log_strings, None) + new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) + else: + for elem in old_data: + new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals) + if isinstance(new_data, MutableSequence): + new_data.append(new_elem) + elif isinstance(new_data, MutableSet): + new_data.add(new_elem) + else: + raise TypeError('Unknown container type encountered when removing private values from keys') + + return new_value + + +def remove_values(value, no_log_strings): + """ Remove strings in no_log_strings from value. If value is a container + type, then remove a lot more. + + Use of deferred_removals exists, rather than a pure recursive solution, + because of the potential to hit the maximum recursion depth when dealing with + large amounts of data (see issue #24560). + """ + + deferred_removals = deque() + + no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] + new_value = _remove_values_conditions(value, no_log_strings, deferred_removals) + + while deferred_removals: + old_data, new_data = deferred_removals.popleft() + if isinstance(new_data, Mapping): + for old_key, old_elem in old_data.items(): + new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals) + new_data[old_key] = new_elem + else: + for elem in old_data: + new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals) + if isinstance(new_data, MutableSequence): + new_data.append(new_elem) + elif isinstance(new_data, MutableSet): + new_data.add(new_elem) + else: + raise TypeError('Unknown container type encountered when removing private values from output') + + return new_value diff --git a/lib/ansible/module_utils/common/validation.py b/lib/ansible/module_utils/common/validation.py index df40905987..d8c74e0232 100644 --- a/lib/ansible/module_utils/common/validation.py +++ b/lib/ansible/module_utils/common/validation.py @@ -39,7 +39,35 @@ def count_terms(terms, parameters): return len(set(terms).intersection(parameters)) -def check_mutually_exclusive(terms, parameters): +def safe_eval(value, locals=None, include_exceptions=False): + # do not allow method calls to modules + if not isinstance(value, string_types): + # already templated to a datavaluestructure, perhaps? + if include_exceptions: + return (value, None) + return value + if re.search(r'\w\.\w+\(', value): + if include_exceptions: + return (value, None) + return value + # do not allow imports + if re.search(r'import \w+', value): + if include_exceptions: + return (value, None) + return value + try: + result = literal_eval(value) + if include_exceptions: + return (result, None) + else: + return result + except Exception as e: + if include_exceptions: + return (value, e) + return value + + +def check_mutually_exclusive(terms, parameters, options_context=None): """Check mutually exclusive terms against argument parameters Accepts a single list or list of lists that are groups of terms that should be @@ -63,12 +91,14 @@ def check_mutually_exclusive(terms, parameters): if results: full_list = ['|'.join(check) for check in results] msg = "parameters are mutually exclusive: %s" % ', '.join(full_list) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) raise TypeError(to_native(msg)) return results -def check_required_one_of(terms, parameters): +def check_required_one_of(terms, parameters, options_context=None): """Check each list of terms to ensure at least one exists in the given module parameters @@ -93,12 +123,14 @@ def check_required_one_of(terms, parameters): if results: for term in results: msg = "one of the following is required: %s" % ', '.join(term) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) raise TypeError(to_native(msg)) return results -def check_required_together(terms, parameters): +def check_required_together(terms, parameters, options_context=None): """Check each list of terms to ensure every parameter in each list exists in the given parameters @@ -125,12 +157,14 @@ def check_required_together(terms, parameters): if results: for term in results: msg = "parameters are required together: %s" % ', '.join(term) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) raise TypeError(to_native(msg)) return results -def check_required_by(requirements, parameters): +def check_required_by(requirements, parameters, options_context=None): """For each key in requirements, check the corresponding list to see if they exist in parameters @@ -161,12 +195,14 @@ def check_required_by(requirements, parameters): for key, missing in result.items(): if len(missing) > 0: msg = "missing parameter(s) required by '%s': %s" % (key, ', '.join(missing)) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) raise TypeError(to_native(msg)) return result -def check_required_arguments(argument_spec, parameters): +def check_required_arguments(argument_spec, parameters, options_context=None): """Check all paramaters in argument_spec and return a list of parameters that are required but not present in parameters @@ -190,12 +226,14 @@ def check_required_arguments(argument_spec, parameters): if missing: msg = "missing required arguments: %s" % ", ".join(sorted(missing)) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) raise TypeError(to_native(msg)) return missing -def check_required_if(requirements, parameters): +def check_required_if(requirements, parameters, options_context=None): """Check parameters that are conditionally required Raises TypeError if the check fails @@ -272,6 +310,8 @@ def check_required_if(requirements, parameters): for missing in results: msg = "%s is %s but %s of the following are missing: %s" % ( missing['parameter'], missing['value'], missing['requires'], ', '.join(missing['missing'])) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) raise TypeError(to_native(msg)) return results @@ -304,34 +344,6 @@ def check_missing_parameters(parameters, required_parameters=None): return missing_params -def safe_eval(value, locals=None, include_exceptions=False): - # do not allow method calls to modules - if not isinstance(value, string_types): - # already templated to a datavaluestructure, perhaps? - if include_exceptions: - return (value, None) - return value - if re.search(r'\w\.\w+\(', value): - if include_exceptions: - return (value, None) - return value - # do not allow imports - if re.search(r'import \w+', value): - if include_exceptions: - return (value, None) - return value - try: - result = literal_eval(value) - if include_exceptions: - return (result, None) - else: - return result - except Exception as e: - if include_exceptions: - return (value, e) - return value - - # FIXME: The param and prefix parameters here are coming from AnsibleModule._check_type_string() # which is using those for the warning messaged based on string conversion warning settings. # Not sure how to deal with that here since we don't have config state to query. diff --git a/lib/ansible/module_utils/errors.py b/lib/ansible/module_utils/errors.py new file mode 100644 index 0000000000..953c78dce2 --- /dev/null +++ b/lib/ansible/module_utils/errors.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +class AnsibleFallbackNotFound(Exception): + """Fallback validator was not found""" + + +class AnsibleValidationError(Exception): + """Single argument spec validation error""" + + def __init__(self, message): + super(AnsibleValidationError, self).__init__(message) + self.error_message = message + + @property + def msg(self): + return self.args[0] + + +class AnsibleValidationErrorMultiple(AnsibleValidationError): + """Multiple argument spec validation errors""" + + def __init__(self, errors=None): + self.errors = errors[:] if errors else [] + + def __getitem__(self, key): + return self.errors[key] + + def __setitem__(self, key, value): + self.errors[key] = value + + def __delitem__(self, key): + del self.errors[key] + + @property + def msg(self): + return self.errors[0].args[0] + + @property + def messages(self): + return [err.msg for err in self.errors] + + def append(self, error): + self.errors.append(error) + + def extend(self, errors): + self.errors.extend(errors) + + +class AliasError(AnsibleValidationError): + """Error handling aliases""" + + +class ArgumentTypeError(AnsibleValidationError): + """Error with parameter type""" + + +class ArgumentValueError(AnsibleValidationError): + """Error with parameter value""" + + +class ElementError(AnsibleValidationError): + """Error when validating elements""" + + +class MutuallyExclusiveError(AnsibleValidationError): + """Mutually exclusive parameters were supplied""" + + +class NoLogError(AnsibleValidationError): + """Error converting no_log values""" + + +class RequiredByError(AnsibleValidationError): + """Error with parameters that are required by other parameters""" + + +class RequiredDefaultError(AnsibleValidationError): + """A required parameter was assigned a default value""" + + +class RequiredError(AnsibleValidationError): + """Missing a required parameter""" + + +class RequiredIfError(AnsibleValidationError): + """Error with conditionally required parameters""" + + +class RequiredOneOfError(AnsibleValidationError): + """Error with parameters where at least one is required""" + + +class RequiredTogetherError(AnsibleValidationError): + """Error with parameters that are required together""" + + +class SubParameterTypeError(AnsibleValidationError): + """Incorrect type for subparameter""" + + +class UnsupportedError(AnsibleValidationError): + """Unsupported parameters were supplied""" diff --git a/lib/ansible/plugins/action/validate_argument_spec.py b/lib/ansible/plugins/action/validate_argument_spec.py index 9923de9fcb..4162005202 100644 --- a/lib/ansible/plugins/action/validate_argument_spec.py +++ b/lib/ansible/plugins/action/validate_argument_spec.py @@ -8,6 +8,7 @@ from ansible.errors import AnsibleError from ansible.plugins.action import ActionBase from ansible.module_utils.six import iteritems, string_types from ansible.module_utils.common.arg_spec import ArgumentSpecValidator +from ansible.module_utils.errors import AnsibleValidationErrorMultiple class ActionModule(ActionBase): @@ -82,13 +83,14 @@ class ActionModule(ActionBase): args_from_vars = self.get_args_from_task_vars(argument_spec_data, task_vars) provided_arguments.update(args_from_vars) - validator = ArgumentSpecValidator(argument_spec_data, provided_arguments) + validator = ArgumentSpecValidator(argument_spec_data) + validation_result = validator.validate(provided_arguments) - if not validator.validate(): + if validation_result.error_messages: result['failed'] = True - result['msg'] = 'Validation of arguments failed:\n%s' % '\n'.join(validator.error_messages) + result['msg'] = 'Validation of arguments failed:\n%s' % '\n'.join(validation_result.error_messages) result['argument_spec_data'] = argument_spec_data - result['argument_errors'] = validator.error_messages + result['argument_errors'] = validation_result.error_messages return result result['changed'] = False diff --git a/test/units/executor/module_common/test_recursive_finder.py b/test/units/executor/module_common/test_recursive_finder.py index 97a2e43fb0..074dfb2f19 100644 --- a/test/units/executor/module_common/test_recursive_finder.py +++ b/test/units/executor/module_common/test_recursive_finder.py @@ -29,7 +29,6 @@ from io import BytesIO import ansible.errors from ansible.executor.module_common import recursive_finder -from ansible.module_utils.six import PY2 # These are the modules that are brought in by module_utils/basic.py This may need to be updated @@ -58,12 +57,14 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/__init__.py', 'ansible/module_utils/common/text/formatters.py', 'ansible/module_utils/common/validation.py', 'ansible/module_utils/common/_utils.py', + 'ansible/module_utils/common/arg_spec.py', 'ansible/module_utils/compat/__init__.py', 'ansible/module_utils/compat/_selectors2.py', 'ansible/module_utils/compat/selectors.py', 'ansible/module_utils/compat/selinux.py', 'ansible/module_utils/distro/__init__.py', 'ansible/module_utils/distro/_distro.py', + 'ansible/module_utils/errors.py', 'ansible/module_utils/parsing/__init__.py', 'ansible/module_utils/parsing/convert_bool.py', 'ansible/module_utils/pycompat24.py', diff --git a/test/units/module_utils/basic/test_argument_spec.py b/test/units/module_utils/basic/test_argument_spec.py index 36a699d66c..1b3f703521 100644 --- a/test/units/module_utils/basic/test_argument_spec.py +++ b/test/units/module_utils/basic/test_argument_spec.py @@ -84,9 +84,9 @@ INVALID_SPECS = ( ({'arg': {'type': 'list', 'elements': MOCK_VALIDATOR_FAIL}}, {'arg': [1, "bad"]}, "bad conversion"), # unknown parameter ({'arg': {'type': 'int'}}, {'other': 'bad', '_ansible_module_name': 'ansible_unittest'}, - 'Unsupported parameters for (ansible_unittest) module: other Supported parameters include: arg'), + 'Unsupported parameters for (ansible_unittest) module: other. Supported parameters include: arg.'), ({'arg': {'type': 'int', 'aliases': ['argument']}}, {'other': 'bad', '_ansible_module_name': 'ansible_unittest'}, - 'Unsupported parameters for (ansible_unittest) module: other Supported parameters include: arg (argument)'), + 'Unsupported parameters for (ansible_unittest) module: other. Supported parameters include: arg (argument).'), # parameter is required ({'arg': {'required': True}}, {}, 'missing required arguments: arg'), ) @@ -496,7 +496,7 @@ class TestComplexOptions: # Missing required option ({'foobar': [{}]}, 'missing required arguments: foo found in foobar'), # Invalid option - ({'foobar': [{"foo": "hello", "bam": "good", "invalid": "bad"}]}, 'module: invalid found in foobar. Supported parameters include'), + ({'foobar': [{"foo": "hello", "bam": "good", "invalid": "bad"}]}, 'module: foobar.invalid. Supported parameters include'), # Mutually exclusive options found ({'foobar': [{"foo": "test", "bam": "bad", "bam1": "bad", "baz": "req_to"}]}, 'parameters are mutually exclusive: bam|bam1 found in foobar'), @@ -520,7 +520,7 @@ class TestComplexOptions: ({'foobar': {}}, 'missing required arguments: foo found in foobar'), # Invalid option ({'foobar': {"foo": "hello", "bam": "good", "invalid": "bad"}}, - 'module: invalid found in foobar. Supported parameters include'), + 'module: foobar.invalid. Supported parameters include'), # Mutually exclusive options found ({'foobar': {"foo": "test", "bam": "bad", "bam1": "bad", "baz": "req_to"}}, 'parameters are mutually exclusive: bam|bam1 found in foobar'), diff --git a/test/units/module_utils/common/arg_spec/test__add_error.py b/test/units/module_utils/common/arg_spec/test__add_error.py deleted file mode 100644 index bf98580398..0000000000 --- a/test/units/module_utils/common/arg_spec/test__add_error.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2021 Ansible Project -# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) - -from __future__ import absolute_import, division, print_function -__metaclass__ = type - -import pytest - -from ansible.module_utils.common.arg_spec import ArgumentSpecValidator - - -def test_add_sequence(): - v = ArgumentSpecValidator({}, {}) - errors = [ - 'one error', - 'another error', - ] - v._add_error(errors) - - assert v.error_messages == errors - - -def test_invalid_error_message(): - v = ArgumentSpecValidator({}, {}) - - with pytest.raises(ValueError, match="Error messages must be a string or sequence not a"): - v._add_error(None) diff --git a/test/units/module_utils/common/arg_spec/test_aliases.py b/test/units/module_utils/common/arg_spec/test_aliases.py index 4dcdf938a7..f4c96c74e5 100644 --- a/test/units/module_utils/common/arg_spec/test_aliases.py +++ b/test/units/module_utils/common/arg_spec/test_aliases.py @@ -7,10 +7,11 @@ __metaclass__ = type import pytest -from ansible.module_utils.common.arg_spec import ArgumentSpecValidator +from ansible.module_utils.errors import AnsibleValidationError, AnsibleValidationErrorMultiple +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult from ansible.module_utils.common.warnings import get_deprecation_messages, get_warning_messages -# id, argument spec, parameters, expected parameters, expected pass/fail, error, deprecation, warning +# id, argument spec, parameters, expected parameters, deprecation, warning ALIAS_TEST_CASES = [ ( "alias", @@ -20,29 +21,6 @@ ALIAS_TEST_CASES = [ 'dir': '/tmp', 'path': '/tmp', }, - True, - "", - "", - "", - ), - ( - "alias-invalid", - {'path': {'aliases': 'bad'}}, - {}, - {'path': None}, - False, - "internal error: aliases must be a list or tuple", - "", - "", - ), - ( - # This isn't related to aliases, but it exists in the alias handling code - "default-and-required", - {'name': {'default': 'ray', 'required': True}}, - {}, - {'name': 'ray'}, - False, - "internal error: required and default are mutually exclusive for name", "", "", ), @@ -58,10 +36,8 @@ ALIAS_TEST_CASES = [ 'directory': '/tmp', 'path': '/tmp', }, - True, - "", "", - "Both option path and its alias directory are set", + {'alias': 'directory', 'option': 'path'}, ), ( "deprecated-alias", @@ -81,39 +57,66 @@ ALIAS_TEST_CASES = [ 'path': '/tmp', 'not_yo_path': '/tmp', }, - True, - "", - "Alias 'not_yo_path' is deprecated.", + {'version': '1.7', 'date': None, 'collection_name': None, 'name': 'not_yo_path'}, "", ) ] +# id, argument spec, parameters, expected parameters, error +ALIAS_TEST_CASES_INVALID = [ + ( + "alias-invalid", + {'path': {'aliases': 'bad'}}, + {}, + {'path': None}, + "internal error: aliases must be a list or tuple", + ), + ( + # This isn't related to aliases, but it exists in the alias handling code + "default-and-required", + {'name': {'default': 'ray', 'required': True}}, + {}, + {'name': 'ray'}, + "internal error: required and default are mutually exclusive for name", + ), +] + + @pytest.mark.parametrize( - ('arg_spec', 'parameters', 'expected', 'passfail', 'error', 'deprecation', 'warning'), - ((i[1], i[2], i[3], i[4], i[5], i[6], i[7]) for i in ALIAS_TEST_CASES), + ('arg_spec', 'parameters', 'expected', 'deprecation', 'warning'), + ((i[1:]) for i in ALIAS_TEST_CASES), ids=[i[0] for i in ALIAS_TEST_CASES] ) -def test_aliases(arg_spec, parameters, expected, passfail, error, deprecation, warning): - v = ArgumentSpecValidator(arg_spec, parameters) - passed = v.validate() +def test_aliases(arg_spec, parameters, expected, deprecation, warning): + v = ArgumentSpecValidator(arg_spec) + result = v.validate(parameters) - assert passed is passfail - assert v.validated_parameters == expected + assert isinstance(result, ValidationResult) + assert result.validated_parameters == expected + assert result.error_messages == [] - if not error: - assert v.error_messages == [] + if deprecation: + assert deprecation == result._deprecations[0] else: - assert error in v.error_messages[0] + assert result._deprecations == [] - deprecations = get_deprecation_messages() - if not deprecations: - assert deprecations == () + if warning: + assert warning == result._warnings[0] else: - assert deprecation in get_deprecation_messages()[0]['msg'] + assert result._warnings == [] - warnings = get_warning_messages() - if not warning: - assert warnings == () - else: - assert warning in warnings[0] + +@pytest.mark.parametrize( + ('arg_spec', 'parameters', 'expected', 'error'), + ((i[1:]) for i in ALIAS_TEST_CASES_INVALID), + ids=[i[0] for i in ALIAS_TEST_CASES_INVALID] +) +def test_aliases_invalid(arg_spec, parameters, expected, error): + v = ArgumentSpecValidator(arg_spec) + result = v.validate(parameters) + + assert isinstance(result, ValidationResult) + assert error in result.error_messages + assert isinstance(result.errors.errors[0], AnsibleValidationError) + assert isinstance(result.errors, AnsibleValidationErrorMultiple) diff --git a/test/units/module_utils/common/arg_spec/test_module_validate.py b/test/units/module_utils/common/arg_spec/test_module_validate.py new file mode 100644 index 0000000000..14e6e1e7c7 --- /dev/null +++ b/test/units/module_utils/common/arg_spec/test_module_validate.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import ansible.module_utils.common.warnings as warnings + +from ansible.module_utils.common.arg_spec import ModuleArgumentSpecValidator, ValidationResult + + +def test_module_validate(): + arg_spec = {'name': {}} + parameters = {'name': 'larry'} + expected = {'name': 'larry'} + + v = ModuleArgumentSpecValidator(arg_spec) + result = v.validate(parameters) + + assert isinstance(result, ValidationResult) + assert result.error_messages == [] + assert result._deprecations == [] + assert result._warnings == [] + assert result.validated_parameters == expected + + +def test_module_alias_deprecations_warnings(): + arg_spec = { + 'path': { + 'aliases': ['source', 'src', 'flamethrower'], + 'deprecated_aliases': [{'name': 'flamethrower', 'date': '2020-03-04'}], + }, + } + parameters = {'flamethrower': '/tmp', 'source': '/tmp'} + expected = { + 'path': '/tmp', + 'flamethrower': '/tmp', + 'source': '/tmp', + } + + v = ModuleArgumentSpecValidator(arg_spec) + result = v.validate(parameters) + + assert result.validated_parameters == expected + assert result._deprecations == [ + { + 'collection_name': None, + 'date': '2020-03-04', + 'name': 'flamethrower', + 'version': None, + } + ] + assert "Alias 'flamethrower' is deprecated" in warnings._global_deprecations[0]['msg'] + assert result._warnings == [{'alias': 'flamethrower', 'option': 'path'}] + assert "Both option path and its alias flamethrower are set" in warnings._global_warnings[0] diff --git a/test/units/module_utils/common/arg_spec/test_sub_spec.py b/test/units/module_utils/common/arg_spec/test_sub_spec.py index eaa775fdf5..a6e75754eb 100644 --- a/test/units/module_utils/common/arg_spec/test_sub_spec.py +++ b/test/units/module_utils/common/arg_spec/test_sub_spec.py @@ -5,7 +5,7 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type -from ansible.module_utils.common.arg_spec import ArgumentSpecValidator +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult def test_sub_spec(): @@ -39,12 +39,12 @@ def test_sub_spec(): } } - v = ArgumentSpecValidator(arg_spec, parameters) - passed = v.validate() + v = ArgumentSpecValidator(arg_spec) + result = v.validate(parameters) - assert passed is True - assert v.error_messages == [] - assert v.validated_parameters == expected + assert isinstance(result, ValidationResult) + assert result.validated_parameters == expected + assert result.error_messages == [] def test_nested_sub_spec(): @@ -98,9 +98,9 @@ def test_nested_sub_spec(): } } - v = ArgumentSpecValidator(arg_spec, parameters) - passed = v.validate() + v = ArgumentSpecValidator(arg_spec) + result = v.validate(parameters) - assert passed is True - assert v.error_messages == [] - assert v.validated_parameters == expected + assert isinstance(result, ValidationResult) + assert result.validated_parameters == expected + assert result.error_messages == [] diff --git a/test/units/module_utils/common/arg_spec/test_validate_invalid.py b/test/units/module_utils/common/arg_spec/test_validate_invalid.py index 99ab62b18d..5384ee225a 100644 --- a/test/units/module_utils/common/arg_spec/test_validate_invalid.py +++ b/test/units/module_utils/common/arg_spec/test_validate_invalid.py @@ -7,17 +7,19 @@ __metaclass__ = type import pytest -from ansible.module_utils.common.arg_spec import ArgumentSpecValidator +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult +from ansible.module_utils.errors import AnsibleValidationErrorMultiple from ansible.module_utils.six import PY2 -# Each item is id, argument_spec, parameters, expected, error test string +# Each item is id, argument_spec, parameters, expected, unsupported parameters, error test string INVALID_SPECS = [ ( 'invalid-list', {'packages': {'type': 'list'}}, {'packages': {'key': 'value'}}, {'packages': {'key': 'value'}}, + set(), "unable to convert to list: <class 'dict'> cannot be converted to a list", ), ( @@ -25,6 +27,7 @@ INVALID_SPECS = [ {'users': {'type': 'dict'}}, {'users': ['one', 'two']}, {'users': ['one', 'two']}, + set(), "unable to convert to dict: <class 'list'> cannot be converted to a dict", ), ( @@ -32,6 +35,7 @@ INVALID_SPECS = [ {'bool': {'type': 'bool'}}, {'bool': {'k': 'v'}}, {'bool': {'k': 'v'}}, + set(), "unable to convert to bool: <class 'dict'> cannot be converted to a bool", ), ( @@ -39,6 +43,7 @@ INVALID_SPECS = [ {'float': {'type': 'float'}}, {'float': 'hello'}, {'float': 'hello'}, + set(), "unable to convert to float: <class 'str'> cannot be converted to a float", ), ( @@ -46,6 +51,7 @@ INVALID_SPECS = [ {'bytes': {'type': 'bytes'}}, {'bytes': 'one'}, {'bytes': 'one'}, + set(), "unable to convert to bytes: <class 'str'> cannot be converted to a Byte value", ), ( @@ -53,6 +59,7 @@ INVALID_SPECS = [ {'bits': {'type': 'bits'}}, {'bits': 'one'}, {'bits': 'one'}, + set(), "unable to convert to bits: <class 'str'> cannot be converted to a Bit value", ), ( @@ -60,6 +67,7 @@ INVALID_SPECS = [ {'some_json': {'type': 'jsonarg'}}, {'some_json': set()}, {'some_json': set()}, + set(), "unable to convert to jsonarg: <class 'set'> cannot be converted to a json string", ), ( @@ -74,13 +82,15 @@ INVALID_SPECS = [ 'badparam': '', 'another': '', }, - "Unsupported parameters: another, badparam", + set(('another', 'badparam')), + "another, badparam. Supported parameters include: name.", ), ( 'invalid-elements', {'numbers': {'type': 'list', 'elements': 'int'}}, {'numbers': [55, 33, 34, {'key': 'value'}]}, {'numbers': [55, 33, 34]}, + set(), "Elements value for option 'numbers' is of type <class 'dict'> and we were unable to convert to int: <class 'dict'> cannot be converted to an int" ), ( @@ -88,23 +98,29 @@ INVALID_SPECS = [ {'req': {'required': True}}, {}, {'req': None}, + set(), "missing required arguments: req" ) ] @pytest.mark.parametrize( - ('arg_spec', 'parameters', 'expected', 'error'), - ((i[1], i[2], i[3], i[4]) for i in INVALID_SPECS), + ('arg_spec', 'parameters', 'expected', 'unsupported', 'error'), + (i[1:] for i in INVALID_SPECS), ids=[i[0] for i in INVALID_SPECS] ) -def test_invalid_spec(arg_spec, parameters, expected, error): - v = ArgumentSpecValidator(arg_spec, parameters) - passed = v.validate() +def test_invalid_spec(arg_spec, parameters, expected, unsupported, error): + v = ArgumentSpecValidator(arg_spec) + result = v.validate(parameters) + + with pytest.raises(AnsibleValidationErrorMultiple) as exc_info: + raise result.errors if PY2: error = error.replace('class', 'type') - assert error in v.error_messages[0] - assert v.validated_parameters == expected - assert passed is False + assert isinstance(result, ValidationResult) + assert error in exc_info.value.msg + assert error in result.error_messages[0] + assert result.unsupported_parameters == unsupported + assert result.validated_parameters == expected diff --git a/test/units/module_utils/common/arg_spec/test_validate_valid.py b/test/units/module_utils/common/arg_spec/test_validate_valid.py index 0b139aff5d..b35b856f3a 100644 --- a/test/units/module_utils/common/arg_spec/test_validate_valid.py +++ b/test/units/module_utils/common/arg_spec/test_validate_valid.py @@ -7,45 +7,53 @@ __metaclass__ = type import pytest -from ansible.module_utils.common.arg_spec import ArgumentSpecValidator +import ansible.module_utils.common.warnings as warnings -# Each item is id, argument_spec, parameters, expected +from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult + +# Each item is id, argument_spec, parameters, expected, valid parameter names VALID_SPECS = [ ( 'str-no-type-specified', {'name': {}}, {'name': 'rey'}, {'name': 'rey'}, + set(('name',)), ), ( 'str', {'name': {'type': 'str'}}, {'name': 'rey'}, {'name': 'rey'}, + set(('name',)), ), ( 'str-convert', {'name': {'type': 'str'}}, {'name': 5}, {'name': '5'}, + set(('name',)), ), ( 'list', {'packages': {'type': 'list'}}, {'packages': ['vim', 'python']}, {'packages': ['vim', 'python']}, + set(('packages',)), ), ( 'list-comma-string', {'packages': {'type': 'list'}}, {'packages': 'vim,python'}, {'packages': ['vim', 'python']}, + set(('packages',)), ), ( 'list-comma-string-space', {'packages': {'type': 'list'}}, {'packages': 'vim, python'}, {'packages': ['vim', ' python']}, + set(('packages',)), ), ( 'dict', @@ -64,6 +72,7 @@ VALID_SPECS = [ 'last': 'skywalker', } }, + set(('user',)), ), ( 'dict-k=v', @@ -76,6 +85,7 @@ VALID_SPECS = [ 'last': 'skywalker', } }, + set(('user',)), ), ( 'dict-k=v-spaces', @@ -88,6 +98,7 @@ VALID_SPECS = [ 'last': 'skywalker', } }, + set(('user',)), ), ( 'bool', @@ -103,6 +114,7 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'bool-ints', @@ -118,6 +130,7 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'bool-true-false', @@ -133,6 +146,7 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'bool-yes-no', @@ -148,6 +162,7 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'bool-y-n', @@ -163,6 +178,7 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'bool-on-off', @@ -178,6 +194,7 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'bool-1-0', @@ -193,6 +210,7 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'bool-float', @@ -208,89 +226,112 @@ VALID_SPECS = [ 'enabled': True, 'disabled': False, }, + set(('enabled', 'disabled')), ), ( 'float', {'digit': {'type': 'float'}}, {'digit': 3.14159}, {'digit': 3.14159}, + set(('digit',)), ), ( 'float-str', {'digit': {'type': 'float'}}, {'digit': '3.14159'}, {'digit': 3.14159}, + set(('digit',)), ), ( 'path', {'path': {'type': 'path'}}, {'path': '~/bin'}, {'path': '/home/ansible/bin'}, + set(('path',)), ), ( 'raw', {'raw': {'type': 'raw'}}, {'raw': 0x644}, {'raw': 0x644}, + set(('raw',)), ), ( 'bytes', {'bytes': {'type': 'bytes'}}, {'bytes': '2K'}, {'bytes': 2048}, + set(('bytes',)), ), ( 'bits', {'bits': {'type': 'bits'}}, {'bits': '1Mb'}, {'bits': 1048576}, + set(('bits',)), ), ( 'jsonarg', {'some_json': {'type': 'jsonarg'}}, {'some_json': '{"users": {"bob": {"role": "accountant"}}}'}, {'some_json': '{"users": {"bob": {"role": "accountant"}}}'}, + set(('some_json',)), ), ( 'jsonarg-list', {'some_json': {'type': 'jsonarg'}}, {'some_json': ['one', 'two']}, {'some_json': '["one", "two"]'}, + set(('some_json',)), ), ( 'jsonarg-dict', {'some_json': {'type': 'jsonarg'}}, {'some_json': {"users": {"bob": {"role": "accountant"}}}}, {'some_json': '{"users": {"bob": {"role": "accountant"}}}'}, + set(('some_json',)), ), ( 'defaults', {'param': {'default': 'DEFAULT'}}, {}, {'param': 'DEFAULT'}, + set(('param',)), ), ( 'elements', {'numbers': {'type': 'list', 'elements': 'int'}}, {'numbers': [55, 33, 34, '22']}, {'numbers': [55, 33, 34, 22]}, + set(('numbers',)), ), + ( + 'aliases', + {'src': {'aliases': ['path', 'source']}}, + {'src': '/tmp'}, + {'src': '/tmp'}, + set(('src (path, source)',)), + ) ] @pytest.mark.parametrize( - ('arg_spec', 'parameters', 'expected'), - ((i[1], i[2], i[3]) for i in VALID_SPECS), + ('arg_spec', 'parameters', 'expected', 'valid_params'), + (i[1:] for i in VALID_SPECS), ids=[i[0] for i in VALID_SPECS] ) -def test_valid_spec(arg_spec, parameters, expected, mocker): - +def test_valid_spec(arg_spec, parameters, expected, valid_params, mocker): mocker.patch('ansible.module_utils.common.validation.os.path.expanduser', return_value='/home/ansible/bin') mocker.patch('ansible.module_utils.common.validation.os.path.expandvars', return_value='/home/ansible/bin') - v = ArgumentSpecValidator(arg_spec, parameters) - passed = v.validate() + v = ArgumentSpecValidator(arg_spec) + result = v.validate(parameters) + + assert isinstance(result, ValidationResult) + assert result.validated_parameters == expected + assert result.unsupported_parameters == set() + assert result.error_messages == [] + assert v._valid_parameter_names == valid_params - assert v.validated_parameters == expected - assert v.error_messages == [] - assert passed is True + # Again to check caching + assert v._valid_parameter_names == valid_params diff --git a/test/units/module_utils/common/parameters/test_check_arguments.py b/test/units/module_utils/common/parameters/test_check_arguments.py index 48bbfe7d71..5311217930 100644 --- a/test/units/module_utils/common/parameters/test_check_arguments.py +++ b/test/units/module_utils/common/parameters/test_check_arguments.py @@ -8,7 +8,7 @@ __metaclass__ = type import pytest -from ansible.module_utils.common.parameters import get_unsupported_parameters +from ansible.module_utils.common.parameters import _get_unsupported_parameters @pytest.fixture @@ -19,32 +19,6 @@ def argument_spec(): } -def mock_handle_aliases(*args): - aliases = {} - legal_inputs = [ - '_ansible_check_mode', - '_ansible_debug', - '_ansible_diff', - '_ansible_keep_remote_files', - '_ansible_module_name', - '_ansible_no_log', - '_ansible_remote_tmp', - '_ansible_selinux_special_fs', - '_ansible_shell_executable', - '_ansible_socket', - '_ansible_string_conversion_action', - '_ansible_syslog_facility', - '_ansible_tmpdir', - '_ansible_verbosity', - '_ansible_version', - 'state', - 'status', - 'enabled', - ] - - return aliases, legal_inputs - - @pytest.mark.parametrize( ('module_parameters', 'legal_inputs', 'expected'), ( @@ -59,7 +33,6 @@ def mock_handle_aliases(*args): ) ) def test_check_arguments(argument_spec, module_parameters, legal_inputs, expected, mocker): - mocker.patch('ansible.module_utils.common.parameters.handle_aliases', side_effect=mock_handle_aliases) - result = get_unsupported_parameters(argument_spec, module_parameters, legal_inputs) + result = _get_unsupported_parameters(argument_spec, module_parameters, legal_inputs) assert result == expected diff --git a/test/units/module_utils/common/parameters/test_handle_aliases.py b/test/units/module_utils/common/parameters/test_handle_aliases.py index bc88437fba..e20a88824b 100644 --- a/test/units/module_utils/common/parameters/test_handle_aliases.py +++ b/test/units/module_utils/common/parameters/test_handle_aliases.py @@ -8,27 +8,9 @@ __metaclass__ = type import pytest -from ansible.module_utils.common.parameters import handle_aliases +from ansible.module_utils.common.parameters import _handle_aliases from ansible.module_utils._text import to_native -DEFAULT_LEGAL_INPUTS = [ - '_ansible_check_mode', - '_ansible_debug', - '_ansible_diff', - '_ansible_keep_remote_files', - '_ansible_module_name', - '_ansible_no_log', - '_ansible_remote_tmp', - '_ansible_selinux_special_fs', - '_ansible_shell_executable', - '_ansible_socket', - '_ansible_string_conversion_action', - '_ansible_syslog_facility', - '_ansible_tmpdir', - '_ansible_verbosity', - '_ansible_version', -] - def test_handle_aliases_no_aliases(): argument_spec = { @@ -40,14 +22,9 @@ def test_handle_aliases_no_aliases(): 'path': 'bar' } - expected = ( - {}, - DEFAULT_LEGAL_INPUTS + ['name'], - ) - expected[1].sort() + expected = {} + result = _handle_aliases(argument_spec, params) - result = handle_aliases(argument_spec, params) - result[1].sort() assert expected == result @@ -63,14 +40,9 @@ def test_handle_aliases_basic(): 'nick': 'foo', } - expected = ( - {'surname': 'name', 'nick': 'name'}, - DEFAULT_LEGAL_INPUTS + ['name', 'surname', 'nick'], - ) - expected[1].sort() + expected = {'surname': 'name', 'nick': 'name'} + result = _handle_aliases(argument_spec, params) - result = handle_aliases(argument_spec, params) - result[1].sort() assert expected == result @@ -84,7 +56,7 @@ def test_handle_aliases_value_error(): } with pytest.raises(ValueError) as ve: - handle_aliases(argument_spec, params) + _handle_aliases(argument_spec, params) assert 'internal error: aliases must be a list or tuple' == to_native(ve.error) @@ -98,5 +70,5 @@ def test_handle_aliases_type_error(): } with pytest.raises(TypeError) as te: - handle_aliases(argument_spec, params) + _handle_aliases(argument_spec, params) assert 'internal error: required and default are mutually exclusive' in to_native(te.error) diff --git a/test/units/module_utils/common/parameters/test_list_deprecations.py b/test/units/module_utils/common/parameters/test_list_deprecations.py index 0a17187c04..6f0bb71a6f 100644 --- a/test/units/module_utils/common/parameters/test_list_deprecations.py +++ b/test/units/module_utils/common/parameters/test_list_deprecations.py @@ -7,7 +7,7 @@ __metaclass__ = type import pytest -from ansible.module_utils.common.parameters import list_deprecations +from ansible.module_utils.common.parameters import _list_deprecations @pytest.fixture @@ -33,7 +33,7 @@ def test_list_deprecations(): 'foo': {'old': 'value'}, 'bar': [{'old': 'value'}, {}], } - result = list_deprecations(argument_spec, params) + result = _list_deprecations(argument_spec, params) assert len(result) == 3 result.sort(key=lambda entry: entry['msg']) assert result[0]['msg'] == """Param 'bar["old"]' is deprecated. See the module docs for more information""" diff --git a/test/units/module_utils/common/parameters/test_list_no_log_values.py b/test/units/module_utils/common/parameters/test_list_no_log_values.py index 1b74055593..ac0e7353fd 100644 --- a/test/units/module_utils/common/parameters/test_list_no_log_values.py +++ b/test/units/module_utils/common/parameters/test_list_no_log_values.py @@ -7,7 +7,7 @@ __metaclass__ = type import pytest -from ansible.module_utils.common.parameters import list_no_log_values +from ansible.module_utils.common.parameters import _list_no_log_values @pytest.fixture @@ -55,12 +55,12 @@ def test_list_no_log_values_no_secrets(module_parameters): 'value': {'type': 'int'}, } expected = set() - assert expected == list_no_log_values(argument_spec, module_parameters) + assert expected == _list_no_log_values(argument_spec, module_parameters) def test_list_no_log_values(argument_spec, module_parameters): expected = set(('under', 'makeshift')) - assert expected == list_no_log_values(argument_spec(), module_parameters()) + assert expected == _list_no_log_values(argument_spec(), module_parameters()) @pytest.mark.parametrize('extra_params', [ @@ -81,7 +81,7 @@ def test_list_no_log_values_invalid_suboptions(argument_spec, module_parameters, with pytest.raises(TypeError, match=r"(Value '.*?' in the sub parameter field '.*?' must by a dict, not '.*?')" r"|(dictionary requested, could not parse JSON or key=value)"): - list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) + _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) def test_list_no_log_values_suboptions(argument_spec, module_parameters): @@ -103,7 +103,7 @@ def test_list_no_log_values_suboptions(argument_spec, module_parameters): } expected = set(('under', 'makeshift', 'bagel')) - assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) + assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) def test_list_no_log_values_sub_suboptions(argument_spec, module_parameters): @@ -136,7 +136,7 @@ def test_list_no_log_values_sub_suboptions(argument_spec, module_parameters): } expected = set(('under', 'makeshift', 'saucy', 'corporate')) - assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) + assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) def test_list_no_log_values_suboptions_list(argument_spec, module_parameters): @@ -164,7 +164,7 @@ def test_list_no_log_values_suboptions_list(argument_spec, module_parameters): } expected = set(('under', 'makeshift', 'playroom', 'luxury')) - assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) + assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) def test_list_no_log_values_sub_suboptions_list(argument_spec, module_parameters): @@ -204,7 +204,7 @@ def test_list_no_log_values_sub_suboptions_list(argument_spec, module_parameters } expected = set(('under', 'makeshift', 'playroom', 'luxury', 'basis', 'gave', 'composure', 'thumping')) - assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) + assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) @pytest.mark.parametrize('extra_params, expected', ( @@ -225,4 +225,4 @@ def test_string_suboptions_as_string(argument_spec, module_parameters, extra_par result = set(('under', 'makeshift')) result.update(expected) - assert result == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) + assert result == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params)) |