summaryrefslogtreecommitdiffstats
path: root/test/units/module_utils/basic/test_selinux.py
blob: 1851616b3e6a2fcd7e4ac65fa8242ba40b5b8439 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# -*- coding: utf-8 -*-
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2016 Toshio Kuratomi <tkuratomi@ansible.com>
# (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from __future__ import annotations

import errno
import json
import pytest

from unittest.mock import mock_open, patch

from ansible.module_utils import basic
from ansible.module_utils.common.text.converters import to_bytes
from ansible.module_utils.six.moves import builtins


@pytest.fixture
def no_args_module_exec():
    """Run the test with ``basic._ANSIBLE_ARGS`` set to an empty argument payload.

    The patch is applied before the test body runs and undone on teardown.
    """
    empty_payload = b'{"ANSIBLE_MODULE_ARGS": {}}'
    args_patch = patch.object(basic, '_ANSIBLE_ARGS', empty_payload)
    with args_patch:
        # the patched state lives on the ``basic`` module itself, so the fixture has no value to hand out
        yield


def no_args_module(selinux_enabled=None, selinux_mls_enabled=None):
    """Build an ``AnsibleModule`` with an empty argument spec for the SELinux tests.

    Each argument that is a ``bool`` replaces the same-named method on the
    returned instance with a stub returning that value. Any other value
    (including the default ``None``) leaves the real method in place.

    The stubs are plain instance attributes rather than started
    ``patch.object`` patchers: a patcher that is started but never stopped
    stays registered with ``unittest.mock`` and would keep every module built
    here alive for the rest of the test session.
    """
    am = basic.AnsibleModule(argument_spec={})
    # shadow the wrappers on this one short-lived instance only; the class is untouched
    if isinstance(selinux_enabled, bool):
        am.selinux_enabled = lambda: selinux_enabled
    if isinstance(selinux_mls_enabled, bool):
        am.selinux_mls_enabled = lambda: selinux_mls_enabled
    return am


