Merge branch 'main' into web/issue-4880-multi-select-limitations

* main: (30 commits)
  outposts/proxy: better Redis error message (#8044)
  translate: Updates for file web/xliff/en.xlf in fr (#8046)
  web: bump the eslint group in /tests/wdio with 2 updates (#8041)
  web: bump the storybook group in /web with 7 updates (#8042)
  web: bump the eslint group in /web with 2 updates (#8043)
  web: bump @types/guacamole-common-js from 1.3.2 to 1.5.2 in /web (#8030)
  translate: Updates for file web/xliff/en.xlf in zh_CN (#8038)
  translate: Updates for file web/xliff/en.xlf in zh-Hans (#8039)
  website: bump clsx from 2.0.0 to 2.1.0 in /website (#8033)
  core: bump golang from 1.21.3-bookworm to 1.21.5-bookworm (#8027)
  web: bump the babel group in /web with 4 updates (#8028)
  web: bump the esbuild group in /web with 2 updates (#8029)
  web: bump rollup from 4.9.1 to 4.9.2 in /web (#8031)
  tests/e2e: fix tests to work without docker network_mode host (#8035)
  website/docs: fix typo (#8015)
  web: bump API Client version (#8025)
  enterprise/providers: Add RAC [AUTH-15] (#7291)
  outposts: disable deployment and secret reconciler for embedded outpost in code instead of in config (#8021)
  providers/proxy: use access token (#8022)
  website/integrations: Add custom Group/Role mapping documentation for Grafana (#7453)
  ...
This commit is contained in:
Ken Sternberg 2024-01-02 15:55:35 -08:00
commit 73bba0498f
146 changed files with 8337 additions and 1523 deletions

View File

@ -9,3 +9,4 @@ blueprints/local
.git .git
!gen-ts-api/node_modules !gen-ts-api/node_modules
!gen-ts-api/dist/** !gen-ts-api/dist/**
!gen-go-api/

View File

@ -2,3 +2,4 @@ keypair
keypairs keypairs
hass hass
warmup warmup
ontext

View File

@ -249,12 +249,6 @@ jobs:
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max
- name: Comment on PR
if: github.event_name == 'pull_request'
continue-on-error: true
uses: ./.github/actions/comment-pr-instructions
with:
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
build-arm64: build-arm64:
needs: ci-core-mark needs: ci-core-mark
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -303,3 +297,26 @@ jobs:
platforms: linux/arm64 platforms: linux/arm64
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max
pr-comment:
needs:
- build
- build-arm64
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request' }}
permissions:
# Needed to write comments on PRs
pull-requests: write
timeout-minutes: 120
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
- name: Comment on PR
uses: ./.github/actions/comment-pr-instructions
with:
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}

View File

@ -65,6 +65,7 @@ jobs:
- proxy - proxy
- ldap - ldap
- radius - radius
- rac
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
# Needed to upload contianer images to ghcr.io # Needed to upload contianer images to ghcr.io
@ -119,6 +120,7 @@ jobs:
- proxy - proxy
- ldap - ldap
- radius - radius
- rac
goos: [linux] goos: [linux]
goarch: [amd64, arm64] goarch: [amd64, arm64]
steps: steps:

View File

@ -65,6 +65,7 @@ jobs:
- proxy - proxy
- ldap - ldap
- radius - radius
- rac
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5

View File

@ -58,7 +58,7 @@ test: ## Run the server tests and produce a coverage report (locally)
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors. lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
isort $(PY_SOURCES) isort $(PY_SOURCES)
black $(PY_SOURCES) black $(PY_SOURCES)
ruff $(PY_SOURCES) ruff --fix $(PY_SOURCES)
codespell -w $(CODESPELL_ARGS) codespell -w $(CODESPELL_ARGS)
lint: ## Lint the python and golang sources lint: ## Lint the python and golang sources

View File

@ -40,7 +40,7 @@ class ManagedAppConfig(AppConfig):
meth() meth()
self._logger.debug("Successfully reconciled", name=name) self._logger.debug("Successfully reconciled", name=name)
except (DatabaseError, ProgrammingError, InternalError) as exc: except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to run reconcile", name=name, exc=exc) self._logger.warning("Failed to run reconcile", name=name, exc=exc)
class AuthentikBlueprintsConfig(ManagedAppConfig): class AuthentikBlueprintsConfig(ManagedAppConfig):

View File

@ -1,22 +1,29 @@
"""Channels base classes""" """Channels base classes"""
from channels.db import database_sync_to_async
from channels.exceptions import DenyConnection from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer
from rest_framework.exceptions import AuthenticationFailed from rest_framework.exceptions import AuthenticationFailed
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.api.authentication import bearer_auth from authentik.api.authentication import bearer_auth
from authentik.core.models import User
LOGGER = get_logger() LOGGER = get_logger()
class AuthJsonConsumer(JsonWebsocketConsumer): class TokenOutpostMiddleware:
"""Authorize a client with a token""" """Authorize a client with a token"""
user: User def __init__(self, inner):
self.inner = inner
def connect(self): async def __call__(self, scope, receive, send):
headers = dict(self.scope["headers"]) scope = dict(scope)
await self.auth(scope)
return await self.inner(scope, receive, send)
@database_sync_to_async
def auth(self, scope):
"""Authenticate request from header"""
headers = dict(scope["headers"])
if b"authorization" not in headers: if b"authorization" not in headers:
LOGGER.warning("WS Request without authorization header") LOGGER.warning("WS Request without authorization header")
raise DenyConnection() raise DenyConnection()
@ -32,4 +39,4 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
LOGGER.warning("Failed to authenticate", exc=exc) LOGGER.warning("Failed to authenticate", exc=exc)
raise DenyConnection() raise DenyConnection()
self.user = user scope["user"] = user

View File

@ -22,6 +22,7 @@ class InterfaceView(TemplateView):
kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}" kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}"
kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}" kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}"
kwargs["build"] = get_build_hash() kwargs["build"] = get_build_hash()
kwargs["url_kwargs"] = self.kwargs
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)

View File

@ -1,6 +1,8 @@
"""Enterprise license policies""" """Enterprise license policies"""
from typing import Optional from typing import Optional
from django.utils.translation import gettext_lazy as _
from authentik.core.models import User, UserTypes from authentik.core.models import User, UserTypes
from authentik.enterprise.models import LicenseKey from authentik.enterprise.models import LicenseKey
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
@ -13,10 +15,10 @@ class EnterprisePolicyAccessView(PolicyAccessView):
def check_license(self): def check_license(self):
"""Check license""" """Check license"""
if not LicenseKey.get_total().is_valid(): if not LicenseKey.get_total().is_valid():
return False return PolicyResult(False, _("Enterprise required to access this feature."))
if self.request.user.type != UserTypes.INTERNAL: if self.request.user.type != UserTypes.INTERNAL:
return False return PolicyResult(False, _("Feature only accessible for internal users."))
return True return PolicyResult(True)
def user_has_access(self, user: Optional[User] = None) -> PolicyResult: def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
user = user or self.request.user user = user or self.request.user
@ -24,7 +26,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
request.http_request = self.request request.http_request = self.request
result = super().user_has_access(user) result = super().user_has_access(user)
enterprise_result = self.check_license() enterprise_result = self.check_license()
if not enterprise_result: if not enterprise_result.passing:
return enterprise_result return enterprise_result
return result return result

View File

@ -0,0 +1,133 @@
"""RAC Provider API Views"""
from typing import Optional
from django.core.cache import cache
from django.db.models import QuerySet
from django.urls import reverse
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework.fields import SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.core.api.used_by import UsedByMixin
from authentik.core.models import Provider
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
from authentik.enterprise.providers.rac.models import Endpoint
from authentik.policies.engine import PolicyEngine
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger()
def user_endpoint_cache_key(user_pk: str) -> str:
"""Cache key where endpoint list for user is saved"""
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}"
class EndpointSerializer(ModelSerializer):
"""Endpoint Serializer"""
provider_obj = RACProviderSerializer(source="provider", read_only=True)
launch_url = SerializerMethodField()
def get_launch_url(self, endpoint: Endpoint) -> Optional[str]:
"""Build actual launch URL (the provider itself does not have one, just
individual endpoints)"""
try:
# pylint: disable=no-member
return reverse(
"authentik_providers_rac:start",
kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk},
)
except Provider.application.RelatedObjectDoesNotExist:
return None
class Meta:
model = Endpoint
fields = [
"pk",
"name",
"provider",
"provider_obj",
"protocol",
"host",
"settings",
"property_mappings",
"auth_mode",
"launch_url",
]
class EndpointViewSet(UsedByMixin, ModelViewSet):
"""Endpoint Viewset"""
queryset = Endpoint.objects.all()
serializer_class = EndpointSerializer
filterset_fields = ["name", "provider"]
search_fields = ["name", "protocol"]
ordering = ["name", "protocol"]
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
"""Custom filter_queryset method which ignores guardian, but still supports sorting"""
for backend in list(self.filter_backends):
if backend == ObjectFilter:
continue
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
def _get_allowed_endpoints(self, queryset: QuerySet) -> list[Endpoint]:
endpoints = []
for endpoint in queryset:
engine = PolicyEngine(endpoint, self.request.user, self.request)
engine.build()
if engine.passing:
endpoints.append(endpoint)
return endpoints
@extend_schema(
parameters=[
OpenApiParameter(
"search",
OpenApiTypes.STR,
),
OpenApiParameter(
name="superuser_full_list",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.BOOL,
),
],
responses={
200: EndpointSerializer(many=True),
400: OpenApiResponse(description="Bad request"),
},
)
def list(self, request: Request, *args, **kwargs) -> Response:
"""List accessible endpoints"""
should_cache = request.GET.get("search", "") == ""
superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
if superuser_full_list and request.user.is_superuser:
return super().list(request)
queryset = self._filter_queryset_for_list(self.get_queryset())
self.paginate_queryset(queryset)
allowed_endpoints = []
if not should_cache:
allowed_endpoints = self._get_allowed_endpoints(queryset)
if should_cache:
allowed_endpoints = cache.get(user_endpoint_cache_key(self.request.user.pk))
if not allowed_endpoints:
LOGGER.debug("Caching allowed endpoint list")
allowed_endpoints = self._get_allowed_endpoints(queryset)
cache.set(
user_endpoint_cache_key(self.request.user.pk),
allowed_endpoints,
timeout=86400,
)
serializer = self.get_serializer(allowed_endpoints, many=True)
return self.get_paginated_response(serializer.data)

View File

@ -0,0 +1,35 @@
"""RAC Provider API Views"""
from rest_framework.fields import CharField
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField
from authentik.enterprise.providers.rac.models import RACPropertyMapping
class RACPropertyMappingSerializer(PropertyMappingSerializer):
"""RACPropertyMapping Serializer"""
static_settings = JSONDictField()
expression = CharField(allow_blank=True, required=False)
def validate_expression(self, expression: str) -> str:
"""Test Syntax"""
if expression == "":
return expression
return super().validate_expression(expression)
class Meta:
model = RACPropertyMapping
fields = PropertyMappingSerializer.Meta.fields + ["static_settings"]
class RACPropertyMappingViewSet(UsedByMixin, ModelViewSet):
"""RACPropertyMapping Viewset"""
queryset = RACPropertyMapping.objects.all()
serializer_class = RACPropertyMappingSerializer
search_fields = ["name"]
ordering = ["name"]
filterset_fields = ["name", "managed"]

View File

@ -0,0 +1,31 @@
"""RAC Provider API Views"""
from rest_framework.fields import CharField, ListField
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.providers.rac.models import RACProvider
class RACProviderSerializer(ProviderSerializer):
"""RACProvider Serializer"""
outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
class Meta:
model = RACProvider
fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
class RACProviderViewSet(UsedByMixin, ModelViewSet):
"""RACProvider Viewset"""
queryset = RACProvider.objects.all()
serializer_class = RACProviderSerializer
filterset_fields = {
"application": ["isnull"],
"name": ["iexact"],
}
search_fields = ["name"]
ordering = ["name"]

View File

@ -0,0 +1,17 @@
"""RAC app config"""
from authentik.blueprints.apps import ManagedAppConfig
class AuthentikEnterpriseProviderRAC(ManagedAppConfig):
"""authentik enterprise rac app config"""
name = "authentik.enterprise.providers.rac"
label = "authentik_providers_rac"
verbose_name = "authentik Enterprise.Providers.RAC"
default = True
mountpoint = ""
ws_mountpoint = "authentik.enterprise.providers.rac.urls"
def reconcile_load_rac_signals(self):
"""Load rac signals"""
self.import_module("authentik.enterprise.providers.rac.signals")

View File

@ -0,0 +1,163 @@
"""RAC Client consumer"""
from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull, DenyConnection
from channels.generic.websocket import AsyncWebsocketConsumer
from django.http.request import QueryDict
from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.providers.rac.models import ConnectionToken, RACProvider
from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE
from authentik.outposts.models import Outpost, OutpostState, OutpostType
# Global broadcast group, which messages are sent to when the outpost connects back
# to authentik for a specific connection
# The `RACClientConsumer` consumer adds itself to this group on connection,
# and removes itself once it has been assigned a specific outpost channel
RAC_CLIENT_GROUP = "group_enterprise_rac_client"
# A group for all connections in a given authentik session ID
# A disconnect message is sent to this group when the session expires/is deleted
RAC_CLIENT_GROUP_SESSION = "group_enterprise_rac_client_%(session)s"
# A group for all connections with a specific token, which in almost all cases
# is just one connection, however this is used to disconnect the connection
# when the token is deleted
RAC_CLIENT_GROUP_TOKEN = "group_enterprise_rac_token_%(token)s" # nosec
# Step 1: Client connects to this websocket endpoint
# Step 2: We prepare all the connection args for Guac
# Step 3: Send a websocket message to a single outpost that has this provider assigned
# (Currently sending to all of them)
# (Should probably do different load balancing algorithms)
# Step 4: Outpost creates a websocket connection back to authentik
# with /ws/outpost_rac/<our_channel_id>/
# Step 5: This consumer transfers data between the two channels
class RACClientConsumer(AsyncWebsocketConsumer):
"""RAC client consumer the browser connects to"""
dest_channel_id: str = ""
provider: RACProvider
token: ConnectionToken
logger: BoundLogger
async def connect(self):
await self.accept("guacamole")
await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
await self.channel_layer.group_add(
RAC_CLIENT_GROUP_SESSION % {"session": self.scope["session"].session_key},
self.channel_name,
)
await self.init_outpost_connection()
async def disconnect(self, code):
self.logger.debug("Disconnecting")
# Tell the outpost we're disconnecting
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.disconnect",
},
)
@database_sync_to_async
def init_outpost_connection(self):
"""Initialize guac connection settings"""
self.token = ConnectionToken.filter_not_expired(
token=self.scope["url_route"]["kwargs"]["token"]
).first()
if not self.token:
raise DenyConnection()
self.provider = self.token.provider
params = self.token.get_settings()
self.logger = get_logger().bind(
endpoint=self.token.endpoint.name, user=self.scope["user"].username
)
msg = {
"type": "event.provider.specific",
"sub_type": "init_connection",
"dest_channel_id": self.channel_name,
"params": params,
"protocol": self.token.endpoint.protocol,
}
query = QueryDict(self.scope["query_string"].decode())
for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
value = query.get(key, None)
if not value:
continue
msg[key] = str(value)
outposts = Outpost.objects.filter(
type=OutpostType.RAC,
providers__in=[self.provider],
)
if not outposts.exists():
self.logger.warning("Provider has no outpost")
raise DenyConnection()
for outpost in outposts:
# Sort all states for the outpost by connection count
states = sorted(
OutpostState.for_outpost(outpost),
key=lambda state: int(state.args.get("active_connections", 0)),
)
if len(states) < 1:
continue
self.logger.debug("Sending out connection broadcast")
async_to_sync(self.channel_layer.group_send)(
OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
msg,
)
async def receive(self, text_data=None, bytes_data=None):
"""Mirror data received from client to the dest_channel_id
which is the channel talking to guacd"""
if self.dest_channel_id == "":
return
if self.token.is_expired:
await self.event_disconnect({"reason": "token_expiry"})
return
try:
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.send",
"text_data": text_data,
"bytes_data": bytes_data,
},
)
except ChannelFull:
pass
async def event_outpost_connected(self, event: dict):
"""Handle event broadcasted from outpost consumer, and check if they
created a connection for us"""
outpost_channel = event.get("outpost_channel")
if event.get("client_channel") != self.channel_name:
return
if self.dest_channel_id != "":
# We've already selected an outpost channel, so tell the other channel to disconnect
# This should never happen since we remove ourselves from the broadcast group
await self.channel_layer.send(
outpost_channel,
{
"type": "event.disconnect",
},
)
return
self.logger.debug("Connected to a single outpost instance")
self.dest_channel_id = outpost_channel
# Since we have a specific outpost channel now, we can remove
# ourselves from the global broadcast group
await self.channel_layer.group_discard(RAC_CLIENT_GROUP, self.channel_name)
async def event_send(self, event: dict):
"""Handler called by outpost websocket that sends data to this specific
client connection"""
if self.token.is_expired:
await self.event_disconnect({"reason": "token_expiry"})
return
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
async def event_disconnect(self, event: dict):
"""Disconnect when the session ends"""
self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
await self.close()

View File

@ -0,0 +1,48 @@
"""RAC consumer"""
from channels.exceptions import ChannelFull
from channels.generic.websocket import AsyncWebsocketConsumer
from authentik.enterprise.providers.rac.consumer_client import RAC_CLIENT_GROUP
class RACOutpostConsumer(AsyncWebsocketConsumer):
"""Consumer the outpost connects to, to send specific data back to a client connection"""
dest_channel_id: str
async def connect(self):
self.dest_channel_id = self.scope["url_route"]["kwargs"]["channel"]
await self.accept()
await self.channel_layer.group_send(
RAC_CLIENT_GROUP,
{
"type": "event.outpost.connected",
"outpost_channel": self.channel_name,
"client_channel": self.dest_channel_id,
},
)
async def receive(self, text_data=None, bytes_data=None):
"""Mirror data received from guacd running in the outpost
to the dest_channel_id which is the channel talking to the browser"""
try:
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.send",
"text_data": text_data,
"bytes_data": bytes_data,
},
)
except ChannelFull:
pass
async def event_send(self, event: dict):
"""Handler called by client websocket that sends data to this specific
outpost connection"""
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
async def event_disconnect(self, event: dict):
"""Tell outpost we're about to disconnect"""
await self.send(text_data="0.authentik.disconnect")
await self.close()

View File

@ -0,0 +1,11 @@
"""RAC Provider Docker Controller"""
from authentik.outposts.controllers.docker import DockerController
from authentik.outposts.models import DockerServiceConnection, Outpost
class RACDockerController(DockerController):
"""RAC Provider Docker Controller"""
def __init__(self, outpost: Outpost, connection: DockerServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []

View File

@ -0,0 +1,13 @@
"""RAC Provider Kubernetes Controller"""
from authentik.outposts.controllers.k8s.service import ServiceReconciler
from authentik.outposts.controllers.kubernetes import KubernetesController
from authentik.outposts.models import KubernetesServiceConnection, Outpost
class RACKubernetesController(KubernetesController):
"""RAC Provider Kubernetes Controller"""
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []
del self.reconcilers[ServiceReconciler.reconciler_name()]

View File

@ -0,0 +1,164 @@
# Generated by Django 4.2.8 on 2023-12-29 15:58
import uuid
import django.db.models.deletion
from django.db import migrations, models
import authentik.core.models
import authentik.lib.utils.time
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_policies", "0011_policybinding_failure_result_and_more"),
("authentik_core", "0032_group_roles"),
]
operations = [
migrations.CreateModel(
name="RACPropertyMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
("static_settings", models.JSONField(default=dict)),
],
options={
"verbose_name": "RAC Property Mapping",
"verbose_name_plural": "RAC Property Mappings",
},
bases=("authentik_core.propertymapping",),
),
migrations.CreateModel(
name="RACProvider",
fields=[
(
"provider_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.provider",
),
),
("settings", models.JSONField(default=dict)),
(
"auth_mode",
models.TextField(
choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt"
),
),
(
"connection_expiry",
models.TextField(
default="hours=8",
help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)",
validators=[authentik.lib.utils.time.timedelta_string_validator],
),
),
],
options={
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
},
bases=("authentik_core.provider",),
),
migrations.CreateModel(
name="Endpoint",
fields=[
(
"policybindingmodel_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_policies.policybindingmodel",
),
),
("name", models.TextField()),
("host", models.TextField()),
(
"protocol",
models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]),
),
("settings", models.JSONField(default=dict)),
(
"auth_mode",
models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]),
),
(
"property_mappings",
models.ManyToManyField(
blank=True, default=None, to="authentik_core.propertymapping"
),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.racprovider",
),
),
],
options={
"verbose_name": "RAC Endpoint",
"verbose_name_plural": "RAC Endpoints",
},
bases=("authentik_policies.policybindingmodel", models.Model),
),
migrations.CreateModel(
name="ConnectionToken",
fields=[
(
"expires",
models.DateTimeField(default=authentik.core.models.default_token_duration),
),
("expiring", models.BooleanField(default=True)),
(
"connection_token_uuid",
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
),
("token", models.TextField(default=authentik.core.models.default_token_key)),
("settings", models.JSONField(default=dict)),
(
"endpoint",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.endpoint",
),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.racprovider",
),
),
(
"session",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_core.authenticatedsession",
),
),
],
options={
"abstract": False,
},
),
]

View File

@ -0,0 +1,191 @@
"""RAC Models"""
from typing import Optional
from uuid import uuid4
from deepmerge import always_merger
from django.db import models
from django.db.models import QuerySet
from django.utils.translation import gettext as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key
from authentik.events.models import Event, EventAction
from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_string_validator
from authentik.policies.models import PolicyBindingModel
LOGGER = get_logger()
class Protocols(models.TextChoices):
"""Supported protocols"""
RDP = "rdp"
VNC = "vnc"
SSH = "ssh"
class AuthenticationMode(models.TextChoices):
"""Authentication modes"""
STATIC = "static"
PROMPT = "prompt"
class RACProvider(Provider):
"""Remotely access computers/servers"""
settings = models.JSONField(default=dict)
auth_mode = models.TextField(
choices=AuthenticationMode.choices, default=AuthenticationMode.PROMPT
)
connection_expiry = models.TextField(
default="hours=8",
validators=[timedelta_string_validator],
help_text=_(
"Determines how long a session lasts. Default of 0 means "
"that the sessions lasts until the browser is closed. "
"(Format: hours=-1;minutes=-2;seconds=-3)"
),
)
@property
def launch_url(self) -> Optional[str]:
"""URL to this provider and initiate authorization for the user.
Can return None for providers that are not URL-based"""
return "goauthentik.io://providers/rac/launch"
@property
def component(self) -> str:
return "ak-provider-rac-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
return RACProviderSerializer
class Meta:
verbose_name = _("RAC Provider")
verbose_name_plural = _("RAC Providers")
class Endpoint(SerializerModel, PolicyBindingModel):
"""Remote-accessible endpoint"""
name = models.TextField()
host = models.TextField()
protocol = models.TextField(choices=Protocols.choices)
settings = models.JSONField(default=dict)
auth_mode = models.TextField(choices=AuthenticationMode.choices)
provider = models.ForeignKey("RACProvider", on_delete=models.CASCADE)
property_mappings = models.ManyToManyField(
"authentik_core.PropertyMapping", default=None, blank=True
)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer
return EndpointSerializer
def __str__(self):
return f"RAC Endpoint {self.name}"
class Meta:
verbose_name = _("RAC Endpoint")
verbose_name_plural = _("RAC Endpoints")
class RACPropertyMapping(PropertyMapping):
"""Configure settings for remote access endpoints."""
static_settings = models.JSONField(default=dict)
@property
def component(self) -> str:
return "ak-property-mapping-rac-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.rac.api.property_mappings import (
RACPropertyMappingSerializer,
)
return RACPropertyMappingSerializer
class Meta:
verbose_name = _("RAC Property Mapping")
verbose_name_plural = _("RAC Property Mappings")
class ConnectionToken(ExpiringModel):
"""Token for a single connection to a specified endpoint"""
connection_token_uuid = models.UUIDField(default=uuid4, primary_key=True)
provider = models.ForeignKey(RACProvider, on_delete=models.CASCADE)
endpoint = models.ForeignKey(Endpoint, on_delete=models.CASCADE)
token = models.TextField(default=default_token_key)
settings = models.JSONField(default=dict)
session = models.ForeignKey("authentik_core.AuthenticatedSession", on_delete=models.CASCADE)
def get_settings(self) -> dict:
"""Get settings"""
default_settings = {}
if ":" in self.endpoint.host:
host, _, port = self.endpoint.host.partition(":")
default_settings["hostname"] = host
default_settings["port"] = str(port)
else:
default_settings["hostname"] = self.endpoint.host
default_settings["client-name"] = "authentik"
# default_settings["enable-drive"] = "true"
# default_settings["drive-name"] = "authentik"
settings = {}
always_merger.merge(settings, default_settings)
always_merger.merge(settings, self.endpoint.provider.settings)
always_merger.merge(settings, self.endpoint.settings)
always_merger.merge(settings, self.settings)
def mapping_evaluator(mappings: QuerySet):
for mapping in mappings:
mapping: RACPropertyMapping
if len(mapping.static_settings) > 0:
always_merger.merge(settings, mapping.static_settings)
continue
try:
mapping_settings = mapping.evaluate(
self.session.user, None, endpoint=self.endpoint, provider=self.provider
)
always_merger.merge(settings, mapping_settings)
except PropertyMappingExpressionException as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping: '{mapping.name}'",
provider=self.provider,
mapping=mapping,
).set_user(self.session.user).save()
LOGGER.warning("Failed to evaluate property mapping", exc=exc)
mapping_evaluator(
RACPropertyMapping.objects.filter(provider__in=[self.provider]).order_by("name")
)
mapping_evaluator(
RACPropertyMapping.objects.filter(endpoint__in=[self.endpoint]).order_by("name")
)
settings["drive-path"] = f"/tmp/connection/{self.token}" # nosec
settings["create-drive-path"] = "true"
# Ensure all values of the settings dict are strings
for key, value in settings.items():
if isinstance(value, str):
continue
# Special case for bools
if isinstance(value, bool):
settings[key] = str(value).lower()
continue
settings[key] = str(value)
return settings

View File

@ -0,0 +1,54 @@
"""RAC Signals"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models import Model
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import User
from authentik.enterprise.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.enterprise.providers.rac.consumer_client import (
RAC_CLIENT_GROUP_SESSION,
RAC_CLIENT_GROUP_TOKEN,
)
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint
@receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Disconnect any open RAC connections"""
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_SESSION
% {
"session": request.session.session_key,
},
{"type": "event.disconnect", "reason": "session_logout"},
)
@receiver(pre_delete, sender=ConnectionToken)
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
"""Disconnect session when connection token is deleted"""
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_TOKEN
% {
"token": instance.token,
},
{"type": "event.disconnect", "reason": "token_delete"},
)
@receiver(post_save, sender=Endpoint)
def post_save_application(sender: type[Model], instance, created: bool, **_):
"""Clear user's application cache upon application creation"""
if not created: # pragma: no cover
return
# Delete user endpoint cache
keys = cache.keys(user_endpoint_cache_key("*"))
cache.delete_many(keys)

View File

@ -0,0 +1,18 @@
{% extends "base/skeleton.html" %}
{% load static %}
{% block head %}
<script src="{% static 'dist/enterprise/rac/index.js' %}?version={{ version }}" type="module"></script>
<meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)">
<meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)">
<link rel="icon" href="{{ tenant.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
{% include "base/header_js.html" %}
{% endblock %}
{% block body %}
<ak-rac token="{{ url_kwargs.token }}" endpointName="{{ token.endpoint.name }}">
<ak-loading></ak-loading>
</ak-rac>
{% endblock %}

View File

@ -0,0 +1,168 @@
"""Test Endpoints API"""
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
class TestEndpointsAPI(APITestCase):
"""Test endpoints API"""
def setUp(self) -> None:
self.user = create_test_admin_user()
self.provider = RACProvider.objects.create(
name=generate_id(),
)
self.app = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
self.allowed = Endpoint.objects.create(
name=f"a-{generate_id()}",
host=generate_id(),
protocol=Protocols.RDP,
provider=self.provider,
)
self.denied = Endpoint.objects.create(
name=f"b-{generate_id()}",
host=generate_id(),
protocol=Protocols.RDP,
provider=self.provider,
)
PolicyBinding.objects.create(
target=self.denied,
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
order=0,
)
def test_list(self):
"""Test list operation without superuser_full_list"""
self.client.force_login(self.user)
response = self.client.get(reverse("authentik_api:endpoint-list"))
self.assertJSONEqual(
response.content.decode(),
{
"pagination": {
"next": 0,
"previous": 0,
"count": 2,
"current": 1,
"total_pages": 1,
"start_index": 1,
"end_index": 2,
},
"results": [
{
"pk": str(self.allowed.pk),
"name": self.allowed.name,
"provider": self.provider.pk,
"provider_obj": {
"pk": self.provider.pk,
"name": self.provider.name,
"authentication_flow": None,
"authorization_flow": None,
"property_mappings": [],
"connection_expiry": "hours=8",
"component": "ak-provider-rac-form",
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",
"settings": {},
"outpost_set": [],
},
"protocol": "rdp",
"host": self.allowed.host,
"settings": {},
"property_mappings": [],
"auth_mode": "",
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
},
],
},
)
def test_list_superuser_full_list(self):
"""Test list operation with superuser_full_list"""
self.client.force_login(self.user)
response = self.client.get(
reverse("authentik_api:endpoint-list") + "?superuser_full_list=true"
)
self.assertJSONEqual(
response.content.decode(),
{
"pagination": {
"next": 0,
"previous": 0,
"count": 2,
"current": 1,
"total_pages": 1,
"start_index": 1,
"end_index": 2,
},
"results": [
{
"pk": str(self.allowed.pk),
"name": self.allowed.name,
"provider": self.provider.pk,
"provider_obj": {
"pk": self.provider.pk,
"name": self.provider.name,
"authentication_flow": None,
"authorization_flow": None,
"property_mappings": [],
"component": "ak-provider-rac-form",
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"connection_expiry": "hours=8",
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",
"settings": {},
"outpost_set": [],
},
"protocol": "rdp",
"host": self.allowed.host,
"settings": {},
"property_mappings": [],
"auth_mode": "",
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
},
{
"pk": str(self.denied.pk),
"name": self.denied.name,
"provider": self.provider.pk,
"provider_obj": {
"pk": self.provider.pk,
"name": self.provider.name,
"authentication_flow": None,
"authorization_flow": None,
"property_mappings": [],
"component": "ak-provider-rac-form",
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"connection_expiry": "hours=8",
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",
"settings": {},
"outpost_set": [],
},
"protocol": "rdp",
"host": self.denied.host,
"settings": {},
"property_mappings": [],
"auth_mode": "",
"launch_url": f"/application/rac/{self.app.slug}/{str(self.denied.pk)}/",
},
],
},
)

View File

@ -0,0 +1,144 @@
"""Test RAC Models"""
from django.test import TransactionTestCase
from authentik.core.models import Application, AuthenticatedSession
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.providers.rac.models import (
ConnectionToken,
Endpoint,
Protocols,
RACPropertyMapping,
RACProvider,
)
from authentik.lib.generators import generate_id
class TestModels(TransactionTestCase):
"""Test RAC Models"""
def setUp(self):
self.user = create_test_admin_user()
self.provider = RACProvider.objects.create(
name=generate_id(),
)
self.app = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
self.endpoint = Endpoint.objects.create(
name=generate_id(),
host=f"{generate_id()}:1324",
protocol=Protocols.RDP,
provider=self.provider,
)
def test_settings_merge(self):
"""Test settings merge"""
token = ConnectionToken.objects.create(
provider=self.provider,
endpoint=self.endpoint,
session=AuthenticatedSession.objects.create(
user=self.user,
session_key=generate_id(),
),
)
path = f"/tmp/connection/{token.token}" # nosec
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
},
)
# Set settings in provider
self.provider.settings = {"level": "provider"}
self.provider.save()
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "provider",
},
)
# Set settings in endpoint
self.endpoint.settings = {
"level": "endpoint",
}
self.endpoint.save()
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "endpoint",
},
)
# Set settings in token
token.settings = {
"level": "token",
}
token.save()
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "token",
},
)
# Set settings in property mapping (provider)
mapping = RACPropertyMapping.objects.create(
name=generate_id(),
expression="""return {
"level": "property_mapping_provider"
}""",
)
self.provider.property_mappings.add(mapping)
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "property_mapping_provider",
},
)
# Set settings in property mapping (endpoint)
mapping = RACPropertyMapping.objects.create(
name=generate_id(),
static_settings={
"level": "property_mapping_endpoint",
"foo": True,
"bar": 6,
},
)
self.endpoint.property_mappings.add(mapping)
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "property_mapping_endpoint",
"foo": "true",
"bar": "6",
},
)

View File

@ -0,0 +1,132 @@
"""RAC Views tests"""
from datetime import timedelta
from json import loads
from time import mktime
from unittest.mock import MagicMock, patch
from django.urls import reverse
from django.utils.timezone import now
from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.enterprise.models import License, LicenseKey
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
from authentik.lib.generators import generate_id
from authentik.policies.denied import AccessDeniedResponse
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
class TestRACViews(APITestCase):
"""RAC Views tests"""
def setUp(self):
self.user = create_test_admin_user()
self.flow = create_test_flow()
self.provider = RACProvider.objects.create(name=generate_id(), authorization_flow=self.flow)
self.app = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
self.endpoint = Endpoint.objects.create(
name=generate_id(),
host=f"{generate_id()}:1324",
protocol=Protocols.RDP,
provider=self.provider,
)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_no_policy(self):
"""Test request"""
License.objects.create(key=generate_id())
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertEqual(response.status_code, 302)
flow_response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
body = loads(flow_response.content)
next_url = body["to"]
final_response = self.client.get(next_url)
self.assertEqual(final_response.status_code, 200)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_app_deny(self):
"""Test request (deny on app level)"""
PolicyBinding.objects.create(
target=self.app,
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
order=0,
)
License.objects.create(key=generate_id())
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertIsInstance(response, AccessDeniedResponse)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_endpoint_deny(self):
"""Test request (deny on endpoint level)"""
PolicyBinding.objects.create(
target=self.endpoint,
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
order=0,
)
License.objects.create(key=generate_id())
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertEqual(response.status_code, 302)
flow_response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
body = loads(flow_response.content)
self.assertEqual(body["component"], "ak-stage-access-denied")

View File

@ -0,0 +1,47 @@
"""rac urls"""
from channels.auth import AuthMiddleware
from channels.sessions import CookieMiddleware
from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from authentik.core.channels import TokenOutpostMiddleware
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet
from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet
from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer
from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer
from authentik.enterprise.providers.rac.views import RACInterface, RACStartView
from authentik.root.asgi_middleware import SessionMiddleware
from authentik.root.middleware import ChannelsLoggingMiddleware
urlpatterns = [
path(
"application/rac/<slug:app>/<uuid:endpoint>/",
ensure_csrf_cookie(RACStartView.as_view()),
name="start",
),
path(
"if/rac/<str:token>/",
ensure_csrf_cookie(RACInterface.as_view()),
name="if-rac",
),
]
websocket_urlpatterns = [
path(
"ws/rac/<str:token>/",
ChannelsLoggingMiddleware(
CookieMiddleware(SessionMiddleware(AuthMiddleware(RACClientConsumer.as_asgi())))
),
),
path(
"ws/outpost_rac/<str:channel>/",
ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())),
),
]
api_urlpatterns = [
("providers/rac", RACProviderViewSet),
("propertymappings/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet),
]

View File

@ -0,0 +1,115 @@
"""RAC Views"""
from typing import Any
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, redirect
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.models import Application, AuthenticatedSession
from authentik.core.views.interface import InterfaceView
from authentik.enterprise.policy import EnterprisePolicyAccessView
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint, RACProvider
from authentik.flows.challenge import RedirectChallenge
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import FlowPlanner
from authentik.flows.stage import RedirectStage
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.utils.time import timedelta_from_string
from authentik.lib.utils.urls import redirect_with_qs
from authentik.policies.engine import PolicyEngine
class RACStartView(EnterprisePolicyAccessView):
"""Start a RAC connection by checking access and creating a connection token"""
endpoint: Endpoint
def resolve_provider_application(self):
self.application = get_object_or_404(Application, slug=self.kwargs["app"])
# Endpoint permissions are validated in the RACFinalStage below
self.endpoint = get_object_or_404(Endpoint, pk=self.kwargs["endpoint"])
self.provider = RACProvider.objects.get(application=self.application)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Start flow planner for RAC provider"""
planner = FlowPlanner(self.provider.authorization_flow)
planner.allow_empty_flows = True
try:
plan = planner.plan(self.request)
except FlowNonApplicableException:
raise Http404
plan.insert_stage(
in_memory_stage(
RACFinalStage,
endpoint=self.endpoint,
provider=self.provider,
)
)
request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
request.GET,
flow_slug=self.provider.authorization_flow.slug,
)
class RACInterface(InterfaceView):
"""Start RAC connection"""
template_name = "if/rac.html"
token: ConnectionToken
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# Early sanity check to ensure token still exists
token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first()
if not token:
return redirect("authentik_core:if-user")
self.token = token
return super().dispatch(request, *args, **kwargs)
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs["token"] = self.token
return super().get_context_data(**kwargs)
class RACFinalStage(RedirectStage):
"""RAC Connection final stage, set the connection token in the stage"""
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
endpoint: Endpoint = self.executor.current_stage.endpoint
engine = PolicyEngine(endpoint, self.request.user, self.request)
engine.use_cache = False
engine.build()
passing = engine.result
if not passing.passing:
return self.executor.stage_invalid(", ".join(passing.messages))
return super().dispatch(request, *args, **kwargs)
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
endpoint: Endpoint = self.executor.current_stage.endpoint
provider: RACProvider = self.executor.current_stage.provider
token = ConnectionToken.objects.create(
provider=provider,
endpoint=endpoint,
settings=self.executor.plan.context.get("connection_settings", {}),
session=AuthenticatedSession.objects.filter(
session_key=self.request.session.session_key
).first(),
expires=now() + timedelta_from_string(provider.connection_expiry),
expiring=True,
)
setattr(
self.executor.current_stage,
"destination",
self.request.build_absolute_uri(
reverse(
"authentik_providers_rac:if-rac",
kwargs={
"token": str(token.token),
},
)
),
)
return super().get_challenge(*args, **kwargs)

View File

@ -10,3 +10,7 @@ CELERY_BEAT_SCHEDULE = {
"options": {"queue": "authentik_scheduled"}, "options": {"queue": "authentik_scheduled"},
} }
} }
INSTALLED_APPS = [
"authentik.enterprise.providers.rac",
]

View File

@ -6,6 +6,7 @@ import django_filters
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.fields.json import KeyTextTransform, KeyTransform from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.functions import ExtractDay, ExtractHour from django.db.models.functions import ExtractDay, ExtractHour
from django.db.models.query_utils import Q
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema from drf_spectacular.utils import OpenApiParameter, extend_schema
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
@ -87,7 +88,12 @@ class EventsFilter(django_filters.FilterSet):
we need to remove the dashes that a client may send. We can't use a we need to remove the dashes that a client may send. We can't use a
UUIDField for this, as some models might not have a UUID PK""" UUIDField for this, as some models might not have a UUID PK"""
value = str(value).replace("-", "") value = str(value).replace("-", "")
return queryset.filter(context__model__pk=value) query = Q(context__model__pk=value)
try:
query |= Q(context__model__pk=int(value))
except ValueError:
pass
return queryset.filter(query)
class Meta: class Meta:
model = Event model = Event

View File

@ -1,4 +1,5 @@
"""Event API tests""" """Event API tests"""
from json import loads
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
@ -11,6 +12,9 @@ from authentik.events.models import (
NotificationSeverity, NotificationSeverity,
TransportMode, TransportMode,
) )
from authentik.events.utils import model_to_dict
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider
class TestEventsAPI(APITestCase): class TestEventsAPI(APITestCase):
@ -20,6 +24,25 @@ class TestEventsAPI(APITestCase):
self.user = create_test_admin_user() self.user = create_test_admin_user()
self.client.force_login(self.user) self.client.force_login(self.user)
def test_filter_model_pk_int(self):
"""Test event list with context_model_pk and integer PKs"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
)
event = Event.new(EventAction.MODEL_CREATED, model=model_to_dict(provider))
event.save()
response = self.client.get(
reverse("authentik_api:event-list"),
data={
"context_model_pk": provider.pk,
"context_model_app": "authentik_providers_oauth2",
"context_model_name": "oauth2provider",
},
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["pagination"]["count"], 1)
def test_top_n(self): def test_top_n(self):
"""Test top_per_user""" """Test top_per_user"""
event = Event.new(EventAction.AUTHORIZE_APPLICATION) event = Event.new(EventAction.AUTHORIZE_APPLICATION)

View File

@ -17,8 +17,9 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField, PassiveSerializer from authentik.core.api.utils import JSONDictField, PassiveSerializer
from authentik.core.models import Provider from authentik.core.models import Provider
from authentik.enterprise.providers.rac.models import RACProvider
from authentik.outposts.api.service_connections import ServiceConnectionSerializer from authentik.outposts.api.service_connections import ServiceConnectionSerializer
from authentik.outposts.apps import MANAGED_OUTPOST from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
from authentik.outposts.models import ( from authentik.outposts.models import (
Outpost, Outpost,
OutpostConfig, OutpostConfig,
@ -47,12 +48,23 @@ class OutpostSerializer(ModelSerializer):
source="service_connection", read_only=True source="service_connection", read_only=True
) )
def validate_name(self, name: str) -> str:
"""Validate name (especially for embedded outpost)"""
if not self.instance:
return name
if self.instance.managed == MANAGED_OUTPOST and name != MANAGED_OUTPOST_NAME:
raise ValidationError("Embedded outpost's name cannot be changed")
if self.instance.name == MANAGED_OUTPOST_NAME:
self.instance.managed = MANAGED_OUTPOST
return name
def validate_providers(self, providers: list[Provider]) -> list[Provider]: def validate_providers(self, providers: list[Provider]) -> list[Provider]:
"""Check that all providers match the type of the outpost""" """Check that all providers match the type of the outpost"""
type_map = { type_map = {
OutpostType.LDAP: LDAPProvider, OutpostType.LDAP: LDAPProvider,
OutpostType.PROXY: ProxyProvider, OutpostType.PROXY: ProxyProvider,
OutpostType.RADIUS: RadiusProvider, OutpostType.RADIUS: RadiusProvider,
OutpostType.RAC: RACProvider,
None: Provider, None: Provider,
} }
for provider in providers: for provider in providers:

View File

@ -15,6 +15,7 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
["outpost", "uid", "version"], ["outpost", "uid", "version"],
) )
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded" MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
class AuthentikOutpostConfig(ManagedAppConfig): class AuthentikOutpostConfig(ManagedAppConfig):
@ -35,14 +36,17 @@ class AuthentikOutpostConfig(ManagedAppConfig):
DockerServiceConnection, DockerServiceConnection,
KubernetesServiceConnection, KubernetesServiceConnection,
Outpost, Outpost,
OutpostConfig,
OutpostType, OutpostType,
) )
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
outpost.managed = MANAGED_OUTPOST
outpost.save()
return
outpost, updated = Outpost.objects.update_or_create( outpost, updated = Outpost.objects.update_or_create(
defaults={ defaults={
"name": "authentik Embedded Outpost",
"type": OutpostType.PROXY, "type": OutpostType.PROXY,
"name": MANAGED_OUTPOST_NAME,
}, },
managed=MANAGED_OUTPOST, managed=MANAGED_OUTPOST,
) )
@ -51,10 +55,4 @@ class AuthentikOutpostConfig(ManagedAppConfig):
outpost.service_connection = KubernetesServiceConnection.objects.first() outpost.service_connection = KubernetesServiceConnection.objects.first()
elif DockerServiceConnection.objects.exists(): elif DockerServiceConnection.objects.exists():
outpost.service_connection = DockerServiceConnection.objects.first() outpost.service_connection = DockerServiceConnection.objects.first()
outpost.config = OutpostConfig(
kubernetes_disabled_components=[
"deployment",
"secret",
]
)
outpost.save() outpost.save()

View File

@ -6,16 +6,18 @@ from typing import Any, Optional
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer
from dacite.core import from_dict from dacite.core import from_dict
from dacite.data import Data from dacite.data import Data
from django.http.request import QueryDict
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from authentik.core.channels import AuthJsonConsumer
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s" OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
OUTPOST_GROUP_INSTANCE = "group_outpost_%(outpost_pk)s_%(instance)s"
class WebsocketMessageInstruction(IntEnum): class WebsocketMessageInstruction(IntEnum):
@ -42,25 +44,23 @@ class WebsocketMessage:
args: dict[str, Any] = field(default_factory=dict) args: dict[str, Any] = field(default_factory=dict)
class OutpostConsumer(AuthJsonConsumer): class OutpostConsumer(JsonWebsocketConsumer):
"""Handler for Outposts that connect over websockets for health checks and live updates""" """Handler for Outposts that connect over websockets for health checks and live updates"""
outpost: Optional[Outpost] = None outpost: Optional[Outpost] = None
logger: BoundLogger logger: BoundLogger
last_uid: Optional[str] = None instance_uid: Optional[str] = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.logger = get_logger() self.logger = get_logger()
def connect(self): def connect(self):
super().connect()
uuid = self.scope["url_route"]["kwargs"]["pk"] uuid = self.scope["url_route"]["kwargs"]["pk"]
user = self.scope["user"]
outpost = ( outpost = (
get_objects_for_user(self.user, "authentik_outposts.view_outpost") get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
.filter(pk=uuid)
.first()
) )
if not outpost: if not outpost:
raise DenyConnection() raise DenyConnection()
@ -71,13 +71,19 @@ class OutpostConsumer(AuthJsonConsumer):
self.logger.warning("runtime error during accept", exc=exc) self.logger.warning("runtime error during accept", exc=exc)
raise DenyConnection() raise DenyConnection()
self.outpost = outpost self.outpost = outpost
self.last_uid = self.channel_name query = QueryDict(self.scope["query_string"].decode())
self.instance_uid = query.get("instance_uuid", self.channel_name)
async_to_sync(self.channel_layer.group_add)( async_to_sync(self.channel_layer.group_add)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
) )
async_to_sync(self.channel_layer.group_add)(
OUTPOST_GROUP_INSTANCE
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
self.channel_name,
)
GAUGE_OUTPOSTS_CONNECTED.labels( GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name, outpost=self.outpost.name,
uid=self.last_uid, uid=self.instance_uid,
expected=self.outpost.config.kubernetes_replicas, expected=self.outpost.config.kubernetes_replicas,
).inc() ).inc()
@ -86,34 +92,37 @@ class OutpostConsumer(AuthJsonConsumer):
async_to_sync(self.channel_layer.group_discard)( async_to_sync(self.channel_layer.group_discard)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
) )
if self.outpost and self.last_uid: if self.instance_uid:
async_to_sync(self.channel_layer.group_discard)(
OUTPOST_GROUP_INSTANCE
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
self.channel_name,
)
if self.outpost and self.instance_uid:
GAUGE_OUTPOSTS_CONNECTED.labels( GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name, outpost=self.outpost.name,
uid=self.last_uid, uid=self.instance_uid,
expected=self.outpost.config.kubernetes_replicas, expected=self.outpost.config.kubernetes_replicas,
).dec() ).dec()
def receive_json(self, content: Data, **kwargs): def receive_json(self, content: Data, **kwargs):
msg = from_dict(WebsocketMessage, content) msg = from_dict(WebsocketMessage, content)
uid = msg.args.get("uuid", self.channel_name)
self.last_uid = uid
if not self.outpost: if not self.outpost:
raise DenyConnection() raise DenyConnection()
state = OutpostState.for_instance_uid(self.outpost, uid) state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
state.last_seen = datetime.now() state.last_seen = datetime.now()
state.hostname = msg.args.pop("hostname", "") state.hostname = msg.args.pop("hostname", "")
if msg.instruction == WebsocketMessageInstruction.HELLO: if msg.instruction == WebsocketMessageInstruction.HELLO:
state.version = msg.args.pop("version", None) state.version = msg.args.pop("version", None)
state.build_hash = msg.args.pop("buildHash", "") state.build_hash = msg.args.pop("buildHash", "")
state.args = msg.args state.args.update(msg.args)
elif msg.instruction == WebsocketMessageInstruction.ACK: elif msg.instruction == WebsocketMessageInstruction.ACK:
return return
GAUGE_OUTPOSTS_LAST_UPDATE.labels( GAUGE_OUTPOSTS_LAST_UPDATE.labels(
outpost=self.outpost.name, outpost=self.outpost.name,
uid=self.last_uid or "", uid=self.instance_uid or "",
version=state.version or "", version=state.version or "",
).set_to_current_time() ).set_to_current_time()
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5) state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)

View File

@ -43,6 +43,10 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
self.api = AppsV1Api(controller.client) self.api = AppsV1Api(controller.client)
self.outpost = self.controller.outpost self.outpost = self.controller.outpost
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod @staticmethod
def reconciler_name() -> str: def reconciler_name() -> str:
return "deployment" return "deployment"

View File

@ -24,6 +24,10 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
super().__init__(controller) super().__init__(controller)
self.api = CoreV1Api(controller.client) self.api = CoreV1Api(controller.client)
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod @staticmethod
def reconciler_name() -> str: def reconciler_name() -> str:
return "secret" return "secret"

View File

@ -77,7 +77,10 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
@property @property
def noop(self) -> bool: def noop(self) -> bool:
return (not self._crd_exists()) or (self.is_embedded) if not self._crd_exists():
self.logger.debug("CRD doesn't exist")
return True
return self.is_embedded
def _crd_exists(self) -> bool: def _crd_exists(self) -> bool:
"""Check if the Prometheus ServiceMonitor exists""" """Check if the Prometheus ServiceMonitor exists"""

View File

@ -1,5 +1,6 @@
"""k8s utils""" """k8s utils"""
from pathlib import Path from pathlib import Path
from typing import Optional
from kubernetes.client.models.v1_container_port import V1ContainerPort from kubernetes.client.models.v1_container_port import V1ContainerPort
from kubernetes.client.models.v1_service_port import V1ServicePort from kubernetes.client.models.v1_service_port import V1ServicePort
@ -37,9 +38,12 @@ def compare_port(
def compare_ports( def compare_ports(
current: list[V1ServicePort | V1ContainerPort], reference: list[V1ServicePort | V1ContainerPort] current: Optional[list[V1ServicePort | V1ContainerPort]],
reference: Optional[list[V1ServicePort | V1ContainerPort]],
): ):
"""Compare ports of a list""" """Compare ports of a list"""
if not current or not reference:
raise NeedsRecreate()
if len(current) != len(reference): if len(current) != len(reference):
raise NeedsRecreate() raise NeedsRecreate()
for port in reference: for port in reference:

View File

@ -81,7 +81,10 @@ class KubernetesController(BaseController):
def up(self): def up(self):
try: try:
for reconcile_key in self.reconcile_order: for reconcile_key in self.reconcile_order:
reconciler = self.reconcilers[reconcile_key](self) reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
reconciler.up() reconciler.up()
except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc: except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc:
@ -95,7 +98,10 @@ class KubernetesController(BaseController):
all_logs += [f"{reconcile_key.title()}: Disabled"] all_logs += [f"{reconcile_key.title()}: Disabled"]
continue continue
with capture_logs() as logs: with capture_logs() as logs:
reconciler = self.reconcilers[reconcile_key](self) reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
reconciler.up() reconciler.up()
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs] all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
return all_logs return all_logs
@ -105,7 +111,10 @@ class KubernetesController(BaseController):
def down(self): def down(self):
try: try:
for reconcile_key in self.reconcile_order: for reconcile_key in self.reconcile_order:
reconciler = self.reconcilers[reconcile_key](self) reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
self.logger.debug("Tearing down object", name=reconcile_key) self.logger.debug("Tearing down object", name=reconcile_key)
reconciler.down() reconciler.down()
@ -120,7 +129,10 @@ class KubernetesController(BaseController):
all_logs += [f"{reconcile_key.title()}: Disabled"] all_logs += [f"{reconcile_key.title()}: Disabled"]
continue continue
with capture_logs() as logs: with capture_logs() as logs:
reconciler = self.reconcilers[reconcile_key](self) reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
reconciler.down() reconciler.down()
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs] all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
return all_logs return all_logs
@ -130,7 +142,10 @@ class KubernetesController(BaseController):
def get_static_deployment(self) -> str: def get_static_deployment(self) -> str:
documents = [] documents = []
for reconcile_key in self.reconcile_order: for reconcile_key in self.reconcile_order:
reconciler = self.reconcilers[reconcile_key](self) reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
if reconciler.noop: if reconciler.noop:
continue continue
documents.append(reconciler.get_reference_object().to_dict()) documents.append(reconciler.get_reference_object().to_dict())

View File

@ -0,0 +1,25 @@
# Generated by Django 4.2.6 on 2023-10-14 19:23
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_outposts", "0020_alter_outpost_type"),
]
operations = [
migrations.AlterField(
model_name="outpost",
name="type",
field=models.TextField(
choices=[
("proxy", "Proxy"),
("ldap", "Ldap"),
("radius", "Radius"),
("rac", "Rac"),
],
default="proxy",
),
),
]

View File

@ -90,11 +90,12 @@ class OutpostModel(Model):
class OutpostType(models.TextChoices): class OutpostType(models.TextChoices):
"""Outpost types, currently only the reverse proxy is available""" """Outpost types"""
PROXY = "proxy" PROXY = "proxy"
LDAP = "ldap" LDAP = "ldap"
RADIUS = "radius" RADIUS = "radius"
RAC = "rac"
def default_outpost_config(host: Optional[str] = None): def default_outpost_config(host: Optional[str] = None):
@ -459,7 +460,7 @@ class OutpostState:
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState": def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
"""Get state for a single instance""" """Get state for a single instance"""
key = f"{outpost.state_cache_prefix}/{uid}" key = f"{outpost.state_cache_prefix}/{uid}"
default_data = {"uid": uid, "channel_ids": []} default_data = {"uid": uid}
data = cache.get(key, default_data) data = cache.get(key, default_data)
if isinstance(data, str): if isinstance(data, str):
cache.delete(key) cache.delete(key)

View File

@ -17,6 +17,8 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from yaml import safe_load from yaml import safe_load
from authentik.enterprise.providers.rac.controllers.docker import RACDockerController
from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController
from authentik.events.monitored_tasks import ( from authentik.events.monitored_tasks import (
MonitoredTask, MonitoredTask,
TaskResult, TaskResult,
@ -71,6 +73,11 @@ def controller_for_outpost(outpost: Outpost) -> Optional[type[BaseController]]:
return RadiusDockerController return RadiusDockerController
if isinstance(service_connection, KubernetesServiceConnection): if isinstance(service_connection, KubernetesServiceConnection):
return RadiusKubernetesController return RadiusKubernetesController
if outpost.type == OutpostType.RAC:
if isinstance(service_connection, DockerServiceConnection):
return RACDockerController
if isinstance(service_connection, KubernetesServiceConnection):
return RACKubernetesController
return None return None

View File

@ -2,11 +2,13 @@
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.blueprints.tests import reconcile_app
from authentik.core.models import PropertyMapping from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.outposts.api.outposts import OutpostSerializer from authentik.outposts.api.outposts import OutpostSerializer
from authentik.outposts.models import OutpostType, default_outpost_config from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import Outpost, OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
@ -22,7 +24,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
self.user = create_test_admin_user() self.user = create_test_admin_user()
self.client.force_login(self.user) self.client.force_login(self.user)
def test_outpost_validaton(self): @reconcile_app("authentik_outposts")
def test_managed_name_change(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content, {"name": ["Embedded outpost's name cannot be changed"]}
)
@reconcile_app("authentik_outposts")
def test_managed_without_managed(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
embedded_outpost.managed = ""
embedded_outpost.save()
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 200)
embedded_outpost.refresh_from_db()
self.assertEqual(embedded_outpost.managed, MANAGED_OUTPOST)
def test_outpost_validation(self):
"""Test Outpost validation""" """Test Outpost validation"""
valid = OutpostSerializer( valid = OutpostSerializer(
data={ data={

View File

@ -1,6 +1,7 @@
"""Websocket tests""" """Websocket tests"""
from dataclasses import asdict from dataclasses import asdict
from channels.exceptions import DenyConnection
from channels.routing import URLRouter from channels.routing import URLRouter
from channels.testing import WebsocketCommunicator from channels.testing import WebsocketCommunicator
from django.test import TransactionTestCase from django.test import TransactionTestCase
@ -35,6 +36,7 @@ class TestOutpostWS(TransactionTestCase):
communicator = WebsocketCommunicator( communicator = WebsocketCommunicator(
URLRouter(websocket.websocket_urlpatterns), f"/ws/outpost/{self.outpost.pk}/" URLRouter(websocket.websocket_urlpatterns), f"/ws/outpost/{self.outpost.pk}/"
) )
with self.assertRaises(DenyConnection):
connected, _ = await communicator.connect() connected, _ = await communicator.connect()
self.assertFalse(connected) self.assertFalse(connected)

View File

@ -1,6 +1,7 @@
"""Outpost Websocket URLS""" """Outpost Websocket URLS"""
from django.urls import path from django.urls import path
from authentik.core.channels import TokenOutpostMiddleware
from authentik.outposts.api.outposts import OutpostViewSet from authentik.outposts.api.outposts import OutpostViewSet
from authentik.outposts.api.service_connections import ( from authentik.outposts.api.service_connections import (
DockerServiceConnectionViewSet, DockerServiceConnectionViewSet,
@ -11,7 +12,10 @@ from authentik.outposts.consumer import OutpostConsumer
from authentik.root.middleware import ChannelsLoggingMiddleware from authentik.root.middleware import ChannelsLoggingMiddleware
websocket_urlpatterns = [ websocket_urlpatterns = [
path("ws/outpost/<uuid:pk>/", ChannelsLoggingMiddleware(OutpostConsumer.as_asgi())), path(
"ws/outpost/<uuid:pk>/",
ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())),
),
] ]
api_urlpatterns = [ api_urlpatterns = [

View File

@ -40,10 +40,9 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}", f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
) )
CONFIG.set("error_reporting.sample_rate", 0) CONFIG.set("error_reporting.sample_rate", 0)
sentry_init( CONFIG.set("error_reporting.environment", "testing")
environment="testing", CONFIG.set("error_reporting.send_pii", True)
send_default_pii=True, sentry_init()
)
@classmethod @classmethod
def add_arguments(cls, parser: ArgumentParser): def add_arguments(cls, parser: ArgumentParser):

View File

@ -99,7 +99,9 @@ class OAuthSourceSerializer(SourceSerializer):
]: ]:
if getattr(provider_type, url, None) is None: if getattr(provider_type, url, None) is None:
if url not in attrs: if url not in attrs:
raise ValidationError(f"{url} is required for provider {provider_type.name}") raise ValidationError(
f"{url} is required for provider {provider_type.verbose_name}"
)
return attrs return attrs
class Meta: class Meta:

View File

@ -104,8 +104,8 @@ class AppleType(SourceType):
callback_view = AppleOAuth2Callback callback_view = AppleOAuth2Callback
redirect_view = AppleOAuthRedirect redirect_view = AppleOAuthRedirect
name = "Apple" verbose_name = "Apple"
slug = "apple" name = "apple"
authorization_url = "https://appleid.apple.com/auth/authorize" authorization_url = "https://appleid.apple.com/auth/authorize"
access_token_url = "https://appleid.apple.com/auth/token" # nosec access_token_url = "https://appleid.apple.com/auth/token" # nosec

View File

@ -43,8 +43,8 @@ class AzureADType(SourceType):
callback_view = AzureADOAuthCallback callback_view = AzureADOAuthCallback
redirect_view = AzureADOAuthRedirect redirect_view = AzureADOAuthRedirect
name = "Azure AD" verbose_name = "Azure AD"
slug = "azuread" name = "azuread"
urls_customizable = True urls_customizable = True

View File

@ -36,8 +36,8 @@ class DiscordType(SourceType):
callback_view = DiscordOAuth2Callback callback_view = DiscordOAuth2Callback
redirect_view = DiscordOAuthRedirect redirect_view = DiscordOAuthRedirect
name = "Discord" verbose_name = "Discord"
slug = "discord" name = "discord"
authorization_url = "https://discord.com/api/oauth2/authorize" authorization_url = "https://discord.com/api/oauth2/authorize"
access_token_url = "https://discord.com/api/oauth2/token" # nosec access_token_url = "https://discord.com/api/oauth2/token" # nosec

View File

@ -48,8 +48,8 @@ class FacebookType(SourceType):
callback_view = FacebookOAuth2Callback callback_view = FacebookOAuth2Callback
redirect_view = FacebookOAuthRedirect redirect_view = FacebookOAuthRedirect
name = "Facebook" verbose_name = "Facebook"
slug = "facebook" name = "facebook"
authorization_url = "https://www.facebook.com/v7.0/dialog/oauth" authorization_url = "https://www.facebook.com/v7.0/dialog/oauth"
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec

View File

@ -68,8 +68,8 @@ class GitHubType(SourceType):
callback_view = GitHubOAuth2Callback callback_view = GitHubOAuth2Callback
redirect_view = GitHubOAuthRedirect redirect_view = GitHubOAuthRedirect
name = "GitHub" verbose_name = "GitHub"
slug = "github" name = "github"
urls_customizable = True urls_customizable = True

View File

@ -34,8 +34,8 @@ class GoogleType(SourceType):
callback_view = GoogleOAuth2Callback callback_view = GoogleOAuth2Callback
redirect_view = GoogleOAuthRedirect redirect_view = GoogleOAuthRedirect
name = "Google" verbose_name = "Google"
slug = "google" name = "google"
authorization_url = "https://accounts.google.com/o/oauth2/auth" authorization_url = "https://accounts.google.com/o/oauth2/auth"
access_token_url = "https://oauth2.googleapis.com/token" # nosec access_token_url = "https://oauth2.googleapis.com/token" # nosec

View File

@ -63,7 +63,7 @@ class MailcowType(SourceType):
callback_view = MailcowOAuth2Callback callback_view = MailcowOAuth2Callback
redirect_view = MailcowOAuthRedirect redirect_view = MailcowOAuthRedirect
name = "Mailcow" verbose_name = "Mailcow"
slug = "mailcow" name = "mailcow"
urls_customizable = True urls_customizable = True

View File

@ -42,7 +42,7 @@ class OpenIDConnectType(SourceType):
callback_view = OpenIDConnectOAuth2Callback callback_view = OpenIDConnectOAuth2Callback
redirect_view = OpenIDConnectOAuthRedirect redirect_view = OpenIDConnectOAuthRedirect
name = "OpenID Connect" verbose_name = "OpenID Connect"
slug = "openidconnect" name = "openidconnect"
urls_customizable = True urls_customizable = True

View File

@ -42,7 +42,7 @@ class OktaType(SourceType):
callback_view = OktaOAuth2Callback callback_view = OktaOAuth2Callback
redirect_view = OktaOAuthRedirect redirect_view = OktaOAuthRedirect
name = "Okta" verbose_name = "Okta"
slug = "okta" name = "okta"
urls_customizable = True urls_customizable = True

View File

@ -43,8 +43,8 @@ class PatreonType(SourceType):
callback_view = PatreonOAuthCallback callback_view = PatreonOAuthCallback
redirect_view = PatreonOAuthRedirect redirect_view = PatreonOAuthRedirect
name = "Patreon" verbose_name = "Patreon"
slug = "patreon" name = "patreon"
authorization_url = "https://www.patreon.com/oauth2/authorize" authorization_url = "https://www.patreon.com/oauth2/authorize"
access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec

View File

@ -51,8 +51,8 @@ class RedditType(SourceType):
callback_view = RedditOAuth2Callback callback_view = RedditOAuth2Callback
redirect_view = RedditOAuthRedirect redirect_view = RedditOAuthRedirect
name = "Reddit" verbose_name = "Reddit"
slug = "reddit" name = "reddit"
authorization_url = "https://www.reddit.com/api/v1/authorize" authorization_url = "https://www.reddit.com/api/v1/authorize"
access_token_url = "https://www.reddit.com/api/v1/access_token" # nosec access_token_url = "https://www.reddit.com/api/v1/access_token" # nosec

View File

@ -28,7 +28,7 @@ class SourceType:
callback_view = OAuthCallback callback_view = OAuthCallback
redirect_view = OAuthRedirect redirect_view = OAuthRedirect
name: str = "default" name: str = "default"
slug: str = "default" verbose_name: str = "Default source type"
urls_customizable = False urls_customizable = False
@ -41,7 +41,7 @@ class SourceType:
def icon_url(self) -> str: def icon_url(self) -> str:
"""Get Icon URL for login""" """Get Icon URL for login"""
return static(f"authentik/sources/{self.slug}.svg") return static(f"authentik/sources/{self.name}.svg")
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge: def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
"""Allow types to return custom challenges""" """Allow types to return custom challenges"""
@ -77,20 +77,20 @@ class SourceTypeRegistry:
def get_name_tuple(self): def get_name_tuple(self):
"""Get list of tuples of all registered names""" """Get list of tuples of all registered names"""
return [(x.slug, x.name) for x in self.__sources] return [(x.name, x.verbose_name) for x in self.__sources]
def find_type(self, type_name: str) -> Type[SourceType]: def find_type(self, type_name: str) -> Type[SourceType]:
"""Find type based on source""" """Find type based on source"""
found_type = None found_type = None
for src_type in self.__sources: for src_type in self.__sources:
if src_type.slug == type_name: if src_type.name == type_name:
return src_type return src_type
if not found_type: if not found_type:
found_type = SourceType found_type = SourceType
LOGGER.warning( LOGGER.warning(
"no matching type found, using default", "no matching type found, using default",
wanted=type_name, wanted=type_name,
have=[x.slug for x in self.__sources], have=[x.name for x in self.__sources],
) )
return found_type return found_type

View File

@ -49,8 +49,8 @@ class TwitchType(SourceType):
callback_view = TwitchOAuth2Callback callback_view = TwitchOAuth2Callback
redirect_view = TwitchOAuthRedirect redirect_view = TwitchOAuthRedirect
name = "Twitch" verbose_name = "Twitch"
slug = "twitch" name = "twitch"
authorization_url = "https://id.twitch.tv/oauth2/authorize" authorization_url = "https://id.twitch.tv/oauth2/authorize"
access_token_url = "https://id.twitch.tv/oauth2/token" # nosec access_token_url = "https://id.twitch.tv/oauth2/token" # nosec

View File

@ -66,8 +66,8 @@ class TwitterType(SourceType):
callback_view = TwitterOAuthCallback callback_view = TwitterOAuthCallback
redirect_view = TwitterOAuthRedirect redirect_view = TwitterOAuthRedirect
name = "Twitter" verbose_name = "Twitter"
slug = "twitter" name = "twitter"
authorization_url = "https://twitter.com/i/oauth2/authorize" authorization_url = "https://twitter.com/i/oauth2/authorize"
access_token_url = "https://api.twitter.com/2/oauth2/token" # nosec access_token_url = "https://api.twitter.com/2/oauth2/token" # nosec

View File

@ -2779,6 +2779,117 @@
} }
} }
}, },
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_rac.racprovider"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_rac.racprovider"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_rac.racprovider"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_rac.endpoint"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_rac.endpoint"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_rac.endpoint"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_rac.racpropertymapping"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_rac.racpropertymapping"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_rac.racpropertymapping"
}
}
},
{ {
"type": "object", "type": "object",
"required": [ "required": [
@ -3296,7 +3407,8 @@
"enum": [ "enum": [
"proxy", "proxy",
"ldap", "ldap",
"radius" "radius",
"rac"
], ],
"title": "Type" "title": "Type"
}, },
@ -3476,7 +3588,8 @@
"authentik.tenants", "authentik.tenants",
"authentik.blueprints", "authentik.blueprints",
"authentik.core", "authentik.core",
"authentik.enterprise" "authentik.enterprise",
"authentik.enterprise.providers.rac"
], ],
"title": "App", "title": "App",
"description": "Match events created by selected application. When left empty, all applications are matched." "description": "Match events created by selected application. When left empty, all applications are matched."
@ -3561,7 +3674,10 @@
"authentik_core.user", "authentik_core.user",
"authentik_core.application", "authentik_core.application",
"authentik_core.token", "authentik_core.token",
"authentik_enterprise.license" "authentik_enterprise.license",
"authentik_providers_rac.racprovider",
"authentik_providers_rac.endpoint",
"authentik_providers_rac.racpropertymapping"
], ],
"title": "Model", "title": "Model",
"description": "Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched." "description": "Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched."
@ -8758,6 +8874,123 @@
}, },
"required": [] "required": []
}, },
"model_authentik_providers_rac.racprovider": {
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"authentication_flow": {
"type": "integer",
"title": "Authentication flow",
"description": "Flow used for authentication when the associated application is accessed by an un-authenticated user."
},
"authorization_flow": {
"type": "integer",
"title": "Authorization flow",
"description": "Flow used when authorizing this provider."
},
"property_mappings": {
"type": "array",
"items": {
"type": "integer"
},
"title": "Property mappings"
},
"settings": {
"type": "object",
"additionalProperties": true,
"title": "Settings"
},
"connection_expiry": {
"type": "string",
"minLength": 1,
"title": "Connection expiry",
"description": "Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)"
}
},
"required": []
},
"model_authentik_providers_rac.endpoint": {
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"provider": {
"type": "integer",
"title": "Provider"
},
"protocol": {
"type": "string",
"enum": [
"rdp",
"vnc",
"ssh"
],
"title": "Protocol"
},
"host": {
"type": "string",
"minLength": 1,
"title": "Host"
},
"settings": {
"type": "object",
"additionalProperties": true,
"title": "Settings"
},
"property_mappings": {
"type": "array",
"items": {
"type": "integer"
},
"title": "Property mappings"
},
"auth_mode": {
"type": "string",
"enum": [
"static",
"prompt"
],
"title": "Auth mode"
}
},
"required": []
},
"model_authentik_providers_rac.racpropertymapping": {
"type": "object",
"properties": {
"managed": {
"type": [
"string",
"null"
],
"minLength": 1,
"title": "Managed by authentik",
"description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
},
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"expression": {
"type": "string",
"title": "Expression"
},
"static_settings": {
"type": "object",
"additionalProperties": true,
"title": "Static settings"
}
},
"required": []
},
"model_authentik_blueprints.metaapplyblueprint": { "model_authentik_blueprints.metaapplyblueprint": {
"type": "object", "type": "object",
"properties": { "properties": {

View File

@ -0,0 +1,32 @@
version: 1
metadata:
labels:
blueprints.goauthentik.io/system: "true"
name: System - RAC Provider - Mappings
entries:
- identifiers:
managed: goauthentik.io/providers/rac/rdp-default
model: authentik_providers_rac.racpropertymapping
attrs:
name: "authentik default RAC Mapping: RDP Default settings"
static_settings:
resize-method: "display-update"
enable-wallpaper: "true"
enable-font-smoothing: "true"
- identifiers:
managed: goauthentik.io/providers/rac/rdp-high-fidelity
model: authentik_providers_rac.racpropertymapping
attrs:
name: "authentik default RAC Mapping: RDP High Fidelity"
static_settings:
enable-theming: "true"
enable-full-window-drag: "true"
enable-desktop-composition: "true"
enable-menu-animations: "true"
- identifiers:
managed: goauthentik.io/providers/rac/ssh-default
model: authentik_providers_rac.racpropertymapping
attrs:
name: "authentik default RAC Mapping: SSH Default settings"
static_settings:
terminal-type: "xterm-256color"

93
cmd/rac/main.go Normal file
View File

@ -0,0 +1,93 @@
package main
import (
"fmt"
"net/url"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"goauthentik.io/internal/common"
"goauthentik.io/internal/debug"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ak/healthcheck"
"goauthentik.io/internal/outpost/rac"
)
const helpMessage = `authentik RAC
Required environment variables:
- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company")
- AUTHENTIK_TOKEN: Token to authenticate with
- AUTHENTIK_INSECURE: Skip SSL Certificate verification`
var rootCmd = &cobra.Command{
Long: helpMessage,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
log.SetLevel(log.DebugLevel)
log.SetFormatter(&log.JSONFormatter{
FieldMap: log.FieldMap{
log.FieldKeyMsg: "event",
log.FieldKeyTime: "timestamp",
},
DisableHTMLEscape: true,
})
},
Run: func(cmd *cobra.Command, args []string) {
debug.EnableDebugServer()
akURL, found := os.LookupEnv("AUTHENTIK_HOST")
if !found {
fmt.Println("env AUTHENTIK_HOST not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akToken, found := os.LookupEnv("AUTHENTIK_TOKEN")
if !found {
fmt.Println("env AUTHENTIK_TOKEN not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akURLActual, err := url.Parse(akURL)
if err != nil {
fmt.Println(err)
fmt.Println(helpMessage)
os.Exit(1)
}
ex := common.Init()
defer common.Defer()
go func() {
for {
<-ex
os.Exit(0)
}
}()
ac := ak.NewAPIController(*akURLActual, akToken)
if ac == nil {
os.Exit(1)
}
defer ac.Shutdown()
ac.Server = rac.NewServer(ac)
err = ac.Start()
if err != nil {
log.WithError(err).Panic("Failed to run server")
}
for {
<-ex
}
},
}
func main() {
rootCmd.AddCommand(healthcheck.Command)
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}

15
go.mod
View File

@ -10,7 +10,7 @@ require (
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.6 github.com/go-ldap/ldap/v3 v3.4.6
github.com/go-openapi/runtime v0.26.2 github.com/go-openapi/runtime v0.26.2
github.com/go-openapi/strfmt v0.21.10 github.com/go-openapi/strfmt v0.22.0
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.5.0 github.com/google/uuid v1.5.0
github.com/gorilla/handlers v1.5.2 github.com/gorilla/handlers v1.5.2
@ -22,12 +22,13 @@ require (
github.com/mitchellh/mapstructure v1.5.0 github.com/mitchellh/mapstructure v1.5.0
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484 github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
github.com/pires/go-proxyproto v0.7.0 github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.17.0 github.com/prometheus/client_golang v1.18.0
github.com/redis/go-redis/v9 v9.3.1 github.com/redis/go-redis/v9 v9.3.1
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.0 github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
goauthentik.io/api/v3 v3.2023105.2 github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2023105.3
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.15.0 golang.org/x/oauth2 v0.15.0
golang.org/x/sync v0.5.0 golang.org/x/sync v0.5.0
@ -60,14 +61,14 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect github.com/oklog/ulid v1.3.1 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/common v0.45.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect github.com/prometheus/procfs v0.12.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
go.mongodb.org/mongo-driver v1.13.1 // indirect go.mongodb.org/mongo-driver v1.13.1 // indirect
go.opentelemetry.io/otel v1.17.0 // indirect go.opentelemetry.io/otel v1.17.0 // indirect

36
go.sum
View File

@ -116,8 +116,8 @@ github.com/go-openapi/spec v0.20.6/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6
github.com/go-openapi/spec v0.20.11 h1:J/TzFDLTt4Rcl/l1PmyErvkqlJDncGvPTMnCI39I4gY= github.com/go-openapi/spec v0.20.11 h1:J/TzFDLTt4Rcl/l1PmyErvkqlJDncGvPTMnCI39I4gY=
github.com/go-openapi/spec v0.20.11/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6VaaBKcWA= github.com/go-openapi/spec v0.20.11/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6VaaBKcWA=
github.com/go-openapi/strfmt v0.21.3/go.mod h1:k+RzNO0Da+k3FrrynSNN8F7n/peCmQQqbbXjtDfvmGg= github.com/go-openapi/strfmt v0.21.3/go.mod h1:k+RzNO0Da+k3FrrynSNN8F7n/peCmQQqbbXjtDfvmGg=
github.com/go-openapi/strfmt v0.21.10 h1:JIsly3KXZB/Qf4UzvzJpg4OELH/0ASDQsyk//TTBDDk= github.com/go-openapi/strfmt v0.22.0 h1:Ew9PnEYc246TwrEspvBdDHS4BVKXy/AOVsfqGDgAcaI=
github.com/go-openapi/strfmt v0.21.10/go.mod h1:vNDMwbilnl7xKiO/Ve/8H8Bb2JIInBnH+lqiw6QWgis= github.com/go-openapi/strfmt v0.22.0/go.mod h1:HzJ9kokGIju3/K6ap8jL+OlGAbjpSv27135Yr9OivU4=
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
@ -195,6 +195,7 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY=
github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
@ -210,6 +211,7 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@ -223,8 +225,8 @@ github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN
github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k=
github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
@ -247,21 +249,22 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac h1:jWKYCNlX4J5s8M0nHYkh7Y7c9gRVDEb3mq51j5J0F5M= github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac h1:jWKYCNlX4J5s8M0nHYkh7Y7c9gRVDEb3mq51j5J0F5M=
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac/go.mod h1:hoLfEwdY11HjRfKFH6KqnPsfxlo3BP6bJehpDv8t6sQ= github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac/go.mod h1:hoLfEwdY11HjRfKFH6KqnPsfxlo3BP6bJehpDv8t6sQ=
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q= github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk=
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY= github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM= github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY= github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM=
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY=
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds= github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
@ -269,8 +272,10 @@ github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyh
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@ -281,6 +286,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/wwt/guac v1.3.2 h1:sH6OFGa/1tBs7ieWBVlZe7t6F5JAOWBry/tqQL/Vup4=
github.com/wwt/guac v1.3.2/go.mod h1:eKm+NrnK7A88l4UBEcYNpZQGMpZRryYKoz4D/0/n1C0=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g=
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
@ -309,8 +316,8 @@ go.opentelemetry.io/otel/trace v1.17.0 h1:/SWhSRHmDPOImIAetP1QAeMnZYiQXrTy4fMMYO
go.opentelemetry.io/otel/trace v1.17.0/go.mod h1:I/4vKTgFclIsXRVucpH25X0mpFSczM7aHeaz0ZBLWjY= go.opentelemetry.io/otel/trace v1.17.0/go.mod h1:I/4vKTgFclIsXRVucpH25X0mpFSczM7aHeaz0ZBLWjY=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
goauthentik.io/api/v3 v3.2023105.2 h1:ZUblqN5LidnCSlEZ/L19h7OnwppnAA3m5AGC7wUN0Ew= goauthentik.io/api/v3 v3.2023105.3 h1:x0pMJIKkbN198OOssqA94h8bO6ft9gwG8bpZqZL7WVg=
goauthentik.io/api/v3 v3.2023105.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw= goauthentik.io/api/v3 v3.2023105.3/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -414,6 +421,7 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -159,8 +159,8 @@ func (a *APIController) AddRefreshHandler(handler func()) {
a.refreshHandlers = append(a.refreshHandlers, handler) a.refreshHandlers = append(a.refreshHandlers, handler)
} }
func (a *APIController) AddWSHandler(handler WSHandler) { func (a *APIController) Token() string {
a.wsHandlers = append(a.wsHandlers, handler) return a.token
} }
func (a *APIController) OnRefresh() error { func (a *APIController) OnRefresh() error {
@ -182,7 +182,7 @@ func (a *APIController) OnRefresh() error {
return err return err
} }
func (a *APIController) getWebsocketArgs() map[string]interface{} { func (a *APIController) getWebsocketPingArgs() map[string]interface{} {
args := map[string]interface{}{ args := map[string]interface{}{
"version": constants.VERSION, "version": constants.VERSION,
"buildHash": constants.BUILD("tagged"), "buildHash": constants.BUILD("tagged"),

View File

@ -18,6 +18,8 @@ import (
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
pathTemplate := "%s://%s/ws/outpost/%s/?%s" pathTemplate := "%s://%s/ws/outpost/%s/?%s"
query := akURL.Query()
query.Set("instance_uuid", ac.instanceUUID.String())
scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws") scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")
authHeader := fmt.Sprintf("Bearer %s", ac.token) authHeader := fmt.Sprintf("Bearer %s", ac.token)
@ -45,7 +47,7 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
// Send hello message with our version // Send hello message with our version
msg := websocketMessage{ msg := websocketMessage{
Instruction: WebsocketInstructionHello, Instruction: WebsocketInstructionHello,
Args: ac.getWebsocketArgs(), Args: ac.getWebsocketPingArgs(),
} }
err = ws.WriteJSON(msg) err = ws.WriteJSON(msg)
if err != nil { if err != nil {
@ -53,7 +55,7 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
return err return err
} }
ac.lastWsReconnect = time.Now() ac.lastWsReconnect = time.Now()
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket") ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Info("Successfully connected websocket")
return nil return nil
} }
@ -157,23 +159,19 @@ func (ac *APIController) startWSHandler() {
func (ac *APIController) startWSHealth() { func (ac *APIController) startWSHealth() {
ticker := time.NewTicker(time.Second * 10) ticker := time.NewTicker(time.Second * 10)
for ; true; <-ticker.C { for ; true; <-ticker.C {
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: ac.getWebsocketArgs(),
}
if ac.wsConn == nil { if ac.wsConn == nil {
go ac.reconnectWS() go ac.reconnectWS()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} }
err := ac.wsConn.WriteJSON(aliveMsg) err := ac.SendWSHello(map[string]interface{}{})
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
if err != nil { if err != nil {
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error") ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
go ac.reconnectWS() go ac.reconnectWS()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} else { } else {
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
ConnectionStatus.With(prometheus.Labels{ ConnectionStatus.With(prometheus.Labels{
"outpost_name": ac.Outpost.Name, "outpost_name": ac.Outpost.Name,
"outpost_type": ac.Server.Type(), "outpost_type": ac.Server.Type(),
@ -202,3 +200,20 @@ func (ac *APIController) startIntervalUpdater() {
} }
} }
} }
func (a *APIController) AddWSHandler(handler WSHandler) {
a.wsHandlers = append(a.wsHandlers, handler)
}
func (a *APIController) SendWSHello(args map[string]interface{}) error {
allArgs := a.getWebsocketPingArgs()
for key, value := range args {
allArgs[key] = value
}
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: allArgs,
}
err := a.wsConn.WriteJSON(aliveMsg)
return err
}

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"beryju.io/ldap" "beryju.io/ldap"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ldap/constants" "goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/utils" "goauthentik.io/internal/outpost/ldap/utils"
@ -49,8 +50,8 @@ func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
constants.OCPosixAccount, constants.OCPosixAccount,
constants.OCAKUser, constants.OCAKUser,
}, },
"uidNumber": {pi.GetUidNumber(u)}, "uidNumber": {pi.GetUserUidNumber(u)},
"gidNumber": {pi.GetUidNumber(u)}, "gidNumber": {pi.GetUserGidNumber(u)},
"homeDirectory": {fmt.Sprintf("/home/%s", u.Username)}, "homeDirectory": {fmt.Sprintf("/home/%s", u.Username)},
"sn": {u.Name}, "sn": {u.Name},
}) })

View File

@ -4,6 +4,7 @@ import (
"strconv" "strconv"
"beryju.io/ldap" "beryju.io/ldap"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ldap/constants" "goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/server"
@ -50,7 +51,7 @@ func FromAPIGroup(g api.Group, si server.LDAPServerInstance) *LDAPGroup {
DN: si.GetGroupDN(g.Name), DN: si.GetGroupDN(g.Name),
CN: g.Name, CN: g.Name,
Uid: string(g.Pk), Uid: string(g.Pk),
GidNumber: si.GetGidNumber(g), GidNumber: si.GetGroupGidNumber(g),
Member: si.UsersForGroup(g), Member: si.UsersForGroup(g),
IsVirtualGroup: false, IsVirtualGroup: false,
IsSuperuser: *g.IsSuperuser, IsSuperuser: *g.IsSuperuser,
@ -63,7 +64,7 @@ func FromAPIUser(u api.User, si server.LDAPServerInstance) *LDAPGroup {
DN: si.GetVirtualGroupDN(u.Username), DN: si.GetVirtualGroupDN(u.Username),
CN: u.Username, CN: u.Username,
Uid: u.Uid, Uid: u.Uid,
GidNumber: si.GetUidNumber(u), GidNumber: si.GetUserGidNumber(u),
Member: []string{si.GetUserDN(u.Username)}, Member: []string{si.GetUserDN(u.Username)},
IsVirtualGroup: true, IsVirtualGroup: true,
IsSuperuser: false, IsSuperuser: false,

View File

@ -3,6 +3,7 @@ package server
import ( import (
"beryju.io/ldap" "beryju.io/ldap"
"github.com/go-openapi/strfmt" "github.com/go-openapi/strfmt"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ldap/flags" "goauthentik.io/internal/outpost/ldap/flags"
) )
@ -28,8 +29,9 @@ type LDAPServerInstance interface {
GetGroupDN(string) string GetGroupDN(string) string
GetVirtualGroupDN(string) string GetVirtualGroupDN(string) string
GetUidNumber(api.User) string GetUserUidNumber(api.User) string
GetGidNumber(api.Group) string GetUserGidNumber(api.User) string
GetGroupGidNumber(api.Group) string
UsersForGroup(api.Group) []string UsersForGroup(api.Group) []string

View File

@ -35,7 +35,7 @@ func (pi *ProviderInstance) GetVirtualGroupDN(group string) string {
return fmt.Sprintf("cn=%s,%s", group, pi.VirtualGroupDN) return fmt.Sprintf("cn=%s,%s", group, pi.VirtualGroupDN)
} }
func (pi *ProviderInstance) GetUidNumber(user api.User) string { func (pi *ProviderInstance) GetUserUidNumber(user api.User) string {
uidNumber, ok := user.GetAttributes()["uidNumber"].(string) uidNumber, ok := user.GetAttributes()["uidNumber"].(string)
if ok { if ok {
@ -45,7 +45,17 @@ func (pi *ProviderInstance) GetUidNumber(user api.User) string {
return strconv.FormatInt(int64(pi.uidStartNumber+user.Pk), 10) return strconv.FormatInt(int64(pi.uidStartNumber+user.Pk), 10)
} }
func (pi *ProviderInstance) GetGidNumber(group api.Group) string { func (pi *ProviderInstance) GetUserGidNumber(user api.User) string {
gidNumber, ok := user.GetAttributes()["gidNumber"].(string)
if ok {
return gidNumber
}
return pi.GetUserUidNumber(user)
}
func (pi *ProviderInstance) GetGroupGidNumber(group api.Group) string {
gidNumber, ok := group.GetAttributes()["gidNumber"].(string) gidNumber, ok := group.GetAttributes()["gidNumber"].(string)
if ok { if ok {

View File

@ -31,16 +31,11 @@ func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Co
return nil, err return nil, err
} }
// Extract the ID Token from OAuth2 token. jwt := oauth2Token.AccessToken
rawIDToken, ok := oauth2Token.Extra("id_token").(string) a.log.WithField("jwt", jwt).Trace("access_token")
if !ok {
return nil, fmt.Errorf("missing id_token")
}
a.log.WithField("id_token", rawIDToken).Trace("id_token")
// Parse and verify ID Token payload. // Parse and verify ID Token payload.
idToken, err := a.tokenVerifier.Verify(ctx, rawIDToken) idToken, err := a.tokenVerifier.Verify(ctx, jwt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -53,6 +48,6 @@ func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Co
if claims.Proxy == nil { if claims.Proxy == nil {
claims.Proxy = &ProxyClaims{} claims.Proxy = &ProxyClaims{}
} }
claims.RawToken = rawIDToken claims.RawToken = jwt
return claims, nil return claims, nil
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/config" "goauthentik.io/internal/config"
"goauthentik.io/internal/outpost/proxyv2/codecs" "goauthentik.io/internal/outpost/proxyv2/codecs"
@ -40,7 +41,7 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
// New default RedisStore // New default RedisStore
rs, err := redisstore.NewRedisStore(context.Background(), client) rs, err := redisstore.NewRedisStore(context.Background(), client)
if err != nil { if err != nil {
panic(err) a.log.WithError(err).Panic("failed to connect to redis")
} }
rs.KeyPrefix(RedisKeyPrefix) rs.KeyPrefix(RedisKeyPrefix)

View File

@ -0,0 +1,124 @@
package connection
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"github.com/wwt/guac"
"goauthentik.io/internal/config"
"goauthentik.io/internal/constants"
"goauthentik.io/internal/outpost/ak"
)
const guacAddr = "0.0.0.0:4822"
type Connection struct {
log *log.Entry
st *guac.SimpleTunnel
ac *ak.APIController
ws *websocket.Conn
ctx context.Context
ctxCancel context.CancelFunc
OnError func(error)
closing bool
}
func NewConnection(ac *ak.APIController, forChannel string, cfg *guac.Config) (*Connection, error) {
ctx, canc := context.WithCancel(context.Background())
c := &Connection{
ac: ac,
log: log.WithField("connection", forChannel),
ctx: ctx,
ctxCancel: canc,
OnError: func(err error) {},
closing: false,
}
err := c.initGuac(cfg)
if err != nil {
return nil, err
}
err = c.initSocket(forChannel)
if err != nil {
_ = c.st.Close()
return nil, err
}
c.initMirror()
return c, nil
}
func (c *Connection) initSocket(forChannel string) error {
pathTemplate := "%s://%s/ws/outpost_rac/%s/"
scheme := strings.ReplaceAll(c.ac.Client.GetConfig().Scheme, "http", "ws")
authHeader := fmt.Sprintf("Bearer %s", c.ac.Token())
header := http.Header{
"Authorization": []string{authHeader},
"User-Agent": []string{constants.OutpostUserAgent()},
}
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Get().AuthentikInsecure,
},
}
url := fmt.Sprintf(pathTemplate, scheme, c.ac.Client.GetConfig().Host, forChannel)
ws, _, err := dialer.Dial(url, header)
if err != nil {
c.log.WithError(err).Warning("failed to connect websocket")
return err
}
c.ws = ws
return nil
}
func (c *Connection) initGuac(cfg *guac.Config) error {
addr, err := net.ResolveTCPAddr("tcp", guacAddr)
if err != nil {
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
stream := guac.NewStream(conn, guac.SocketTimeout)
err = stream.Handshake(cfg)
if err != nil {
return err
}
st := guac.NewSimpleTunnel(stream)
c.st = st
return nil
}
func (c *Connection) initMirror() {
go c.wsToGuacd()
go c.guacdToWs()
}
func (c *Connection) onError(err error) {
if c.closing {
return
}
c.closing = true
e := c.st.Close()
if e != nil {
c.log.WithError(e).Warning("failed to close guacd connection")
}
c.log.WithError(err).Info("removing connection")
c.ctxCancel()
c.OnError(err)
}

View File

@ -0,0 +1,103 @@
package connection
import (
"bytes"
"fmt"
"github.com/gorilla/websocket"
"github.com/wwt/guac"
)
var (
internalOpcodeIns = []byte(fmt.Sprint(len(guac.InternalDataOpcode), ".", guac.InternalDataOpcode))
authentikOpcode = []byte("0.authentik.")
)
// MessageReader wraps a websocket connection and only permits Reading
type MessageReader interface {
// ReadMessage should return a single complete message to send to guac
ReadMessage() (int, []byte, error)
}
func (c *Connection) wsToGuacd() {
w := c.st.AcquireWriter()
for {
select {
default:
_, data, e := c.ws.ReadMessage()
if e != nil {
c.log.WithError(e).Trace("Error reading message from ws")
c.onError(e)
return
}
if bytes.HasPrefix(data, internalOpcodeIns) {
if bytes.HasPrefix(data, authentikOpcode) {
switch string(bytes.Replace(data, authentikOpcode, []byte{}, 1)) {
case "disconnect":
_, e := w.Write([]byte(guac.NewInstruction("disconnect").String()))
c.onError(e)
return
}
}
// messages starting with the InternalDataOpcode are never sent to guacd
continue
}
if _, e = w.Write(data); e != nil {
c.log.WithError(e).Trace("Failed writing to guacd")
c.onError(e)
return
}
case <-c.ctx.Done():
return
}
}
}
// MessageWriter wraps a websocket connection and only permits Writing
type MessageWriter interface {
// WriteMessage writes one or more complete guac commands to the websocket
WriteMessage(int, []byte) error
}
func (c *Connection) guacdToWs() {
r := c.st.AcquireReader()
buf := bytes.NewBuffer(make([]byte, 0, guac.MaxGuacMessage*2))
for {
select {
default:
ins, e := r.ReadSome()
if e != nil {
c.log.WithError(e).Trace("Error reading from guacd")
c.onError(e)
return
}
if bytes.HasPrefix(ins, internalOpcodeIns) {
// messages starting with the InternalDataOpcode are never sent to the websocket
continue
}
if _, e = buf.Write(ins); e != nil {
c.log.WithError(e).Trace("Failed to buffer guacd to ws")
c.onError(e)
return
}
// if the buffer has more data in it or we've reached the max buffer size, send the data and reset
if !r.Available() || buf.Len() >= guac.MaxGuacMessage {
if e = c.ws.WriteMessage(1, buf.Bytes()); e != nil {
if e == websocket.ErrCloseSent {
return
}
c.log.WithError(e).Trace("Failed sending message to ws")
c.onError(e)
return
}
buf.Reset()
}
case <-c.ctx.Done():
return
}
}
}

View File

@ -0,0 +1,26 @@
package rac
import (
"os"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ak"
)
const (
guacdPath = "/opt/guacamole/sbin/guacd"
guacdDefaultArgs = " -b 0.0.0.0 -f"
)
func (rs *RACServer) startGuac() error {
guacdArgs := strings.Split(guacdDefaultArgs, " ")
guacdArgs = append(guacdArgs, "-L", rs.ac.Outpost.Config[ak.ConfigLogLevel].(string))
rs.guacd = exec.Command(guacdPath, guacdArgs...)
rs.guacd.Env = os.Environ()
rs.guacd.Stdout = rs.log.WithField("logger", "authentik.outpost.rac.guacd").WriterLevel(log.InfoLevel)
rs.guacd.Stderr = rs.log.WithField("logger", "authentik.outpost.rac.guacd").WriterLevel(log.InfoLevel)
rs.log.Info("starting guacd")
return rs.guacd.Start()
}

View File

@ -0,0 +1,28 @@
package metrics
import (
"net/http"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/config"
"goauthentik.io/internal/utils/sentry"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func RunServer() {
m := mux.NewRouter()
l := log.WithField("logger", "authentik.outpost.metrics")
m.Use(sentry.SentryNoSampleMiddleware)
m.HandleFunc("/outpost.goauthentik.io/ping", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(204)
})
m.Path("/metrics").Handler(promhttp.Handler())
listen := config.Get().Listen.Metrics
l.WithField("listen", listen).Info("Starting Metrics server")
err := http.ListenAndServe(listen, m)
if err != nil {
l.WithError(err).Warning("Failed to start metrics listener")
}
}

126
internal/outpost/rac/rac.go Normal file
View File

@ -0,0 +1,126 @@
package rac
import (
"context"
"os/exec"
"strconv"
"sync"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
"github.com/wwt/guac"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/rac/connection"
"goauthentik.io/internal/outpost/rac/metrics"
)
type RACServer struct {
log *log.Entry
ac *ak.APIController
guacd *exec.Cmd
connm sync.RWMutex
conns map[string]connection.Connection
}
func NewServer(ac *ak.APIController) *RACServer {
rs := &RACServer{
log: log.WithField("logger", "authentik.outpost.rac"),
ac: ac,
connm: sync.RWMutex{},
conns: map[string]connection.Connection{},
}
ac.AddWSHandler(rs.wsHandler)
return rs
}
type WSMessage struct {
ConnID string `mapstructure:"conn_id"`
DestChannelID string `mapstructure:"dest_channel_id"`
Params map[string]string `mapstructure:"params"`
Protocol string `mapstructure:"protocol"`
OptimalScreenWidth string `mapstructure:"screen_width"`
OptimalScreenHeight string `mapstructure:"screen_height"`
OptimalScreenDPI string `mapstructure:"screen_dpi"`
}
func parseIntOrZero(input string) int {
x, err := strconv.Atoi(input)
if err != nil {
return 0
}
return x
}
func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) {
wsm := WSMessage{}
err := mapstructure.Decode(args, &wsm)
if err != nil {
rs.log.WithError(err).Warning("invalid ws message")
return
}
config := guac.NewGuacamoleConfiguration()
config.Protocol = wsm.Protocol
config.Parameters = wsm.Params
config.OptimalScreenWidth = parseIntOrZero(wsm.OptimalScreenWidth)
config.OptimalScreenHeight = parseIntOrZero(wsm.OptimalScreenHeight)
config.OptimalResolution = parseIntOrZero(wsm.OptimalScreenDPI)
config.AudioMimetypes = []string{
"audio/L8",
"audio/L16",
}
cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config)
if err != nil {
rs.log.WithError(err).Warning("failed to setup connection")
return
}
cc.OnError = func(err error) {
rs.connm.Lock()
delete(rs.conns, wsm.ConnID)
_ = rs.ac.SendWSHello(map[string]interface{}{
"active_connections": len(rs.conns),
})
rs.connm.Unlock()
}
rs.connm.Lock()
rs.conns[wsm.ConnID] = *cc
_ = rs.ac.SendWSHello(map[string]interface{}{
"active_connections": len(rs.conns),
})
rs.connm.Unlock()
}
func (rs *RACServer) Start() error {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
metrics.RunServer()
}()
go func() {
defer wg.Done()
err := rs.startGuac()
if err != nil {
panic(err)
}
}()
wg.Wait()
return nil
}
func (rs *RACServer) Stop() error {
if rs.guacd != nil {
return rs.guacd.Process.Kill()
}
return nil
}
func (rs *RACServer) TimerFlowCacheExpiry(context.Context) {}
func (rs *RACServer) Type() string {
return "rac"
}
func (rs *RACServer) Refresh() error {
return nil
}

View File

@ -34,6 +34,11 @@ func (ws *WebServer) configureStatic() {
}) })
indexLessRouter.PathPrefix("/if/admin/assets").Handler(http.StripPrefix("/if/admin", distFs)) indexLessRouter.PathPrefix("/if/admin/assets").Handler(http.StripPrefix("/if/admin", distFs))
indexLessRouter.PathPrefix("/if/user/assets").Handler(http.StripPrefix("/if/user", distFs)) indexLessRouter.PathPrefix("/if/user/assets").Handler(http.StripPrefix("/if/user", distFs))
indexLessRouter.PathPrefix("/if/rac/{app_slug}/assets").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
web.DisableIndex(http.StripPrefix(fmt.Sprintf("/if/rac/%s", vars["app_slug"]), distFs)).ServeHTTP(rw, r)
})
indexLessRouter.PathPrefix("/media/").Handler(http.StripPrefix("/media", fs)) indexLessRouter.PathPrefix("/media/").Handler(http.StripPrefix("/media", fs))

Binary file not shown.

Binary file not shown.

1065
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -145,7 +145,12 @@ geoip2 = "*"
gunicorn = "*" gunicorn = "*"
kubernetes = "*" kubernetes = "*"
ldap3 = "*" ldap3 = "*"
lxml = "*" lxml = [
# 5.0.0 works with libxml2 2.11.x, which is standard on brew
{ version = "5.0.0", platform = "darwin" },
# 4.9.x works with previous libxml2 versions, which is what we get on linux
{ version = "4.9.4", platform = "linux" },
]
opencontainers = { extras = ["reggie"], version = "*" } opencontainers = { extras = ["reggie"], version = "*" }
packaging = "*" packaging = "*"
paramiko = "*" paramiko = "*"

38
rac.Dockerfile Normal file
View File

@ -0,0 +1,38 @@
# syntax=docker/dockerfile:1
# Stage 1: Build
FROM docker.io/golang:1.21.5-bookworm AS builder
WORKDIR /go/src/goauthentik.io
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
--mount=type=bind,target=/go/src/goauthentik.io/gen-go-api,src=./gen-go-api \
--mount=type=cache,target=/go/pkg/mod \
go mod download
ENV CGO_ENABLED=0
COPY . .
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
go build -o /go/rac ./cmd/rac
# Stage 2: Run
FROM ghcr.io/beryju/guacd:1.5.3
ARG GIT_BUILD_HASH
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
LABEL org.opencontainers.image.url https://goauthentik.io
LABEL org.opencontainers.image.description goauthentik.io RAC outpost, see https://goauthentik.io for more info.
LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version ${VERSION}
LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
COPY --from=builder /go/rac /
HEALTHCHECK --interval=5s --retries=20 --start-period=3s CMD [ "/rac", "healthcheck" ]
USER 1000
ENTRYPOINT ["/rac"]

1232
schema.yml

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,6 @@
"""LDAP and Outpost e2e tests""" """LDAP and Outpost e2e tests"""
from dataclasses import asdict from dataclasses import asdict
from sys import platform
from time import sleep from time import sleep
from unittest.case import skipUnless
from docker.client import DockerClient, from_env from docker.client import DockerClient, from_env
from docker.models.containers import Container from docker.models.containers import Container
@ -14,13 +12,13 @@ from authentik.blueprints.tests import apply_blueprint, reconcile_app
from authentik.core.models import Application, User from authentik.core.models import Application, User
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.lib.generators import generate_id
from authentik.outposts.apps import MANAGED_OUTPOST from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import Outpost, OutpostConfig, OutpostType from authentik.outposts.models import Outpost, OutpostConfig, OutpostType
from authentik.providers.ldap.models import APIAccessMode, LDAPProvider from authentik.providers.ldap.models import APIAccessMode, LDAPProvider
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderLDAP(SeleniumTestCase): class TestProviderLDAP(SeleniumTestCase):
"""LDAP and Outpost e2e tests""" """LDAP and Outpost e2e tests"""
@ -37,7 +35,10 @@ class TestProviderLDAP(SeleniumTestCase):
container = client.containers.run( container = client.containers.run(
image=self.get_container_image("ghcr.io/goauthentik/dev-ldap"), image=self.get_container_image("ghcr.io/goauthentik/dev-ldap"),
detach=True, detach=True,
network_mode="host", ports={
"3389": "3389",
"6636": "6636",
},
environment={ environment={
"AUTHENTIK_HOST": self.live_server_url, "AUTHENTIK_HOST": self.live_server_url,
"AUTHENTIK_TOKEN": outpost.token.key, "AUTHENTIK_TOKEN": outpost.token.key,
@ -51,15 +52,15 @@ class TestProviderLDAP(SeleniumTestCase):
self.user.save() self.user.save()
ldap: LDAPProvider = LDAPProvider.objects.create( ldap: LDAPProvider = LDAPProvider.objects.create(
name="ldap_provider", name=generate_id(),
authorization_flow=Flow.objects.get(slug="default-authentication-flow"), authorization_flow=Flow.objects.get(slug="default-authentication-flow"),
search_group=self.user.ak_groups.first(), search_group=self.user.ak_groups.first(),
search_mode=APIAccessMode.CACHED, search_mode=APIAccessMode.CACHED,
) )
# we need to create an application to actually access the ldap # we need to create an application to actually access the ldap
Application.objects.create(name="ldap", slug="ldap", provider=ldap) Application.objects.create(name=generate_id(), slug=generate_id(), provider=ldap)
outpost: Outpost = Outpost.objects.create( outpost: Outpost = Outpost.objects.create(
name="ldap_outpost", name=generate_id(),
type=OutpostType.LDAP, type=OutpostType.LDAP,
_config=asdict(OutpostConfig(log_level="debug")), _config=asdict(OutpostConfig(log_level="debug")),
) )

View File

@ -1,8 +1,6 @@
"""test OAuth Provider flow""" """test OAuth Provider flow"""
from sys import platform
from time import sleep from time import sleep
from typing import Any, Optional from typing import Any, Optional
from unittest.case import skipUnless
from docker.types import Healthcheck from docker.types import Healthcheck
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
@ -18,7 +16,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2Github(SeleniumTestCase): class TestProviderOAuth2Github(SeleniumTestCase):
"""test OAuth Provider flow""" """test OAuth Provider flow"""
@ -32,7 +29,9 @@ class TestProviderOAuth2Github(SeleniumTestCase):
return { return {
"image": "grafana/grafana:7.1.0", "image": "grafana/grafana:7.1.0",
"detach": True, "detach": True,
"network_mode": "host", "ports": {
"3000": "3000",
},
"auto_remove": True, "auto_remove": True,
"healthcheck": Healthcheck( "healthcheck": Healthcheck(
test=["CMD", "wget", "--spider", "http://localhost:3000"], test=["CMD", "wget", "--spider", "http://localhost:3000"],

View File

@ -1,8 +1,6 @@
"""test OAuth2 OpenID Provider flow""" """test OAuth2 OpenID Provider flow"""
from sys import platform
from time import sleep from time import sleep
from typing import Any, Optional from typing import Any, Optional
from unittest.case import skipUnless
from docker.types import Healthcheck from docker.types import Healthcheck
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
@ -24,7 +22,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2OAuth(SeleniumTestCase): class TestProviderOAuth2OAuth(SeleniumTestCase):
"""test OAuth with OAuth Provider flow""" """test OAuth with OAuth Provider flow"""
@ -38,13 +35,15 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
return { return {
"image": "grafana/grafana:7.1.0", "image": "grafana/grafana:7.1.0",
"detach": True, "detach": True,
"network_mode": "host",
"auto_remove": True, "auto_remove": True,
"healthcheck": Healthcheck( "healthcheck": Healthcheck(
test=["CMD", "wget", "--spider", "http://localhost:3000"], test=["CMD", "wget", "--spider", "http://localhost:3000"],
interval=5 * 1_000 * 1_000_000, interval=5 * 1_000 * 1_000_000,
start_period=1 * 1_000 * 1_000_000, start_period=1 * 1_000 * 1_000_000,
), ),
"ports": {
"3000": "3000",
},
"environment": { "environment": {
"GF_AUTH_GENERIC_OAUTH_ENABLED": "true", "GF_AUTH_GENERIC_OAUTH_ENABLED": "true",
"GF_AUTH_GENERIC_OAUTH_CLIENT_ID": self.client_id, "GF_AUTH_GENERIC_OAUTH_CLIENT_ID": self.client_id,

View File

@ -1,8 +1,6 @@
"""test OAuth2 OpenID Provider flow""" """test OAuth2 OpenID Provider flow"""
from json import loads from json import loads
from sys import platform
from time import sleep from time import sleep
from unittest.case import skipUnless
from docker import DockerClient, from_env from docker import DockerClient, from_env
from docker.models.containers import Container from docker.models.containers import Container
@ -25,7 +23,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2OIDC(SeleniumTestCase): class TestProviderOAuth2OIDC(SeleniumTestCase):
"""test OAuth with OpenID Provider flow""" """test OAuth with OpenID Provider flow"""
@ -36,13 +33,15 @@ class TestProviderOAuth2OIDC(SeleniumTestCase):
super().setUp() super().setUp()
def setup_client(self) -> Container: def setup_client(self) -> Container:
"""Setup client saml-sp container which we test SAML against""" """Setup client oidc-test-client container which we test OIDC against"""
sleep(1) sleep(1)
client: DockerClient = from_env() client: DockerClient = from_env()
container = client.containers.run( container = client.containers.run(
image="ghcr.io/beryju/oidc-test-client:1.3", image="ghcr.io/beryju/oidc-test-client:1.3",
detach=True, detach=True,
network_mode="host", ports={
"9009": "9009",
},
environment={ environment={
"OIDC_CLIENT_ID": self.client_id, "OIDC_CLIENT_ID": self.client_id,
"OIDC_CLIENT_SECRET": self.client_secret, "OIDC_CLIENT_SECRET": self.client_secret,

View File

@ -1,8 +1,6 @@
"""test OAuth2 OpenID Provider flow""" """test OAuth2 OpenID Provider flow"""
from json import loads from json import loads
from sys import platform
from time import sleep from time import sleep
from unittest.case import skipUnless
from docker import DockerClient, from_env from docker import DockerClient, from_env
from docker.models.containers import Container from docker.models.containers import Container
@ -25,7 +23,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2OIDCImplicit(SeleniumTestCase): class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
"""test OAuth with OpenID Provider flow""" """test OAuth with OpenID Provider flow"""
@ -36,13 +33,15 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
super().setUp() super().setUp()
def setup_client(self) -> Container: def setup_client(self) -> Container:
"""Setup client saml-sp container which we test SAML against""" """Setup client oidc-test-client container which we test OIDC against"""
sleep(1) sleep(1)
client: DockerClient = from_env() client: DockerClient = from_env()
container = client.containers.run( container = client.containers.run(
image="ghcr.io/beryju/oidc-test-client:1.3", image="ghcr.io/beryju/oidc-test-client:1.3",
detach=True, detach=True,
network_mode="host", ports={
"9009": "9009",
},
environment={ environment={
"OIDC_CLIENT_ID": self.client_id, "OIDC_CLIENT_ID": self.client_id,
"OIDC_CLIENT_SECRET": self.client_secret, "OIDC_CLIENT_SECRET": self.client_secret,

View File

@ -21,7 +21,6 @@ from authentik.providers.proxy.models import ProxyProvider
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderProxy(SeleniumTestCase): class TestProviderProxy(SeleniumTestCase):
"""Proxy and Outpost e2e tests""" """Proxy and Outpost e2e tests"""
@ -36,7 +35,9 @@ class TestProviderProxy(SeleniumTestCase):
return { return {
"image": "traefik/whoami:latest", "image": "traefik/whoami:latest",
"detach": True, "detach": True,
"network_mode": "host", "ports": {
"80": "80",
},
"auto_remove": True, "auto_remove": True,
} }
@ -46,7 +47,9 @@ class TestProviderProxy(SeleniumTestCase):
container = client.containers.run( container = client.containers.run(
image=self.get_container_image("ghcr.io/goauthentik/dev-proxy"), image=self.get_container_image("ghcr.io/goauthentik/dev-proxy"),
detach=True, detach=True,
network_mode="host", ports={
"9000": "9000",
},
environment={ environment={
"AUTHENTIK_HOST": self.live_server_url, "AUTHENTIK_HOST": self.live_server_url,
"AUTHENTIK_TOKEN": outpost.token.key, "AUTHENTIK_TOKEN": outpost.token.key,
@ -78,7 +81,7 @@ class TestProviderProxy(SeleniumTestCase):
authorization_flow=Flow.objects.get( authorization_flow=Flow.objects.get(
slug="default-provider-authorization-implicit-consent" slug="default-provider-authorization-implicit-consent"
), ),
internal_host="http://localhost", internal_host=f"http://{self.host}",
external_host="http://localhost:9000", external_host="http://localhost:9000",
) )
# Ensure OAuth2 Params are set # Ensure OAuth2 Params are set
@ -145,7 +148,7 @@ class TestProviderProxy(SeleniumTestCase):
authorization_flow=Flow.objects.get( authorization_flow=Flow.objects.get(
slug="default-provider-authorization-implicit-consent" slug="default-provider-authorization-implicit-consent"
), ),
internal_host="http://localhost", internal_host=f"http://{self.host}",
external_host="http://localhost:9000", external_host="http://localhost:9000",
basic_auth_enabled=True, basic_auth_enabled=True,
basic_auth_user_attribute="basic-username", basic_auth_user_attribute="basic-username",

View File

@ -1,8 +1,6 @@
"""Radius e2e tests""" """Radius e2e tests"""
from dataclasses import asdict from dataclasses import asdict
from sys import platform
from time import sleep from time import sleep
from unittest.case import skipUnless
from docker.client import DockerClient, from_env from docker.client import DockerClient, from_env
from docker.models.containers import Container from docker.models.containers import Container
@ -19,7 +17,6 @@ from authentik.providers.radius.models import RadiusProvider
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderRadius(SeleniumTestCase): class TestProviderRadius(SeleniumTestCase):
"""Radius Outpost e2e tests""" """Radius Outpost e2e tests"""
@ -40,7 +37,7 @@ class TestProviderRadius(SeleniumTestCase):
container = client.containers.run( container = client.containers.run(
image=self.get_container_image("ghcr.io/goauthentik/dev-radius"), image=self.get_container_image("ghcr.io/goauthentik/dev-radius"),
detach=True, detach=True,
network_mode="host", ports={"1812/udp": "1812/udp"},
environment={ environment={
"AUTHENTIK_HOST": self.live_server_url, "AUTHENTIK_HOST": self.live_server_url,
"AUTHENTIK_TOKEN": outpost.token.key, "AUTHENTIK_TOKEN": outpost.token.key,

View File

@ -1,8 +1,6 @@
"""test SAML Provider flow""" """test SAML Provider flow"""
from json import loads from json import loads
from sys import platform
from time import sleep from time import sleep
from unittest.case import skipUnless
from docker import DockerClient, from_env from docker import DockerClient, from_env
from docker.models.containers import Container from docker.models.containers import Container
@ -20,7 +18,6 @@ from authentik.sources.saml.processors.constants import SAML_BINDING_POST
from tests.e2e.utils import SeleniumTestCase, retry from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderSAML(SeleniumTestCase): class TestProviderSAML(SeleniumTestCase):
"""test SAML Provider flow""" """test SAML Provider flow"""
@ -41,7 +38,9 @@ class TestProviderSAML(SeleniumTestCase):
container = client.containers.run( container = client.containers.run(
image="ghcr.io/beryju/saml-test-sp:1.1", image="ghcr.io/beryju/saml-test-sp:1.1",
detach=True, detach=True,
network_mode="host", ports={
"9009": "9009",
},
environment={ environment={
"SP_ENTITY_ID": provider.issuer, "SP_ENTITY_ID": provider.issuer,
"SP_SSO_BINDING": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", "SP_SSO_BINDING": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",

View File

@ -0,0 +1,141 @@
"""test OAuth Source"""
from time import sleep
from typing import Any, Optional
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.support import expected_conditions as ec
from selenium.webdriver.support.wait import WebDriverWait
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import User
from authentik.flows.models import Flow
from authentik.lib.generators import generate_id, generate_key
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.stages.identification.models import IdentificationStage
from tests.e2e.utils import SeleniumTestCase, retry
class OAuth1Callback(OAuthCallback):
"""OAuth1 Callback with custom getters"""
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("id")
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("screen_name"),
"email": info.get("email"),
"name": info.get("name"),
}
@registry.register()
class OAUth1Type(SourceType):
"""OAuth1 Type definition"""
callback_view = OAuth1Callback
verbose_name = "OAuth1"
name = "oauth1"
request_token_url = "http://localhost:5001/oauth/request_token" # nosec
access_token_url = "http://localhost:5001/oauth/access_token" # nosec
authorization_url = "http://localhost:5001/oauth/authorize"
profile_url = "http://localhost:5001/api/me"
urls_customizable = False
class TestSourceOAuth1(SeleniumTestCase):
"""Test OAuth1 Source"""
def setUp(self) -> None:
self.client_id = generate_id()
self.client_secret = generate_key()
self.source_slug = generate_id()
super().setUp()
def get_container_specs(self) -> Optional[dict[str, Any]]:
return {
"image": "ghcr.io/beryju/oauth1-test-server:v1.1",
"detach": True,
"ports": {"5000": "5001"},
"auto_remove": True,
"environment": {
"OAUTH1_CLIENT_ID": self.client_id,
"OAUTH1_CLIENT_SECRET": self.client_secret,
"OAUTH1_REDIRECT_URI": self.url(
"authentik_sources_oauth:oauth-client-callback",
source_slug=self.source_slug,
),
},
}
def create_objects(self):
"""Create required objects"""
# Bootstrap all needed objects
authentication_flow = Flow.objects.get(slug="default-source-authentication")
enrollment_flow = Flow.objects.get(slug="default-source-enrollment")
source = OAuthSource.objects.create( # nosec
name=generate_id(),
slug=self.source_slug,
authentication_flow=authentication_flow,
enrollment_flow=enrollment_flow,
provider_type="oauth1",
consumer_key=self.client_id,
consumer_secret=self.client_secret,
)
ident_stage = IdentificationStage.objects.first()
ident_stage.sources.set([source])
ident_stage.save()
@retry()
@apply_blueprint(
"default/flow-default-authentication-flow.yaml",
"default/flow-default-invalidation-flow.yaml",
)
@apply_blueprint(
"default/flow-default-source-authentication.yaml",
"default/flow-default-source-enrollment.yaml",
"default/flow-default-source-pre-authentication.yaml",
)
def test_oauth_enroll(self):
"""test OAuth Source With With OIDC"""
self.create_objects()
self.driver.get(self.live_server_url)
flow_executor = self.get_shadow_root("ak-flow-executor")
identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor)
wait = WebDriverWait(identification_stage, self.wait_timeout)
wait.until(
ec.presence_of_element_located(
(By.CSS_SELECTOR, ".pf-c-login__main-footer-links-item > button")
)
)
identification_stage.find_element(
By.CSS_SELECTOR, ".pf-c-login__main-footer-links-item > button"
).click()
# Now we should be at the IDP, wait for the login field
self.wait.until(ec.presence_of_element_located((By.NAME, "username")))
self.driver.find_element(By.NAME, "username").send_keys("example-user")
self.driver.find_element(By.NAME, "username").send_keys(Keys.ENTER)
sleep(2)
# Wait until we're logged in
self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "[name='confirm']")))
self.driver.find_element(By.CSS_SELECTOR, "[name='confirm']").click()
# Wait until we've loaded the user info page
sleep(2)
# Wait until we've logged in
self.wait_for_url(self.if_user_url("/library"))
self.driver.get(self.if_user_url("/settings"))
self.assert_user(User(username="example-user", name="test name", email="foo@example.com"))

Some files were not shown because too many files have changed in this diff Show More