ATH-01-014: save authenticator validation state in flow context

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

bugfixes

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-06-16 15:16:27 +02:00
parent ce77d82b24
commit f15cac39c8
No known key found for this signature in database
6 changed files with 80 additions and 63 deletions

View File

@ -82,8 +82,9 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
def retrieve_file(self) -> str: def retrieve_file(self) -> str:
"""Get blueprint from path""" """Get blueprint from path"""
try: try:
full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path)).resolve() base = Path(CONFIG.y("blueprints_dir"))
if not str(full_path).startswith(CONFIG.y("blueprints_dir")): full_path = base.joinpath(Path(self.path)).resolve()
if not str(full_path).startswith(str(base.resolve())):
raise BlueprintRetrievalFailed("Invalid blueprint path") raise BlueprintRetrievalFailed("Invalid blueprint path")
with full_path.open("r", encoding="utf-8") as _file: with full_path.open("r", encoding="utf-8") as _file:
return _file.read() return _file.read()

View File

@ -204,12 +204,12 @@ class ChallengeStageView(StageView):
for field, errors in response.errors.items(): for field, errors in response.errors.items():
for error in errors: for error in errors:
full_errors.setdefault(field, []) full_errors.setdefault(field, [])
full_errors[field].append( field_error = {
{
"string": str(error), "string": str(error),
"code": error.code,
} }
) if hasattr(error, "code"):
field_error["code"] = error.code
full_errors[field].append(field_error)
challenge_response.initial_data["response_errors"] = full_errors challenge_response.initial_data["response_errors"] = full_errors
if not challenge_response.is_valid(): if not challenge_response.is_valid():
self.logger.error( self.logger.error(

View File

@ -132,9 +132,9 @@ class TestPolicyProcess(TestCase):
) )
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
http_request = self.factory.get(reverse("authentik_core:impersonate-end")) http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
http_request.user = self.user http_request.user = self.user
http_request.resolver_match = resolve(reverse("authentik_core:impersonate-end")) http_request.resolver_match = resolve(reverse("authentik_api:user-impersonate-end"))
request = PolicyRequest(self.user) request = PolicyRequest(self.user)
request.set_http_request(http_request) request.set_http_request(http_request)

View File

