diff --git a/authentik/stages/authenticator_static/stage.py b/authentik/stages/authenticator_static/stage.py index 841bb35ad..35464f96d 100644 --- a/authentik/stages/authenticator_static/stage.py +++ b/authentik/stages/authenticator_static/stage.py @@ -4,9 +4,13 @@ from django_otp.plugins.otp_static.models import StaticDevice, StaticToken from rest_framework.fields import CharField, ListField from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge +from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView from authentik.stages.authenticator_static.models import AuthenticatorStaticStage +SESSION_STATIC_DEVICE = "static_device" +SESSION_STATIC_TOKENS = "static_device_tokens" + class AuthenticatorStaticChallenge(WithUserInfoChallenge): """Static authenticator challenge""" @@ -27,8 +31,7 @@ class AuthenticatorStaticStageView(ChallengeStageView): response_class = AuthenticatorStaticChallengeResponse def get_challenge(self, *args, **kwargs) -> AuthenticatorStaticChallenge: - user = self.get_pending_user() - tokens: list[StaticToken] = StaticToken.objects.filter(device__user=user) + tokens: list[StaticToken] = self.request.session[SESSION_STATIC_TOKENS] return AuthenticatorStaticChallenge( data={ "type": ChallengeTypes.NATIVE.value, @@ -44,25 +47,22 @@ class AuthenticatorStaticStageView(ChallengeStageView): stage: AuthenticatorStaticStage = self.executor.current_stage - devices = StaticDevice.objects.filter(user=user) - # Currently, this stage only supports one device per user. If the user already - # has a device, just skip to the next stage - if devices.exists(): - if not any(x.confirmed for x in devices): - return super().get(request, *args, **kwargs) - return self.executor.stage_ok() - - device = StaticDevice.objects.create(user=user, confirmed=False, name="Static Token") - for _ in range(0, stage.token_count): - StaticToken.objects.create(device=device, token=StaticToken.random_token()) + if SESSION_STATIC_DEVICE not in self.request.session: + device = StaticDevice(user=user, confirmed=False, name="Static Token") + tokens = [] + for _ in range(0, stage.token_count): + tokens.append(StaticToken(device=device, token=StaticToken.random_token())) + self.request.session[SESSION_STATIC_DEVICE] = device + self.request.session[SESSION_STATIC_TOKENS] = tokens return super().get(request, *args, **kwargs) def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: """Verify OTP Token""" - user = self.get_pending_user() - device: StaticDevice = StaticDevice.objects.filter(user=user).first() - if not device: - return self.executor.stage_invalid() + device: StaticDevice = self.request.session[SESSION_STATIC_DEVICE] device.confirmed = True device.save() + for token in self.request.session[SESSION_STATIC_TOKENS]: + token.save() + del self.request.session[SESSION_STATIC_DEVICE] + del self.request.session[SESSION_STATIC_TOKENS] return self.executor.stage_ok() diff --git a/authentik/stages/authenticator_totp/stage.py b/authentik/stages/authenticator_totp/stage.py index 11fcb2fac..856f02793 100644 --- a/authentik/stages/authenticator_totp/stage.py +++ b/authentik/stages/authenticator_totp/stage.py @@ -13,10 +13,13 @@ from authentik.flows.challenge import ( ChallengeTypes, WithUserInfoChallenge, ) +from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView from authentik.stages.authenticator_totp.models import AuthenticatorTOTPStage from authentik.stages.authenticator_totp.settings import OTP_TOTP_ISSUER +SESSION_TOTP_DEVICE = "totp_device" + class AuthenticatorTOTPChallenge(WithUserInfoChallenge): """TOTP Setup challenge""" @@ -49,8 +52,7 @@ class AuthenticatorTOTPStageView(ChallengeStageView): response_class = AuthenticatorTOTPChallengeResponse def get_challenge(self, *args, **kwargs) -> Challenge: - user = self.get_pending_user() - device: TOTPDevice = TOTPDevice.objects.filter(user=user).first() + device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE] return AuthenticatorTOTPChallenge( data={ "type": ChallengeTypes.NATIVE.value, @@ -62,8 +64,7 @@ class AuthenticatorTOTPStageView(ChallengeStageView): def get_response_instance(self, data: QueryDict) -> ChallengeResponse: response = super().get_response_instance(data) - user = self.get_pending_user() - response.device = TOTPDevice.objects.filter(user=user).first() + response.device = self.request.session.get(SESSION_TOTP_DEVICE) return response def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: @@ -74,17 +75,18 @@ class AuthenticatorTOTPStageView(ChallengeStageView): stage: AuthenticatorTOTPStage = self.executor.current_stage - TOTPDevice.objects.create( - user=user, confirmed=False, digits=stage.digits, name="TOTP Authenticator" - ) + if SESSION_TOTP_DEVICE not in self.request.session: + device = TOTPDevice( + user=user, confirmed=False, digits=stage.digits, name="TOTP Authenticator" + ) + + self.request.session[SESSION_TOTP_DEVICE] = device return super().get(request, *args, **kwargs) def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: """TOTP Token is validated by challenge""" - user = self.get_pending_user() - device: TOTPDevice = TOTPDevice.objects.filter(user=user).first() - if not device: - return self.executor.stage_invalid() + device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE] device.confirmed = True device.save() + del self.request.session[SESSION_TOTP_DEVICE] return self.executor.stage_ok()