summaryrefslogtreecommitdiffstats
path: root/awx/sso/middleware.py
blob: 1944ff4d0f2a0cc1b9a6808c10b776f6f116bca7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved.

# Python
import urllib

# Six
import six

# Django
from django.utils.functional import LazyObject
from django.shortcuts import redirect

# Python Social Auth
from social_core.exceptions import SocialAuthBaseException
from social_core.utils import social_logger
from social_django.middleware import SocialAuthExceptionMiddleware


class SocialAuthMiddleware(SocialAuthExceptionMiddleware):
    """Glue between python-social-auth and AWX request handling.

    * ``process_view`` remembers which backend an ``/sso/login/`` request used.
    * ``process_request`` unwraps a lazy ``request.user`` for authenticated
      non-SSO requests and clears leftover social-auth session state.
    * ``process_exception`` turns social-auth failures into a redirect, with
      the error stored in ``request.session['social_auth_error']``.
    """

    def process_view(self, request, callback, callback_args, callback_kwargs):
        """Record the backend named by an ``/sso/login/`` request.

        The name is stored in the session under ``social_auth_last_backend``
        so that ``process_exception`` can attribute a later failure to it.
        Always returns ``None`` so Django continues to the view.
        """
        # Guard the lookup: indexing directly raises KeyError (a 500) for a
        # login URL whose pattern did not capture a ``backend`` kwarg.
        if request.path.startswith('/sso/login/') and 'backend' in callback_kwargs:
            request.session['social_auth_last_backend'] = callback_kwargs['backend']

    def process_request(self, request):
        """Normalize ``request`` before the view runs.

        Ensures ``request.successful_authenticator`` exists (default ``None``).
        For authenticated requests whose path is outside ``/sso/`` and does
        not contain ``migrations_notran``, replaces a ``LazyObject`` user with
        the wrapped object and drops social-auth state left in the session by
        an earlier login attempt. Always returns ``None``.
        """
        if not hasattr(request, 'successful_authenticator'):
            request.successful_authenticator = None

        if request.path.startswith('/sso/') or 'migrations_notran' in request.path:
            return

        # NOTE(review): is_authenticated() is called as a method, which matches
        # older Django; in Django >= 2.0 it is a property. Confirm the supported
        # Django version before upgrading.
        if not (request.user and request.user.is_authenticated()):
            return

        # The rest of the code base relies heavily on type/inheritance checks.
        # The LazyObject set by Django's auth middleware can be buggy there if
        # it is not converted back to the object it wraps.
        if isinstance(request.user, LazyObject) and request.user._wrapped:
            request.user = request.user._wrapped
        request.session.pop('social_auth_error', None)
        request.session.pop('social_auth_last_backend', None)

    def process_exception(self, request, exception):
        """Convert a social-auth failure into a redirect carrying the error.

        Returns ``None`` (Django handles the exception normally) when the
        request has no ``social_strategy``, when ``raise_exception`` asks for
        the exception to propagate, or when the exception is neither a
        ``SocialAuthBaseException`` nor raised under ``/sso/``. Otherwise it
        logs the message, stores ``(full_backend_name, message)`` in
        ``request.session['social_auth_error']`` and returns a redirect to
        ``get_redirect_uri``.
        """
        strategy = getattr(request, 'social_strategy', None)
        if strategy is None or self.raise_exception(request, exception):
            return

        if not (isinstance(exception, SocialAuthBaseException) or request.path.startswith('/sso/')):
            return

        backend = getattr(request, 'backend', None)
        backend_name = getattr(backend, 'name', 'unknown-backend')
        message = self.get_message(request, exception)

        last_backend = request.session.get('social_auth_last_backend')
        if last_backend != backend_name:
            # The failing request's backend differs from the one recorded by
            # process_view, so prefer the recorded one. Keep the name we
            # already have rather than storing None when nothing was recorded.
            # Also prefer an ``error_description`` query parameter when the
            # request carries one.
            backend_name = last_backend or backend_name
            message = request.GET.get('error_description', message)

        full_backend_name = backend_name
        try:
            # NOTE(review): RelayState presumably identifies the responding
            # identity provider (e.g. for SAML) - confirm against the backends.
            idp_name = strategy.request_data()['RelayState']
        except KeyError:
            pass
        else:
            full_backend_name = '%s:%s' % (backend_name, idp_name)

        social_logger.error(message)

        url = self.get_redirect_uri(request, exception)
        request.session['social_auth_error'] = (full_backend_name, message)
        return redirect(url)

    def get_message(self, request, exception):
        """Return the exception text as unicode, ending in '.', '?' or '!'.

        A period is appended when the text is non-empty and does not already
        end in one of those characters. Empty text is returned unchanged.
        """
        msg = six.text_type(exception)
        if msg and msg[-1] not in '.?!':
            msg = msg + '.'
        return msg

    def get_redirect_uri(self, request, exception):
        """Return the session's ``next`` URL, else the ``LOGIN_ERROR_URL`` setting.

        Assumes ``request.social_strategy`` is set. ``process_exception`` only
        calls this after checking that it is not ``None``.
        """
        strategy = getattr(request, 'social_strategy', None)
        return strategy.session_get('next', '') or strategy.setting('LOGIN_ERROR_URL')