diff --git a/e2e/test_sources_oauth.py b/e2e/test_sources_oauth.py index c0339f83d..e78bd5b3e 100644 --- a/e2e/test_sources_oauth.py +++ b/e2e/test_sources_oauth.py @@ -16,6 +16,7 @@ from passbook.flows.models import Flow from passbook.sources.oauth.models import OAuthSource TOKEN_URL = "http://127.0.0.1:5556/dex/token" +CONFIG_PATH = "/tmp/dex.yml" class TestSourceOAuth(SeleniumTestCase): @@ -60,8 +61,7 @@ class TestSourceOAuth(SeleniumTestCase): "storage": {"config": {"file": "/tmp/dex.db"}, "type": "sqlite3"}, "web": {"http": "0.0.0.0:5556"}, } - config_file = "./e2e/dex/config-dev.yaml" - with open(config_file, "w+") as _file: + with open(CONFIG_PATH, "w+") as _file: safe_dump(config, _file) def setup_client(self) -> Container: @@ -80,7 +80,7 @@ class TestSourceOAuth(SeleniumTestCase): start_period=1 * 100 * 1000000, ), volumes={ - abspath("./e2e/dex/config-dev.yaml"): { + abspath(CONFIG_PATH): { "bind": "/config.yml", "mode": "ro", } diff --git a/passbook/flows/models.py b/passbook/flows/models.py index a48dc5880..aa1523722 100644 --- a/passbook/flows/models.py +++ b/passbook/flows/models.py @@ -1,5 +1,5 @@ """Flow models""" -from typing import Callable, Optional +from typing import TYPE_CHECKING, Optional, Type from uuid import uuid4 from django.db import models @@ -12,6 +12,9 @@ from passbook.core.types import UIUserSettings from passbook.lib.utils.reflection import class_to_path from passbook.policies.models import PolicyBindingModel +if TYPE_CHECKING: + from passbook.flows.stage import StageView + LOGGER = get_logger() @@ -57,9 +60,9 @@ class Stage(models.Model): return f"Stage {self.name}" -def in_memory_stage(_type: Callable) -> Stage: +def in_memory_stage(view: Type["StageView"]) -> Stage: """Creates an in-memory stage instance, based on a `_type` as view.""" - class_path = class_to_path(_type) + class_path = class_to_path(view) stage = Stage() stage.type = class_path return stage diff --git a/passbook/sources/oauth/views/callback.py b/passbook/sources/oauth/views/callback.py index a38781958..73cce9d26 100644 --- a/passbook/sources/oauth/views/callback.py +++ b/passbook/sources/oauth/views/callback.py @@ -12,7 +12,7 @@ from structlog import get_logger from passbook.audit.models import Event, EventAction from passbook.core.models import User -from passbook.flows.models import Flow +from passbook.flows.models import Flow, in_memory_stage from passbook.flows.planner import ( PLAN_CONTEXT_PENDING_USER, PLAN_CONTEXT_SSO, @@ -24,6 +24,10 @@ from passbook.policies.utils import delete_none_keys from passbook.sources.oauth.auth import AuthorizedServiceBackend from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.views.base import OAuthClientMixin +from passbook.sources.oauth.views.flows import ( + PLAN_CONTEXT_SOURCES_OAUTH_ACCESS, + PostUserEnrollmentStage, +) from passbook.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND from passbook.stages.prompt.stage import PLAN_CONTEXT_PROMPT @@ -36,16 +40,17 @@ class OAuthCallback(OAuthClientMixin, View): source_id = None source = None + # pylint: disable=too-many-return-statements def get(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: """View Get handler""" slug = kwargs.get("source_slug", "") try: self.source = OAuthSource.objects.get(slug=slug) except OAuthSource.DoesNotExist: - raise Http404("Unknown OAuth source '%s'." % slug) + raise Http404(f"Unknown OAuth source '{slug}'.") else: if not self.source.enabled: - raise Http404("source %s is not enabled." % slug) + raise Http404(f"Source {slug} is not enabled.") client = self.get_client(self.source) callback = self.get_callback_url(self.source) # Fetch access token @@ -89,8 +94,11 @@ class OAuthCallback(OAuthClientMixin, View): source=self.source, identifier=identifier, request=request ) if user is None: - LOGGER.debug("Handling new connection", source=self.source) - return self.handle_new_connection(self.source, connection, info) + if self.request.user.is_authenticated: + LOGGER.debug("Linking existing user", source=self.source) + return self.handle_existing_user_link(self.source, connection, info) + LOGGER.debug("Handling enrollment of new user", source=self.source) + return self.handle_enroll(self.source, connection, info) LOGGER.debug("Handling existing user", source=self.source) return self.handle_existing_user(self.source, user, connection, info) @@ -122,6 +130,12 @@ class OAuthCallback(OAuthClientMixin, View): return info["id"] return None + def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse: + "Message user and redirect on error." + LOGGER.warning("Authentication Failure", reason=reason) + messages.error(self.request, _("Authentication Failed.")) + return redirect(self.get_error_redirect(source, reason)) + def handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse: """Prepare Authentication Plan, redirect user FlowExecutor""" kwargs.update( @@ -133,7 +147,7 @@ class OAuthCallback(OAuthClientMixin, View): ) # We run the Flow planner here so we can pass the Pending user in the context planner = FlowPlanner(flow) - plan = planner.plan(self.request, kwargs,) + plan = planner.plan(self.request, kwargs) self.request.session[SESSION_KEY_PLAN] = plan return redirect_with_qs( "passbook_flows:flow-executor-shell", self.request.GET, flow_slug=flow.slug, @@ -158,40 +172,40 @@ class OAuthCallback(OAuthClientMixin, View): flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user} return self.handle_login_flow(source.authentication_flow, **flow_kwargs) - def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse: - "Message user and redirect on error." - LOGGER.warning("Authentication Failure", reason=reason) - messages.error(self.request, _("Authentication Failed.")) - return redirect(self.get_error_redirect(source, reason)) - - def handle_new_connection( + def handle_existing_user_link( self, source: OAuthSource, access: UserOAuthSourceConnection, info: Dict[str, Any], ) -> HttpResponse: - """Check if a user exists for the connection and connect them, otherwise - prepare to enroll a new user.""" - if self.request.user.is_authenticated: - # there's already a user logged in, just link them up - user = self.request.user - access.user = user - access.save() - UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) - Event.new( - EventAction.CUSTOM, message="Linked OAuth Source", source=source - ).from_http(self.request) - messages.success( - self.request, - _("Successfully linked %(source)s!" % {"source": self.source.name}), + """Handler when the user was already authenticated and linked an external source + to their account.""" + # there's already a user logged in, just link them up + user = self.request.user + access.user = user + access.save() + UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) + Event.new( + EventAction.CUSTOM, message="Linked OAuth Source", source=source + ).from_http(self.request) + messages.success( + self.request, + _("Successfully linked %(source)s!" % {"source": self.source.name}), + ) + return redirect( + reverse( + "passbook_sources_oauth:oauth-client-user", + kwargs={"source_slug": self.source.slug}, ) - return redirect( - reverse( - "passbook_sources_oauth:oauth-client-user", - kwargs={"source_slug": self.source.slug}, - ) - ) - # User was not authenticated, new user will be created + ) + + def handle_enroll( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> HttpResponse: + """User was not authenticated and previous request was not authenticated.""" messages.success( self.request, _( @@ -199,11 +213,23 @@ class OAuthCallback(OAuthClientMixin, View): % {"source": self.source.name} ), ) - # Trim out all keys that have a value of None, - # so we use `"key" in ` checks in policies + # Because we inject a stage into the planned flow, we can't use `self.handle_login_flow` context = { + # Since we authenticate the user by their token, they have no backend set + PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend", + PLAN_CONTEXT_SSO: True, PLAN_CONTEXT_PROMPT: delete_none_keys( self.get_user_enroll_context(source, access, info) - ) + ), + PLAN_CONTEXT_SOURCES_OAUTH_ACCESS: access, } - return self.handle_login_flow(source.enrollment_flow, **context) + # We run the Flow planner here so we can pass the Pending user in the context + planner = FlowPlanner(source.enrollment_flow) + plan = planner.plan(self.request, context) + plan.append(in_memory_stage(PostUserEnrollmentStage)) + self.request.session[SESSION_KEY_PLAN] = plan + return redirect_with_qs( + "passbook_flows:flow-executor-shell", + self.request.GET, + flow_slug=source.enrollment_flow.slug, + ) diff --git a/passbook/sources/oauth/views/flows.py b/passbook/sources/oauth/views/flows.py new file mode 100644 index 000000000..1cc89eab1 --- /dev/null +++ b/passbook/sources/oauth/views/flows.py @@ -0,0 +1,28 @@ +"""OAuth Stages""" +from django.http import HttpRequest, HttpResponse + +from passbook.audit.models import Event, EventAction +from passbook.core.models import User +from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER +from passbook.flows.stage import StageView +from passbook.sources.oauth.models import UserOAuthSourceConnection + +PLAN_CONTEXT_SOURCES_OAUTH_ACCESS = "sources_oauth_access" + + +class PostUserEnrollmentStage(StageView): + """Dynamically injected stage which saves the OAuth Connection after + the user has been enrolled.""" + + def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + access: UserOAuthSourceConnection = self.executor.plan.context[ + PLAN_CONTEXT_SOURCES_OAUTH_ACCESS + ] + user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] + access.user = user + access.save() + UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) + Event.new( + EventAction.CUSTOM, message="Linked OAuth Source", source=access.source + ).from_http(self.request) + return self.executor.stage_ok()