|
1 | 1 | import logging
|
2 | 2 | from enum import StrEnum
|
3 |
| -from urllib.parse import urlencode, urljoin, urlparse, urlunparse |
| 3 | +from urllib.parse import parse_qs, urlencode, urljoin, urlparse, urlunparse |
4 | 4 |
|
5 | 5 | from allauth.socialaccount.adapter import get_adapter
|
6 | 6 | from allauth.socialaccount.models import SocialApp
|
7 | 7 | from django.conf import settings
|
8 | 8 | from django.contrib import messages
|
| 9 | +from django.http import HttpRequest, HttpResponse |
9 | 10 | from django.shortcuts import redirect
|
10 | 11 | from django.urls import reverse
|
11 |
| -from django.views.generic import TemplateView |
| 12 | +from django.utils.translation import gettext_lazy as _ |
| 13 | +from django.views.generic import TemplateView, View |
12 | 14 | from pydantic import ValidationError
|
13 | 15 |
|
14 | 16 | from pretix.base.models import User
|
|
18 | 20 | from pretix.helpers.urls import build_absolute_uri
|
19 | 21 |
|
20 | 22 | from .schemas.login_providers import LoginProviders
|
| 23 | +from .schemas.oauth2_params import OAuth2Params |
21 | 24 |
|
22 | 25 | logger = logging.getLogger(__name__)
|
23 | 26 | adapter = get_adapter()
|
24 | 27 |
|
25 | 28 |
|
26 |
| -def oauth_login(request, provider): |
27 |
| - gs = GlobalSettingsObject() |
28 |
| - client_id = gs.settings.get('login_providers', as_type=dict).get(provider, {}).get('client_id') |
29 |
| - provider = adapter.get_provider(request, provider, client_id=client_id) |
30 |
| - |
31 |
| - base_url = provider.get_login_url(request) |
32 |
| - query_params = { |
33 |
| - "next": build_absolute_uri("plugins:socialauth:social.oauth.return") |
34 |
| - } |
35 |
| - parsed_url = urlparse(base_url) |
36 |
| - updated_url = parsed_url._replace(query=urlencode(query_params)) |
37 |
| - return redirect(urlunparse(updated_url)) |
| 29 | +class OAuthLoginView(View): |
| 30 | + def get(self, request: HttpRequest, provider: str) -> HttpResponse: |
| 31 | + self.set_oauth2_params(request) |
38 | 32 |
|
| 33 | + gs = GlobalSettingsObject() |
| 34 | + client_id = ( |
| 35 | + gs.settings.get("login_providers", as_type=dict) |
| 36 | + .get(provider, {}) |
| 37 | + .get("client_id") |
| 38 | + ) |
| 39 | + provider_instance = adapter.get_provider(request, provider, client_id=client_id) |
| 40 | + |
| 41 | + base_url = provider_instance.get_login_url(request) |
| 42 | + query_params = { |
| 43 | + "next": build_absolute_uri("plugins:socialauth:social.oauth.return") |
| 44 | + } |
| 45 | + parsed_url = urlparse(base_url) |
| 46 | + updated_url = parsed_url._replace(query=urlencode(query_params)) |
| 47 | + return redirect(urlunparse(updated_url)) |
| 48 | + |
| 49 | + @staticmethod |
| 50 | + def set_oauth2_params(request: HttpRequest) -> None: |
| 51 | + """ |
| 52 | + Handle Login with SSO button from other components |
| 53 | + This function will set 'oauth2_params' in session for oauth2_callback |
| 54 | + """ |
| 55 | + next_url = request.GET.get("next", "") |
| 56 | + if not next_url: |
| 57 | + return |
| 58 | + |
| 59 | + parsed = urlparse(next_url) |
| 60 | + |
| 61 | + # Only allow relative URLs |
| 62 | + if parsed.netloc or parsed.scheme: |
| 63 | + return |
| 64 | + |
| 65 | + params = parse_qs(parsed.query) |
| 66 | + sanitized_params = { |
| 67 | + k: v[0] |
| 68 | + for k, v in params.items() |
| 69 | + if k in OAuth2Params.model_fields.keys() |
| 70 | + } |
| 71 | + |
| 72 | + try: |
| 73 | + oauth2_params = OAuth2Params.model_validate(sanitized_params) |
| 74 | + request.session["oauth2_params"] = oauth2_params.model_dump() |
| 75 | + except ValidationError as e: |
| 76 | + logger.warning("Ignore invalid OAuth2 parameters: %s.", e) |
| 77 | + |
| 78 | + |
| 79 | +class OAuthReturnView(View): |
| 80 | + def get(self, request: HttpRequest) -> HttpResponse: |
| 81 | + try: |
| 82 | + user = self.get_or_create_user(request) |
| 83 | + response = process_login_and_set_cookie(request, user, False) |
| 84 | + oauth2_params = request.session.pop("oauth2_params", {}) |
| 85 | + if oauth2_params: |
| 86 | + try: |
| 87 | + oauth2_params = OAuth2Params.model_validate(oauth2_params) |
| 88 | + query_string = urlencode(oauth2_params.model_dump()) |
| 89 | + auth_url = reverse("control:oauth2_provider.authorize") |
| 90 | + return redirect(f"{auth_url}?{query_string}") |
| 91 | + except ValidationError as e: |
| 92 | + logger.warning("Ignore invalid OAuth2 parameters: %s.", e) |
| 93 | + |
| 94 | + return response |
| 95 | + except AttributeError as e: |
| 96 | + messages.error( |
| 97 | + request, _("Error while authorizing: no email address available.") |
| 98 | + ) |
| 99 | + logger.error("Error while authorizing: %s", e) |
| 100 | + return redirect("control:auth.login") |
39 | 101 |
|
40 |
| -def oauth_return(request): |
41 |
| - try: |
42 |
| - user, _ = User.objects.get_or_create( |
| 102 | + @staticmethod |
| 103 | + def get_or_create_user(request: HttpRequest) -> User: |
| 104 | + """ |
| 105 | + Get or create a user from social auth information. |
| 106 | + """ |
| 107 | + return User.objects.get_or_create( |
43 | 108 | email=request.user.email,
|
44 | 109 | defaults={
|
45 |
| - 'locale': getattr(request, 'LANGUAGE_CODE', settings.LANGUAGE_CODE), |
46 |
| - 'timezone': getattr(request, 'timezone', settings.TIME_ZONE), |
47 |
| - 'auth_backend': 'native', |
48 |
| - 'password': '', |
| 110 | + "locale": getattr(request, "LANGUAGE_CODE", settings.LANGUAGE_CODE), |
| 111 | + "timezone": getattr(request, "timezone", settings.TIME_ZONE), |
| 112 | + "auth_backend": "native", |
| 113 | + "password": "", |
49 | 114 | },
|
50 |
| - ) |
51 |
| - return process_login_and_set_cookie(request, user, False) |
52 |
| - except AttributeError: |
53 |
| - messages.error( |
54 |
| - request, _('Error while authorizing: no email address available.') |
55 |
| - ) |
56 |
| - logger.error('Error while authorizing: user has no email address.') |
57 |
| - return redirect('control:auth.login') |
| 115 | + )[0] |
58 | 116 |
|
59 | 117 |
|
60 | 118 | class SocialLoginView(AdministratorPermissionRequiredMixin, TemplateView):
|
|
0 commit comments