flows: remove need for post() wrapper by using dispatch (#6765)
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
7cbce1bb3d
commit
e373bae189
|
@ -48,7 +48,7 @@ class Action(Enum):
|
||||||
class MessageStage(StageView):
|
class MessageStage(StageView):
|
||||||
"""Show a pre-configured message after the flow is done"""
|
"""Show a pre-configured message after the flow is done"""
|
||||||
|
|
||||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
"""Show a pre-configured message after the flow is done"""
|
"""Show a pre-configured message after the flow is done"""
|
||||||
message = getattr(self.executor.current_stage, "message", "")
|
message = getattr(self.executor.current_stage, "message", "")
|
||||||
level = getattr(self.executor.current_stage, "level", messages.SUCCESS)
|
level = getattr(self.executor.current_stage, "level", messages.SUCCESS)
|
||||||
|
@ -59,10 +59,6 @@ class MessageStage(StageView):
|
||||||
)
|
)
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
|
||||||
"""Wrapper for post requests"""
|
|
||||||
return self.get(request)
|
|
||||||
|
|
||||||
|
|
||||||
class SourceFlowManager:
|
class SourceFlowManager:
|
||||||
"""Help sources decide what they should do after authorization. Based on source settings and
|
"""Help sources decide what they should do after authorization. Based on source settings and
|
||||||
|
|
|
@ -13,7 +13,7 @@ class PostUserEnrollmentStage(StageView):
|
||||||
"""Dynamically injected stage which saves the Connection after
|
"""Dynamically injected stage which saves the Connection after
|
||||||
the user has been enrolled."""
|
the user has been enrolled."""
|
||||||
|
|
||||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Stage used after the user has been enrolled"""
|
"""Stage used after the user has been enrolled"""
|
||||||
connection: UserSourceConnection = self.executor.plan.context[
|
connection: UserSourceConnection = self.executor.plan.context[
|
||||||
PLAN_CONTEXT_SOURCES_CONNECTION
|
PLAN_CONTEXT_SOURCES_CONNECTION
|
||||||
|
@ -27,7 +27,3 @@ class PostUserEnrollmentStage(StageView):
|
||||||
source=connection.source,
|
source=connection.source,
|
||||||
).from_http(self.request)
|
).from_http(self.request)
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
|
||||||
"""Wrapper for post requests"""
|
|
||||||
return self.get(request)
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ def view_tester_factory(view_class: type[StageView]) -> Callable:
|
||||||
|
|
||||||
def tester(self: TestViews):
|
def tester(self: TestViews):
|
||||||
model_class = view_class(self.exec)
|
model_class = view_class(self.exec)
|
||||||
|
if not hasattr(model_class, "dispatch"):
|
||||||
self.assertIsNotNone(model_class.post)
|
self.assertIsNotNone(model_class.post)
|
||||||
self.assertIsNotNone(model_class.get)
|
self.assertIsNotNone(model_class.get)
|
||||||
|
|
||||||
|
|
|
@ -295,7 +295,7 @@ class FlowExecutorView(APIView):
|
||||||
span.set_data("Method", "GET")
|
span.set_data("Method", "GET")
|
||||||
span.set_data("authentik Stage", self.current_stage_view)
|
span.set_data("authentik Stage", self.current_stage_view)
|
||||||
span.set_data("authentik Flow", self.flow.slug)
|
span.set_data("authentik Flow", self.flow.slug)
|
||||||
stage_response = self.current_stage_view.get(request, *args, **kwargs)
|
stage_response = self.current_stage_view.dispatch(request)
|
||||||
return to_stage_response(request, stage_response)
|
return to_stage_response(request, stage_response)
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
return self.handle_exception(exc)
|
return self.handle_exception(exc)
|
||||||
|
@ -339,7 +339,7 @@ class FlowExecutorView(APIView):
|
||||||
span.set_data("Method", "POST")
|
span.set_data("Method", "POST")
|
||||||
span.set_data("authentik Stage", self.current_stage_view)
|
span.set_data("authentik Stage", self.current_stage_view)
|
||||||
span.set_data("authentik Flow", self.flow.slug)
|
span.set_data("authentik Flow", self.flow.slug)
|
||||||
stage_response = self.current_stage_view.post(request, *args, **kwargs)
|
stage_response = self.current_stage_view.dispatch(request)
|
||||||
return to_stage_response(request, stage_response)
|
return to_stage_response(request, stage_response)
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
return self.handle_exception(exc)
|
return self.handle_exception(exc)
|
||||||
|
|
|
@ -7,10 +7,6 @@ from authentik.flows.stage import StageView
|
||||||
class DenyStageView(StageView):
|
class DenyStageView(StageView):
|
||||||
"""Cancels the current flow"""
|
"""Cancels the current flow"""
|
||||||
|
|
||||||
def get(self, request: HttpRequest) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Cancels the current flow"""
|
"""Cancels the current flow"""
|
||||||
return self.executor.stage_invalid()
|
return self.executor.stage_invalid()
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
|
||||||
"""Wrapper for post requests"""
|
|
||||||
return self.get(request)
|
|
||||||
|
|
|
@ -21,10 +21,6 @@ INVITATION = "invitation"
|
||||||
class InvitationStageView(StageView):
|
class InvitationStageView(StageView):
|
||||||
"""Finalise Authentication flow by logging the user in"""
|
"""Finalise Authentication flow by logging the user in"""
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
|
||||||
"""Wrapper for post requests"""
|
|
||||||
return self.get(request)
|
|
||||||
|
|
||||||
def get_token(self) -> Optional[str]:
|
def get_token(self) -> Optional[str]:
|
||||||
"""Get token from saved get-arguments or prompt_data"""
|
"""Get token from saved get-arguments or prompt_data"""
|
||||||
# Check for ?token= and ?itoken=
|
# Check for ?token= and ?itoken=
|
||||||
|
@ -55,7 +51,7 @@ class InvitationStageView(StageView):
|
||||||
return None
|
return None
|
||||||
return invite
|
return invite
|
||||||
|
|
||||||
def get(self, request: HttpRequest) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Apply data to the current flow based on a URL"""
|
"""Apply data to the current flow based on a URL"""
|
||||||
stage: InvitationStage = self.executor.current_stage
|
stage: InvitationStage = self.executor.current_stage
|
||||||
|
|
||||||
|
|
|
@ -11,11 +11,7 @@ from authentik.flows.stage import StageView
|
||||||
class UserDeleteStageView(StageView):
|
class UserDeleteStageView(StageView):
|
||||||
"""Finalise unenrollment flow by deleting the user object."""
|
"""Finalise unenrollment flow by deleting the user object."""
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Wrapper for post requests"""
|
|
||||||
return self.get(request)
|
|
||||||
|
|
||||||
def get(self, request: HttpRequest) -> HttpResponse:
|
|
||||||
"""Delete currently pending user"""
|
"""Delete currently pending user"""
|
||||||
user = self.get_pending_user()
|
user = self.get_pending_user()
|
||||||
if not user.is_authenticated:
|
if not user.is_authenticated:
|
||||||
|
|
|
@ -41,17 +41,11 @@ class UserLoginStageView(ChallengeStageView):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Wrapper for post requests"""
|
"""Check for remember_me, and do login"""
|
||||||
stage: UserLoginStage = self.executor.current_stage
|
stage: UserLoginStage = self.executor.current_stage
|
||||||
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
|
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
|
||||||
return super().post(request, *args, **kwargs)
|
return super().dispatch(request)
|
||||||
return self.do_login(request)
|
|
||||||
|
|
||||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
|
||||||
stage: UserLoginStage = self.executor.current_stage
|
|
||||||
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
|
|
||||||
return super().get(request, *args, **kwargs)
|
|
||||||
return self.do_login(request)
|
return self.do_login(request)
|
||||||
|
|
||||||
def challenge_valid(self, response: UserLoginChallengeResponse) -> HttpResponse:
|
def challenge_valid(self, response: UserLoginChallengeResponse) -> HttpResponse:
|
||||||
|
|
|
@ -8,7 +8,7 @@ from authentik.flows.stage import StageView
|
||||||
class UserLogoutStageView(StageView):
|
class UserLogoutStageView(StageView):
|
||||||
"""Finalise Authentication flow by logging the user in"""
|
"""Finalise Authentication flow by logging the user in"""
|
||||||
|
|
||||||
def get(self, request: HttpRequest) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Remove the user from the current session"""
|
"""Remove the user from the current session"""
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Logged out",
|
"Logged out",
|
||||||
|
@ -17,7 +17,3 @@ class UserLogoutStageView(StageView):
|
||||||
)
|
)
|
||||||
logout(self.request)
|
logout(self.request)
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
|
||||||
"""Wrapper for post requests"""
|
|
||||||
return self.get(request)
|
|
||||||
|
|
|
@ -51,10 +51,6 @@ class UserWriteStageView(StageView):
|
||||||
attrs = attrs.get(comp)
|
attrs = attrs.get(comp)
|
||||||
attrs[parts[-1]] = value
|
attrs[parts[-1]] = value
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
|
||||||
"""Wrapper for post requests"""
|
|
||||||
return self.get(request)
|
|
||||||
|
|
||||||
def ensure_user(self) -> tuple[Optional[User], bool]:
|
def ensure_user(self) -> tuple[Optional[User], bool]:
|
||||||
"""Ensure a user exists"""
|
"""Ensure a user exists"""
|
||||||
user_created = False
|
user_created = False
|
||||||
|
@ -127,7 +123,7 @@ class UserWriteStageView(StageView):
|
||||||
if connection.source.name not in user.attributes[USER_ATTRIBUTE_SOURCES]:
|
if connection.source.name not in user.attributes[USER_ATTRIBUTE_SOURCES]:
|
||||||
user.attributes[USER_ATTRIBUTE_SOURCES].append(connection.source.name)
|
user.attributes[USER_ATTRIBUTE_SOURCES].append(connection.source.name)
|
||||||
|
|
||||||
def get(self, request: HttpRequest) -> HttpResponse:
|
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Save data in the current flow to the currently pending user. If no user is pending,
|
"""Save data in the current flow to the currently pending user. If no user is pending,
|
||||||
a new user is created."""
|
a new user is created."""
|
||||||
if PLAN_CONTEXT_PROMPT not in self.executor.plan.context:
|
if PLAN_CONTEXT_PROMPT not in self.executor.plan.context:
|
||||||
|
|
Reference in New Issue