sources/oauth: fix handling of sources with spaces in their name

This commit is contained in:
Jens Langhammer 2020-05-19 21:53:36 +02:00
parent 4d45dc31a9
commit f58ee7fb52
4 changed files with 24 additions and 15 deletions

View File

@ -21,5 +21,6 @@ class PassbookSourceOAuthConfig(AppConfig):
for source_type in settings.PASSBOOK_SOURCES_OAUTH_TYPES: for source_type in settings.PASSBOOK_SOURCES_OAUTH_TYPES:
try: try:
import_module(source_type) import_module(source_type)
LOGGER.debug("Loaded OAuth Source Type", type=source_type)
except ImportError as exc: except ImportError as exc:
LOGGER.debug(exc) LOGGER.debug(exc)

View File

@ -1,9 +1,11 @@
"""Source type manager""" """Source type manager"""
from enum import Enum from enum import Enum
from typing import Callable, Dict, List
from django.utils.text import slugify from django.utils.text import slugify
from structlog import get_logger from structlog import get_logger
from passbook.sources.oauth.models import OAuthSource
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
LOGGER = get_logger() LOGGER = get_logger()
@ -19,18 +21,20 @@ class RequestKind(Enum):
class SourceTypeManager: class SourceTypeManager:
"""Manager to hold all Source types.""" """Manager to hold all Source types."""
__source_types = {} __source_types: Dict[RequestKind, Dict[str, Callable]] = {}
__names = [] __names: List[str] = []
def source(self, kind, name): def source(self, kind: RequestKind, name: str):
"""Class decorator to register classes inline.""" """Class decorator to register classes inline."""
def inner_wrapper(cls): def inner_wrapper(cls):
if kind not in self.__source_types: if kind.value not in self.__source_types:
self.__source_types[kind] = {} self.__source_types[kind.value] = {}
self.__source_types[kind][name.lower()] = cls self.__source_types[kind.value][slugify(name)] = cls
self.__names.append(name) self.__names.append(name)
LOGGER.debug("Registered source", source_class=cls.__name__, kind=kind) LOGGER.debug(
"Registered source", source_class=cls.__name__, kind=kind.value
)
return cls return cls
return inner_wrapper return inner_wrapper
@ -39,15 +43,16 @@ class SourceTypeManager:
"""Get list of tuples of all registered names""" """Get list of tuples of all registered names"""
return [(slugify(x), x) for x in set(self.__names)] return [(slugify(x), x) for x in set(self.__names)]
def find(self, source, kind): def find(self, source: OAuthSource, kind: RequestKind) -> Callable:
"""Find fitting Source Type""" """Find fitting Source Type"""
if kind in self.__source_types: if kind.value in self.__source_types:
if source.provider_type in self.__source_types[kind]: if source.provider_type in self.__source_types[kind.value]:
return self.__source_types[kind][source.provider_type] return self.__source_types[kind.value][source.provider_type]
LOGGER.warning("no matching type found, using default")
# Return defaults # Return defaults
if kind == RequestKind.callback: if kind.value == RequestKind.callback:
return OAuthCallback return OAuthCallback
if kind == RequestKind.redirect: if kind.value == RequestKind.redirect:
return OAuthRedirect return OAuthRedirect
raise KeyError raise KeyError

View File

@ -1,11 +1,10 @@
"""OAuth Client User Creation Utils""" """OAuth Client User Creation Utils"""
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from passbook.core.models import User from passbook.core.models import User
def user_get_or_create(**kwargs): def user_get_or_create(**kwargs: str) -> User:
"""Create user or return existing user""" """Create user or return existing user"""
try: try:
new_user = User.objects.create_user(**kwargs) new_user = User.objects.create_user(**kwargs)

View File

@ -2,10 +2,13 @@
from django.http import Http404 from django.http import Http404
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.views import View from django.views import View
from structlog import get_logger
from passbook.sources.oauth.models import OAuthSource from passbook.sources.oauth.models import OAuthSource
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
LOGGER = get_logger()
class DispatcherView(View): class DispatcherView(View):
"""Dispatch OAuth Redirect/Callback views to their proper class based on URL parameters""" """Dispatch OAuth Redirect/Callback views to their proper class based on URL parameters"""
@ -19,4 +22,5 @@ class DispatcherView(View):
raise Http404 raise Http404
source = get_object_or_404(OAuthSource, slug=slug) source = get_object_or_404(OAuthSource, slug=slug)
view = MANAGER.find(source, kind=RequestKind(self.kind)) view = MANAGER.find(source, kind=RequestKind(self.kind))
LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind)
return view.as_view()(*args, **kwargs) return view.as_view()(*args, **kwargs)