diff --git a/authentik/core/api/sources.py b/authentik/core/api/sources.py index a9c481a07..e1da3a2fb 100644 --- a/authentik/core/api/sources.py +++ b/authentik/core/api/sources.py @@ -45,6 +45,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer): "verbose_name", "verbose_name_plural", "policy_engine_mode", + "user_matching_mode", ] diff --git a/authentik/core/migrations/0020_source_user_matching_mode.py b/authentik/core/migrations/0020_source_user_matching_mode.py new file mode 100644 index 000000000..68f6f9524 --- /dev/null +++ b/authentik/core/migrations/0020_source_user_matching_mode.py @@ -0,0 +1,40 @@ +# Generated by Django 3.2 on 2021-05-03 17:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_core", "0019_source_managed"), + ] + + operations = [ + migrations.AddField( + model_name="source", + name="user_matching_mode", + field=models.TextField( + choices=[ + ("identifier", "Use the source-specific identifier"), + ( + "email_link", + "Link to a user with identical email address. Can have security implications when a source doesn't validate email addresses.", + ), + ( + "email_deny", + "Use the user's email address, but deny enrollment when the email address already exists.", + ), + ( + "username_link", + "Link to a user with identical username address. Can have security implications when a username is used with another source.", + ), + ( + "username_deny", + "Use the user's username, but deny enrollment when the username already exists.", + ), + ], + default="identifier", + help_text="How the source determines if an existing user should be authenticated or a new user enrolled.", + ), + ), + ] diff --git a/authentik/core/models.py b/authentik/core/models.py index 501b556c7..b1f8fc2af 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -240,6 +240,30 @@ class Application(PolicyBindingModel): verbose_name_plural = _("Applications") +class SourceUserMatchingModes(models.TextChoices): + """Different modes a source can handle new/returning users""" + + IDENTIFIER = "identifier", _("Use the source-specific identifier") + EMAIL_LINK = "email_link", _( + ( + "Link to a user with identical email address. Can have security implications " + "when a source doesn't validate email addresses." + ) + ) + EMAIL_DENY = "email_deny", _( + "Use the user's email address, but deny enrollment when the email address already exists." + ) + USERNAME_LINK = "username_link", _( + ( + "Link to a user with identical username address. Can have security implications " + "when a username is used with another source." + ) + ) + USERNAME_DENY = "username_deny", _( + "Use the user's username, but deny enrollment when the username already exists." + ) + + class Source(ManagedModel, SerializerModel, PolicyBindingModel): """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" @@ -272,6 +296,17 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): related_name="source_enrollment", ) + user_matching_mode = models.TextField( + choices=SourceUserMatchingModes.choices, + default=SourceUserMatchingModes.IDENTIFIER, + help_text=_( + ( + "How the source determines if an existing user should be authenticated or " + "a new user enrolled." + ) + ), + ) + objects = InheritanceManager() @property @@ -301,6 +336,8 @@ class UserSourceConnection(CreatedUpdatedModel): user = models.ForeignKey(User, on_delete=models.CASCADE) source = models.ForeignKey(Source, on_delete=models.CASCADE) + objects = InheritanceManager() + class Meta: unique_together = (("user", "source"),) diff --git a/authentik/core/sources/__init__.py b/authentik/core/sources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py new file mode 100644 index 000000000..c6f5b9795 --- /dev/null +++ b/authentik/core/sources/flow_manager.py @@ -0,0 +1,261 @@ +"""Source decision helper""" +from enum import Enum +from typing import Any, Optional, Type + +from django.contrib import messages +from django.db.models.query_utils import Q +from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest +from django.shortcuts import redirect +from django.urls import reverse +from django.utils.translation import gettext as _ +from structlog.stdlib import get_logger + +from authentik.core.models import ( + Source, + SourceUserMatchingModes, + User, + UserSourceConnection, +) +from authentik.core.sources.stage import ( + PLAN_CONTEXT_SOURCES_CONNECTION, + PostUserEnrollmentStage, +) +from authentik.events.models import Event, EventAction +from authentik.flows.models import Flow, Stage, in_memory_stage +from authentik.flows.planner import ( + PLAN_CONTEXT_PENDING_USER, + PLAN_CONTEXT_REDIRECT, + PLAN_CONTEXT_SOURCE, + PLAN_CONTEXT_SSO, + FlowPlanner, +) +from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN +from authentik.lib.utils.urls import redirect_with_qs +from authentik.policies.utils import delete_none_keys +from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND +from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT + + +class Action(Enum): + """Actions that can be decided based on the request + and source settings""" + + LINK = "link" + AUTH = "auth" + ENROLL = "enroll" + DENY = "deny" + + +class SourceFlowManager: + """Help sources decide what they should do after authorization. Based on source settings and + previous connections, authenticate the user, enroll a new user, link to an existing user + or deny the request.""" + + source: Source + request: HttpRequest + + identifier: str + + connection_type: Type[UserSourceConnection] = UserSourceConnection + + def __init__( + self, + source: Source, + request: HttpRequest, + identifier: str, + enroll_info: dict[str, Any], + ) -> None: + self.source = source + self.request = request + self.identifier = identifier + self.enroll_info = enroll_info + self._logger = get_logger().bind(source=source, identifier=identifier) + + # pylint: disable=too-many-return-statements + def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: + """decide which action should be taken""" + new_connection = self.connection_type( + source=self.source, identifier=self.identifier + ) + # When request is authenticated, always link + if self.request.user.is_authenticated: + new_connection.user = self.request.user + new_connection = self.update_connection(new_connection, **kwargs) + new_connection.save() + return Action.LINK, new_connection + + existing_connections = self.connection_type.objects.filter( + source=self.source, identifier=self.identifier + ) + if existing_connections.exists(): + connection = existing_connections.first() + return Action.AUTH, self.update_connection(connection, **kwargs) + # No connection exists, but we match on identifier, so enroll + if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER: + # We don't save the connection here cause it doesn't have a user assigned yet + return Action.ENROLL, self.update_connection(new_connection, **kwargs) + + # Check for existing users with matching attributes + query = Q() + # Either query existing user based on email or username + if self.source.user_matching_mode in [ + SourceUserMatchingModes.EMAIL_LINK, + SourceUserMatchingModes.EMAIL_DENY, + ]: + if not self.enroll_info.get("email", None): + self._logger.warning("Refusing to use none email", source=self.source) + return Action.DENY, None + query = Q(email__exact=self.enroll_info.get("email", None)) + if self.source.user_matching_mode in [ + SourceUserMatchingModes.USERNAME_LINK, + SourceUserMatchingModes.USERNAME_DENY, + ]: + if not self.enroll_info.get("username", None): + self._logger.warning( + "Refusing to use none username", source=self.source + ) + return Action.DENY, None + query = Q(username__exact=self.enroll_info.get("username", None)) + matching_users = User.objects.filter(query) + # No matching users, always enroll + if not matching_users.exists(): + return Action.ENROLL, self.update_connection(new_connection, **kwargs) + + user = matching_users.first() + if self.source.user_matching_mode in [ + SourceUserMatchingModes.EMAIL_LINK, + SourceUserMatchingModes.USERNAME_LINK, + ]: + new_connection.user = user + new_connection = self.update_connection(new_connection, **kwargs) + new_connection.save() + return Action.LINK, new_connection + if self.source.user_matching_mode in [ + SourceUserMatchingModes.EMAIL_DENY, + SourceUserMatchingModes.USERNAME_DENY, + ]: + return Action.DENY, None + return Action.DENY, None + + def update_connection( + self, connection: UserSourceConnection, **kwargs + ) -> UserSourceConnection: + """Optionally make changes to the connection after it is looked up/created.""" + return connection + + def get_flow(self, **kwargs) -> HttpResponse: + """Get the flow response based on user_matching_mode""" + action, connection = self.get_action() + if action == Action.LINK: + self._logger.debug("Linking existing user") + return self.handle_existing_user_link() + if not connection: + return redirect("/") + if action == Action.AUTH: + self._logger.debug("Handling auth user") + return self.handle_auth_user(connection) + if action == Action.ENROLL: + self._logger.debug("Handling enrollment of new user") + return self.handle_enroll(connection) + return redirect("/") + + # pylint: disable=unused-argument + def get_stages_to_append(self, flow: Flow) -> list[Stage]: + """Hook to override stages which are appended to the flow""" + if flow.slug == self.source.enrollment_flow.slug: + return [ + in_memory_stage(PostUserEnrollmentStage), + ] + return [] + + def _handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse: + """Prepare Authentication Plan, redirect user FlowExecutor""" + # Ensure redirect is carried through when user was trying to + # authorize application + final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get( + NEXT_ARG_NAME, "authentik_core:if-admin" + ) + kwargs.update( + { + # Since we authenticate the user by their token, they have no backend set + PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend", + PLAN_CONTEXT_SSO: True, + PLAN_CONTEXT_SOURCE: self.source, + PLAN_CONTEXT_REDIRECT: final_redirect, + } + ) + if not flow: + return HttpResponseBadRequest() + # We run the Flow planner here so we can pass the Pending user in the context + planner = FlowPlanner(flow) + plan = planner.plan(self.request, kwargs) + for stage in self.get_stages_to_append(flow): + plan.append(stage) + self.request.session[SESSION_KEY_PLAN] = plan + return redirect_with_qs( + "authentik_core:if-flow", + self.request.GET, + flow_slug=flow.slug, + ) + + # pylint: disable=unused-argument + def handle_auth_user( + self, + connection: UserSourceConnection, + ) -> HttpResponse: + """Login user and redirect.""" + messages.success( + self.request, + _( + "Successfully authenticated with %(source)s!" + % {"source": self.source.name} + ), + ) + flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} + return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs) + + def handle_existing_user_link( + self, + ) -> HttpResponse: + """Handler when the user was already authenticated and linked an external source + to their account.""" + Event.new( + EventAction.SOURCE_LINKED, + message="Linked Source", + source=self.source, + ).from_http(self.request) + messages.success( + self.request, + _("Successfully linked %(source)s!" % {"source": self.source.name}), + ) + return redirect( + reverse( + "authentik_core:if-admin", + ) + + f"#/user;page-{self.source.slug}" + ) + + def handle_enroll( + self, + connection: UserSourceConnection, + ) -> HttpResponse: + """User was not authenticated and previous request was not authenticated.""" + messages.success( + self.request, + _( + "Successfully authenticated with %(source)s!" + % {"source": self.source.name} + ), + ) + + # We run the Flow planner here so we can pass the Pending user in the context + if not self.source.enrollment_flow: + self._logger.warning("source has no enrollment flow") + return HttpResponseBadRequest() + return self._handle_login_flow( + self.source.enrollment_flow, + **{ + PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info), + PLAN_CONTEXT_SOURCES_CONNECTION: connection, + }, + ) diff --git a/authentik/sources/oauth/views/flows.py b/authentik/core/sources/stage.py similarity index 53% rename from authentik/sources/oauth/views/flows.py rename to authentik/core/sources/stage.py index 1dc239aed..bcc37adf0 100644 --- a/authentik/sources/oauth/views/flows.py +++ b/authentik/core/sources/stage.py @@ -1,32 +1,30 @@ -"""OAuth Stages""" +"""Source flow manager stages""" from django.http import HttpRequest, HttpResponse -from authentik.core.models import User +from authentik.core.models import User, UserSourceConnection from authentik.events.models import Event, EventAction from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import StageView -from authentik.sources.oauth.models import UserOAuthSourceConnection -PLAN_CONTEXT_SOURCES_OAUTH_ACCESS = "sources_oauth_access" +PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection" class PostUserEnrollmentStage(StageView): - """Dynamically injected stage which saves the OAuth Connection after + """Dynamically injected stage which saves the Connection after the user has been enrolled.""" # pylint: disable=unused-argument def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: """Stage used after the user has been enrolled""" - access: UserOAuthSourceConnection = self.executor.plan.context[ - PLAN_CONTEXT_SOURCES_OAUTH_ACCESS + connection: UserSourceConnection = self.executor.plan.context[ + PLAN_CONTEXT_SOURCES_CONNECTION ] user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] - access.user = user - access.save() - UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) + connection.user = user + connection.save() Event.new( EventAction.SOURCE_LINKED, - message="Linked OAuth Source", - source=access.source, + message="Linked Source", + source=connection.source, ).from_http(self.request) return self.executor.stage_ok() diff --git a/authentik/policies/event_matcher/migrations/0013_alter_eventmatcherpolicy_app.py b/authentik/policies/event_matcher/migrations/0013_alter_eventmatcherpolicy_app.py new file mode 100644 index 000000000..46cc8443a --- /dev/null +++ b/authentik/policies/event_matcher/migrations/0013_alter_eventmatcherpolicy_app.py @@ -0,0 +1,84 @@ +# Generated by Django 3.2 on 2021-05-02 17:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_policies_event_matcher", "0012_auto_20210323_1339"), + ] + + operations = [ + migrations.AlterField( + model_name="eventmatcherpolicy", + name="app", + field=models.TextField( + blank=True, + choices=[ + ("authentik.admin", "authentik Admin"), + ("authentik.api", "authentik API"), + ("authentik.events", "authentik Events"), + ("authentik.crypto", "authentik Crypto"), + ("authentik.flows", "authentik Flows"), + ("authentik.outposts", "authentik Outpost"), + ("authentik.lib", "authentik lib"), + ("authentik.policies", "authentik Policies"), + ("authentik.policies.dummy", "authentik Policies.Dummy"), + ( + "authentik.policies.event_matcher", + "authentik Policies.Event Matcher", + ), + ("authentik.policies.expiry", "authentik Policies.Expiry"), + ("authentik.policies.expression", "authentik Policies.Expression"), + ("authentik.policies.hibp", "authentik Policies.HaveIBeenPwned"), + ("authentik.policies.password", "authentik Policies.Password"), + ("authentik.policies.reputation", "authentik Policies.Reputation"), + ("authentik.providers.proxy", "authentik Providers.Proxy"), + ("authentik.providers.oauth2", "authentik Providers.OAuth2"), + ("authentik.providers.saml", "authentik Providers.SAML"), + ("authentik.recovery", "authentik Recovery"), + ("authentik.sources.ldap", "authentik Sources.LDAP"), + ("authentik.sources.oauth", "authentik Sources.OAuth"), + ("authentik.sources.plex", "authentik Sources.Plex"), + ("authentik.sources.saml", "authentik Sources.SAML"), + ( + "authentik.stages.authenticator_static", + "authentik Stages.Authenticator.Static", + ), + ( + "authentik.stages.authenticator_totp", + "authentik Stages.Authenticator.TOTP", + ), + ( + "authentik.stages.authenticator_validate", + "authentik Stages.Authenticator.Validate", + ), + ( + "authentik.stages.authenticator_webauthn", + "authentik Stages.Authenticator.WebAuthn", + ), + ("authentik.stages.captcha", "authentik Stages.Captcha"), + ("authentik.stages.consent", "authentik Stages.Consent"), + ("authentik.stages.deny", "authentik Stages.Deny"), + ("authentik.stages.dummy", "authentik Stages.Dummy"), + ("authentik.stages.email", "authentik Stages.Email"), + ( + "authentik.stages.identification", + "authentik Stages.Identification", + ), + ("authentik.stages.invitation", "authentik Stages.User Invitation"), + ("authentik.stages.password", "authentik Stages.Password"), + ("authentik.stages.prompt", "authentik Stages.Prompt"), + ("authentik.stages.user_delete", "authentik Stages.User Delete"), + ("authentik.stages.user_login", "authentik Stages.User Login"), + ("authentik.stages.user_logout", "authentik Stages.User Logout"), + ("authentik.stages.user_write", "authentik Stages.User Write"), + ("authentik.core", "authentik Core"), + ("authentik.managed", "authentik Managed"), + ], + default="", + help_text="Match events created by selected application. When left empty, all applications are matched.", + ), + ), + ] diff --git a/authentik/sources/oauth/auth.py b/authentik/sources/oauth/auth.py deleted file mode 100644 index 62836f574..000000000 --- a/authentik/sources/oauth/auth.py +++ /dev/null @@ -1,23 +0,0 @@ -"""authentik oauth_client Authorization backend""" -from typing import Optional - -from django.contrib.auth.backends import ModelBackend -from django.http import HttpRequest - -from authentik.core.models import User -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection - - -class AuthorizedServiceBackend(ModelBackend): - "Authentication backend for users registered with remote OAuth provider." - - def authenticate( - self, request: HttpRequest, source: OAuthSource, identifier: str - ) -> Optional[User]: - "Fetch user for a given source by id." - access = UserOAuthSourceConnection.objects.filter( - source=source, identifier=identifier - ).select_related("user") - if not access.exists(): - return None - return access.first().user diff --git a/authentik/sources/oauth/tests/test_type_discord.py b/authentik/sources/oauth/tests/test_type_discord.py index 6b337d1bd..86340afed 100644 --- a/authentik/sources/oauth/tests/test_type_discord.py +++ b/authentik/sources/oauth/tests/test_type_discord.py @@ -1,7 +1,7 @@ """Discord Type tests""" from django.test import TestCase -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection +from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.discord import DiscordOAuth2Callback # https://discord.com/developers/docs/resources/user#user-object @@ -33,9 +33,7 @@ class TestTypeDiscord(TestCase): def test_enroll_context(self): """Test discord Enrollment context""" - ak_context = DiscordOAuth2Callback().get_user_enroll_context( - self.source, UserOAuthSourceConnection(), DISCORD_USER - ) + ak_context = DiscordOAuth2Callback().get_user_enroll_context(DISCORD_USER) self.assertEqual(ak_context["username"], DISCORD_USER["username"]) self.assertEqual(ak_context["email"], DISCORD_USER["email"]) self.assertEqual(ak_context["name"], DISCORD_USER["username"]) diff --git a/authentik/sources/oauth/tests/test_type_github.py b/authentik/sources/oauth/tests/test_type_github.py index 3acce60fc..50a699b9c 100644 --- a/authentik/sources/oauth/tests/test_type_github.py +++ b/authentik/sources/oauth/tests/test_type_github.py @@ -1,7 +1,7 @@ """GitHub Type tests""" from django.test import TestCase -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection +from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.github import GitHubOAuth2Callback # https://developer.github.com/v3/users/#get-the-authenticated-user @@ -63,9 +63,7 @@ class TestTypeGitHub(TestCase): def test_enroll_context(self): """Test GitHub Enrollment context""" - ak_context = GitHubOAuth2Callback().get_user_enroll_context( - self.source, UserOAuthSourceConnection(), GITHUB_USER - ) + ak_context = GitHubOAuth2Callback().get_user_enroll_context(GITHUB_USER) self.assertEqual(ak_context["username"], GITHUB_USER["login"]) self.assertEqual(ak_context["email"], GITHUB_USER["email"]) self.assertEqual(ak_context["name"], GITHUB_USER["name"]) diff --git a/authentik/sources/oauth/tests/test_type_google.py b/authentik/sources/oauth/tests/test_type_google.py index 6f43812ad..6f79f8a2d 100644 --- a/authentik/sources/oauth/tests/test_type_google.py +++ b/authentik/sources/oauth/tests/test_type_google.py @@ -1,7 +1,7 @@ """google Type tests""" from django.test import TestCase -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection +from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.google import GoogleOAuth2Callback # https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en @@ -32,9 +32,7 @@ class TestTypeGoogle(TestCase): def test_enroll_context(self): """Test Google Enrollment context""" - ak_context = GoogleOAuth2Callback().get_user_enroll_context( - self.source, UserOAuthSourceConnection(), GOOGLE_USER - ) + ak_context = GoogleOAuth2Callback().get_user_enroll_context(GOOGLE_USER) self.assertEqual(ak_context["username"], GOOGLE_USER["email"]) self.assertEqual(ak_context["email"], GOOGLE_USER["email"]) self.assertEqual(ak_context["name"], GOOGLE_USER["name"]) diff --git a/authentik/sources/oauth/tests/test_type_twitter.py b/authentik/sources/oauth/tests/test_type_twitter.py index b0918fa62..84fdd0f80 100644 --- a/authentik/sources/oauth/tests/test_type_twitter.py +++ b/authentik/sources/oauth/tests/test_type_twitter.py @@ -1,7 +1,7 @@ """Twitter Type tests""" from django.test import Client, TestCase -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection +from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.twitter import TwitterOAuthCallback # https://developer.twitter.com/en/docs/twitter-api/v1/accounts-and-users/manage-account-settings/ \ @@ -104,9 +104,7 @@ class TestTypeGitHub(TestCase): def test_enroll_context(self): """Test Twitter Enrollment context""" - ak_context = TwitterOAuthCallback().get_user_enroll_context( - self.source, UserOAuthSourceConnection(), TWITTER_USER - ) + ak_context = TwitterOAuthCallback().get_user_enroll_context(TWITTER_USER) self.assertEqual(ak_context["username"], TWITTER_USER["screen_name"]) self.assertEqual(ak_context["email"], TWITTER_USER.get("email", None)) self.assertEqual(ak_context["name"], TWITTER_USER["name"]) diff --git a/authentik/sources/oauth/types/azure_ad.py b/authentik/sources/oauth/types/azure_ad.py index fbd81f08f..7d5dc02fb 100644 --- a/authentik/sources/oauth/types/azure_ad.py +++ b/authentik/sources/oauth/types/azure_ad.py @@ -2,7 +2,6 @@ from typing import Any, Optional from uuid import UUID -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback @@ -10,7 +9,7 @@ from authentik.sources.oauth.views.callback import OAuthCallback class AzureADOAuthCallback(OAuthCallback): """AzureAD OAuth2 Callback""" - def get_user_id(self, source: OAuthSource, info: dict[str, Any]) -> Optional[str]: + def get_user_id(self, info: dict[str, Any]) -> Optional[str]: try: return str(UUID(info.get("objectId")).int) except TypeError: @@ -18,8 +17,6 @@ class AzureADOAuthCallback(OAuthCallback): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: mail = info.get("mail", None) or info.get("otherMails", [None])[0] diff --git a/authentik/sources/oauth/types/discord.py b/authentik/sources/oauth/types/discord.py index 00bac79fd..a97cca546 100644 --- a/authentik/sources/oauth/types/discord.py +++ b/authentik/sources/oauth/types/discord.py @@ -1,7 +1,6 @@ """Discord OAuth Views""" from typing import Any -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -21,8 +20,6 @@ class DiscordOAuth2Callback(OAuthCallback): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: return { diff --git a/authentik/sources/oauth/types/facebook.py b/authentik/sources/oauth/types/facebook.py index ab27d6b6f..8efe16102 100644 --- a/authentik/sources/oauth/types/facebook.py +++ b/authentik/sources/oauth/types/facebook.py @@ -4,7 +4,6 @@ from typing import Any, Optional from facebook import GraphAPI from authentik.sources.oauth.clients.oauth2 import OAuth2Client -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -34,8 +33,6 @@ class FacebookOAuth2Callback(OAuthCallback): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: return { diff --git a/authentik/sources/oauth/types/github.py b/authentik/sources/oauth/types/github.py index c830d4919..791e98912 100644 --- a/authentik/sources/oauth/types/github.py +++ b/authentik/sources/oauth/types/github.py @@ -1,7 +1,6 @@ """GitHub OAuth Views""" from typing import Any -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback @@ -11,8 +10,6 @@ class GitHubOAuth2Callback(OAuthCallback): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: return { diff --git a/authentik/sources/oauth/types/google.py b/authentik/sources/oauth/types/google.py index e69004254..ee6bdf63f 100644 --- a/authentik/sources/oauth/types/google.py +++ b/authentik/sources/oauth/types/google.py @@ -1,7 +1,6 @@ """Google OAuth Views""" from typing import Any -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -21,8 +20,6 @@ class GoogleOAuth2Callback(OAuthCallback): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: return { diff --git a/authentik/sources/oauth/types/oidc.py b/authentik/sources/oauth/types/oidc.py index e2acf4b63..01fae8dcd 100644 --- a/authentik/sources/oauth/types/oidc.py +++ b/authentik/sources/oauth/types/oidc.py @@ -1,7 +1,7 @@ """OpenID Connect OAuth Views""" from typing import Any -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection +from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -19,13 +19,11 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect): class OpenIDConnectOAuth2Callback(OAuthCallback): """OpenIDConnect OAuth2 Callback""" - def get_user_id(self, source: OAuthSource, info: dict[str, str]) -> str: + def get_user_id(self, info: dict[str, str]) -> str: return info.get("sub", "") def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: return { diff --git a/authentik/sources/oauth/types/reddit.py b/authentik/sources/oauth/types/reddit.py index 74c777e6d..53757b38e 100644 --- a/authentik/sources/oauth/types/reddit.py +++ b/authentik/sources/oauth/types/reddit.py @@ -4,7 +4,6 @@ from typing import Any from requests.auth import HTTPBasicAuth from authentik.sources.oauth.clients.oauth2 import OAuth2Client -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -36,8 +35,6 @@ class RedditOAuth2Callback(OAuthCallback): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: return { diff --git a/authentik/sources/oauth/types/twitter.py b/authentik/sources/oauth/types/twitter.py index df1ed1a9f..b4df3d607 100644 --- a/authentik/sources/oauth/types/twitter.py +++ b/authentik/sources/oauth/types/twitter.py @@ -1,7 +1,6 @@ """Twitter OAuth Views""" from typing import Any -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback @@ -11,8 +10,6 @@ class TwitterOAuthCallback(OAuthCallback): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: return { diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py index 5ca72c85e..036652f26 100644 --- a/authentik/sources/oauth/views/callback.py +++ b/authentik/sources/oauth/views/callback.py @@ -4,35 +4,14 @@ from typing import Any, Optional from django.conf import settings from django.contrib import messages from django.http import Http404, HttpRequest, HttpResponse -from django.http.response import HttpResponseBadRequest from django.shortcuts import redirect -from django.urls import reverse from django.utils.translation import gettext as _ from django.views.generic import View from structlog.stdlib import get_logger -from authentik.core.models import User -from authentik.events.models import Event, EventAction -from authentik.flows.models import Flow, in_memory_stage -from authentik.flows.planner import ( - PLAN_CONTEXT_PENDING_USER, - PLAN_CONTEXT_REDIRECT, - PLAN_CONTEXT_SOURCE, - PLAN_CONTEXT_SSO, - FlowPlanner, -) -from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN -from authentik.lib.utils.urls import redirect_with_qs -from authentik.policies.utils import delete_none_keys -from authentik.sources.oauth.auth import AuthorizedServiceBackend +from authentik.core.sources.flow_manager import SourceFlowManager from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.views.base import OAuthClientMixin -from authentik.sources.oauth.views.flows import ( - PLAN_CONTEXT_SOURCES_OAUTH_ACCESS, - PostUserEnrollmentStage, -) -from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND -from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT LOGGER = get_logger() @@ -40,8 +19,7 @@ LOGGER = get_logger() class OAuthCallback(OAuthClientMixin, View): "Base OAuth callback view." - source_id = None - source = None + source: OAuthSource # pylint: disable=too-many-return-statements def get(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: @@ -60,47 +38,27 @@ class OAuthCallback(OAuthClientMixin, View): # Fetch access token token = client.get_access_token() if token is None: - return self.handle_login_failure(self.source, "Could not retrieve token.") + return self.handle_login_failure("Could not retrieve token.") if "error" in token: - return self.handle_login_failure(self.source, token["error"]) + return self.handle_login_failure(token["error"]) # Fetch profile info - info = client.get_profile_info(token) - if info is None: - return self.handle_login_failure(self.source, "Could not retrieve profile.") - identifier = self.get_user_id(self.source, info) + raw_info = client.get_profile_info(token) + if raw_info is None: + return self.handle_login_failure("Could not retrieve profile.") + identifier = self.get_user_id(raw_info) if identifier is None: - return self.handle_login_failure(self.source, "Could not determine id.") + return self.handle_login_failure("Could not determine id.") # Get or create access record - defaults = { - "access_token": token.get("access_token"), - } - existing = UserOAuthSourceConnection.objects.filter( - source=self.source, identifier=identifier + enroll_info = self.get_user_enroll_context(raw_info) + sfm = OAuthSourceFlowManager( + source=self.source, + request=self.request, + identifier=identifier, + enroll_info=enroll_info, ) - - if existing.exists(): - connection = existing.first() - connection.access_token = token.get("access_token") - UserOAuthSourceConnection.objects.filter(pk=connection.pk).update( - **defaults - ) - else: - connection = UserOAuthSourceConnection( - source=self.source, - identifier=identifier, - access_token=token.get("access_token"), - ) - user = AuthorizedServiceBackend().authenticate( - source=self.source, identifier=identifier, request=request + return sfm.get_flow( + token=token, ) - if user is None: - if self.request.user.is_authenticated: - LOGGER.debug("Linking existing user", source=self.source) - return self.handle_existing_user_link(self.source, connection, info) - LOGGER.debug("Handling enrollment of new user", source=self.source) - return self.handle_enroll(self.source, connection, info) - LOGGER.debug("Handling existing user", source=self.source) - return self.handle_existing_user(self.source, user, connection, info) # pylint: disable=unused-argument def get_callback_url(self, source: OAuthSource) -> str: @@ -114,132 +72,34 @@ class OAuthCallback(OAuthClientMixin, View): def get_user_enroll_context( self, - source: OAuthSource, - access: UserOAuthSourceConnection, info: dict[str, Any], ) -> dict[str, Any]: """Create a dict of User data""" raise NotImplementedError() # pylint: disable=unused-argument - def get_user_id( - self, source: UserOAuthSourceConnection, info: dict[str, Any] - ) -> Optional[str]: + def get_user_id(self, info: dict[str, Any]) -> Optional[str]: """Return unique identifier from the profile info.""" if "id" in info: return info["id"] return None - def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse: + def handle_login_failure(self, reason: str) -> HttpResponse: "Message user and redirect on error." LOGGER.warning("Authentication Failure", reason=reason) messages.error(self.request, _("Authentication Failed.")) - return redirect(self.get_error_redirect(source, reason)) + return redirect(self.get_error_redirect(self.source, reason)) - def handle_login_flow( - self, flow: Flow, *stages_to_append, **kwargs - ) -> HttpResponse: - """Prepare Authentication Plan, redirect user FlowExecutor""" - # Ensure redirect is carried through when user was trying to - # authorize application - final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get( - NEXT_ARG_NAME, "authentik_core:if-admin" - ) - kwargs.update( - { - # Since we authenticate the user by their token, they have no backend set - PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend", - PLAN_CONTEXT_SSO: True, - PLAN_CONTEXT_SOURCE: self.source, - PLAN_CONTEXT_REDIRECT: final_redirect, - } - ) - if not flow: - return HttpResponseBadRequest() - # We run the Flow planner here so we can pass the Pending user in the context - planner = FlowPlanner(flow) - plan = planner.plan(self.request, kwargs) - for stage in stages_to_append: - plan.append(stage) - self.request.session[SESSION_KEY_PLAN] = plan - return redirect_with_qs( - "authentik_core:if-flow", - self.request.GET, - flow_slug=flow.slug, - ) - # pylint: disable=unused-argument - def handle_existing_user( - self, - source: OAuthSource, - user: User, - access: UserOAuthSourceConnection, - info: dict[str, Any], - ) -> HttpResponse: - "Login user and redirect." - messages.success( - self.request, - _( - "Successfully authenticated with %(source)s!" - % {"source": self.source.name} - ), - ) - flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user} - return self.handle_login_flow(source.authentication_flow, **flow_kwargs) +class OAuthSourceFlowManager(SourceFlowManager): + """Flow manager for oauth sources""" - def handle_existing_user_link( - self, - source: OAuthSource, - access: UserOAuthSourceConnection, - info: dict[str, Any], - ) -> HttpResponse: - """Handler when the user was already authenticated and linked an external source - to their account.""" - # there's already a user logged in, just link them up - user = self.request.user - access.user = user - access.save() - UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) - Event.new( - EventAction.SOURCE_LINKED, message="Linked OAuth Source", source=source - ).from_http(self.request) - messages.success( - self.request, - _("Successfully linked %(source)s!" % {"source": self.source.name}), - ) - return redirect( - reverse( - "authentik_core:if-admin", - ) - + f"#/user;page-{self.source.slug}" - ) + connection_type = UserOAuthSourceConnection - def handle_enroll( - self, - source: OAuthSource, - access: UserOAuthSourceConnection, - info: dict[str, Any], - ) -> HttpResponse: - """User was not authenticated and previous request was not authenticated.""" - messages.success( - self.request, - _( - "Successfully authenticated with %(source)s!" - % {"source": self.source.name} - ), - ) - - # We run the Flow planner here so we can pass the Pending user in the context - if not source.enrollment_flow: - LOGGER.warning("source has no enrollment flow", source=source) - return HttpResponseBadRequest() - return self.handle_login_flow( - source.enrollment_flow, - in_memory_stage(PostUserEnrollmentStage), - **{ - PLAN_CONTEXT_PROMPT: delete_none_keys( - self.get_user_enroll_context(source, access, info) - ), - PLAN_CONTEXT_SOURCES_OAUTH_ACCESS: access, - }, - ) + def update_connection( + self, connection: UserOAuthSourceConnection, token: dict[str, Any] + ) -> UserOAuthSourceConnection: + """Set the access_token on the connection""" + connection.access_token = token.get("access_token") + connection.save() + return connection diff --git a/authentik/sources/plex/api.py b/authentik/sources/plex/api.py index a9501132c..2a923d383 100644 --- a/authentik/sources/plex/api.py +++ b/authentik/sources/plex/api.py @@ -1,26 +1,22 @@ """Plex Source Serializer""" -from urllib.parse import urlencode - from django.http import Http404 from django.shortcuts import get_object_or_404 from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema -from requests import RequestException, get from rest_framework.decorators import action from rest_framework.fields import CharField from rest_framework.permissions import AllowAny from rest_framework.request import Request from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet -from structlog.stdlib import get_logger from authentik.api.decorators import permission_required from authentik.core.api.sources import SourceSerializer from authentik.core.api.utils import PassiveSerializer -from authentik.flows.challenge import ChallengeTypes, RedirectChallenge +from authentik.flows.challenge import RedirectChallenge +from authentik.flows.views import to_stage_response from authentik.sources.plex.models import PlexSource - -LOGGER = get_logger() +from authentik.sources.plex.plex import PlexAuth class PlexSourceSerializer(SourceSerializer): @@ -72,29 +68,8 @@ class PlexSourceViewSet(ModelViewSet): plex_token = request.data.get("plex_token", None) if not plex_token: raise Http404 - qs = {"X-Plex-Token": plex_token, "X-Plex-Client-Identifier": source.client_id} - try: - response = get( - f"https://plex.tv/api/v2/resources?{urlencode(qs)}", - headers={"Accept": "application/json"}, - ) - response.raise_for_status() - except RequestException as exc: - LOGGER.warning("Unable to fetch user resources", exc=exc) + auth_api = PlexAuth(source, plex_token) + if not auth_api.check_server_overlap(): raise Http404 - else: - resources: list[dict] = response.json() - for resource in resources: - if resource["provides"] != "server": - continue - if resource["clientIdentifier"] in source.allowed_servers: - LOGGER.info( - "Plex allowed access from server", name=resource["name"] - ) - request.session["foo"] = "bar" - break - return Response( - RedirectChallenge( - {"type": ChallengeTypes.REDIRECT.value, "to": ""} - ).data - ) + response = auth_api.get_user_url(request) + return to_stage_response(request, response) diff --git a/authentik/sources/plex/migrations/0002_plexsourceconnection.py b/authentik/sources/plex/migrations/0002_plexsourceconnection.py new file mode 100644 index 000000000..3f139a0ff --- /dev/null +++ b/authentik/sources/plex/migrations/0002_plexsourceconnection.py @@ -0,0 +1,38 @@ +# Generated by Django 3.2 on 2021-05-03 17:06 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_core", "0020_source_user_matching_mode"), + ("authentik_sources_plex", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="PlexSourceConnection", + fields=[ + ( + "usersourceconnection_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="authentik_core.usersourceconnection", + ), + ), + ("plex_token", models.TextField()), + ("identifier", models.TextField()), + ], + options={ + "verbose_name": "User Plex Source Connection", + "verbose_name_plural": "User Plex Source Connections", + }, + bases=("authentik_core.usersourceconnection",), + ), + ] diff --git a/authentik/sources/plex/models.py b/authentik/sources/plex/models.py index da66c1902..4ad8d0901 100644 --- a/authentik/sources/plex/models.py +++ b/authentik/sources/plex/models.py @@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework.fields import CharField from rest_framework.serializers import BaseSerializer -from authentik.core.models import Source +from authentik.core.models import Source, UserSourceConnection from authentik.core.types import UILoginButton from authentik.flows.challenge import Challenge, ChallengeTypes @@ -53,3 +53,15 @@ class PlexSource(Source): verbose_name = _("Plex Source") verbose_name_plural = _("Plex Sources") + + +class PlexSourceConnection(UserSourceConnection): + """Connect user and plex source""" + + plex_token = models.TextField() + identifier = models.TextField() + + class Meta: + + verbose_name = _("User Plex Source Connection") + verbose_name_plural = _("User Plex Source Connections") diff --git a/authentik/sources/plex/plex.py b/authentik/sources/plex/plex.py index 1f6b394a8..c51ef83bf 100644 --- a/authentik/sources/plex/plex.py +++ b/authentik/sources/plex/plex.py @@ -1,136 +1,113 @@ -"""Plex OAuth Views""" -from typing import Any, Optional +"""Plex Views""" from urllib.parse import urlencode -from django.http.response import Http404 -from requests import post -from requests.api import get +import requests +from django.http.request import HttpRequest +from django.http.response import Http404, HttpResponse from requests.exceptions import RequestException from structlog.stdlib import get_logger from authentik import __version__ -from authentik.sources.oauth.clients.oauth2 import OAuth2Client -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection -from authentik.sources.oauth.types.manager import MANAGER, SourceType -from authentik.sources.oauth.views.callback import OAuthCallback -from authentik.sources.oauth.views.redirect import OAuthRedirect +from authentik.core.sources.flow_manager import SourceFlowManager +from authentik.sources.plex.models import PlexSource, PlexSourceConnection LOGGER = get_logger() SESSION_ID_KEY = "PLEX_ID" SESSION_CODE_KEY = "PLEX_CODE" -DEFAULT_PAYLOAD = { - "X-Plex-Product": "authentik", - "X-Plex-Version": __version__, - "X-Plex-Device-Vendor": "BeryJu.org", -} -class PlexRedirect(OAuthRedirect): - """Plex Auth redirect, get a pin then redirect to a URL to claim it""" +class PlexAuth: + """Plex authentication utilities""" - headers = {} + _source: PlexSource + _token: str - def get_pin(self, **data) -> dict: - """Get plex pin that the user will claim - https://forums.plex.tv/t/authenticating-with-plex/609370""" - return post( - "https://plex.tv/api/v2/pins.json?strong=true", - data=data, - headers=self.headers, - ).json() - - def get_redirect_url(self, **kwargs) -> str: - slug = kwargs.get("source_slug", "") - self.headers = {"Origin": self.request.build_absolute_uri("/")} - try: - source: OAuthSource = OAuthSource.objects.get(slug=slug) - except OAuthSource.DoesNotExist: - raise Http404(f"Unknown OAuth source '{slug}'.") - else: - payload = DEFAULT_PAYLOAD.copy() - payload["X-Plex-Client-Identifier"] = source.consumer_key - # Get a pin first - pin = self.get_pin(**payload) - LOGGER.debug("Got pin", **pin) - self.request.session[SESSION_ID_KEY] = pin["id"] - self.request.session[SESSION_CODE_KEY] = pin["code"] - qs = { - "clientID": source.consumer_key, - "code": pin["code"], - "forwardUrl": self.request.build_absolute_uri( - self.get_callback_url(source) - ), - } - return f"https://app.plex.tv/auth#!?{urlencode(qs)}" - - -class PlexOAuthClient(OAuth2Client): - """Retrive the plex token after authentication, then ask the plex API about user info""" - - def check_application_state(self) -> bool: - return SESSION_ID_KEY in self.request.session - - def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: - payload = dict(DEFAULT_PAYLOAD) - payload["X-Plex-Client-Identifier"] = self.source.consumer_key - payload["Accept"] = "application/json" - response = get( - f"https://plex.tv/api/v2/pins/{self.request.session[SESSION_ID_KEY]}", - headers=payload, + def __init__(self, source: PlexSource, token: str): + self._source = source + self._token = token + self._session = requests.Session() + self._session.headers.update( + {"Accept": "application/json", "Content-Type": "application/json"} ) - response.raise_for_status() - token = response.json()["authToken"] - return {"plex_token": token} + self._session.headers.update(self.headers) - def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: - "Fetch user profile information." - qs = {"X-Plex-Token": token["plex_token"]} - print(token) - try: - response = self.do_request( - "get", f"https://plex.tv/users/account.json?{urlencode(qs)}" - ) - response.raise_for_status() - except RequestException as exc: - LOGGER.warning("Unable to fetch user profile", exc=exc) - return None - else: - info = response.json() - return info.get("user", {}) - - -class PlexOAuth2Callback(OAuthCallback): - """Plex OAuth2 Callback""" - - client_class = PlexOAuthClient - - def get_user_id( - self, source: UserOAuthSourceConnection, info: dict[str, Any] - ) -> Optional[str]: - return info.get("uuid") - - def get_user_enroll_context( - self, - source: OAuthSource, - access: UserOAuthSourceConnection, - info: dict[str, Any], - ) -> dict[str, Any]: + @property + def headers(self) -> dict[str, str]: + """Get common headers""" return { - "username": info.get("username"), - "email": info.get("email"), - "name": info.get("title"), + "X-Plex-Product": "authentik", + "X-Plex-Version": __version__, + "X-Plex-Device-Vendor": "BeryJu.org", } + def get_resources(self) -> list[dict]: + """Get all resources the plex-token has access to""" + qs = { + "X-Plex-Token": self._token, + "X-Plex-Client-Identifier": self._source.client_id, + } + response = self._session.get( + f"https://plex.tv/api/v2/resources?{urlencode(qs)}", + ) + response.raise_for_status() + return response.json() -@MANAGER.type() -class PlexType(SourceType): - """Plex Type definition""" + def get_user_info(self) -> tuple[dict, int]: + """Get user info of the plex token""" + qs = { + "X-Plex-Token": self._token, + "X-Plex-Client-Identifier": self._source.client_id, + } + response = self._session.get( + f"https://plex.tv/api/v2/user?{urlencode(qs)}", + ) + response.raise_for_status() + raw_user_info = response.json() + return { + "username": raw_user_info.get("username"), + "email": raw_user_info.get("email"), + "name": raw_user_info.get("title"), + }, raw_user_info.get("id") - redirect_view = PlexRedirect - callback_view = PlexOAuth2Callback - name = "Plex" - slug = "plex" + def check_server_overlap(self) -> bool: + """Check if the plex-token has any server overlap with our configured servers""" + try: + resources = self.get_resources() + except RequestException as exc: + LOGGER.warning("Unable to fetch user resources", exc=exc) + raise Http404 + else: + for resource in resources: + if resource["provides"] != "server": + continue + if resource["clientIdentifier"] in self._source.allowed_servers: + LOGGER.info( + "Plex allowed access from server", name=resource["name"] + ) + return True + return False - authorization_url = "" - access_token_url = "" # nosec - profile_url = "" + def get_user_url(self, request: HttpRequest) -> HttpResponse: + """Get a URL to a flow executor for either enrollment or authentication""" + user_info, identifier = self.get_user_info() + sfm = PlexSourceFlowManager( + source=self._source, + request=request, + identifier=str(identifier), + enroll_info=user_info, + ) + return sfm.get_flow(plex_token=self._token) + + +class PlexSourceFlowManager(SourceFlowManager): + """Flow manager for plex sources""" + + connection_type = PlexSourceConnection + + def update_connection( + self, connection: PlexSourceConnection, plex_token: str + ) -> PlexSourceConnection: + """Set the access_token on the connection""" + connection.plex_token = plex_token + connection.save() + return connection diff --git a/swagger.yaml b/swagger.yaml index e0bf34a22..e6db70d15 100755 --- a/swagger.yaml +++ b/swagger.yaml @@ -17289,6 +17289,17 @@ definitions: enum: - all - any + user_matching_mode: + title: User matching mode + description: How the source determines if an existing user should be authenticated + or a new user enrolled. + type: string + enum: + - identifier + - email_link + - email_deny + - username_link + - username_deny UserSetting: required: - object_uid @@ -17369,6 +17380,17 @@ definitions: enum: - all - any + user_matching_mode: + title: User matching mode + description: How the source determines if an existing user should be authenticated + or a new user enrolled. + type: string + enum: + - identifier + - email_link + - email_deny + - username_link + - username_deny server_uri: title: Server URI type: string @@ -17549,6 +17571,17 @@ definitions: enum: - all - any + user_matching_mode: + title: User matching mode + description: How the source determines if an existing user should be authenticated + or a new user enrolled. + type: string + enum: + - identifier + - email_link + - email_deny + - username_link + - username_deny provider_type: title: Provider type type: string @@ -17678,6 +17711,17 @@ definitions: enum: - all - any + user_matching_mode: + title: User matching mode + description: How the source determines if an existing user should be authenticated + or a new user enrolled. + type: string + enum: + - identifier + - email_link + - email_deny + - username_link + - username_deny client_id: title: Client id type: string @@ -17792,6 +17836,17 @@ definitions: enum: - all - any + user_matching_mode: + title: User matching mode + description: How the source determines if an existing user should be authenticated + or a new user enrolled. + type: string + enum: + - identifier + - email_link + - email_deny + - username_link + - username_deny pre_authentication_flow: title: Pre authentication flow description: Flow used before authentication. @@ -18537,6 +18592,17 @@ definitions: enabled: title: Enabled type: boolean + user_matching_mode: + title: User matching mode + description: How the source determines if an existing user should + be authenticated or a new user enrolled. + type: string + enum: + - identifier + - email_link + - email_deny + - username_link + - username_deny authentication_flow: title: Authentication flow description: Flow to use when authenticating existing users.