diff --git a/e2e/test_provider_proxy.py b/e2e/test_provider_proxy.py index a6ff18535..411ef73a8 100644 --- a/e2e/test_provider_proxy.py +++ b/e2e/test_provider_proxy.py @@ -4,14 +4,14 @@ from time import sleep from typing import Any, Dict, Optional from unittest.case import skipUnless +from channels.testing import ChannelsLiveServerTestCase from docker.client import DockerClient, from_env from docker.models.containers import Container from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys -from channels.testing import ChannelsLiveServerTestCase -from passbook import __version__ from e2e.utils import USER, SeleniumTestCase +from passbook import __version__ from passbook.core.models import Application from passbook.flows.models import Flow from passbook.outposts.models import Outpost, OutpostDeploymentType, OutpostType @@ -124,6 +124,7 @@ class TestProviderProxyConnect(ChannelsLiveServerTestCase): return container def test_proxy_connectivity(self): + """Test proxy connectivity over websocket""" SeleniumTestCase().apply_default_data() proxy: ProxyProvider = ProxyProvider.objects.create( name="proxy_provider", diff --git a/passbook/providers/oauth2/constants.py b/passbook/providers/oauth2/constants.py index 1bd3379a3..6bfc7c81a 100644 --- a/passbook/providers/oauth2/constants.py +++ b/passbook/providers/oauth2/constants.py @@ -7,7 +7,6 @@ PROMPT_CONSNET = "consent" SCOPE_OPENID = "openid" SCOPE_OPENID_PROFILE = "profile" SCOPE_OPENID_EMAIL = "email" -SCOPE_OPENID_INTROSPECTION = "token_introspection" # Read/write full user (including email) SCOPE_GITHUB_USER = "user" diff --git a/passbook/providers/oauth2/models.py b/passbook/providers/oauth2/models.py index 73b22cf9f..441006a11 100644 --- a/passbook/providers/oauth2/models.py +++ b/passbook/providers/oauth2/models.py @@ -202,11 +202,6 @@ class OAuth2Provider(Provider): ), ) - @property - def scope_names(self) -> List[str]: - """Return list of assigned scopes seperated with a space""" - return [pm.scope_name for pm in self.property_mappings.all()] - def create_refresh_token( self, user: User, scope: List[str], id_token: Optional["IDToken"] = None ) -> "RefreshToken": diff --git a/passbook/providers/oauth2/utils.py b/passbook/providers/oauth2/utils.py index 909725a27..adcae91cc 100644 --- a/passbook/providers/oauth2/utils.py +++ b/passbook/providers/oauth2/utils.py @@ -2,7 +2,7 @@ import re from base64 import b64decode from binascii import Error -from typing import List, Tuple +from typing import List, Optional, Tuple from django.http import HttpRequest, HttpResponse, JsonResponse from django.utils.cache import patch_vary_headers @@ -50,7 +50,7 @@ def cors_allow_any(request, response): return response -def extract_access_token(request: HttpRequest) -> str: +def extract_access_token(request: HttpRequest) -> Optional[str]: """ Get the access token using Authorization Request Header Field method. Or try getting via GET. @@ -66,7 +66,7 @@ def extract_access_token(request: HttpRequest) -> str: return request.POST.get("access_token") if "access_token" in request.GET: return request.GET.get("access_token") - return "" + return None def extract_client_auth(request: HttpRequest) -> Tuple[str, str]: @@ -103,9 +103,12 @@ def protected_resource_view(scopes: List[str]): def wrapper(view): def view_wrapper(request, *args, **kwargs): - access_token = extract_access_token(request) - try: + access_token = extract_access_token(request) + if not access_token: + LOGGER.debug("No token passed") + raise BearerTokenError("invalid_token") + try: kwargs["token"] = RefreshToken.objects.get( access_token=access_token diff --git a/passbook/providers/oauth2/views/introspection.py b/passbook/providers/oauth2/views/introspection.py index b72964159..7673d4adf 100644 --- a/passbook/providers/oauth2/views/introspection.py +++ b/passbook/providers/oauth2/views/introspection.py @@ -1,15 +1,17 @@ """passbook OAuth2 Token Introspection Views""" -from dataclasses import InitVar, dataclass -from typing import Optional +from dataclasses import dataclass, field from django.http import HttpRequest, HttpResponse from django.views import View from structlog import get_logger -from passbook.providers.oauth2.constants import SCOPE_OPENID_INTROSPECTION from passbook.providers.oauth2.errors import TokenIntrospectionError from passbook.providers.oauth2.models import IDToken, OAuth2Provider, RefreshToken -from passbook.providers.oauth2.utils import TokenResponse, extract_client_auth +from passbook.providers.oauth2.utils import ( + TokenResponse, + extract_access_token, + extract_client_auth, +) LOGGER = get_logger() @@ -18,39 +20,17 @@ LOGGER = get_logger() class TokenIntrospectionParams: """Parameters for Token Introspection""" - client_id: str - client_secret: str + token: RefreshToken - raw_token: InitVar[str] + provider: OAuth2Provider = field(init=False) + id_token: IDToken = field(init=False) - token: Optional[RefreshToken] = None - - provider: Optional[OAuth2Provider] = None - id_token: Optional[IDToken] = None - - def __post_init__(self, raw_token: str): - try: - self.token = RefreshToken.objects.get(access_token=raw_token) - except RefreshToken.DoesNotExist: - LOGGER.debug("Token does not exist", token=raw_token) - raise TokenIntrospectionError() + def __post_init__(self): if self.token.is_expired: - LOGGER.debug("Token is not valid", token=raw_token) - raise TokenIntrospectionError() - try: - self.provider = OAuth2Provider.objects.get( - client_id=self.client_id, client_secret=self.client_secret, - ) - except OAuth2Provider.DoesNotExist: - LOGGER.debug("provider for ID not found", client_id=self.client_id) - raise TokenIntrospectionError() - if SCOPE_OPENID_INTROSPECTION not in self.provider.scope_names: - LOGGER.debug( - "OAuth2Provider does not have introspection scope", - client_id=self.client_id, - ) + LOGGER.debug("Token is not valid") raise TokenIntrospectionError() + self.provider = self.token.provider self.id_token = self.token.id_token if not self.token.id_token: @@ -59,31 +39,61 @@ class TokenIntrospectionParams: ) raise TokenIntrospectionError() - audience = self.token.id_token.aud - if not audience: - LOGGER.debug( - "No audience found for token", token=self.token, - ) + def authenticate_basic(self, request: HttpRequest) -> bool: + """Attempt to authenticate via Basic auth of client_id:client_secret""" + client_id, client_secret = extract_client_auth(request) + if client_id == client_secret == "": + return False + if ( + client_id != self.provider.client_id + or client_secret != self.provider.client_secret + ): + LOGGER.debug("(basic) Provider for basic auth does not exist") raise TokenIntrospectionError() + return True - if audience not in self.provider.scope_names: - LOGGER.debug( - "provider does not audience scope", - client_id=self.client_id, - audience=audience, - ) + def authenticate_bearer(self, request: HttpRequest) -> bool: + """Attempt to authenticate via token sent as bearer header""" + body_token = extract_access_token(request) + if not body_token: + return False + tokens = RefreshToken.objects.filter(access_token=body_token).select_related( + "provider" + ) + if not tokens.exists(): + LOGGER.debug("(bearer) Token does not exist") raise TokenIntrospectionError() + if tokens.first().provider != self.provider: + LOGGER.debug("(bearer) Token providers don't match") + raise TokenIntrospectionError() + return True @staticmethod def from_request(request: HttpRequest) -> "TokenIntrospectionParams": """Extract required Parameters from HTTP Request""" - # Introspection only supports POST requests - client_id, client_secret = extract_client_auth(request) - return TokenIntrospectionParams( - raw_token=request.POST.get("token"), - client_id=client_id, - client_secret=client_secret, - ) + raw_token = request.POST.get("token") + token_type_hint = request.POST.get("token_type_hint", "access_token") + token_filter = {token_type_hint: raw_token} + + if token_type_hint not in ["access_token", "refresh_token"]: + LOGGER.debug("token_type_hint has invalid value", value=token_type_hint) + raise TokenIntrospectionError() + + try: + token: RefreshToken = RefreshToken.objects.select_related("provider").get( + **token_filter + ) + except RefreshToken.DoesNotExist: + LOGGER.debug("Token does not exist", token=raw_token) + raise TokenIntrospectionError() + + params = TokenIntrospectionParams(token=token) + if not any( + [params.authenticate_basic(request), params.authenticate_bearer(request)] + ): + LOGGER.debug("Not authenticated") + raise TokenIntrospectionError() + return params class TokenIntrospectionView(View): @@ -101,12 +111,12 @@ class TokenIntrospectionView(View): self.params = TokenIntrospectionParams.from_request(request) response_dic = {} - if self.id_token: - token_dict = self.id_token.to_dict() + if self.params.id_token: + token_dict = self.params.id_token.to_dict() for k in ("aud", "sub", "exp", "iat", "iss"): response_dic[k] = token_dict[k] response_dic["active"] = True - response_dic["client_id"] = self.token.provider.client_id + response_dic["client_id"] = self.params.token.provider.client_id return TokenResponse(response_dic) except TokenIntrospectionError: