diff options
-rw-r--r-- | awx/main/access.py | 100 | ||||
-rw-r--r-- | awx/main/models/workflow.py | 19 | ||||
-rw-r--r-- | awx/main/tests/conftest.py | 2 | ||||
-rw-r--r-- | awx/main/tests/functional/api/test_rbac_displays.py | 10 |
4 files changed, 40 insertions, 91 deletions
diff --git a/awx/main/access.py b/awx/main/access.py index 6ce7600026..7b7ede48e4 100644 --- a/awx/main/access.py +++ b/awx/main/access.py @@ -31,7 +31,7 @@ __all__ = ['get_user_queryset', 'check_user_access', 'check_user_access_with_err logger = logging.getLogger('awx.main.access') access_registry = { - # <model_class>: [<access_class>, ...], + # <model_class>: <access_class>, # ... } @@ -41,8 +41,7 @@ class StateConflict(ValidationError): def register_access(model_class, access_class): - access_classes = access_registry.setdefault(model_class, []) - access_classes.append(access_class) + access_registry[model_class] = access_class @property @@ -66,19 +65,9 @@ def get_user_queryset(user, model_class): Return a queryset for the given model_class containing only the instances that should be visible to the given user. ''' - querysets = [] - for access_class in access_registry.get(model_class, []): - access_instance = access_class(user) - querysets.append(access_instance.get_queryset()) - if not querysets: - return model_class.objects.none() - elif len(querysets) == 1: - return querysets[0] - else: - queryset = model_class.objects.all() - for qs in querysets: - queryset = queryset.filter(pk__in=qs.values_list('pk', flat=True)) - return queryset + access_class = access_registry[model_class] + access_instance = access_class(user) + return access_instance.get_queryset() def check_user_access(user, model_class, action, *args, **kwargs): @@ -86,33 +75,26 @@ def check_user_access(user, model_class, action, *args, **kwargs): Return True if user can perform action against model_class with the provided parameters. ''' - for access_class in access_registry.get(model_class, []): - access_instance = access_class(user) - access_method = getattr(access_instance, 'can_%s' % action, None) - if not access_method: - logger.debug('%s.%s not found', access_instance.__class__.__name__, - 'can_%s' % action) - continue - result = access_method(*args, **kwargs) - logger.debug('%s.%s %r returned %r', access_instance.__class__.__name__, - getattr(access_method, '__name__', 'unknown'), args, result) - if result: - return result - return False + access_class = access_registry[model_class] + access_instance = access_class(user) + access_method = getattr(access_instance, 'can_%s' % action) + result = access_method(*args, **kwargs) + logger.debug('%s.%s %r returned %r', access_instance.__class__.__name__, + getattr(access_method, '__name__', 'unknown'), args, result) + return result def check_user_access_with_errors(user, model_class, action, *args, **kwargs): ''' Return T/F permission and summary of problems with the action. ''' - for access_class in access_registry.get(model_class, []): - access_instance = access_class(user, save_messages=True) - access_method = getattr(access_instance, 'can_%s' % action, None) - result = access_method(*args, **kwargs) - logger.debug('%s.%s %r returned %r', access_instance.__class__.__name__, - access_method.__name__, args, result) - return (result, access_instance.messages) - return (False, '') + access_class = access_registry[model_class] + access_instance = access_class(user, save_messages=True) + access_method = getattr(access_instance, 'can_%s' % action, None) + result = access_method(*args, **kwargs) + logger.debug('%s.%s %r returned %r', access_instance.__class__.__name__, + access_method.__name__, args, result) + return (result, access_instance.messages) def get_user_capabilities(user, instance, **kwargs): @@ -123,9 +105,8 @@ def get_user_capabilities(user, instance, **kwargs): convenient for the user interface to consume and hide or show various actions in the interface. ''' - for access_class in access_registry.get(type(instance), []): - return access_class(user).get_user_capabilities(instance, **kwargs) - return None + access_class = access_registry[instance.__class__] + return access_class(user).get_user_capabilities(instance, **kwargs) def check_superuser(func): @@ -2008,7 +1989,7 @@ class UnifiedJobTemplateAccess(BaseAccess): return qs.all() def can_start(self, obj, validate_license=True): - access_class = access_registry.get(obj.__class__, [])[0] + access_class = access_registry[obj.__class__] access_instance = access_class(self.user) return access_instance.can_start(obj, validate_license=validate_license) @@ -2376,38 +2357,5 @@ class RoleAccess(BaseAccess): return False -register_access(User, UserAccess) -register_access(Organization, OrganizationAccess) -register_access(Inventory, InventoryAccess) -register_access(Host, HostAccess) -register_access(Group, GroupAccess) -register_access(InventorySource, InventorySourceAccess) -register_access(InventoryUpdate, InventoryUpdateAccess) -register_access(Credential, CredentialAccess) -register_access(CredentialType, CredentialTypeAccess) -register_access(Team, TeamAccess) -register_access(Project, ProjectAccess) -register_access(ProjectUpdate, ProjectUpdateAccess) -register_access(JobTemplate, JobTemplateAccess) -register_access(Job, JobAccess) -register_access(JobHostSummary, JobHostSummaryAccess) -register_access(JobEvent, JobEventAccess) -register_access(SystemJobTemplate, SystemJobTemplateAccess) -register_access(SystemJob, SystemJobAccess) -register_access(AdHocCommand, AdHocCommandAccess) -register_access(AdHocCommandEvent, AdHocCommandEventAccess) -register_access(Schedule, ScheduleAccess) -register_access(UnifiedJobTemplate, UnifiedJobTemplateAccess) -register_access(UnifiedJob, UnifiedJobAccess) -register_access(ActivityStream, ActivityStreamAccess) -register_access(CustomInventoryScript, CustomInventoryScriptAccess) -register_access(Role, RoleAccess) -register_access(NotificationTemplate, NotificationTemplateAccess) -register_access(Notification, NotificationAccess) -register_access(Label, LabelAccess) -register_access(WorkflowJobTemplateNode, WorkflowJobTemplateNodeAccess) -register_access(WorkflowJobNode, WorkflowJobNodeAccess) -register_access(WorkflowJobTemplate, WorkflowJobTemplateAccess) -register_access(WorkflowJob, WorkflowJobAccess) -register_access(Instance, InstanceAccess) -register_access(InstanceGroup, InstanceGroupAccess) +for cls in BaseAccess.__subclasses__(): + access_registry[cls.model] = cls diff --git a/awx/main/models/workflow.py b/awx/main/models/workflow.py index b5a3afba1d..52d58717a0 100644 --- a/awx/main/models/workflow.py +++ b/awx/main/models/workflow.py @@ -187,15 +187,16 @@ class WorkflowJobTemplateNode(WorkflowNodeBase): ''' create_kwargs = {} for field_name in self._get_workflow_job_field_names(): - if hasattr(self, field_name): - item = getattr(self, field_name) - if field_name in ['inventory', 'credential']: - if not user.can_access(item.__class__, 'use', item): - continue - if field_name in ['unified_job_template']: - if not user.can_access(item.__class__, 'start', item, validate_license=False): - continue - create_kwargs[field_name] = item + item = getattr(self, field_name, None) + if item is None: + continue + if field_name in ['inventory', 'credential']: + if not user.can_access(item.__class__, 'use', item): + continue + if field_name in ['unified_job_template']: + if not user.can_access(item.__class__, 'start', item, validate_license=False): + continue + create_kwargs[field_name] = item create_kwargs['workflow_job_template'] = workflow_job_template return self.__class__.objects.create(**create_kwargs) diff --git a/awx/main/tests/conftest.py b/awx/main/tests/conftest.py index 35f77c1f1d..6118bf83bd 100644 --- a/awx/main/tests/conftest.py +++ b/awx/main/tests/conftest.py @@ -24,7 +24,7 @@ def mock_access(): mock_instance = mock.MagicMock(__name__='foobar') MockAccess = mock.MagicMock(return_value=mock_instance) the_patch = mock.patch.dict('awx.main.access.access_registry', - {TowerClass: [MockAccess]}, clear=False) + {TowerClass: MockAccess}, clear=False) the_patch.__enter__() yield mock_instance finally: diff --git a/awx/main/tests/functional/api/test_rbac_displays.py b/awx/main/tests/functional/api/test_rbac_displays.py index 99723c545f..ca83d6df4a 100644 --- a/awx/main/tests/functional/api/test_rbac_displays.py +++ b/awx/main/tests/functional/api/test_rbac_displays.py @@ -188,7 +188,7 @@ class TestAccessListCapabilities: self, inventory, rando, get, mocker, mock_access_method): inventory.admin_role.members.add(rando) - with mocker.patch.object(access_registry[Role][0], 'can_unattach', mock_access_method): + with mocker.patch.object(access_registry[Role], 'can_unattach', mock_access_method): response = get(reverse('api:inventory_access_list', kwargs={'pk': inventory.id}), rando) mock_access_method.assert_called_once_with(inventory.admin_role, rando, 'members', **self.extra_kwargs) @@ -198,7 +198,7 @@ class TestAccessListCapabilities: def test_access_list_indirect_access_capability( self, inventory, organization, org_admin, get, mocker, mock_access_method): - with mocker.patch.object(access_registry[Role][0], 'can_unattach', mock_access_method): + with mocker.patch.object(access_registry[Role], 'can_unattach', mock_access_method): response = get(reverse('api:inventory_access_list', kwargs={'pk': inventory.id}), org_admin) mock_access_method.assert_called_once_with(organization.admin_role, org_admin, 'members', **self.extra_kwargs) @@ -210,7 +210,7 @@ class TestAccessListCapabilities: self, inventory, team, team_member, get, mocker, mock_access_method): team.member_role.children.add(inventory.admin_role) - with mocker.patch.object(access_registry[Role][0], 'can_unattach', mock_access_method): + with mocker.patch.object(access_registry[Role], 'can_unattach', mock_access_method): response = get(reverse('api:inventory_access_list', kwargs={'pk': inventory.id}), team_member) mock_access_method.assert_called_once_with(inventory.admin_role, team.member_role, 'parents', **self.extra_kwargs) @@ -229,7 +229,7 @@ class TestAccessListCapabilities: def test_team_roles_unattach(mocker, team, team_member, inventory, mock_access_method, get): team.member_role.children.add(inventory.admin_role) - with mocker.patch.object(access_registry[Role][0], 'can_unattach', mock_access_method): + with mocker.patch.object(access_registry[Role], 'can_unattach', mock_access_method): response = get(reverse('api:team_roles_list', kwargs={'pk': team.id}), team_member) # Did we assess whether team_member can remove team's permission to the inventory? @@ -244,7 +244,7 @@ def test_user_roles_unattach(mocker, organization, alice, bob, mock_access_metho organization.member_role.members.add(alice) organization.member_role.members.add(bob) - with mocker.patch.object(access_registry[Role][0], 'can_unattach', mock_access_method): + with mocker.patch.object(access_registry[Role], 'can_unattach', mock_access_method): response = get(reverse('api:user_roles_list', kwargs={'pk': alice.id}), bob) # Did we assess whether bob can remove alice's permission to the inventory? |