diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index c35485253..4e34e1f3b 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -12,7 +12,7 @@ from django.utils.timezone import datetime, now from django.views import View from django.views.decorators.csrf import csrf_exempt from guardian.shortcuts import get_anonymous_user -from jwt import PyJWK, PyJWTError, decode +from jwt import PyJWK, PyJWT, PyJWTError, decode from sentry_sdk.hub import Hub from structlog.stdlib import get_logger @@ -306,7 +306,24 @@ class TokenParams: source: Optional[OAuthSource] = None parsed_key: Optional[PyJWK] = None - for source in self.provider.jwks_sources.all(): + + # Fully decode the JWT without verifying the signature, so we can get access to + # the header. + # Get the Key ID from the header, and use that to optimise our source query to only find + # sources that have a JWK for that Key ID + # The Key ID doesn't have a fixed format, but must match between an issued JWT + # and whatever is returned by the JWKS endpoint + try: + decode_unvalidated = PyJWT().decode_complete( + assertion, options={"verify_signature": False} + ) + except (PyJWTError, ValueError, TypeError, AttributeError) as exc: + LOGGER.warning("failed to parse jwt for kid lookup", exc=exc) + raise TokenError("invalid_grant") + expected_kid = decode_unvalidated["header"]["kid"] + for source in self.provider.jwks_sources.filter( + oidc_jwks__keys__contains=[{"kid": expected_kid}] + ): LOGGER.debug("verifying jwt with source", source=source.slug) keys = source.oidc_jwks.get("keys", []) for key in keys: