sources: rewrite onboarding
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
e56c3fc54c
commit
35faf269db
|
@ -45,6 +45,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
|
|||
"verbose_name",
|
||||
"verbose_name_plural",
|
||||
"policy_engine_mode",
|
||||
"user_matching_mode",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Generated by Django 3.2 on 2021-05-03 17:06
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0019_source_managed"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="source",
|
||||
name="user_matching_mode",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("identifier", "Use the source-specific identifier"),
|
||||
(
|
||||
"email_link",
|
||||
"Link to a user with identical email address. Can have security implications when a source doesn't validate email addresses.",
|
||||
),
|
||||
(
|
||||
"email_deny",
|
||||
"Use the user's email address, but deny enrollment when the email address already exists.",
|
||||
),
|
||||
(
|
||||
"username_link",
|
||||
"Link to a user with identical username address. Can have security implications when a username is used with another source.",
|
||||
),
|
||||
(
|
||||
"username_deny",
|
||||
"Use the user's username, but deny enrollment when the username already exists.",
|
||||
),
|
||||
],
|
||||
default="identifier",
|
||||
help_text="How the source determines if an existing user should be authenticated or a new user enrolled.",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -240,6 +240,30 @@ class Application(PolicyBindingModel):
|
|||
verbose_name_plural = _("Applications")
|
||||
|
||||
|
||||
class SourceUserMatchingModes(models.TextChoices):
|
||||
"""Different modes a source can handle new/returning users"""
|
||||
|
||||
IDENTIFIER = "identifier", _("Use the source-specific identifier")
|
||||
EMAIL_LINK = "email_link", _(
|
||||
(
|
||||
"Link to a user with identical email address. Can have security implications "
|
||||
"when a source doesn't validate email addresses."
|
||||
)
|
||||
)
|
||||
EMAIL_DENY = "email_deny", _(
|
||||
"Use the user's email address, but deny enrollment when the email address already exists."
|
||||
)
|
||||
USERNAME_LINK = "username_link", _(
|
||||
(
|
||||
"Link to a user with identical username address. Can have security implications "
|
||||
"when a username is used with another source."
|
||||
)
|
||||
)
|
||||
USERNAME_DENY = "username_deny", _(
|
||||
"Use the user's username, but deny enrollment when the username already exists."
|
||||
)
|
||||
|
||||
|
||||
class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
|
||||
|
||||
|
@ -272,6 +296,17 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
|||
related_name="source_enrollment",
|
||||
)
|
||||
|
||||
user_matching_mode = models.TextField(
|
||||
choices=SourceUserMatchingModes.choices,
|
||||
default=SourceUserMatchingModes.IDENTIFIER,
|
||||
help_text=_(
|
||||
(
|
||||
"How the source determines if an existing user should be authenticated or "
|
||||
"a new user enrolled."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
objects = InheritanceManager()
|
||||
|
||||
@property
|
||||
|
@ -301,6 +336,8 @@ class UserSourceConnection(CreatedUpdatedModel):
|
|||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
source = models.ForeignKey(Source, on_delete=models.CASCADE)
|
||||
|
||||
objects = InheritanceManager()
|
||||
|
||||
class Meta:
|
||||
|
||||
unique_together = (("user", "source"),)
|
||||
|
|
|
@ -0,0 +1,261 @@
|
|||
"""Source decision helper"""
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from django.contrib import messages
|
||||
from django.db.models.query_utils import Q
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest
|
||||
from django.shortcuts import redirect
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import (
|
||||
Source,
|
||||
SourceUserMatchingModes,
|
||||
User,
|
||||
UserSourceConnection,
|
||||
)
|
||||
from authentik.core.sources.stage import (
|
||||
PLAN_CONTEXT_SOURCES_CONNECTION,
|
||||
PostUserEnrollmentStage,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.models import Flow, Stage, in_memory_stage
|
||||
from authentik.flows.planner import (
|
||||
PLAN_CONTEXT_PENDING_USER,
|
||||
PLAN_CONTEXT_REDIRECT,
|
||||
PLAN_CONTEXT_SOURCE,
|
||||
PLAN_CONTEXT_SSO,
|
||||
FlowPlanner,
|
||||
)
|
||||
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN
|
||||
from authentik.lib.utils.urls import redirect_with_qs
|
||||
from authentik.policies.utils import delete_none_keys
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
"""Actions that can be decided based on the request
|
||||
and source settings"""
|
||||
|
||||
LINK = "link"
|
||||
AUTH = "auth"
|
||||
ENROLL = "enroll"
|
||||
DENY = "deny"
|
||||
|
||||
|
||||
class SourceFlowManager:
|
||||
"""Help sources decide what they should do after authorization. Based on source settings and
|
||||
previous connections, authenticate the user, enroll a new user, link to an existing user
|
||||
or deny the request."""
|
||||
|
||||
source: Source
|
||||
request: HttpRequest
|
||||
|
||||
identifier: str
|
||||
|
||||
connection_type: Type[UserSourceConnection] = UserSourceConnection
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: Source,
|
||||
request: HttpRequest,
|
||||
identifier: str,
|
||||
enroll_info: dict[str, Any],
|
||||
) -> None:
|
||||
self.source = source
|
||||
self.request = request
|
||||
self.identifier = identifier
|
||||
self.enroll_info = enroll_info
|
||||
self._logger = get_logger().bind(source=source, identifier=identifier)
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
|
||||
"""decide which action should be taken"""
|
||||
new_connection = self.connection_type(
|
||||
source=self.source, identifier=self.identifier
|
||||
)
|
||||
# When request is authenticated, always link
|
||||
if self.request.user.is_authenticated:
|
||||
new_connection.user = self.request.user
|
||||
new_connection = self.update_connection(new_connection, **kwargs)
|
||||
new_connection.save()
|
||||
return Action.LINK, new_connection
|
||||
|
||||
existing_connections = self.connection_type.objects.filter(
|
||||
source=self.source, identifier=self.identifier
|
||||
)
|
||||
if existing_connections.exists():
|
||||
connection = existing_connections.first()
|
||||
return Action.AUTH, self.update_connection(connection, **kwargs)
|
||||
# No connection exists, but we match on identifier, so enroll
|
||||
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
|
||||
# We don't save the connection here cause it doesn't have a user assigned yet
|
||||
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
|
||||
|
||||
# Check for existing users with matching attributes
|
||||
query = Q()
|
||||
# Either query existing user based on email or username
|
||||
if self.source.user_matching_mode in [
|
||||
SourceUserMatchingModes.EMAIL_LINK,
|
||||
SourceUserMatchingModes.EMAIL_DENY,
|
||||
]:
|
||||
if not self.enroll_info.get("email", None):
|
||||
self._logger.warning("Refusing to use none email", source=self.source)
|
||||
return Action.DENY, None
|
||||
query = Q(email__exact=self.enroll_info.get("email", None))
|
||||
if self.source.user_matching_mode in [
|
||||
SourceUserMatchingModes.USERNAME_LINK,
|
||||
SourceUserMatchingModes.USERNAME_DENY,
|
||||
]:
|
||||
if not self.enroll_info.get("username", None):
|
||||
self._logger.warning(
|
||||
"Refusing to use none username", source=self.source
|
||||
)
|
||||
return Action.DENY, None
|
||||
query = Q(username__exact=self.enroll_info.get("username", None))
|
||||
matching_users = User.objects.filter(query)
|
||||
# No matching users, always enroll
|
||||
if not matching_users.exists():
|
||||
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
|
||||
|
||||
user = matching_users.first()
|
||||
if self.source.user_matching_mode in [
|
||||
SourceUserMatchingModes.EMAIL_LINK,
|
||||
SourceUserMatchingModes.USERNAME_LINK,
|
||||
]:
|
||||
new_connection.user = user
|
||||
new_connection = self.update_connection(new_connection, **kwargs)
|
||||
new_connection.save()
|
||||
return Action.LINK, new_connection
|
||||
if self.source.user_matching_mode in [
|
||||
SourceUserMatchingModes.EMAIL_DENY,
|
||||
SourceUserMatchingModes.USERNAME_DENY,
|
||||
]:
|
||||
return Action.DENY, None
|
||||
return Action.DENY, None
|
||||
|
||||
def update_connection(
|
||||
self, connection: UserSourceConnection, **kwargs
|
||||
) -> UserSourceConnection:
|
||||
"""Optionally make changes to the connection after it is looked up/created."""
|
||||
return connection
|
||||
|
||||
def get_flow(self, **kwargs) -> HttpResponse:
|
||||
"""Get the flow response based on user_matching_mode"""
|
||||
action, connection = self.get_action()
|
||||
if action == Action.LINK:
|
||||
self._logger.debug("Linking existing user")
|
||||
return self.handle_existing_user_link()
|
||||
if not connection:
|
||||
return redirect("/")
|
||||
if action == Action.AUTH:
|
||||
self._logger.debug("Handling auth user")
|
||||
return self.handle_auth_user(connection)
|
||||
if action == Action.ENROLL:
|
||||
self._logger.debug("Handling enrollment of new user")
|
||||
return self.handle_enroll(connection)
|
||||
return redirect("/")
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_stages_to_append(self, flow: Flow) -> list[Stage]:
|
||||
"""Hook to override stages which are appended to the flow"""
|
||||
if flow.slug == self.source.enrollment_flow.slug:
|
||||
return [
|
||||
in_memory_stage(PostUserEnrollmentStage),
|
||||
]
|
||||
return []
|
||||
|
||||
def _handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse:
|
||||
"""Prepare Authentication Plan, redirect user FlowExecutor"""
|
||||
# Ensure redirect is carried through when user was trying to
|
||||
# authorize application
|
||||
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
|
||||
NEXT_ARG_NAME, "authentik_core:if-admin"
|
||||
)
|
||||
kwargs.update(
|
||||
{
|
||||
# Since we authenticate the user by their token, they have no backend set
|
||||
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
|
||||
PLAN_CONTEXT_SSO: True,
|
||||
PLAN_CONTEXT_SOURCE: self.source,
|
||||
PLAN_CONTEXT_REDIRECT: final_redirect,
|
||||
}
|
||||
)
|
||||
if not flow:
|
||||
return HttpResponseBadRequest()
|
||||
# We run the Flow planner here so we can pass the Pending user in the context
|
||||
planner = FlowPlanner(flow)
|
||||
plan = planner.plan(self.request, kwargs)
|
||||
for stage in self.get_stages_to_append(flow):
|
||||
plan.append(stage)
|
||||
self.request.session[SESSION_KEY_PLAN] = plan
|
||||
return redirect_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
self.request.GET,
|
||||
flow_slug=flow.slug,
|
||||
)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def handle_auth_user(
|
||||
self,
|
||||
connection: UserSourceConnection,
|
||||
) -> HttpResponse:
|
||||
"""Login user and redirect."""
|
||||
messages.success(
|
||||
self.request,
|
||||
_(
|
||||
"Successfully authenticated with %(source)s!"
|
||||
% {"source": self.source.name}
|
||||
),
|
||||
)
|
||||
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
|
||||
return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs)
|
||||
|
||||
def handle_existing_user_link(
|
||||
self,
|
||||
) -> HttpResponse:
|
||||
"""Handler when the user was already authenticated and linked an external source
|
||||
to their account."""
|
||||
Event.new(
|
||||
EventAction.SOURCE_LINKED,
|
||||
message="Linked Source",
|
||||
source=self.source,
|
||||
).from_http(self.request)
|
||||
messages.success(
|
||||
self.request,
|
||||
_("Successfully linked %(source)s!" % {"source": self.source.name}),
|
||||
)
|
||||
return redirect(
|
||||
reverse(
|
||||
"authentik_core:if-admin",
|
||||
)
|
||||
+ f"#/user;page-{self.source.slug}"
|
||||
)
|
||||
|
||||
def handle_enroll(
|
||||
self,
|
||||
connection: UserSourceConnection,
|
||||
) -> HttpResponse:
|
||||
"""User was not authenticated and previous request was not authenticated."""
|
||||
messages.success(
|
||||
self.request,
|
||||
_(
|
||||
"Successfully authenticated with %(source)s!"
|
||||
% {"source": self.source.name}
|
||||
),
|
||||
)
|
||||
|
||||
# We run the Flow planner here so we can pass the Pending user in the context
|
||||
if not self.source.enrollment_flow:
|
||||
self._logger.warning("source has no enrollment flow")
|
||||
return HttpResponseBadRequest()
|
||||
return self._handle_login_flow(
|
||||
self.source.enrollment_flow,
|
||||
**{
|
||||
PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info),
|
||||
PLAN_CONTEXT_SOURCES_CONNECTION: connection,
|
||||
},
|
||||
)
|
|
@ -1,32 +1,30 @@
|
|||
"""OAuth Stages"""
|
||||
"""Source flow manager stages"""
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.core.models import User, UserSourceConnection
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
|
||||
from authentik.flows.stage import StageView
|
||||
from authentik.sources.oauth.models import UserOAuthSourceConnection
|
||||
|
||||
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS = "sources_oauth_access"
|
||||
PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection"
|
||||
|
||||
|
||||
class PostUserEnrollmentStage(StageView):
|
||||
"""Dynamically injected stage which saves the OAuth Connection after
|
||||
"""Dynamically injected stage which saves the Connection after
|
||||
the user has been enrolled."""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
"""Stage used after the user has been enrolled"""
|
||||
access: UserOAuthSourceConnection = self.executor.plan.context[
|
||||
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS
|
||||
connection: UserSourceConnection = self.executor.plan.context[
|
||||
PLAN_CONTEXT_SOURCES_CONNECTION
|
||||
]
|
||||
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
|
||||
access.user = user
|
||||
access.save()
|
||||
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
|
||||
connection.user = user
|
||||
connection.save()
|
||||
Event.new(
|
||||
EventAction.SOURCE_LINKED,
|
||||
message="Linked OAuth Source",
|
||||
source=access.source,
|
||||
message="Linked Source",
|
||||
source=connection.source,
|
||||
).from_http(self.request)
|
||||
return self.executor.stage_ok()
|
|
@ -0,0 +1,84 @@
|
|||
# Generated by Django 3.2 on 2021-05-02 17:06
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_policies_event_matcher", "0012_auto_20210323_1339"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="eventmatcherpolicy",
|
||||
name="app",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
choices=[
|
||||
("authentik.admin", "authentik Admin"),
|
||||
("authentik.api", "authentik API"),
|
||||
("authentik.events", "authentik Events"),
|
||||
("authentik.crypto", "authentik Crypto"),
|
||||
("authentik.flows", "authentik Flows"),
|
||||
("authentik.outposts", "authentik Outpost"),
|
||||
("authentik.lib", "authentik lib"),
|
||||
("authentik.policies", "authentik Policies"),
|
||||
("authentik.policies.dummy", "authentik Policies.Dummy"),
|
||||
(
|
||||
"authentik.policies.event_matcher",
|
||||
"authentik Policies.Event Matcher",
|
||||
),
|
||||
("authentik.policies.expiry", "authentik Policies.Expiry"),
|
||||
("authentik.policies.expression", "authentik Policies.Expression"),
|
||||
("authentik.policies.hibp", "authentik Policies.HaveIBeenPwned"),
|
||||
("authentik.policies.password", "authentik Policies.Password"),
|
||||
("authentik.policies.reputation", "authentik Policies.Reputation"),
|
||||
("authentik.providers.proxy", "authentik Providers.Proxy"),
|
||||
("authentik.providers.oauth2", "authentik Providers.OAuth2"),
|
||||
("authentik.providers.saml", "authentik Providers.SAML"),
|
||||
("authentik.recovery", "authentik Recovery"),
|
||||
("authentik.sources.ldap", "authentik Sources.LDAP"),
|
||||
("authentik.sources.oauth", "authentik Sources.OAuth"),
|
||||
("authentik.sources.plex", "authentik Sources.Plex"),
|
||||
("authentik.sources.saml", "authentik Sources.SAML"),
|
||||
(
|
||||
"authentik.stages.authenticator_static",
|
||||
"authentik Stages.Authenticator.Static",
|
||||
),
|
||||
(
|
||||
"authentik.stages.authenticator_totp",
|
||||
"authentik Stages.Authenticator.TOTP",
|
||||
),
|
||||
(
|
||||
"authentik.stages.authenticator_validate",
|
||||
"authentik Stages.Authenticator.Validate",
|
||||
),
|
||||
(
|
||||
"authentik.stages.authenticator_webauthn",
|
||||
"authentik Stages.Authenticator.WebAuthn",
|
||||
),
|
||||
("authentik.stages.captcha", "authentik Stages.Captcha"),
|
||||
("authentik.stages.consent", "authentik Stages.Consent"),
|
||||
("authentik.stages.deny", "authentik Stages.Deny"),
|
||||
("authentik.stages.dummy", "authentik Stages.Dummy"),
|
||||
("authentik.stages.email", "authentik Stages.Email"),
|
||||
(
|
||||
"authentik.stages.identification",
|
||||
"authentik Stages.Identification",
|
||||
),
|
||||
("authentik.stages.invitation", "authentik Stages.User Invitation"),
|
||||
("authentik.stages.password", "authentik Stages.Password"),
|
||||
("authentik.stages.prompt", "authentik Stages.Prompt"),
|
||||
("authentik.stages.user_delete", "authentik Stages.User Delete"),
|
||||
("authentik.stages.user_login", "authentik Stages.User Login"),
|
||||
("authentik.stages.user_logout", "authentik Stages.User Logout"),
|
||||
("authentik.stages.user_write", "authentik Stages.User Write"),
|
||||
("authentik.core", "authentik Core"),
|
||||
("authentik.managed", "authentik Managed"),
|
||||
],
|
||||
default="",
|
||||
help_text="Match events created by selected application. When left empty, all applications are matched.",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -1,23 +0,0 @@
|
|||
"""authentik oauth_client Authorization backend"""
|
||||
from typing import Optional
|
||||
|
||||
from django.contrib.auth.backends import ModelBackend
|
||||
from django.http import HttpRequest
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
|
||||
|
||||
class AuthorizedServiceBackend(ModelBackend):
|
||||
"Authentication backend for users registered with remote OAuth provider."
|
||||
|
||||
def authenticate(
|
||||
self, request: HttpRequest, source: OAuthSource, identifier: str
|
||||
) -> Optional[User]:
|
||||
"Fetch user for a given source by id."
|
||||
access = UserOAuthSourceConnection.objects.filter(
|
||||
source=source, identifier=identifier
|
||||
).select_related("user")
|
||||
if not access.exists():
|
||||
return None
|
||||
return access.first().user
|
|
@ -1,7 +1,7 @@
|
|||
"""Discord Type tests"""
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.discord import DiscordOAuth2Callback
|
||||
|
||||
# https://discord.com/developers/docs/resources/user#user-object
|
||||
|
@ -33,9 +33,7 @@ class TestTypeDiscord(TestCase):
|
|||
|
||||
def test_enroll_context(self):
|
||||
"""Test discord Enrollment context"""
|
||||
ak_context = DiscordOAuth2Callback().get_user_enroll_context(
|
||||
self.source, UserOAuthSourceConnection(), DISCORD_USER
|
||||
)
|
||||
ak_context = DiscordOAuth2Callback().get_user_enroll_context(DISCORD_USER)
|
||||
self.assertEqual(ak_context["username"], DISCORD_USER["username"])
|
||||
self.assertEqual(ak_context["email"], DISCORD_USER["email"])
|
||||
self.assertEqual(ak_context["name"], DISCORD_USER["username"])
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""GitHub Type tests"""
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.github import GitHubOAuth2Callback
|
||||
|
||||
# https://developer.github.com/v3/users/#get-the-authenticated-user
|
||||
|
@ -63,9 +63,7 @@ class TestTypeGitHub(TestCase):
|
|||
|
||||
def test_enroll_context(self):
|
||||
"""Test GitHub Enrollment context"""
|
||||
ak_context = GitHubOAuth2Callback().get_user_enroll_context(
|
||||
self.source, UserOAuthSourceConnection(), GITHUB_USER
|
||||
)
|
||||
ak_context = GitHubOAuth2Callback().get_user_enroll_context(GITHUB_USER)
|
||||
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
|
||||
self.assertEqual(ak_context["email"], GITHUB_USER["email"])
|
||||
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""google Type tests"""
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.google import GoogleOAuth2Callback
|
||||
|
||||
# https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en
|
||||
|
@ -32,9 +32,7 @@ class TestTypeGoogle(TestCase):
|
|||
|
||||
def test_enroll_context(self):
|
||||
"""Test Google Enrollment context"""
|
||||
ak_context = GoogleOAuth2Callback().get_user_enroll_context(
|
||||
self.source, UserOAuthSourceConnection(), GOOGLE_USER
|
||||
)
|
||||
ak_context = GoogleOAuth2Callback().get_user_enroll_context(GOOGLE_USER)
|
||||
self.assertEqual(ak_context["username"], GOOGLE_USER["email"])
|
||||
self.assertEqual(ak_context["email"], GOOGLE_USER["email"])
|
||||
self.assertEqual(ak_context["name"], GOOGLE_USER["name"])
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Twitter Type tests"""
|
||||
from django.test import Client, TestCase
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.twitter import TwitterOAuthCallback
|
||||
|
||||
# https://developer.twitter.com/en/docs/twitter-api/v1/accounts-and-users/manage-account-settings/ \
|
||||
|
@ -104,9 +104,7 @@ class TestTypeGitHub(TestCase):
|
|||
|
||||
def test_enroll_context(self):
|
||||
"""Test Twitter Enrollment context"""
|
||||
ak_context = TwitterOAuthCallback().get_user_enroll_context(
|
||||
self.source, UserOAuthSourceConnection(), TWITTER_USER
|
||||
)
|
||||
ak_context = TwitterOAuthCallback().get_user_enroll_context(TWITTER_USER)
|
||||
self.assertEqual(ak_context["username"], TWITTER_USER["screen_name"])
|
||||
self.assertEqual(ak_context["email"], TWITTER_USER.get("email", None))
|
||||
self.assertEqual(ak_context["name"], TWITTER_USER["name"])
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
|
||||
|
@ -10,7 +9,7 @@ from authentik.sources.oauth.views.callback import OAuthCallback
|
|||
class AzureADOAuthCallback(OAuthCallback):
|
||||
"""AzureAD OAuth2 Callback"""
|
||||
|
||||
def get_user_id(self, source: OAuthSource, info: dict[str, Any]) -> Optional[str]:
|
||||
def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
|
||||
try:
|
||||
return str(UUID(info.get("objectId")).int)
|
||||
except TypeError:
|
||||
|
@ -18,8 +17,6 @@ class AzureADOAuthCallback(OAuthCallback):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Discord OAuth Views"""
|
||||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
@ -21,8 +20,6 @@ class DiscordOAuth2Callback(OAuthCallback):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Any, Optional
|
|||
from facebook import GraphAPI
|
||||
|
||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
@ -34,8 +33,6 @@ class FacebookOAuth2Callback(OAuthCallback):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""GitHub OAuth Views"""
|
||||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
|
||||
|
@ -11,8 +10,6 @@ class GitHubOAuth2Callback(OAuthCallback):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Google OAuth Views"""
|
||||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
@ -21,8 +20,6 @@ class GoogleOAuth2Callback(OAuthCallback):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""OpenID Connect OAuth Views"""
|
||||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
@ -19,13 +19,11 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
|||
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
||||
"""OpenIDConnect OAuth2 Callback"""
|
||||
|
||||
def get_user_id(self, source: OAuthSource, info: dict[str, str]) -> str:
|
||||
def get_user_id(self, info: dict[str, str]) -> str:
|
||||
return info.get("sub", "")
|
||||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Any
|
|||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
@ -36,8 +35,6 @@ class RedditOAuth2Callback(OAuthCallback):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Twitter OAuth Views"""
|
||||
from typing import Any
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
|
||||
|
@ -11,8 +10,6 @@ class TwitterOAuthCallback(OAuthCallback):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -4,35 +4,14 @@ from typing import Any, Optional
|
|||
from django.conf import settings
|
||||
from django.contrib import messages
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.http.response import HttpResponseBadRequest
|
||||
from django.shortcuts import redirect
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
from django.views.generic import View
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.models import Flow, in_memory_stage
|
||||
from authentik.flows.planner import (
|
||||
PLAN_CONTEXT_PENDING_USER,
|
||||
PLAN_CONTEXT_REDIRECT,
|
||||
PLAN_CONTEXT_SOURCE,
|
||||
PLAN_CONTEXT_SSO,
|
||||
FlowPlanner,
|
||||
)
|
||||
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN
|
||||
from authentik.lib.utils.urls import redirect_with_qs
|
||||
from authentik.policies.utils import delete_none_keys
|
||||
from authentik.sources.oauth.auth import AuthorizedServiceBackend
|
||||
from authentik.core.sources.flow_manager import SourceFlowManager
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.views.base import OAuthClientMixin
|
||||
from authentik.sources.oauth.views.flows import (
|
||||
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS,
|
||||
PostUserEnrollmentStage,
|
||||
)
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
|
||||
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
@ -40,8 +19,7 @@ LOGGER = get_logger()
|
|||
class OAuthCallback(OAuthClientMixin, View):
|
||||
"Base OAuth callback view."
|
||||
|
||||
source_id = None
|
||||
source = None
|
||||
source: OAuthSource
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def get(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
|
||||
|
@ -60,47 +38,27 @@ class OAuthCallback(OAuthClientMixin, View):
|
|||
# Fetch access token
|
||||
token = client.get_access_token()
|
||||
if token is None:
|
||||
return self.handle_login_failure(self.source, "Could not retrieve token.")
|
||||
return self.handle_login_failure("Could not retrieve token.")
|
||||
if "error" in token:
|
||||
return self.handle_login_failure(self.source, token["error"])
|
||||
return self.handle_login_failure(token["error"])
|
||||
# Fetch profile info
|
||||
info = client.get_profile_info(token)
|
||||
if info is None:
|
||||
return self.handle_login_failure(self.source, "Could not retrieve profile.")
|
||||
identifier = self.get_user_id(self.source, info)
|
||||
raw_info = client.get_profile_info(token)
|
||||
if raw_info is None:
|
||||
return self.handle_login_failure("Could not retrieve profile.")
|
||||
identifier = self.get_user_id(raw_info)
|
||||
if identifier is None:
|
||||
return self.handle_login_failure(self.source, "Could not determine id.")
|
||||
return self.handle_login_failure("Could not determine id.")
|
||||
# Get or create access record
|
||||
defaults = {
|
||||
"access_token": token.get("access_token"),
|
||||
}
|
||||
existing = UserOAuthSourceConnection.objects.filter(
|
||||
source=self.source, identifier=identifier
|
||||
)
|
||||
|
||||
if existing.exists():
|
||||
connection = existing.first()
|
||||
connection.access_token = token.get("access_token")
|
||||
UserOAuthSourceConnection.objects.filter(pk=connection.pk).update(
|
||||
**defaults
|
||||
)
|
||||
else:
|
||||
connection = UserOAuthSourceConnection(
|
||||
enroll_info = self.get_user_enroll_context(raw_info)
|
||||
sfm = OAuthSourceFlowManager(
|
||||
source=self.source,
|
||||
request=self.request,
|
||||
identifier=identifier,
|
||||
access_token=token.get("access_token"),
|
||||
enroll_info=enroll_info,
|
||||
)
|
||||
user = AuthorizedServiceBackend().authenticate(
|
||||
source=self.source, identifier=identifier, request=request
|
||||
return sfm.get_flow(
|
||||
token=token,
|
||||
)
|
||||
if user is None:
|
||||
if self.request.user.is_authenticated:
|
||||
LOGGER.debug("Linking existing user", source=self.source)
|
||||
return self.handle_existing_user_link(self.source, connection, info)
|
||||
LOGGER.debug("Handling enrollment of new user", source=self.source)
|
||||
return self.handle_enroll(self.source, connection, info)
|
||||
LOGGER.debug("Handling existing user", source=self.source)
|
||||
return self.handle_existing_user(self.source, user, connection, info)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_callback_url(self, source: OAuthSource) -> str:
|
||||
|
@ -114,132 +72,34 @@ class OAuthCallback(OAuthClientMixin, View):
|
|||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Create a dict of User data"""
|
||||
raise NotImplementedError()
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_user_id(
|
||||
self, source: UserOAuthSourceConnection, info: dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
|
||||
"""Return unique identifier from the profile info."""
|
||||
if "id" in info:
|
||||
return info["id"]
|
||||
return None
|
||||
|
||||
def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse:
|
||||
def handle_login_failure(self, reason: str) -> HttpResponse:
|
||||
"Message user and redirect on error."
|
||||
LOGGER.warning("Authentication Failure", reason=reason)
|
||||
messages.error(self.request, _("Authentication Failed."))
|
||||
return redirect(self.get_error_redirect(source, reason))
|
||||
return redirect(self.get_error_redirect(self.source, reason))
|
||||
|
||||
def handle_login_flow(
|
||||
self, flow: Flow, *stages_to_append, **kwargs
|
||||
) -> HttpResponse:
|
||||
"""Prepare Authentication Plan, redirect user FlowExecutor"""
|
||||
# Ensure redirect is carried through when user was trying to
|
||||
# authorize application
|
||||
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
|
||||
NEXT_ARG_NAME, "authentik_core:if-admin"
|
||||
)
|
||||
kwargs.update(
|
||||
{
|
||||
# Since we authenticate the user by their token, they have no backend set
|
||||
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
|
||||
PLAN_CONTEXT_SSO: True,
|
||||
PLAN_CONTEXT_SOURCE: self.source,
|
||||
PLAN_CONTEXT_REDIRECT: final_redirect,
|
||||
}
|
||||
)
|
||||
if not flow:
|
||||
return HttpResponseBadRequest()
|
||||
# We run the Flow planner here so we can pass the Pending user in the context
|
||||
planner = FlowPlanner(flow)
|
||||
plan = planner.plan(self.request, kwargs)
|
||||
for stage in stages_to_append:
|
||||
plan.append(stage)
|
||||
self.request.session[SESSION_KEY_PLAN] = plan
|
||||
return redirect_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
self.request.GET,
|
||||
flow_slug=flow.slug,
|
||||
)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def handle_existing_user(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
user: User,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> HttpResponse:
|
||||
"Login user and redirect."
|
||||
messages.success(
|
||||
self.request,
|
||||
_(
|
||||
"Successfully authenticated with %(source)s!"
|
||||
% {"source": self.source.name}
|
||||
),
|
||||
)
|
||||
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user}
|
||||
return self.handle_login_flow(source.authentication_flow, **flow_kwargs)
|
||||
class OAuthSourceFlowManager(SourceFlowManager):
|
||||
"""Flow manager for oauth sources"""
|
||||
|
||||
def handle_existing_user_link(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> HttpResponse:
|
||||
"""Handler when the user was already authenticated and linked an external source
|
||||
to their account."""
|
||||
# there's already a user logged in, just link them up
|
||||
user = self.request.user
|
||||
access.user = user
|
||||
access.save()
|
||||
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
|
||||
Event.new(
|
||||
EventAction.SOURCE_LINKED, message="Linked OAuth Source", source=source
|
||||
).from_http(self.request)
|
||||
messages.success(
|
||||
self.request,
|
||||
_("Successfully linked %(source)s!" % {"source": self.source.name}),
|
||||
)
|
||||
return redirect(
|
||||
reverse(
|
||||
"authentik_core:if-admin",
|
||||
)
|
||||
+ f"#/user;page-{self.source.slug}"
|
||||
)
|
||||
connection_type = UserOAuthSourceConnection
|
||||
|
||||
def handle_enroll(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> HttpResponse:
|
||||
"""User was not authenticated and previous request was not authenticated."""
|
||||
messages.success(
|
||||
self.request,
|
||||
_(
|
||||
"Successfully authenticated with %(source)s!"
|
||||
% {"source": self.source.name}
|
||||
),
|
||||
)
|
||||
|
||||
# We run the Flow planner here so we can pass the Pending user in the context
|
||||
if not source.enrollment_flow:
|
||||
LOGGER.warning("source has no enrollment flow", source=source)
|
||||
return HttpResponseBadRequest()
|
||||
return self.handle_login_flow(
|
||||
source.enrollment_flow,
|
||||
in_memory_stage(PostUserEnrollmentStage),
|
||||
**{
|
||||
PLAN_CONTEXT_PROMPT: delete_none_keys(
|
||||
self.get_user_enroll_context(source, access, info)
|
||||
),
|
||||
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS: access,
|
||||
},
|
||||
)
|
||||
def update_connection(
|
||||
self, connection: UserOAuthSourceConnection, token: dict[str, Any]
|
||||
) -> UserOAuthSourceConnection:
|
||||
"""Set the access_token on the connection"""
|
||||
connection.access_token = token.get("access_token")
|
||||
connection.save()
|
||||
return connection
|
||||
|
|
|
@ -1,26 +1,22 @@
|
|||
"""Plex Source Serializer"""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.http import Http404
|
||||
from django.shortcuts import get_object_or_404
|
||||
from drf_yasg import openapi
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from requests import RequestException, get
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.decorators import permission_required
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge
|
||||
from authentik.flows.challenge import RedirectChallenge
|
||||
from authentik.flows.views import to_stage_response
|
||||
from authentik.sources.plex.models import PlexSource
|
||||
|
||||
LOGGER = get_logger()
|
||||
from authentik.sources.plex.plex import PlexAuth
|
||||
|
||||
|
||||
class PlexSourceSerializer(SourceSerializer):
|
||||
|
@ -72,29 +68,8 @@ class PlexSourceViewSet(ModelViewSet):
|
|||
plex_token = request.data.get("plex_token", None)
|
||||
if not plex_token:
|
||||
raise Http404
|
||||
qs = {"X-Plex-Token": plex_token, "X-Plex-Client-Identifier": source.client_id}
|
||||
try:
|
||||
response = get(
|
||||
f"https://plex.tv/api/v2/resources?{urlencode(qs)}",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning("Unable to fetch user resources", exc=exc)
|
||||
auth_api = PlexAuth(source, plex_token)
|
||||
if not auth_api.check_server_overlap():
|
||||
raise Http404
|
||||
else:
|
||||
resources: list[dict] = response.json()
|
||||
for resource in resources:
|
||||
if resource["provides"] != "server":
|
||||
continue
|
||||
if resource["clientIdentifier"] in source.allowed_servers:
|
||||
LOGGER.info(
|
||||
"Plex allowed access from server", name=resource["name"]
|
||||
)
|
||||
request.session["foo"] = "bar"
|
||||
break
|
||||
return Response(
|
||||
RedirectChallenge(
|
||||
{"type": ChallengeTypes.REDIRECT.value, "to": ""}
|
||||
).data
|
||||
)
|
||||
response = auth_api.get_user_url(request)
|
||||
return to_stage_response(request, response)
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Generated by Django 3.2 on 2021-05-03 17:06
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0020_source_user_matching_mode"),
|
||||
("authentik_sources_plex", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="PlexSourceConnection",
|
||||
fields=[
|
||||
(
|
||||
"usersourceconnection_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.usersourceconnection",
|
||||
),
|
||||
),
|
||||
("plex_token", models.TextField()),
|
||||
("identifier", models.TextField()),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "User Plex Source Connection",
|
||||
"verbose_name_plural": "User Plex Source Connections",
|
||||
},
|
||||
bases=("authentik_core.usersourceconnection",),
|
||||
),
|
||||
]
|
|
@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework.fields import CharField
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
|
||||
from authentik.core.models import Source
|
||||
from authentik.core.models import Source, UserSourceConnection
|
||||
from authentik.core.types import UILoginButton
|
||||
from authentik.flows.challenge import Challenge, ChallengeTypes
|
||||
|
||||
|
@ -53,3 +53,15 @@ class PlexSource(Source):
|
|||
|
||||
verbose_name = _("Plex Source")
|
||||
verbose_name_plural = _("Plex Sources")
|
||||
|
||||
|
||||
class PlexSourceConnection(UserSourceConnection):
|
||||
"""Connect user and plex source"""
|
||||
|
||||
plex_token = models.TextField()
|
||||
identifier = models.TextField()
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = _("User Plex Source Connection")
|
||||
verbose_name_plural = _("User Plex Source Connections")
|
||||
|
|
|
@ -1,136 +1,113 @@
|
|||
"""Plex OAuth Views"""
|
||||
from typing import Any, Optional
|
||||
"""Plex Views"""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.http.response import Http404
|
||||
from requests import post
|
||||
from requests.api import get
|
||||
import requests
|
||||
from django.http.request import HttpRequest
|
||||
from django.http.response import Http404, HttpResponse
|
||||
from requests.exceptions import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik import __version__
|
||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
from authentik.core.sources.flow_manager import SourceFlowManager
|
||||
from authentik.sources.plex.models import PlexSource, PlexSourceConnection
|
||||
|
||||
LOGGER = get_logger()
|
||||
SESSION_ID_KEY = "PLEX_ID"
|
||||
SESSION_CODE_KEY = "PLEX_CODE"
|
||||
DEFAULT_PAYLOAD = {
|
||||
|
||||
|
||||
class PlexAuth:
|
||||
"""Plex authentication utilities"""
|
||||
|
||||
_source: PlexSource
|
||||
_token: str
|
||||
|
||||
def __init__(self, source: PlexSource, token: str):
|
||||
self._source = source
|
||||
self._token = token
|
||||
self._session = requests.Session()
|
||||
self._session.headers.update(
|
||||
{"Accept": "application/json", "Content-Type": "application/json"}
|
||||
)
|
||||
self._session.headers.update(self.headers)
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""Get common headers"""
|
||||
return {
|
||||
"X-Plex-Product": "authentik",
|
||||
"X-Plex-Version": __version__,
|
||||
"X-Plex-Device-Vendor": "BeryJu.org",
|
||||
}
|
||||
|
||||
|
||||
class PlexRedirect(OAuthRedirect):
|
||||
"""Plex Auth redirect, get a pin then redirect to a URL to claim it"""
|
||||
|
||||
headers = {}
|
||||
|
||||
def get_pin(self, **data) -> dict:
|
||||
"""Get plex pin that the user will claim
|
||||
https://forums.plex.tv/t/authenticating-with-plex/609370"""
|
||||
return post(
|
||||
"https://plex.tv/api/v2/pins.json?strong=true",
|
||||
data=data,
|
||||
headers=self.headers,
|
||||
).json()
|
||||
|
||||
def get_redirect_url(self, **kwargs) -> str:
|
||||
slug = kwargs.get("source_slug", "")
|
||||
self.headers = {"Origin": self.request.build_absolute_uri("/")}
|
||||
try:
|
||||
source: OAuthSource = OAuthSource.objects.get(slug=slug)
|
||||
except OAuthSource.DoesNotExist:
|
||||
raise Http404(f"Unknown OAuth source '{slug}'.")
|
||||
else:
|
||||
payload = DEFAULT_PAYLOAD.copy()
|
||||
payload["X-Plex-Client-Identifier"] = source.consumer_key
|
||||
# Get a pin first
|
||||
pin = self.get_pin(**payload)
|
||||
LOGGER.debug("Got pin", **pin)
|
||||
self.request.session[SESSION_ID_KEY] = pin["id"]
|
||||
self.request.session[SESSION_CODE_KEY] = pin["code"]
|
||||
def get_resources(self) -> list[dict]:
|
||||
"""Get all resources the plex-token has access to"""
|
||||
qs = {
|
||||
"clientID": source.consumer_key,
|
||||
"code": pin["code"],
|
||||
"forwardUrl": self.request.build_absolute_uri(
|
||||
self.get_callback_url(source)
|
||||
),
|
||||
"X-Plex-Token": self._token,
|
||||
"X-Plex-Client-Identifier": self._source.client_id,
|
||||
}
|
||||
return f"https://app.plex.tv/auth#!?{urlencode(qs)}"
|
||||
|
||||
|
||||
class PlexOAuthClient(OAuth2Client):
|
||||
"""Retrive the plex token after authentication, then ask the plex API about user info"""
|
||||
|
||||
def check_application_state(self) -> bool:
|
||||
return SESSION_ID_KEY in self.request.session
|
||||
|
||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||
payload = dict(DEFAULT_PAYLOAD)
|
||||
payload["X-Plex-Client-Identifier"] = self.source.consumer_key
|
||||
payload["Accept"] = "application/json"
|
||||
response = get(
|
||||
f"https://plex.tv/api/v2/pins/{self.request.session[SESSION_ID_KEY]}",
|
||||
headers=payload,
|
||||
response = self._session.get(
|
||||
f"https://plex.tv/api/v2/resources?{urlencode(qs)}",
|
||||
)
|
||||
response.raise_for_status()
|
||||
token = response.json()["authToken"]
|
||||
return {"plex_token": token}
|
||||
return response.json()
|
||||
|
||||
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
||||
"Fetch user profile information."
|
||||
qs = {"X-Plex-Token": token["plex_token"]}
|
||||
print(token)
|
||||
try:
|
||||
response = self.do_request(
|
||||
"get", f"https://plex.tv/users/account.json?{urlencode(qs)}"
|
||||
def get_user_info(self) -> tuple[dict, int]:
|
||||
"""Get user info of the plex token"""
|
||||
qs = {
|
||||
"X-Plex-Token": self._token,
|
||||
"X-Plex-Client-Identifier": self._source.client_id,
|
||||
}
|
||||
response = self._session.get(
|
||||
f"https://plex.tv/api/v2/user?{urlencode(qs)}",
|
||||
)
|
||||
response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning("Unable to fetch user profile", exc=exc)
|
||||
return None
|
||||
else:
|
||||
info = response.json()
|
||||
return info.get("user", {})
|
||||
|
||||
|
||||
class PlexOAuth2Callback(OAuthCallback):
|
||||
"""Plex OAuth2 Callback"""
|
||||
|
||||
client_class = PlexOAuthClient
|
||||
|
||||
def get_user_id(
|
||||
self, source: UserOAuthSourceConnection, info: dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
return info.get("uuid")
|
||||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
source: OAuthSource,
|
||||
access: UserOAuthSourceConnection,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
raw_user_info = response.json()
|
||||
return {
|
||||
"username": info.get("username"),
|
||||
"email": info.get("email"),
|
||||
"name": info.get("title"),
|
||||
}
|
||||
"username": raw_user_info.get("username"),
|
||||
"email": raw_user_info.get("email"),
|
||||
"name": raw_user_info.get("title"),
|
||||
}, raw_user_info.get("id")
|
||||
|
||||
def check_server_overlap(self) -> bool:
|
||||
"""Check if the plex-token has any server overlap with our configured servers"""
|
||||
try:
|
||||
resources = self.get_resources()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning("Unable to fetch user resources", exc=exc)
|
||||
raise Http404
|
||||
else:
|
||||
for resource in resources:
|
||||
if resource["provides"] != "server":
|
||||
continue
|
||||
if resource["clientIdentifier"] in self._source.allowed_servers:
|
||||
LOGGER.info(
|
||||
"Plex allowed access from server", name=resource["name"]
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_user_url(self, request: HttpRequest) -> HttpResponse:
|
||||
"""Get a URL to a flow executor for either enrollment or authentication"""
|
||||
user_info, identifier = self.get_user_info()
|
||||
sfm = PlexSourceFlowManager(
|
||||
source=self._source,
|
||||
request=request,
|
||||
identifier=str(identifier),
|
||||
enroll_info=user_info,
|
||||
)
|
||||
return sfm.get_flow(plex_token=self._token)
|
||||
|
||||
|
||||
@MANAGER.type()
|
||||
class PlexType(SourceType):
|
||||
"""Plex Type definition"""
|
||||
class PlexSourceFlowManager(SourceFlowManager):
|
||||
"""Flow manager for plex sources"""
|
||||
|
||||
redirect_view = PlexRedirect
|
||||
callback_view = PlexOAuth2Callback
|
||||
name = "Plex"
|
||||
slug = "plex"
|
||||
connection_type = PlexSourceConnection
|
||||
|
||||
authorization_url = ""
|
||||
access_token_url = "" # nosec
|
||||
profile_url = ""
|
||||
def update_connection(
|
||||
self, connection: PlexSourceConnection, plex_token: str
|
||||
) -> PlexSourceConnection:
|
||||
"""Set the access_token on the connection"""
|
||||
connection.plex_token = plex_token
|
||||
connection.save()
|
||||
return connection
|
||||
|
|
66
swagger.yaml
66
swagger.yaml
|
@ -17289,6 +17289,17 @@ definitions:
|
|||
enum:
|
||||
- all
|
||||
- any
|
||||
user_matching_mode:
|
||||
title: User matching mode
|
||||
description: How the source determines if an existing user should be authenticated
|
||||
or a new user enrolled.
|
||||
type: string
|
||||
enum:
|
||||
- identifier
|
||||
- email_link
|
||||
- email_deny
|
||||
- username_link
|
||||
- username_deny
|
||||
UserSetting:
|
||||
required:
|
||||
- object_uid
|
||||
|
@ -17369,6 +17380,17 @@ definitions:
|
|||
enum:
|
||||
- all
|
||||
- any
|
||||
user_matching_mode:
|
||||
title: User matching mode
|
||||
description: How the source determines if an existing user should be authenticated
|
||||
or a new user enrolled.
|
||||
type: string
|
||||
enum:
|
||||
- identifier
|
||||
- email_link
|
||||
- email_deny
|
||||
- username_link
|
||||
- username_deny
|
||||
server_uri:
|
||||
title: Server URI
|
||||
type: string
|
||||
|
@ -17549,6 +17571,17 @@ definitions:
|
|||
enum:
|
||||
- all
|
||||
- any
|
||||
user_matching_mode:
|
||||
title: User matching mode
|
||||
description: How the source determines if an existing user should be authenticated
|
||||
or a new user enrolled.
|
||||
type: string
|
||||
enum:
|
||||
- identifier
|
||||
- email_link
|
||||
- email_deny
|
||||
- username_link
|
||||
- username_deny
|
||||
provider_type:
|
||||
title: Provider type
|
||||
type: string
|
||||
|
@ -17678,6 +17711,17 @@ definitions:
|
|||
enum:
|
||||
- all
|
||||
- any
|
||||
user_matching_mode:
|
||||
title: User matching mode
|
||||
description: How the source determines if an existing user should be authenticated
|
||||
or a new user enrolled.
|
||||
type: string
|
||||
enum:
|
||||
- identifier
|
||||
- email_link
|
||||
- email_deny
|
||||
- username_link
|
||||
- username_deny
|
||||
client_id:
|
||||
title: Client id
|
||||
type: string
|
||||
|
@ -17792,6 +17836,17 @@ definitions:
|
|||
enum:
|
||||
- all
|
||||
- any
|
||||
user_matching_mode:
|
||||
title: User matching mode
|
||||
description: How the source determines if an existing user should be authenticated
|
||||
or a new user enrolled.
|
||||
type: string
|
||||
enum:
|
||||
- identifier
|
||||
- email_link
|
||||
- email_deny
|
||||
- username_link
|
||||
- username_deny
|
||||
pre_authentication_flow:
|
||||
title: Pre authentication flow
|
||||
description: Flow used before authentication.
|
||||
|
@ -18537,6 +18592,17 @@ definitions:
|
|||
enabled:
|
||||
title: Enabled
|
||||
type: boolean
|
||||
user_matching_mode:
|
||||
title: User matching mode
|
||||
description: How the source determines if an existing user should
|
||||
be authenticated or a new user enrolled.
|
||||
type: string
|
||||
enum:
|
||||
- identifier
|
||||
- email_link
|
||||
- email_deny
|
||||
- username_link
|
||||
- username_deny
|
||||
authentication_flow:
|
||||
title: Authentication flow
|
||||
description: Flow to use when authenticating existing users.
|
||||
|
|
Reference in New Issue