@ -36,9 +36,9 @@ from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_ME
COOKIE_NAME_MFA = "authentik_mfa" COOKIE_NAME_MFA = "authentik_mfa"
SESSION_KEY_STAGES = "authentik/stages/authenticator_validate/stages" PLAN_CONTEXT_STAGES = "goauthentik.io/stages/authenticator_validate/stages"
SESSION_KEY_SELECTED_STAGE = "authentik/stages/authenticator_validate/selected_stage" PLAN_CONTEXT_SELECTED_STAGE = "goauthentik.io/stages/authenticator_validate/selected_stage"
SESSION_KEY_DEVICE_CHALLENGES = "authentik/stages/authenticator_validate/device_challenges" PLAN_CONTEXT_DEVICE_CHALLENGES = "goauthentik.io/stages/authenticator_validate/device_challenges"
class SelectableStageSerializer(PassiveSerializer): class SelectableStageSerializer(PassiveSerializer):
@ -72,8 +72,8 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-authenticator-validate") component = CharField(default="ak-stage-authenticator-validate")
def _challenge_allowed(self, classes: list): def _challenge_allowed(self, classes: list):
device_challenges: list[dict] = self.stage.request.session.get( device_challenges: list[dict] = self.stage.executor.plan.context.get(
SESSION_KEY_DEVICE_CHALLENGES, [] PLAN_CONTEXT_DEVICE_CHALLENGES, []
) )
if not any(x["device_class"] in classes for x in device_challenges): if not any(x["device_class"] in classes for x in device_challenges):
raise ValidationError("No compatible device class allowed") raise ValidationError("No compatible device class allowed")
@ -103,7 +103,9 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
"""Check which challenge the user has selected. Actual logic only used for SMS stage.""" """Check which challenge the user has selected. Actual logic only used for SMS stage."""
# First check if the challenge is valid # First check if the challenge is valid
allowed = False allowed = False
for device_challenge in self.stage.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []): for device_challenge in self.stage.executor.plan.context.get(
PLAN_CONTEXT_DEVICE_CHALLENGES, []
):
if device_challenge.get("device_class", "") == challenge.get( if device_challenge.get("device_class", "") == challenge.get(
"device_class", "" "device_class", ""
) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""): ) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""):
@ -121,11 +123,11 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
def validate_selected_stage(self, stage_pk: str) -> str: def validate_selected_stage(self, stage_pk: str) -> str:
"""Check that the selected stage is valid""" """Check that the selected stage is valid"""
stages = self.stage.request.session.get(SESSION_KEY_STAGES, []) stages = self.stage.executor.plan.context.get(PLAN_CONTEXT_STAGES, [])
if not any(str(stage.pk) == stage_pk for stage in stages): if not any(str(stage.pk) == stage_pk for stage in stages):
raise ValidationError("Selected stage is invalid") raise ValidationError("Selected stage is invalid")
self.stage.logger.debug("Setting selected stage to ", stage=stage_pk) self.stage.logger.debug("Setting selected stage to ", stage=stage_pk)
self.stage.request.session[SESSION_KEY_SELECTED_STAGE] = stage_pk self.stage.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = stage_pk
return stage_pk return stage_pk
def validate(self, attrs: dict): def validate(self, attrs: dict):
@ -230,7 +232,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
else: else:
self.logger.debug("No pending user, continuing") self.logger.debug("No pending user, continuing")
return self.executor.stage_ok() return self.executor.stage_ok()
self.request.session[SESSION_KEY_DEVICE_CHALLENGES] = challenges self.executor.plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = challenges
# No allowed devices # No allowed devices
if len(challenges) < 1: if len(challenges) < 1:
@ -263,23 +265,23 @@ class AuthenticatorValidateStageView(ChallengeStageView):
if stage.configuration_stages.count() == 1: if stage.configuration_stages.count() == 1:
next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk) next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk)
self.logger.debug("Single stage configured, auto-selecting", stage=next_stage) self.logger.debug("Single stage configured, auto-selecting", stage=next_stage)
self.request.session[SESSION_KEY_SELECTED_STAGE] = next_stage self.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = next_stage
# Because that normal execution only happens on post, we directly inject it here and # Because that normal execution only happens on post, we directly inject it here and
# return it # return it
self.executor.plan.insert_stage(next_stage) self.executor.plan.insert_stage(next_stage)
return self.executor.stage_ok() return self.executor.stage_ok()
stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses() stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses()
self.request.session[SESSION_KEY_STAGES] = stages self.executor.plan.context[PLAN_CONTEXT_STAGES] = stages
return super().get(self.request, *args, **kwargs) return super().get(self.request, *args, **kwargs)
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
res = super().post(request, *args, **kwargs) res = super().post(request, *args, **kwargs)
if ( if (
SESSION_KEY_SELECTED_STAGE in self.request.session PLAN_CONTEXT_SELECTED_STAGE in self.executor.plan.context
and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
): ):
self.logger.debug("Got selected stage in session, running that") self.logger.debug("Got selected stage in context, running that")
stage_pk = self.request.session.get(SESSION_KEY_SELECTED_STAGE) stage_pk = self.executor.plan.context(PLAN_CONTEXT_SELECTED_STAGE)
# Because the foreign key to stage.configuration_stage points to # Because the foreign key to stage.configuration_stage points to
# a base stage class, we need to do another lookup # a base stage class, we need to do another lookup
stage = Stage.objects.get_subclass(pk=stage_pk) stage = Stage.objects.get_subclass(pk=stage_pk)
@ -290,8 +292,8 @@ class AuthenticatorValidateStageView(ChallengeStageView):
return res return res
def get_challenge(self) -> AuthenticatorValidationChallenge: def get_challenge(self) -> AuthenticatorValidationChallenge:
challenges = self.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []) challenges = self.executor.plan.context.get(PLAN_CONTEXT_DEVICE_CHALLENGES, [])
stages = self.request.session.get(SESSION_KEY_STAGES, []) stages = self.executor.plan.context.get(PLAN_CONTEXT_STAGES, [])
stage_challenges = [] stage_challenges = []
for stage in stages: for stage in stages:
serializer = SelectableStageSerializer( serializer = SelectableStageSerializer(
@ -306,6 +308,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
stage_challenges.append(serializer.data) stage_challenges.append(serializer.data)
return AuthenticatorValidationChallenge( return AuthenticatorValidationChallenge(
data={ data={
"component": "ak-stage-authenticator-validate",
"type": ChallengeTypes.NATIVE.value, "type": ChallengeTypes.NATIVE.value,
"device_challenges": challenges, "device_challenges": challenges,
"configuration_stages": stage_challenges, "configuration_stages": stage_challenges,
@ -385,8 +388,3 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device": webauthn_device, "device": webauthn_device,
} }
return self.set_valid_mfa_cookie(response.device) return self.set_valid_mfa_cookie(response.device)
def cleanup(self):
self.request.session.pop(SESSION_KEY_STAGES, None)
self.request.session.pop(SESSION_KEY_SELECTED_STAGE, None)
self.request.session.pop(SESSION_KEY_DEVICE_CHALLENGES, None)

View File

@ -1,26 +1,19 @@
"""Test validator stage""" """Test validator stage"""
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from django.contrib.sessions.middleware import SessionMiddleware
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.urls.base import reverse from django.urls.base import reverse
from rest_framework.exceptions import ValidationError
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import FlowDesignation, FlowStageBinding, NotConfiguredAction from authentik.flows.models import FlowDesignation, FlowStageBinding, NotConfiguredAction
from authentik.flows.planner import FlowPlan from authentik.flows.planner import FlowPlan
from authentik.flows.stage import StageView
from authentik.flows.tests import FlowTestCase from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id, generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.lib.tests.utils import dummy_get_response
from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice
from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_validate.stage import ( from authentik.stages.authenticator_validate.stage import PLAN_CONTEXT_DEVICE_CHALLENGES
SESSION_KEY_DEVICE_CHALLENGES,
AuthenticatorValidationChallengeResponse,
)
from authentik.stages.identification.models import IdentificationStage, UserFields from authentik.stages.identification.models import IdentificationStage, UserFields
@ -86,12 +79,17 @@ class AuthenticatorValidateStageTests(FlowTestCase):
def test_validate_selected_challenge(self): def test_validate_selected_challenge(self):
"""Test validate_selected_challenge""" """Test validate_selected_challenge"""
# Prepare request with session flow = create_test_flow()
request = self.request_factory.get("/") stage = AuthenticatorValidateStage.objects.create(
name=generate_id(),
not_configured_action=NotConfiguredAction.CONFIGURE,
device_classes=[DeviceClasses.STATIC, DeviceClasses.TOTP],
)
middleware = SessionMiddleware(dummy_get_response) session = self.client.session
middleware.process_request(request) plan = FlowPlan(flow_pk=flow.pk.hex)
request.session[SESSION_KEY_DEVICE_CHALLENGES] = [ plan.append_stage(stage)
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{ {
"device_class": "static", "device_class": "static",
"device_uid": "1", "device_uid": "1",
@ -101,23 +99,43 @@ class AuthenticatorValidateStageTests(FlowTestCase):
"device_uid": "2", "device_uid": "2",
}, },
] ]
request.session.save() session[SESSION_KEY_PLAN] = plan
session.save()
res = AuthenticatorValidationChallengeResponse() response = self.client.post(
res.stage = StageView(FlowExecutorView()) reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
res.stage.request = request data={
with self.assertRaises(ValidationError): "selected_challenge": {
res.validate_selected_challenge(
{
"device_class": "baz", "device_class": "baz",
"device_uid": "quox", "device_uid": "quox",
"challenge": {},
} }
},
) )
res.validate_selected_challenge( self.assertStageResponse(
{ response,
flow,
response_errors={
"selected_challenge": [{"string": "invalid challenge selected", "code": "invalid"}]
},
component="ak-stage-authenticator-validate",
)
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
data={
"selected_challenge": {
"device_class": "static", "device_class": "static",
"device_uid": "1", "device_uid": "1",
} "challenge": {},
},
},
)
self.assertStageResponse(
response,
flow,
response_errors={"non_field_errors": [{"string": "Empty response", "code": "invalid"}]},
component="ak-stage-authenticator-validate",
) )
@patch( @patch(

View File

@ -22,7 +22,7 @@ from authentik.stages.authenticator_validate.challenge import (
) )
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_validate.stage import ( from authentik.stages.authenticator_validate.stage import (
SESSION_KEY_DEVICE_CHALLENGES, PLAN_CONTEXT_DEVICE_CHALLENGES,
AuthenticatorValidateStageView, AuthenticatorValidateStageView,
) )
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
@ -211,14 +211,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
plan.append_stage(stage) plan.append_stage(stage)
plan.append_stage(UserLoginStage(name=generate_id())) plan.append_stage(UserLoginStage(name=generate_id()))
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
session[SESSION_KEY_PLAN] = plan plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
session[SESSION_KEY_DEVICE_CHALLENGES] = [
{ {
"device_class": device.__class__.__name__.lower().replace("device", ""), "device_class": device.__class__.__name__.lower().replace("device", ""),
"device_uid": device.pk, "device_uid": device.pk,
"challenge": {}, "challenge": {},
} }
] ]
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
) )
@ -283,14 +283,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
plan = FlowPlan(flow_pk=flow.pk.hex) plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage) plan.append_stage(stage)
plan.append_stage(UserLoginStage(name=generate_id())) plan.append_stage(UserLoginStage(name=generate_id()))
session[SESSION_KEY_PLAN] = plan plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
session[SESSION_KEY_DEVICE_CHALLENGES] = [
{ {
"device_class": device.__class__.__name__.lower().replace("device", ""), "device_class": device.__class__.__name__.lower().replace("device", ""),
"device_uid": device.pk, "device_uid": device.pk,
"challenge": {}, "challenge": {},
} }
] ]
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
) )