# test AnsibleModule selinux wrapper methods
@pytest.mark.usefixtures('no_args_module_exec')
class TestSELinuxMU:
    """Tests for the SELinux-related methods of ``basic.AnsibleModule``.

    The ``no_args_module_exec`` fixture patches ``basic._ANSIBLE_ARGS`` with an
    empty argument payload for every test. Individual tests patch
    ``basic.HAVE_SELINUX`` and ``basic.selinux`` to simulate each scenario.
    """

    def test_selinux_enabled(self):
        """Check ``selinux_enabled()`` per scenario, and that the answer persists once the patches are removed."""
        # selinux unavailable, should return false
        with patch.object(basic, 'HAVE_SELINUX', False):
            assert no_args_module().selinux_enabled() is False

            # test selinux present/not-enabled
            # NOTE(review): this case still runs with HAVE_SELINUX patched to False, so the mocked
            # is_selinux_enabled() may never be consulted - confirm against basic.py
            disabled_mod = no_args_module()
            with patch.object(basic, 'selinux', create=True) as selinux:
                selinux.is_selinux_enabled.return_value = 0
                assert disabled_mod.selinux_enabled() is False

        # ensure value is cached (same answer after unpatching)
        assert disabled_mod.selinux_enabled() is False

        # and present / enabled
        with patch.object(basic, 'HAVE_SELINUX', True):
            enabled_mod = no_args_module()
            with patch.object(basic, 'selinux', create=True) as selinux:
                selinux.is_selinux_enabled.return_value = 1
                assert enabled_mod.selinux_enabled() is True
        # ensure value is cached (same answer after unpatching)
        assert enabled_mod.selinux_enabled() is True

    def test_selinux_mls_enabled(self):
        """Check ``selinux_mls_enabled()`` with SELinux unavailable, disabled, and enabled with MLS reported on."""
        # selinux unavailable, should return false
        with patch.object(basic, 'HAVE_SELINUX', False):
            assert no_args_module().selinux_mls_enabled() is False
            # selinux disabled, should return false
            with patch.object(basic, 'selinux', create=True) as selinux:
                selinux.is_selinux_mls_enabled.return_value = 0
                assert no_args_module(selinux_enabled=False).selinux_mls_enabled() is False

        with patch.object(basic, 'HAVE_SELINUX', True):
            # selinux enabled, should pass through the value of is_selinux_mls_enabled
            with patch.object(basic, 'selinux', create=True) as selinux:
                selinux.is_selinux_mls_enabled.return_value = 1
                assert no_args_module(selinux_enabled=True).selinux_mls_enabled() is True

    def test_selinux_initial_context(self):
        """Check that ``selinux_initial_context()`` is all-``None`` with three elements, or four when MLS is enabled."""
        # selinux missing/disabled/enabled sans MLS is 3-element None
        assert no_args_module(selinux_enabled=False, selinux_mls_enabled=False).selinux_initial_context() == [None, None, None]
        assert no_args_module(selinux_enabled=True, selinux_mls_enabled=False).selinux_initial_context() == [None, None, None]
        # selinux enabled with MLS is 4-element None
        assert no_args_module(selinux_enabled=True, selinux_mls_enabled=True).selinux_initial_context() == [None, None, None, None]

    def test_selinux_default_context(self):
        """Check ``selinux_default_context()`` for each mocked ``matchpathcon`` outcome.

        A successful lookup yields the context string split into its fields;
        a failure return code or an ``OSError`` yields the all-``None`` context.
        """
        # selinux unavailable
        with patch.object(basic, 'HAVE_SELINUX', False):
            assert no_args_module().selinux_default_context(path='/foo/bar') == [None, None, None]

        # the same module (selinux + MLS stubbed on) is reused for the remaining cases
        am = no_args_module(selinux_enabled=True, selinux_mls_enabled=True)
        with patch.object(basic, 'selinux', create=True) as selinux:
            # matchpathcon success
            selinux.matchpathcon.return_value = [0, 'unconfined_u:object_r:default_t:s0']
            assert am.selinux_default_context(path='/foo/bar') == ['unconfined_u', 'object_r', 'default_t', 's0']

        with patch.object(basic, 'selinux', create=True) as selinux:
            # matchpathcon fail (return initial context value)
            selinux.matchpathcon.return_value = [-1, '']
            assert am.selinux_default_context(path='/foo/bar') == [None, None, None, None]

        with patch.object(basic, 'selinux', create=True) as selinux:
            # matchpathcon OSError
            selinux.matchpathcon.side_effect = OSError
            assert am.selinux_default_context(path='/foo/bar') == [None, None, None, None]

    def test_selinux_context(self):
        """Check ``selinux_context()`` for each mocked ``lgetfilecon_raw`` outcome.

        Unlike the default-context lookup, an ``OSError`` here is expected to
        end in ``SystemExit``.
        """
        # selinux unavailable
        with patch.object(basic, 'HAVE_SELINUX', False):
            assert no_args_module().selinux_context(path='/foo/bar') == [None, None, None]

        # the same module (selinux + MLS stubbed on) is reused for the remaining cases
        am = no_args_module(selinux_enabled=True, selinux_mls_enabled=True)
        # lgetfilecon_raw passthru
        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lgetfilecon_raw.return_value = [0, 'unconfined_u:object_r:default_t:s0']
            assert am.selinux_context(path='/foo/bar') == ['unconfined_u', 'object_r', 'default_t', 's0']

        # lgetfilecon_raw returned a failure
        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lgetfilecon_raw.return_value = [-1, '']
            assert am.selinux_context(path='/foo/bar') == [None, None, None, None]

        # lgetfilecon_raw OSError (should bomb the module)
        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lgetfilecon_raw.side_effect = OSError(errno.ENOENT, 'NotFound')
            with pytest.raises(SystemExit):
                am.selinux_context(path='/foo/bar')

        # same expectation for an OSError that carries no errno
        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lgetfilecon_raw.side_effect = OSError()
            with pytest.raises(SystemExit):
                am.selinux_context(path='/foo/bar')

    def test_is_special_selinux_path(self):
        """Check ``is_special_selinux_path()`` against mocked mount data.

        The module is created with ``_ansible_selinux_special_fs`` set to
        ``nfs,nfsd,foos``. Paths whose stubbed mount point has one of those
        filesystem types should be reported as special, together with the
        stubbed context; an unreadable mount table should report
        ``(False, None)``.
        """
        args = to_bytes(json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos",
                                                             '_ansible_remote_tmp': "/tmp",
                                                             '_ansible_keep_remote_files': False})))

        with patch.object(basic, '_ANSIBLE_ARGS', args):
            am = basic.AnsibleModule(
                argument_spec=dict(),
            )

            # stand-in for mount point discovery, matching the mount points listed in mount_data below
            def _mock_find_mount_point(path):
                if path.startswith('/some/path'):
                    return '/some/path'
                elif path.startswith('/weird/random/fstype'):
                    return '/weird/random/fstype'
                return '/'

            am.find_mount_point = _mock_find_mount_point
            am.selinux_context = lambda path: ['foo_u', 'foo_r', 'foo_t', 's0']

            # first pass: every open() raises OSError, so no mount data can be read
            m = mock_open()
            m.side_effect = OSError

            with patch.object(builtins, 'open', m, create=True):
                assert am.is_special_selinux_path('/some/path/that/should/be/nfs') == (False, None)

            mount_data = [
                '/dev/disk1 / ext4 rw,seclabel,relatime,data=ordered 0 0\n',
                '10.1.1.1:/path/to/nfs /some/path nfs ro 0 0\n',
                'whatever /weird/random/fstype foos rw 0 0\n',
            ]

            # mock_open has a broken readlines() implementation apparently...
            # this should work by default but doesn't, so we fix it
            m = mock_open(read_data=''.join(mount_data))
            m.return_value.readlines.return_value = mount_data

            with patch.object(builtins, 'open', m, create=True):
                assert am.is_special_selinux_path('/some/random/path') == (False, None)
                assert am.is_special_selinux_path('/some/path/that/should/be/nfs') == (True, ['foo_u', 'foo_r', 'foo_t', 's0'])
                assert am.is_special_selinux_path('/weird/random/fstype/path') == (True, ['foo_u', 'foo_r', 'foo_t', 's0'])

    def test_set_context_if_different(self):
        """Check ``set_context_if_different()`` with SELinux off, on, in check mode, on failure, and on a special path."""
        # selinux disabled: the third argument comes back unchanged
        # (presumably the caller's running "changed" flag - confirm against basic.py)
        am = no_args_module(selinux_enabled=False)
        assert am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True) is True
        assert am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False) is False

        # selinux enabled: stub a current context that differs from the requested one
        am = no_args_module(selinux_enabled=True, selinux_mls_enabled=True)
        am.selinux_context = lambda path: ['bar_u', 'bar_r', None, None]
        am.is_special_selinux_path = lambda path: (False, None)

        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lsetfilecon.return_value = 0
            assert am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False) is True
            selinux.lsetfilecon.assert_called_with('/path/to/file', 'foo_u:foo_r:foo_t:s0')
            # check mode: still reports a change but must not call lsetfilecon
            selinux.lsetfilecon.reset_mock()
            am.check_mode = True
            assert am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False) is True
            assert not selinux.lsetfilecon.called
            am.check_mode = False

        # a non-zero return from lsetfilecon should end in SystemExit
        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lsetfilecon.return_value = 1
            with pytest.raises(SystemExit):
                am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True)

        # an OSError from lsetfilecon should end in SystemExit as well
        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lsetfilecon.side_effect = OSError
            with pytest.raises(SystemExit):
                am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True)

        # special path: the context reported by is_special_selinux_path is applied instead of the requested one
        am.is_special_selinux_path = lambda path: (True, ['sp_u', 'sp_r', 'sp_t', 's0'])

        with patch.object(basic, 'selinux', create=True) as selinux:
            selinux.lsetfilecon.return_value = 0
            assert am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False) is True
            selinux.lsetfilecon.assert_called_with('/path/to/file', 'sp_u:sp_r:sp_t:s0')