Merge branch 'dev' into web/sidebar-with-live-content-3
* dev: (131 commits) web: Replace calls to `rootInterface()?.tenant?` with a contextual `this.tenant` object (#7778) web: abstract `rootInterface()?.config?.capabilities.includes()` into `.can()` (#7737) web: update some locale details (#8090) web: bump the eslint group in /web with 2 updates (#8082) web: bump rollup from 4.9.2 to 4.9.4 in /web (#8083) core: bump github.com/redis/go-redis/v9 from 9.3.1 to 9.4.0 (#8085) web: bump the eslint group in /tests/wdio with 2 updates (#8086) website: bump @types/react from 18.2.46 to 18.2.47 in /website (#8088) stages/user_login: only set last_ip in session if a binding is given (#8074) providers/oauth2: fix missing nonce in token endpoint not being saved (#8073) core: bump goauthentik.io/api/v3 from 3.2023105.3 to 3.2023105.5 (#8066) providers/oauth2: fix missing nonce in id_token (#8072) rbac: fix error when looking up permissions for now uninstalled apps (#8068) web/flows: fix device picker incorrect foreground color (#8067) translate: Updates for file web/xliff/en.xlf in zh_CN (#8061) translate: Updates for file web/xliff/en.xlf in zh-Hans (#8062) website: bump postcss from 8.4.32 to 8.4.33 in /website (#8063) web: bump the sentry group in /web with 2 updates (#8064) core: bump golang.org/x/sync from 0.5.0 to 0.6.0 (#8065) website/docs: add link to our example flows (#8052) ...
This commit is contained in:
commit
9768684c3c
|
@ -1,5 +1,5 @@
|
|||
[bumpversion]
|
||||
current_version = 2023.10.4
|
||||
current_version = 2023.10.5
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||
|
|
|
@ -9,3 +9,4 @@ blueprints/local
|
|||
.git
|
||||
!gen-ts-api/node_modules
|
||||
!gen-ts-api/dist/**
|
||||
!gen-go-api/
|
||||
|
|
|
@ -2,3 +2,4 @@ keypair
|
|||
keypairs
|
||||
hass
|
||||
warmup
|
||||
ontext
|
||||
|
|
|
@ -61,10 +61,6 @@ jobs:
|
|||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup authentik env
|
||||
uses: ./.github/actions/setup
|
||||
with:
|
||||
postgresql_version: ${{ matrix.psql }}
|
||||
- name: checkout stable
|
||||
run: |
|
||||
# Delete all poetry envs
|
||||
|
@ -76,7 +72,7 @@ jobs:
|
|||
git checkout version/$(python -c "from authentik import __version__; print(__version__)")
|
||||
rm -rf .github/ scripts/
|
||||
mv ../.github ../scripts .
|
||||
- name: Setup authentik env (ensure stable deps are installed)
|
||||
- name: Setup authentik env (stable)
|
||||
uses: ./.github/actions/setup
|
||||
with:
|
||||
postgresql_version: ${{ matrix.psql }}
|
||||
|
@ -90,15 +86,20 @@ jobs:
|
|||
git clean -d -fx .
|
||||
git checkout $GITHUB_SHA
|
||||
# Delete previous poetry env
|
||||
rm -rf $(poetry env info --path)
|
||||
rm -rf /home/runner/.cache/pypoetry/virtualenvs/*
|
||||
- name: Setup authentik env (ensure latest deps are installed)
|
||||
uses: ./.github/actions/setup
|
||||
with:
|
||||
postgresql_version: ${{ matrix.psql }}
|
||||
- name: migrate to latest
|
||||
run: |
|
||||
poetry install
|
||||
poetry run python -m lifecycle.migrate
|
||||
- name: run tests
|
||||
env:
|
||||
# Test in the main database that we just migrated from the previous stable version
|
||||
AUTHENTIK_POSTGRESQL__TEST__NAME: authentik
|
||||
run: |
|
||||
poetry run make test
|
||||
test-unittest:
|
||||
name: test-unittest - PostgreSQL ${{ matrix.psql }}
|
||||
runs-on: ubuntu-latest
|
||||
|
@ -248,12 +249,6 @@ jobs:
|
|||
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
- name: Comment on PR
|
||||
if: github.event_name == 'pull_request'
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/comment-pr-instructions
|
||||
with:
|
||||
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
|
||||
build-arm64:
|
||||
needs: ci-core-mark
|
||||
runs-on: ubuntu-latest
|
||||
|
@ -302,3 +297,26 @@ jobs:
|
|||
platforms: linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
pr-comment:
|
||||
needs:
|
||||
- build
|
||||
- build-arm64
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
permissions:
|
||||
# Needed to write comments on PRs
|
||||
pull-requests: write
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: prepare variables
|
||||
uses: ./.github/actions/docker-push-variables
|
||||
id: ev
|
||||
env:
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
- name: Comment on PR
|
||||
uses: ./.github/actions/comment-pr-instructions
|
||||
with:
|
||||
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
|
||||
|
|
|
@ -65,6 +65,7 @@ jobs:
|
|||
- proxy
|
||||
- ldap
|
||||
- radius
|
||||
- rac
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
# Needed to upload contianer images to ghcr.io
|
||||
|
@ -119,6 +120,7 @@ jobs:
|
|||
- proxy
|
||||
- ldap
|
||||
- radius
|
||||
- rac
|
||||
goos: [linux]
|
||||
goarch: [amd64, arm64]
|
||||
steps:
|
||||
|
|
|
@ -27,10 +27,10 @@ jobs:
|
|||
- name: Setup authentik env
|
||||
uses: ./.github/actions/setup
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v2
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v2
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v2
|
||||
uses: github/codeql-action/analyze@v3
|
||||
|
|
|
@ -65,6 +65,7 @@ jobs:
|
|||
- proxy
|
||||
- ldap
|
||||
- radius
|
||||
- rac
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
|
|
|
@ -71,7 +71,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
|||
# Stage 4: MaxMind GeoIP
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v6.0 as geoip
|
||||
|
||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City"
|
||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||
ENV GEOIPUPDATE_VERBOSE="true"
|
||||
ENV GEOIPUPDATE_ACCOUNT_ID_FILE="/run/secrets/GEOIPUPDATE_ACCOUNT_ID"
|
||||
ENV GEOIPUPDATE_LICENSE_KEY_FILE="/run/secrets/GEOIPUPDATE_LICENSE_KEY"
|
||||
|
|
7
Makefile
7
Makefile
|
@ -58,7 +58,7 @@ test: ## Run the server tests and produce a coverage report (locally)
|
|||
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
|
||||
isort $(PY_SOURCES)
|
||||
black $(PY_SOURCES)
|
||||
ruff $(PY_SOURCES)
|
||||
ruff --fix $(PY_SOURCES)
|
||||
codespell -w $(CODESPELL_ARGS)
|
||||
|
||||
lint: ## Lint the python and golang sources
|
||||
|
@ -115,8 +115,9 @@ gen-diff: ## (Release) generate the changelog diff between the current schema a
|
|||
npx prettier --write diff.md
|
||||
|
||||
gen-clean:
|
||||
rm -rf web/api/src/
|
||||
rm -rf api/
|
||||
rm -rf gen-go-api/
|
||||
rm -rf gen-ts-api/
|
||||
rm -rf web/node_modules/@goauthentik/api/
|
||||
|
||||
gen-client-ts: ## Build and install the authentik API for Typescript into the authentik UI Application
|
||||
docker run \
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from os import environ
|
||||
from typing import Optional
|
||||
|
||||
__version__ = "2023.10.4"
|
||||
__version__ = "2023.10.5"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ from authentik.blueprints.tests import reconcile_app
|
|||
from authentik.core.models import Token, TokenIntents, User, UserTypes
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||
from authentik.outposts.models import Outpost
|
||||
from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API
|
||||
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider
|
||||
|
||||
|
@ -49,8 +51,12 @@ class TestAPIAuth(TestCase):
|
|||
with self.assertRaises(AuthenticationFailed):
|
||||
bearer_auth(f"Bearer {token.key}".encode())
|
||||
|
||||
def test_managed_outpost(self):
|
||||
@reconcile_app("authentik_outposts")
|
||||
def test_managed_outpost_fail(self):
|
||||
"""Test managed outpost"""
|
||||
outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||
outpost.user.delete()
|
||||
outpost.delete()
|
||||
with self.assertRaises(AuthenticationFailed):
|
||||
bearer_auth(f"Bearer {settings.SECRET_KEY}".encode())
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from rest_framework.response import Response
|
|||
from rest_framework.views import APIView
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.events.geo import GEOIP_READER
|
||||
from authentik.events.context_processors.base import get_context_processors
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
capabilities = Signal()
|
||||
|
@ -30,6 +30,7 @@ class Capabilities(models.TextChoices):
|
|||
|
||||
CAN_SAVE_MEDIA = "can_save_media"
|
||||
CAN_GEO_IP = "can_geo_ip"
|
||||
CAN_ASN = "can_asn"
|
||||
CAN_IMPERSONATE = "can_impersonate"
|
||||
CAN_DEBUG = "can_debug"
|
||||
IS_ENTERPRISE = "is_enterprise"
|
||||
|
@ -68,8 +69,9 @@ class ConfigView(APIView):
|
|||
deb_test = settings.DEBUG or settings.TEST
|
||||
if Path(settings.MEDIA_ROOT).is_mount() or deb_test:
|
||||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||
if GEOIP_READER.enabled:
|
||||
caps.append(Capabilities.CAN_GEO_IP)
|
||||
for processor in get_context_processors():
|
||||
if cap := processor.capability():
|
||||
caps.append(cap)
|
||||
if CONFIG.get_bool("impersonation"):
|
||||
caps.append(Capabilities.CAN_IMPERSONATE)
|
||||
if settings.DEBUG: # pragma: no cover
|
||||
|
|
|
@ -3,7 +3,7 @@ from django.utils.translation import gettext_lazy as _
|
|||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, DateTimeField, JSONField
|
||||
from rest_framework.fields import CharField, DateTimeField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ListSerializer, ModelSerializer
|
||||
|
@ -15,7 +15,7 @@ from authentik.blueprints.v1.importer import Importer
|
|||
from authentik.blueprints.v1.oci import OCI_PREFIX
|
||||
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
|
||||
|
||||
class ManagedSerializer:
|
||||
|
@ -28,7 +28,7 @@ class MetadataSerializer(PassiveSerializer):
|
|||
"""Serializer for blueprint metadata"""
|
||||
|
||||
name = CharField()
|
||||
labels = JSONField()
|
||||
labels = JSONDictField()
|
||||
|
||||
|
||||
class BlueprintInstanceSerializer(ModelSerializer):
|
||||
|
|
|
@ -40,7 +40,7 @@ class ManagedAppConfig(AppConfig):
|
|||
meth()
|
||||
self._logger.debug("Successfully reconciled", name=name)
|
||||
except (DatabaseError, ProgrammingError, InternalError) as exc:
|
||||
self._logger.debug("Failed to run reconcile", name=name, exc=exc)
|
||||
self._logger.warning("Failed to run reconcile", name=name, exc=exc)
|
||||
|
||||
|
||||
class AuthentikBlueprintsConfig(ManagedAppConfig):
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import BooleanField, JSONField
|
||||
from rest_framework.fields import BooleanField
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.blueprints.v1.meta.registry import BaseMetaModel, MetaResult, registry
|
||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.blueprints.models import BlueprintInstance
|
||||
|
@ -17,7 +17,7 @@ LOGGER = get_logger()
|
|||
class ApplyBlueprintMetaSerializer(PassiveSerializer):
|
||||
"""Serializer for meta apply blueprint model"""
|
||||
|
||||
identifiers = JSONField(validators=[is_dict])
|
||||
identifiers = JSONDictField()
|
||||
required = BooleanField(default=True)
|
||||
|
||||
# We cannot override `instance` as that will confuse rest_framework
|
||||
|
|
|
@ -14,7 +14,8 @@ from ua_parser import user_agent_parser
|
|||
from authentik.api.authorization import OwnerSuperuserPermissions
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.models import AuthenticatedSession
|
||||
from authentik.events.geo import GEOIP_READER, GeoIPDict
|
||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR, ASNDict
|
||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR, GeoIPDict
|
||||
|
||||
|
||||
class UserAgentDeviceDict(TypedDict):
|
||||
|
@ -59,6 +60,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
|||
current = SerializerMethodField()
|
||||
user_agent = SerializerMethodField()
|
||||
geo_ip = SerializerMethodField()
|
||||
asn = SerializerMethodField()
|
||||
|
||||
def get_current(self, instance: AuthenticatedSession) -> bool:
|
||||
"""Check if session is currently active session"""
|
||||
|
@ -70,8 +72,12 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
|||
return user_agent_parser.Parse(instance.last_user_agent)
|
||||
|
||||
def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]: # pragma: no cover
|
||||
"""Get parsed user agent"""
|
||||
return GEOIP_READER.city_dict(instance.last_ip)
|
||||
"""Get GeoIP Data"""
|
||||
return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.last_ip)
|
||||
|
||||
def get_asn(self, instance: AuthenticatedSession) -> Optional[ASNDict]: # pragma: no cover
|
||||
"""Get ASN Data"""
|
||||
return ASN_CONTEXT_PROCESSOR.asn_dict(instance.last_ip)
|
||||
|
||||
class Meta:
|
||||
model = AuthenticatedSession
|
||||
|
@ -80,6 +86,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
|||
"current",
|
||||
"user_agent",
|
||||
"geo_ip",
|
||||
"asn",
|
||||
"user",
|
||||
"last_ip",
|
||||
"last_user_agent",
|
||||
|
|
|
@ -8,7 +8,7 @@ from django_filters.filterset import FilterSet
|
|||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import CharField, IntegerField, JSONField
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError
|
||||
|
@ -16,7 +16,7 @@ from rest_framework.viewsets import ModelViewSet
|
|||
|
||||
from authentik.api.decorators import permission_required
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.rbac.api.roles import RoleSerializer
|
||||
|
||||
|
@ -24,7 +24,7 @@ from authentik.rbac.api.roles import RoleSerializer
|
|||
class GroupMemberSerializer(ModelSerializer):
|
||||
"""Stripped down user serializer to show relevant users for groups"""
|
||||
|
||||
attributes = JSONField(validators=[is_dict], required=False)
|
||||
attributes = JSONDictField(required=False)
|
||||
uid = CharField(read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
@ -44,7 +44,7 @@ class GroupMemberSerializer(ModelSerializer):
|
|||
class GroupSerializer(ModelSerializer):
|
||||
"""Group Serializer"""
|
||||
|
||||
attributes = JSONField(validators=[is_dict], required=False)
|
||||
attributes = JSONDictField(required=False)
|
||||
users_obj = ListSerializer(
|
||||
child=GroupMemberSerializer(), read_only=True, source="users", required=False
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@ from authentik.core.api.used_by import UsedByMixin
|
|||
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
||||
from authentik.core.models import PropertyMapping
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
from authentik.policies.api.exec import PolicyTestSerializer
|
||||
|
@ -95,6 +96,7 @@ class PropertyMappingViewSet(
|
|||
"description": subclass.__doc__,
|
||||
"component": subclass().component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||
}
|
||||
)
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
||||
|
|
|
@ -16,6 +16,7 @@ from rest_framework.viewsets import GenericViewSet
|
|||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
|
||||
|
||||
|
@ -113,6 +114,7 @@ class ProviderViewSet(
|
|||
"description": subclass.__doc__,
|
||||
"component": subclass().component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||
}
|
||||
)
|
||||
data.append(
|
||||
|
|
|
@ -32,13 +32,7 @@ from drf_spectacular.utils import (
|
|||
)
|
||||
from guardian.shortcuts import get_anonymous_user, get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import (
|
||||
CharField,
|
||||
IntegerField,
|
||||
JSONField,
|
||||
ListField,
|
||||
SerializerMethodField,
|
||||
)
|
||||
from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import (
|
||||
|
@ -57,7 +51,7 @@ from authentik.admin.api.metrics import CoordinateSerializer
|
|||
from authentik.api.decorators import permission_required
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
|
||||
from authentik.core.api.utils import JSONDictField, LinkSerializer, PassiveSerializer
|
||||
from authentik.core.middleware import (
|
||||
SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
|
||||
SESSION_KEY_IMPERSONATE_USER,
|
||||
|
@ -89,7 +83,7 @@ LOGGER = get_logger()
|
|||
class UserGroupSerializer(ModelSerializer):
|
||||
"""Simplified Group Serializer for user's groups"""
|
||||
|
||||
attributes = JSONField(required=False)
|
||||
attributes = JSONDictField(required=False)
|
||||
parent_name = CharField(source="parent.name", read_only=True)
|
||||
|
||||
class Meta:
|
||||
|
@ -110,7 +104,7 @@ class UserSerializer(ModelSerializer):
|
|||
|
||||
is_superuser = BooleanField(read_only=True)
|
||||
avatar = CharField(read_only=True)
|
||||
attributes = JSONField(validators=[is_dict], required=False)
|
||||
attributes = JSONDictField(required=False)
|
||||
groups = PrimaryKeyRelatedField(
|
||||
allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list
|
||||
)
|
||||
|
|
|
@ -2,7 +2,10 @@
|
|||
from typing import Any
|
||||
|
||||
from django.db.models import Model
|
||||
from rest_framework.fields import CharField, IntegerField, JSONField
|
||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.plumbing import build_basic_type
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from rest_framework.fields import BooleanField, CharField, IntegerField, JSONField
|
||||
from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
|
||||
|
||||
|
||||
|
@ -13,6 +16,21 @@ def is_dict(value: Any):
|
|||
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
|
||||
|
||||
|
||||
class JSONDictField(JSONField):
|
||||
"""JSON Field which only allows dictionaries"""
|
||||
|
||||
default_validators = [is_dict]
|
||||
|
||||
|
||||
class JSONExtension(OpenApiSerializerFieldExtension):
|
||||
"""Generate API Schema for JSON fields as"""
|
||||
|
||||
target_class = "authentik.core.api.utils.JSONDictField"
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
return build_basic_type(OpenApiTypes.OBJECT)
|
||||
|
||||
|
||||
class PassiveSerializer(Serializer):
|
||||
"""Base serializer class which doesn't implement create/update methods"""
|
||||
|
||||
|
@ -26,7 +44,7 @@ class PassiveSerializer(Serializer):
|
|||
class PropertyMappingPreviewSerializer(PassiveSerializer):
|
||||
"""Preview how the current user is mapped via the property mappings selected in a provider"""
|
||||
|
||||
preview = JSONField(read_only=True)
|
||||
preview = JSONDictField(read_only=True)
|
||||
|
||||
|
||||
class MetaNameSerializer(PassiveSerializer):
|
||||
|
@ -56,6 +74,7 @@ class TypeCreateSerializer(PassiveSerializer):
|
|||
description = CharField(required=True)
|
||||
component = CharField(required=True)
|
||||
model_name = CharField(required=True)
|
||||
requires_enterprise = BooleanField(default=False)
|
||||
|
||||
|
||||
class CacheSerializer(PassiveSerializer):
|
||||
|
|
|
@ -1,22 +1,29 @@
|
|||
"""Channels base classes"""
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.authentication import bearer_auth
|
||||
from authentik.core.models import User
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class AuthJsonConsumer(JsonWebsocketConsumer):
|
||||
class TokenOutpostMiddleware:
|
||||
"""Authorize a client with a token"""
|
||||
|
||||
user: User
|
||||
def __init__(self, inner):
|
||||
self.inner = inner
|
||||
|
||||
def connect(self):
|
||||
headers = dict(self.scope["headers"])
|
||||
async def __call__(self, scope, receive, send):
|
||||
scope = dict(scope)
|
||||
await self.auth(scope)
|
||||
return await self.inner(scope, receive, send)
|
||||
|
||||
@database_sync_to_async
|
||||
def auth(self, scope):
|
||||
"""Authenticate request from header"""
|
||||
headers = dict(scope["headers"])
|
||||
if b"authorization" not in headers:
|
||||
LOGGER.warning("WS Request without authorization header")
|
||||
raise DenyConnection()
|
||||
|
@ -32,4 +39,4 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
|
|||
LOGGER.warning("Failed to authenticate", exc=exc)
|
||||
raise DenyConnection()
|
||||
|
||||
self.user = user
|
||||
scope["user"] = user
|
||||
|
|
|
@ -44,6 +44,7 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
|||
if request:
|
||||
req.http_request = request
|
||||
self._context["request"] = req
|
||||
req.context.update(**kwargs)
|
||||
self._context.update(**kwargs)
|
||||
self.dry_run = dry_run
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ from authentik.lib.models import (
|
|||
DomainlessFormattedURLValidator,
|
||||
SerializerModel,
|
||||
)
|
||||
from authentik.lib.utils.http import get_client_ip
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.root.install_id import get_install_id
|
||||
|
||||
|
@ -748,12 +747,14 @@ class AuthenticatedSession(ExpiringModel):
|
|||
@staticmethod
|
||||
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
|
||||
"""Create a new session from a http request"""
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
if not hasattr(request, "session") or not request.session.session_key:
|
||||
return None
|
||||
return AuthenticatedSession(
|
||||
session_key=request.session.session_key,
|
||||
user=user,
|
||||
last_ip=get_client_ip(request),
|
||||
last_ip=ClientIPMiddleware.get_client_ip(request),
|
||||
last_user_agent=request.META.get("HTTP_USER_AGENT", ""),
|
||||
expires=request.session.get_expiry_date(),
|
||||
)
|
||||
|
|
|
@ -44,28 +44,14 @@
|
|||
|
||||
{% block body %}
|
||||
<div class="pf-c-background-image">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" class="pf-c-background-image__filter" width="0" height="0">
|
||||
<filter id="image_overlay">
|
||||
<feColorMatrix in="SourceGraphic" type="matrix" values="1.3 0 0 0 0 0 1.3 0 0 0 0 0 1.3 0 0 0 0 0 1 0" />
|
||||
<feComponentTransfer color-interpolation-filters="sRGB" result="duotone">
|
||||
<feFuncR type="table" tableValues="0.086274509803922 0.43921568627451"></feFuncR>
|
||||
<feFuncG type="table" tableValues="0.086274509803922 0.43921568627451"></feFuncG>
|
||||
<feFuncB type="table" tableValues="0.086274509803922 0.43921568627451"></feFuncB>
|
||||
<feFuncA type="table" tableValues="0 1"></feFuncA>
|
||||
</feComponentTransfer>
|
||||
</filter>
|
||||
</svg>
|
||||
</div>
|
||||
<ak-message-container></ak-message-container>
|
||||
<div class="pf-c-login">
|
||||
<div class="pf-c-login stacked">
|
||||
<div class="ak-login-container">
|
||||
<header class="pf-c-login__header">
|
||||
<div class="pf-c-brand ak-brand">
|
||||
<main class="pf-c-login__main">
|
||||
<div class="pf-c-login__main-header pf-c-brand ak-brand">
|
||||
<img src="{{ tenant.branding_logo }}" alt="authentik Logo" />
|
||||
</div>
|
||||
</header>
|
||||
{% block main_container %}
|
||||
<main class="pf-c-login__main">
|
||||
<header class="pf-c-login__main-header">
|
||||
<h1 class="pf-c-title pf-m-3xl">
|
||||
{% block card_title %}
|
||||
|
@ -77,7 +63,6 @@
|
|||
{% endblock %}
|
||||
</div>
|
||||
</main>
|
||||
{% endblock %}
|
||||
<footer class="pf-c-login__footer">
|
||||
<ul class="pf-c-list pf-m-inline">
|
||||
{% for link in footer_links %}
|
||||
|
|
|
@ -22,6 +22,7 @@ class InterfaceView(TemplateView):
|
|||
kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}"
|
||||
kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}"
|
||||
kwargs["build"] = get_build_hash()
|
||||
kwargs["url_kwargs"] = self.kwargs
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -2,9 +2,11 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
|
@ -20,6 +22,18 @@ from authentik.enterprise.models import License, LicenseKey
|
|||
from authentik.root.install_id import get_install_id
|
||||
|
||||
|
||||
class EnterpriseRequiredMixin:
|
||||
"""Mixin to validate that a valid enterprise license
|
||||
exists before allowing to safe the object"""
|
||||
|
||||
def validate(self, attrs: dict) -> dict:
|
||||
"""Check that a valid license exists"""
|
||||
total = LicenseKey.get_total()
|
||||
if not total.is_valid():
|
||||
raise ValidationError(_("Enterprise is required to create/update this object."))
|
||||
return super().validate(attrs)
|
||||
|
||||
|
||||
class LicenseSerializer(ModelSerializer):
|
||||
"""License Serializer"""
|
||||
|
||||
|
|
|
@ -2,7 +2,11 @@
|
|||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
|
||||
|
||||
class AuthentikEnterpriseConfig(ManagedAppConfig):
|
||||
class EnterpriseConfig(ManagedAppConfig):
|
||||
"""Base app config for all enterprise apps"""
|
||||
|
||||
|
||||
class AuthentikEnterpriseConfig(EnterpriseConfig):
|
||||
"""Enterprise app config"""
|
||||
|
||||
name = "authentik.enterprise"
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Enterprise license policies"""
|
||||
from typing import Optional
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from authentik.core.models import User, UserTypes
|
||||
from authentik.enterprise.models import LicenseKey
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
|
@ -13,10 +15,10 @@ class EnterprisePolicyAccessView(PolicyAccessView):
|
|||
def check_license(self):
|
||||
"""Check license"""
|
||||
if not LicenseKey.get_total().is_valid():
|
||||
return False
|
||||
return PolicyResult(False, _("Enterprise required to access this feature."))
|
||||
if self.request.user.type != UserTypes.INTERNAL:
|
||||
return False
|
||||
return True
|
||||
return PolicyResult(False, _("Feature only accessible for internal users."))
|
||||
return PolicyResult(True)
|
||||
|
||||
def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
|
||||
user = user or self.request.user
|
||||
|
@ -24,7 +26,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
|
|||
request.http_request = self.request
|
||||
result = super().user_has_access(user)
|
||||
enterprise_result = self.check_license()
|
||||
if not enterprise_result:
|
||||
if not enterprise_result.passing:
|
||||
return enterprise_result
|
||||
return result
|
||||
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
"""RAC Provider API Views"""
|
||||
from typing import Optional
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import QuerySet
|
||||
from django.urls import reverse
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from rest_framework.fields import SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
|
||||
from authentik.enterprise.providers.rac.models import Endpoint
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def user_endpoint_cache_key(user_pk: str) -> str:
|
||||
"""Cache key where endpoint list for user is saved"""
|
||||
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}"
|
||||
|
||||
|
||||
class EndpointSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||
"""Endpoint Serializer"""
|
||||
|
||||
provider_obj = RACProviderSerializer(source="provider", read_only=True)
|
||||
launch_url = SerializerMethodField()
|
||||
|
||||
def get_launch_url(self, endpoint: Endpoint) -> Optional[str]:
|
||||
"""Build actual launch URL (the provider itself does not have one, just
|
||||
individual endpoints)"""
|
||||
try:
|
||||
# pylint: disable=no-member
|
||||
return reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk},
|
||||
)
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return None
|
||||
|
||||
class Meta:
|
||||
model = Endpoint
|
||||
fields = [
|
||||
"pk",
|
||||
"name",
|
||||
"provider",
|
||||
"provider_obj",
|
||||
"protocol",
|
||||
"host",
|
||||
"settings",
|
||||
"property_mappings",
|
||||
"auth_mode",
|
||||
"launch_url",
|
||||
"maximum_connections",
|
||||
]
|
||||
|
||||
|
||||
class EndpointViewSet(UsedByMixin, ModelViewSet):
|
||||
"""Endpoint Viewset"""
|
||||
|
||||
queryset = Endpoint.objects.all()
|
||||
serializer_class = EndpointSerializer
|
||||
filterset_fields = ["name", "provider"]
|
||||
search_fields = ["name", "protocol"]
|
||||
ordering = ["name", "protocol"]
|
||||
|
||||
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
|
||||
"""Custom filter_queryset method which ignores guardian, but still supports sorting"""
|
||||
for backend in list(self.filter_backends):
|
||||
if backend == ObjectFilter:
|
||||
continue
|
||||
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||
return queryset
|
||||
|
||||
def _get_allowed_endpoints(self, queryset: QuerySet) -> list[Endpoint]:
|
||||
endpoints = []
|
||||
for endpoint in queryset:
|
||||
engine = PolicyEngine(endpoint, self.request.user, self.request)
|
||||
engine.build()
|
||||
if engine.passing:
|
||||
endpoints.append(endpoint)
|
||||
return endpoints
|
||||
|
||||
@extend_schema(
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
"search",
|
||||
OpenApiTypes.STR,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="superuser_full_list",
|
||||
location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.BOOL,
|
||||
),
|
||||
],
|
||||
responses={
|
||||
200: EndpointSerializer(many=True),
|
||||
400: OpenApiResponse(description="Bad request"),
|
||||
},
|
||||
)
|
||||
def list(self, request: Request, *args, **kwargs) -> Response:
|
||||
"""List accessible endpoints"""
|
||||
should_cache = request.GET.get("search", "") == ""
|
||||
|
||||
superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
|
||||
if superuser_full_list and request.user.is_superuser:
|
||||
return super().list(request)
|
||||
|
||||
queryset = self._filter_queryset_for_list(self.get_queryset())
|
||||
self.paginate_queryset(queryset)
|
||||
|
||||
allowed_endpoints = []
|
||||
if not should_cache:
|
||||
allowed_endpoints = self._get_allowed_endpoints(queryset)
|
||||
if should_cache:
|
||||
allowed_endpoints = cache.get(user_endpoint_cache_key(self.request.user.pk))
|
||||
if not allowed_endpoints:
|
||||
LOGGER.debug("Caching allowed endpoint list")
|
||||
allowed_endpoints = self._get_allowed_endpoints(queryset)
|
||||
cache.set(
|
||||
user_endpoint_cache_key(self.request.user.pk),
|
||||
allowed_endpoints,
|
||||
timeout=86400,
|
||||
)
|
||||
serializer = self.get_serializer(allowed_endpoints, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
|
@ -0,0 +1,36 @@
|
|||
"""RAC Provider API Views"""
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import JSONDictField
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.rac.models import RACPropertyMapping
|
||||
|
||||
|
||||
class RACPropertyMappingSerializer(EnterpriseRequiredMixin, PropertyMappingSerializer):
|
||||
"""RACPropertyMapping Serializer"""
|
||||
|
||||
static_settings = JSONDictField()
|
||||
expression = CharField(allow_blank=True, required=False)
|
||||
|
||||
def validate_expression(self, expression: str) -> str:
|
||||
"""Test Syntax"""
|
||||
if expression == "":
|
||||
return expression
|
||||
return super().validate_expression(expression)
|
||||
|
||||
class Meta:
|
||||
model = RACPropertyMapping
|
||||
fields = PropertyMappingSerializer.Meta.fields + ["static_settings"]
|
||||
|
||||
|
||||
class RACPropertyMappingViewSet(UsedByMixin, ModelViewSet):
|
||||
"""RACPropertyMapping Viewset"""
|
||||
|
||||
queryset = RACPropertyMapping.objects.all()
|
||||
serializer_class = RACPropertyMappingSerializer
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
filterset_fields = ["name", "managed"]
|
|
@ -0,0 +1,32 @@
|
|||
"""RAC Provider API Views"""
|
||||
from rest_framework.fields import CharField, ListField
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.rac.models import RACProvider
|
||||
|
||||
|
||||
class RACProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
|
||||
"""RACProvider Serializer"""
|
||||
|
||||
outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
|
||||
|
||||
class Meta:
|
||||
model = RACProvider
|
||||
fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"]
|
||||
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
|
||||
|
||||
|
||||
class RACProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
"""RACProvider Viewset"""
|
||||
|
||||
queryset = RACProvider.objects.all()
|
||||
serializer_class = RACProviderSerializer
|
||||
filterset_fields = {
|
||||
"application": ["isnull"],
|
||||
"name": ["iexact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
|
@ -0,0 +1,17 @@
|
|||
"""RAC app config"""
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
|
||||
|
||||
class AuthentikEnterpriseProviderRAC(EnterpriseConfig):
|
||||
"""authentik enterprise rac app config"""
|
||||
|
||||
name = "authentik.enterprise.providers.rac"
|
||||
label = "authentik_providers_rac"
|
||||
verbose_name = "authentik Enterprise.Providers.RAC"
|
||||
default = True
|
||||
mountpoint = ""
|
||||
ws_mountpoint = "authentik.enterprise.providers.rac.urls"
|
||||
|
||||
def reconcile_load_rac_signals(self):
|
||||
"""Load rac signals"""
|
||||
self.import_module("authentik.enterprise.providers.rac.signals")
|
|
@ -0,0 +1,163 @@
|
|||
"""RAC Client consumer"""
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.exceptions import ChannelFull, DenyConnection
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from django.http.request import QueryDict
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken, RACProvider
|
||||
from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE
|
||||
from authentik.outposts.models import Outpost, OutpostState, OutpostType
|
||||
|
||||
# Global broadcast group, which messages are sent to when the outpost connects back
|
||||
# to authentik for a specific connection
|
||||
# The `RACClientConsumer` consumer adds itself to this group on connection,
|
||||
# and removes itself once it has been assigned a specific outpost channel
|
||||
RAC_CLIENT_GROUP = "group_enterprise_rac_client"
|
||||
# A group for all connections in a given authentik session ID
|
||||
# A disconnect message is sent to this group when the session expires/is deleted
|
||||
RAC_CLIENT_GROUP_SESSION = "group_enterprise_rac_client_%(session)s"
|
||||
# A group for all connections with a specific token, which in almost all cases
|
||||
# is just one connection, however this is used to disconnect the connection
|
||||
# when the token is deleted
|
||||
RAC_CLIENT_GROUP_TOKEN = "group_enterprise_rac_token_%(token)s" # nosec
|
||||
|
||||
# Step 1: Client connects to this websocket endpoint
|
||||
# Step 2: We prepare all the connection args for Guac
|
||||
# Step 3: Send a websocket message to a single outpost that has this provider assigned
|
||||
# (Currently sending to all of them)
|
||||
# (Should probably do different load balancing algorithms)
|
||||
# Step 4: Outpost creates a websocket connection back to authentik
|
||||
# with /ws/outpost_rac/<our_channel_id>/
|
||||
# Step 5: This consumer transfers data between the two channels
|
||||
|
||||
|
||||
class RACClientConsumer(AsyncWebsocketConsumer):
|
||||
"""RAC client consumer the browser connects to"""
|
||||
|
||||
dest_channel_id: str = ""
|
||||
provider: RACProvider
|
||||
token: ConnectionToken
|
||||
logger: BoundLogger
|
||||
|
||||
async def connect(self):
|
||||
await self.accept("guacamole")
|
||||
await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
|
||||
await self.channel_layer.group_add(
|
||||
RAC_CLIENT_GROUP_SESSION % {"session": self.scope["session"].session_key},
|
||||
self.channel_name,
|
||||
)
|
||||
await self.init_outpost_connection()
|
||||
|
||||
async def disconnect(self, code):
|
||||
self.logger.debug("Disconnecting")
|
||||
# Tell the outpost we're disconnecting
|
||||
await self.channel_layer.send(
|
||||
self.dest_channel_id,
|
||||
{
|
||||
"type": "event.disconnect",
|
||||
},
|
||||
)
|
||||
|
||||
@database_sync_to_async
|
||||
def init_outpost_connection(self):
|
||||
"""Initialize guac connection settings"""
|
||||
self.token = ConnectionToken.filter_not_expired(
|
||||
token=self.scope["url_route"]["kwargs"]["token"]
|
||||
).first()
|
||||
if not self.token:
|
||||
raise DenyConnection()
|
||||
self.provider = self.token.provider
|
||||
params = self.token.get_settings()
|
||||
self.logger = get_logger().bind(
|
||||
endpoint=self.token.endpoint.name, user=self.scope["user"].username
|
||||
)
|
||||
msg = {
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "init_connection",
|
||||
"dest_channel_id": self.channel_name,
|
||||
"params": params,
|
||||
"protocol": self.token.endpoint.protocol,
|
||||
}
|
||||
query = QueryDict(self.scope["query_string"].decode())
|
||||
for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
|
||||
value = query.get(key, None)
|
||||
if not value:
|
||||
continue
|
||||
msg[key] = str(value)
|
||||
outposts = Outpost.objects.filter(
|
||||
type=OutpostType.RAC,
|
||||
providers__in=[self.provider],
|
||||
)
|
||||
if not outposts.exists():
|
||||
self.logger.warning("Provider has no outpost")
|
||||
raise DenyConnection()
|
||||
for outpost in outposts:
|
||||
# Sort all states for the outpost by connection count
|
||||
states = sorted(
|
||||
OutpostState.for_outpost(outpost),
|
||||
key=lambda state: int(state.args.get("active_connections", 0)),
|
||||
)
|
||||
if len(states) < 1:
|
||||
continue
|
||||
self.logger.debug("Sending out connection broadcast")
|
||||
async_to_sync(self.channel_layer.group_send)(
|
||||
OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
|
||||
msg,
|
||||
)
|
||||
|
||||
async def receive(self, text_data=None, bytes_data=None):
|
||||
"""Mirror data received from client to the dest_channel_id
|
||||
which is the channel talking to guacd"""
|
||||
if self.dest_channel_id == "":
|
||||
return
|
||||
if self.token.is_expired:
|
||||
await self.event_disconnect({"reason": "token_expiry"})
|
||||
return
|
||||
try:
|
||||
await self.channel_layer.send(
|
||||
self.dest_channel_id,
|
||||
{
|
||||
"type": "event.send",
|
||||
"text_data": text_data,
|
||||
"bytes_data": bytes_data,
|
||||
},
|
||||
)
|
||||
except ChannelFull:
|
||||
pass
|
||||
|
||||
async def event_outpost_connected(self, event: dict):
|
||||
"""Handle event broadcasted from outpost consumer, and check if they
|
||||
created a connection for us"""
|
||||
outpost_channel = event.get("outpost_channel")
|
||||
if event.get("client_channel") != self.channel_name:
|
||||
return
|
||||
if self.dest_channel_id != "":
|
||||
# We've already selected an outpost channel, so tell the other channel to disconnect
|
||||
# This should never happen since we remove ourselves from the broadcast group
|
||||
await self.channel_layer.send(
|
||||
outpost_channel,
|
||||
{
|
||||
"type": "event.disconnect",
|
||||
},
|
||||
)
|
||||
return
|
||||
self.logger.debug("Connected to a single outpost instance")
|
||||
self.dest_channel_id = outpost_channel
|
||||
# Since we have a specific outpost channel now, we can remove
|
||||
# ourselves from the global broadcast group
|
||||
await self.channel_layer.group_discard(RAC_CLIENT_GROUP, self.channel_name)
|
||||
|
||||
async def event_send(self, event: dict):
|
||||
"""Handler called by outpost websocket that sends data to this specific
|
||||
client connection"""
|
||||
if self.token.is_expired:
|
||||
await self.event_disconnect({"reason": "token_expiry"})
|
||||
return
|
||||
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
|
||||
|
||||
async def event_disconnect(self, event: dict):
|
||||
"""Disconnect when the session ends"""
|
||||
self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
|
||||
await self.close()
|
|
@ -0,0 +1,48 @@
|
|||
"""RAC consumer"""
|
||||
from channels.exceptions import ChannelFull
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
from authentik.enterprise.providers.rac.consumer_client import RAC_CLIENT_GROUP
|
||||
|
||||
|
||||
class RACOutpostConsumer(AsyncWebsocketConsumer):
|
||||
"""Consumer the outpost connects to, to send specific data back to a client connection"""
|
||||
|
||||
dest_channel_id: str
|
||||
|
||||
async def connect(self):
|
||||
self.dest_channel_id = self.scope["url_route"]["kwargs"]["channel"]
|
||||
await self.accept()
|
||||
await self.channel_layer.group_send(
|
||||
RAC_CLIENT_GROUP,
|
||||
{
|
||||
"type": "event.outpost.connected",
|
||||
"outpost_channel": self.channel_name,
|
||||
"client_channel": self.dest_channel_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def receive(self, text_data=None, bytes_data=None):
|
||||
"""Mirror data received from guacd running in the outpost
|
||||
to the dest_channel_id which is the channel talking to the browser"""
|
||||
try:
|
||||
await self.channel_layer.send(
|
||||
self.dest_channel_id,
|
||||
{
|
||||
"type": "event.send",
|
||||
"text_data": text_data,
|
||||
"bytes_data": bytes_data,
|
||||
},
|
||||
)
|
||||
except ChannelFull:
|
||||
pass
|
||||
|
||||
async def event_send(self, event: dict):
|
||||
"""Handler called by client websocket that sends data to this specific
|
||||
outpost connection"""
|
||||
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
|
||||
|
||||
async def event_disconnect(self, event: dict):
|
||||
"""Tell outpost we're about to disconnect"""
|
||||
await self.send(text_data="0.authentik.disconnect")
|
||||
await self.close()
|
|
@ -0,0 +1,11 @@
|
|||
"""RAC Provider Docker Controller"""
|
||||
from authentik.outposts.controllers.docker import DockerController
|
||||
from authentik.outposts.models import DockerServiceConnection, Outpost
|
||||
|
||||
|
||||
class RACDockerController(DockerController):
|
||||
"""RAC Provider Docker Controller"""
|
||||
|
||||
def __init__(self, outpost: Outpost, connection: DockerServiceConnection):
|
||||
super().__init__(outpost, connection)
|
||||
self.deployment_ports = []
|
|
@ -0,0 +1,13 @@
|
|||
"""RAC Provider Kubernetes Controller"""
|
||||
from authentik.outposts.controllers.k8s.service import ServiceReconciler
|
||||
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||
from authentik.outposts.models import KubernetesServiceConnection, Outpost
|
||||
|
||||
|
||||
class RACKubernetesController(KubernetesController):
|
||||
"""RAC Provider Kubernetes Controller"""
|
||||
|
||||
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection):
|
||||
super().__init__(outpost, connection)
|
||||
self.deployment_ports = []
|
||||
del self.reconcilers[ServiceReconciler.reconciler_name()]
|
|
@ -0,0 +1,164 @@
|
|||
# Generated by Django 4.2.8 on 2023-12-29 15:58
|
||||
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import authentik.core.models
|
||||
import authentik.lib.utils.time
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_policies", "0011_policybinding_failure_result_and_more"),
|
||||
("authentik_core", "0032_group_roles"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="RACPropertyMapping",
|
||||
fields=[
|
||||
(
|
||||
"propertymapping_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.propertymapping",
|
||||
),
|
||||
),
|
||||
("static_settings", models.JSONField(default=dict)),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "RAC Property Mapping",
|
||||
"verbose_name_plural": "RAC Property Mappings",
|
||||
},
|
||||
bases=("authentik_core.propertymapping",),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="RACProvider",
|
||||
fields=[
|
||||
(
|
||||
"provider_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.provider",
|
||||
),
|
||||
),
|
||||
("settings", models.JSONField(default=dict)),
|
||||
(
|
||||
"auth_mode",
|
||||
models.TextField(
|
||||
choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt"
|
||||
),
|
||||
),
|
||||
(
|
||||
"connection_expiry",
|
||||
models.TextField(
|
||||
default="hours=8",
|
||||
help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)",
|
||||
validators=[authentik.lib.utils.time.timedelta_string_validator],
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
},
|
||||
bases=("authentik_core.provider",),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="Endpoint",
|
||||
fields=[
|
||||
(
|
||||
"policybindingmodel_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_policies.policybindingmodel",
|
||||
),
|
||||
),
|
||||
("name", models.TextField()),
|
||||
("host", models.TextField()),
|
||||
(
|
||||
"protocol",
|
||||
models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]),
|
||||
),
|
||||
("settings", models.JSONField(default=dict)),
|
||||
(
|
||||
"auth_mode",
|
||||
models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]),
|
||||
),
|
||||
(
|
||||
"property_mappings",
|
||||
models.ManyToManyField(
|
||||
blank=True, default=None, to="authentik_core.propertymapping"
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_rac.racprovider",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "RAC Endpoint",
|
||||
"verbose_name_plural": "RAC Endpoints",
|
||||
},
|
||||
bases=("authentik_policies.policybindingmodel", models.Model),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="ConnectionToken",
|
||||
fields=[
|
||||
(
|
||||
"expires",
|
||||
models.DateTimeField(default=authentik.core.models.default_token_duration),
|
||||
),
|
||||
("expiring", models.BooleanField(default=True)),
|
||||
(
|
||||
"connection_token_uuid",
|
||||
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||
),
|
||||
("token", models.TextField(default=authentik.core.models.default_token_key)),
|
||||
("settings", models.JSONField(default=dict)),
|
||||
(
|
||||
"endpoint",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_rac.endpoint",
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_rac.racprovider",
|
||||
),
|
||||
),
|
||||
(
|
||||
"session",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 5.0 on 2024-01-03 23:44
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("authentik_providers_rac", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="endpoint",
|
||||
name="maximum_connections",
|
||||
field=models.IntegerField(default=1),
|
||||
),
|
||||
]
|
|
@ -0,0 +1,192 @@
|
|||
"""RAC Models"""
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from deepmerge import always_merger
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class Protocols(models.TextChoices):
|
||||
"""Supported protocols"""
|
||||
|
||||
RDP = "rdp"
|
||||
VNC = "vnc"
|
||||
SSH = "ssh"
|
||||
|
||||
|
||||
class AuthenticationMode(models.TextChoices):
|
||||
"""Authentication modes"""
|
||||
|
||||
STATIC = "static"
|
||||
PROMPT = "prompt"
|
||||
|
||||
|
||||
class RACProvider(Provider):
|
||||
"""Remotely access computers/servers via RDP/SSH/VNC."""
|
||||
|
||||
settings = models.JSONField(default=dict)
|
||||
auth_mode = models.TextField(
|
||||
choices=AuthenticationMode.choices, default=AuthenticationMode.PROMPT
|
||||
)
|
||||
connection_expiry = models.TextField(
|
||||
default="hours=8",
|
||||
validators=[timedelta_string_validator],
|
||||
help_text=_(
|
||||
"Determines how long a session lasts. Default of 0 means "
|
||||
"that the sessions lasts until the browser is closed. "
|
||||
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def launch_url(self) -> Optional[str]:
|
||||
"""URL to this provider and initiate authorization for the user.
|
||||
Can return None for providers that are not URL-based"""
|
||||
return "goauthentik.io://providers/rac/launch"
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-provider-rac-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
|
||||
|
||||
return RACProviderSerializer
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Provider")
|
||||
verbose_name_plural = _("RAC Providers")
|
||||
|
||||
|
||||
class Endpoint(SerializerModel, PolicyBindingModel):
|
||||
"""Remote-accessible endpoint"""
|
||||
|
||||
name = models.TextField()
|
||||
host = models.TextField()
|
||||
protocol = models.TextField(choices=Protocols.choices)
|
||||
settings = models.JSONField(default=dict)
|
||||
auth_mode = models.TextField(choices=AuthenticationMode.choices)
|
||||
provider = models.ForeignKey("RACProvider", on_delete=models.CASCADE)
|
||||
maximum_connections = models.IntegerField(default=1)
|
||||
|
||||
property_mappings = models.ManyToManyField(
|
||||
"authentik_core.PropertyMapping", default=None, blank=True
|
||||
)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer
|
||||
|
||||
return EndpointSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"RAC Endpoint {self.name}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Endpoint")
|
||||
verbose_name_plural = _("RAC Endpoints")
|
||||
|
||||
|
||||
class RACPropertyMapping(PropertyMapping):
|
||||
"""Configure settings for remote access endpoints."""
|
||||
|
||||
static_settings = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-rac-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.rac.api.property_mappings import (
|
||||
RACPropertyMappingSerializer,
|
||||
)
|
||||
|
||||
return RACPropertyMappingSerializer
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Property Mapping")
|
||||
verbose_name_plural = _("RAC Property Mappings")
|
||||
|
||||
|
||||
class ConnectionToken(ExpiringModel):
|
||||
"""Token for a single connection to a specified endpoint"""
|
||||
|
||||
connection_token_uuid = models.UUIDField(default=uuid4, primary_key=True)
|
||||
provider = models.ForeignKey(RACProvider, on_delete=models.CASCADE)
|
||||
endpoint = models.ForeignKey(Endpoint, on_delete=models.CASCADE)
|
||||
token = models.TextField(default=default_token_key)
|
||||
settings = models.JSONField(default=dict)
|
||||
session = models.ForeignKey("authentik_core.AuthenticatedSession", on_delete=models.CASCADE)
|
||||
|
||||
def get_settings(self) -> dict:
|
||||
"""Get settings"""
|
||||
default_settings = {}
|
||||
if ":" in self.endpoint.host:
|
||||
host, _, port = self.endpoint.host.partition(":")
|
||||
default_settings["hostname"] = host
|
||||
default_settings["port"] = str(port)
|
||||
else:
|
||||
default_settings["hostname"] = self.endpoint.host
|
||||
default_settings["client-name"] = "authentik"
|
||||
# default_settings["enable-drive"] = "true"
|
||||
# default_settings["drive-name"] = "authentik"
|
||||
settings = {}
|
||||
always_merger.merge(settings, default_settings)
|
||||
always_merger.merge(settings, self.endpoint.provider.settings)
|
||||
always_merger.merge(settings, self.endpoint.settings)
|
||||
always_merger.merge(settings, self.settings)
|
||||
|
||||
def mapping_evaluator(mappings: QuerySet):
|
||||
for mapping in mappings:
|
||||
mapping: RACPropertyMapping
|
||||
if len(mapping.static_settings) > 0:
|
||||
always_merger.merge(settings, mapping.static_settings)
|
||||
continue
|
||||
try:
|
||||
mapping_settings = mapping.evaluate(
|
||||
self.session.user, None, endpoint=self.endpoint, provider=self.provider
|
||||
)
|
||||
always_merger.merge(settings, mapping_settings)
|
||||
except PropertyMappingExpressionException as exc:
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping: '{mapping.name}'",
|
||||
provider=self.provider,
|
||||
mapping=mapping,
|
||||
).set_user(self.session.user).save()
|
||||
LOGGER.warning("Failed to evaluate property mapping", exc=exc)
|
||||
|
||||
mapping_evaluator(
|
||||
RACPropertyMapping.objects.filter(provider__in=[self.provider]).order_by("name")
|
||||
)
|
||||
mapping_evaluator(
|
||||
RACPropertyMapping.objects.filter(endpoint__in=[self.endpoint]).order_by("name")
|
||||
)
|
||||
|
||||
settings["drive-path"] = f"/tmp/connection/{self.token}" # nosec
|
||||
settings["create-drive-path"] = "true"
|
||||
# Ensure all values of the settings dict are strings
|
||||
for key, value in settings.items():
|
||||
if isinstance(value, str):
|
||||
continue
|
||||
# Special case for bools
|
||||
if isinstance(value, bool):
|
||||
settings[key] = str(value).lower()
|
||||
continue
|
||||
settings[key] = str(value)
|
||||
return settings
|
|
@ -0,0 +1,54 @@
|
|||
"""RAC Signals"""
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.contrib.auth.signals import user_logged_out
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import post_save, pre_delete
|
||||
from django.dispatch import receiver
|
||||
from django.http import HttpRequest
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.providers.rac.api.endpoints import user_endpoint_cache_key
|
||||
from authentik.enterprise.providers.rac.consumer_client import (
|
||||
RAC_CLIENT_GROUP_SESSION,
|
||||
RAC_CLIENT_GROUP_TOKEN,
|
||||
)
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint
|
||||
|
||||
|
||||
@receiver(user_logged_out)
|
||||
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
|
||||
"""Disconnect any open RAC connections"""
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
RAC_CLIENT_GROUP_SESSION
|
||||
% {
|
||||
"session": request.session.session_key,
|
||||
},
|
||||
{"type": "event.disconnect", "reason": "session_logout"},
|
||||
)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=ConnectionToken)
|
||||
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
|
||||
"""Disconnect session when connection token is deleted"""
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
RAC_CLIENT_GROUP_TOKEN
|
||||
% {
|
||||
"token": instance.token,
|
||||
},
|
||||
{"type": "event.disconnect", "reason": "token_delete"},
|
||||
)
|
||||
|
||||
|
||||
@receiver(post_save, sender=Endpoint)
|
||||
def post_save_application(sender: type[Model], instance, created: bool, **_):
|
||||
"""Clear user's application cache upon application creation"""
|
||||
if not created: # pragma: no cover
|
||||
return
|
||||
|
||||
# Delete user endpoint cache
|
||||
keys = cache.keys(user_endpoint_cache_key("*"))
|
||||
cache.delete_many(keys)
|
|
@ -0,0 +1,18 @@
|
|||
{% extends "base/skeleton.html" %}
|
||||
|
||||
{% load static %}
|
||||
|
||||
{% block head %}
|
||||
<script src="{% static 'dist/enterprise/rac/index.js' %}?version={{ version }}" type="module"></script>
|
||||
<meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)">
|
||||
<meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)">
|
||||
<link rel="icon" href="{{ tenant.branding_favicon }}">
|
||||
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
|
||||
{% include "base/header_js.html" %}
|
||||
{% endblock %}
|
||||
|
||||
{% block body %}
|
||||
<ak-rac token="{{ url_kwargs.token }}" endpointName="{{ token.endpoint.name }}">
|
||||
<ak-loading></ak-loading>
|
||||
</ak-rac>
|
||||
{% endblock %}
|
|
@ -0,0 +1,171 @@
|
|||
"""Test Endpoints API"""
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
||||
|
||||
class TestEndpointsAPI(APITestCase):
|
||||
"""Test endpoints API"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
self.provider = RACProvider.objects.create(
|
||||
name=generate_id(),
|
||||
)
|
||||
self.app = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
self.allowed = Endpoint.objects.create(
|
||||
name=f"a-{generate_id()}",
|
||||
host=generate_id(),
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
self.denied = Endpoint.objects.create(
|
||||
name=f"b-{generate_id()}",
|
||||
host=generate_id(),
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
PolicyBinding.objects.create(
|
||||
target=self.denied,
|
||||
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||
order=0,
|
||||
)
|
||||
|
||||
def test_list(self):
|
||||
"""Test list operation without superuser_full_list"""
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(reverse("authentik_api:endpoint-list"))
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"pagination": {
|
||||
"next": 0,
|
||||
"previous": 0,
|
||||
"count": 2,
|
||||
"current": 1,
|
||||
"total_pages": 1,
|
||||
"start_index": 1,
|
||||
"end_index": 2,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"pk": str(self.allowed.pk),
|
||||
"name": self.allowed.name,
|
||||
"provider": self.provider.pk,
|
||||
"provider_obj": {
|
||||
"pk": self.provider.pk,
|
||||
"name": self.provider.name,
|
||||
"authentication_flow": None,
|
||||
"authorization_flow": None,
|
||||
"property_mappings": [],
|
||||
"connection_expiry": "hours=8",
|
||||
"component": "ak-provider-rac-form",
|
||||
"assigned_application_slug": self.app.slug,
|
||||
"assigned_application_name": self.app.name,
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||
"settings": {},
|
||||
"outpost_set": [],
|
||||
},
|
||||
"protocol": "rdp",
|
||||
"host": self.allowed.host,
|
||||
"maximum_connections": 1,
|
||||
"settings": {},
|
||||
"property_mappings": [],
|
||||
"auth_mode": "",
|
||||
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_list_superuser_full_list(self):
|
||||
"""Test list operation with superuser_full_list"""
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:endpoint-list") + "?superuser_full_list=true"
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"pagination": {
|
||||
"next": 0,
|
||||
"previous": 0,
|
||||
"count": 2,
|
||||
"current": 1,
|
||||
"total_pages": 1,
|
||||
"start_index": 1,
|
||||
"end_index": 2,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"pk": str(self.allowed.pk),
|
||||
"name": self.allowed.name,
|
||||
"provider": self.provider.pk,
|
||||
"provider_obj": {
|
||||
"pk": self.provider.pk,
|
||||
"name": self.provider.name,
|
||||
"authentication_flow": None,
|
||||
"authorization_flow": None,
|
||||
"property_mappings": [],
|
||||
"component": "ak-provider-rac-form",
|
||||
"assigned_application_slug": self.app.slug,
|
||||
"assigned_application_name": self.app.name,
|
||||
"connection_expiry": "hours=8",
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||
"settings": {},
|
||||
"outpost_set": [],
|
||||
},
|
||||
"protocol": "rdp",
|
||||
"host": self.allowed.host,
|
||||
"maximum_connections": 1,
|
||||
"settings": {},
|
||||
"property_mappings": [],
|
||||
"auth_mode": "",
|
||||
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
|
||||
},
|
||||
{
|
||||
"pk": str(self.denied.pk),
|
||||
"name": self.denied.name,
|
||||
"provider": self.provider.pk,
|
||||
"provider_obj": {
|
||||
"pk": self.provider.pk,
|
||||
"name": self.provider.name,
|
||||
"authentication_flow": None,
|
||||
"authorization_flow": None,
|
||||
"property_mappings": [],
|
||||
"component": "ak-provider-rac-form",
|
||||
"assigned_application_slug": self.app.slug,
|
||||
"assigned_application_name": self.app.name,
|
||||
"connection_expiry": "hours=8",
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||
"settings": {},
|
||||
"outpost_set": [],
|
||||
},
|
||||
"protocol": "rdp",
|
||||
"host": self.denied.host,
|
||||
"maximum_connections": 1,
|
||||
"settings": {},
|
||||
"property_mappings": [],
|
||||
"auth_mode": "",
|
||||
"launch_url": f"/application/rac/{self.app.slug}/{str(self.denied.pk)}/",
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
|
@ -0,0 +1,144 @@
|
|||
"""Test RAC Models"""
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.enterprise.providers.rac.models import (
|
||||
ConnectionToken,
|
||||
Endpoint,
|
||||
Protocols,
|
||||
RACPropertyMapping,
|
||||
RACProvider,
|
||||
)
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestModels(TransactionTestCase):
|
||||
"""Test RAC Models"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.provider = RACProvider.objects.create(
|
||||
name=generate_id(),
|
||||
)
|
||||
self.app = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
self.endpoint = Endpoint.objects.create(
|
||||
name=generate_id(),
|
||||
host=f"{generate_id()}:1324",
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def test_settings_merge(self):
|
||||
"""Test settings merge"""
|
||||
token = ConnectionToken.objects.create(
|
||||
provider=self.provider,
|
||||
endpoint=self.endpoint,
|
||||
session=AuthenticatedSession.objects.create(
|
||||
user=self.user,
|
||||
session_key=generate_id(),
|
||||
),
|
||||
)
|
||||
path = f"/tmp/connection/{token.token}" # nosec
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
},
|
||||
)
|
||||
# Set settings in provider
|
||||
self.provider.settings = {"level": "provider"}
|
||||
self.provider.save()
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "provider",
|
||||
},
|
||||
)
|
||||
# Set settings in endpoint
|
||||
self.endpoint.settings = {
|
||||
"level": "endpoint",
|
||||
}
|
||||
self.endpoint.save()
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "endpoint",
|
||||
},
|
||||
)
|
||||
# Set settings in token
|
||||
token.settings = {
|
||||
"level": "token",
|
||||
}
|
||||
token.save()
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "token",
|
||||
},
|
||||
)
|
||||
# Set settings in property mapping (provider)
|
||||
mapping = RACPropertyMapping.objects.create(
|
||||
name=generate_id(),
|
||||
expression="""return {
|
||||
"level": "property_mapping_provider"
|
||||
}""",
|
||||
)
|
||||
self.provider.property_mappings.add(mapping)
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "property_mapping_provider",
|
||||
},
|
||||
)
|
||||
# Set settings in property mapping (endpoint)
|
||||
mapping = RACPropertyMapping.objects.create(
|
||||
name=generate_id(),
|
||||
static_settings={
|
||||
"level": "property_mapping_endpoint",
|
||||
"foo": True,
|
||||
"bar": 6,
|
||||
},
|
||||
)
|
||||
self.endpoint.property_mappings.add(mapping)
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "property_mapping_endpoint",
|
||||
"foo": "true",
|
||||
"bar": "6",
|
||||
},
|
||||
)
|
|
@ -0,0 +1,132 @@
|
|||
"""RAC Views tests"""
|
||||
from datetime import timedelta
|
||||
from json import loads
|
||||
from time import mktime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.enterprise.models import License, LicenseKey
|
||||
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.denied import AccessDeniedResponse
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
||||
|
||||
class TestRACViews(APITestCase):
|
||||
"""RAC Views tests"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.flow = create_test_flow()
|
||||
self.provider = RACProvider.objects.create(name=generate_id(), authorization_flow=self.flow)
|
||||
self.app = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
self.endpoint = Endpoint.objects.create(
|
||||
name=generate_id(),
|
||||
host=f"{generate_id()}:1324",
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.models.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_no_policy(self):
|
||||
"""Test request"""
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
flow_response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
)
|
||||
body = loads(flow_response.content)
|
||||
next_url = body["to"]
|
||||
final_response = self.client.get(next_url)
|
||||
self.assertEqual(final_response.status_code, 200)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.models.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_app_deny(self):
|
||||
"""Test request (deny on app level)"""
|
||||
PolicyBinding.objects.create(
|
||||
target=self.app,
|
||||
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||
order=0,
|
||||
)
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||
)
|
||||
)
|
||||
self.assertIsInstance(response, AccessDeniedResponse)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.models.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_endpoint_deny(self):
|
||||
"""Test request (deny on endpoint level)"""
|
||||
PolicyBinding.objects.create(
|
||||
target=self.endpoint,
|
||||
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||
order=0,
|
||||
)
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
flow_response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
)
|
||||
body = loads(flow_response.content)
|
||||
self.assertEqual(body["component"], "ak-stage-access-denied")
|
|
@ -0,0 +1,47 @@
|
|||
"""rac urls"""
|
||||
from channels.auth import AuthMiddleware
|
||||
from channels.sessions import CookieMiddleware
|
||||
from django.urls import path
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
|
||||
from authentik.core.channels import TokenOutpostMiddleware
|
||||
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
|
||||
from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet
|
||||
from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet
|
||||
from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer
|
||||
from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer
|
||||
from authentik.enterprise.providers.rac.views import RACInterface, RACStartView
|
||||
from authentik.root.asgi_middleware import SessionMiddleware
|
||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
"application/rac/<slug:app>/<uuid:endpoint>/",
|
||||
ensure_csrf_cookie(RACStartView.as_view()),
|
||||
name="start",
|
||||
),
|
||||
path(
|
||||
"if/rac/<str:token>/",
|
||||
ensure_csrf_cookie(RACInterface.as_view()),
|
||||
name="if-rac",
|
||||
),
|
||||
]
|
||||
|
||||
websocket_urlpatterns = [
|
||||
path(
|
||||
"ws/rac/<str:token>/",
|
||||
ChannelsLoggingMiddleware(
|
||||
CookieMiddleware(SessionMiddleware(AuthMiddleware(RACClientConsumer.as_asgi())))
|
||||
),
|
||||
),
|
||||
path(
|
||||
"ws/outpost_rac/<str:channel>/",
|
||||
ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())),
|
||||
),
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
("providers/rac", RACProviderViewSet),
|
||||
("propertymappings/rac", RACPropertyMappingViewSet),
|
||||
("rac/endpoints", EndpointViewSet),
|
||||
]
|
|
@ -0,0 +1,140 @@
|
|||
"""RAC Views"""
|
||||
from typing import Any
|
||||
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.shortcuts import get_object_or_404, redirect
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
from authentik.core.views.interface import InterfaceView
|
||||
from authentik.enterprise.policy import EnterprisePolicyAccessView
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint, RACProvider
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.challenge import RedirectChallenge
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import in_memory_stage
|
||||
from authentik.flows.planner import FlowPlanner
|
||||
from authentik.flows.stage import RedirectStage
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.lib.utils.urls import redirect_with_qs
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
|
||||
|
||||
class RACStartView(EnterprisePolicyAccessView):
|
||||
"""Start a RAC connection by checking access and creating a connection token"""
|
||||
|
||||
endpoint: Endpoint
|
||||
|
||||
def resolve_provider_application(self):
|
||||
self.application = get_object_or_404(Application, slug=self.kwargs["app"])
|
||||
# Endpoint permissions are validated in the RACFinalStage below
|
||||
self.endpoint = get_object_or_404(Endpoint, pk=self.kwargs["endpoint"])
|
||||
self.provider = RACProvider.objects.get(application=self.application)
|
||||
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
"""Start flow planner for RAC provider"""
|
||||
planner = FlowPlanner(self.provider.authorization_flow)
|
||||
planner.allow_empty_flows = True
|
||||
try:
|
||||
plan = planner.plan(self.request)
|
||||
except FlowNonApplicableException:
|
||||
raise Http404
|
||||
plan.insert_stage(
|
||||
in_memory_stage(
|
||||
RACFinalStage,
|
||||
application=self.application,
|
||||
endpoint=self.endpoint,
|
||||
provider=self.provider,
|
||||
)
|
||||
)
|
||||
request.session[SESSION_KEY_PLAN] = plan
|
||||
return redirect_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
request.GET,
|
||||
flow_slug=self.provider.authorization_flow.slug,
|
||||
)
|
||||
|
||||
|
||||
class RACInterface(InterfaceView):
|
||||
"""Start RAC connection"""
|
||||
|
||||
template_name = "if/rac.html"
|
||||
token: ConnectionToken
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||
# Early sanity check to ensure token still exists
|
||||
token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first()
|
||||
if not token:
|
||||
return redirect("authentik_core:if-user")
|
||||
self.token = token
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||
kwargs["token"] = self.token
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
|
||||
class RACFinalStage(RedirectStage):
|
||||
"""RAC Connection final stage, set the connection token in the stage"""
|
||||
|
||||
endpoint: Endpoint
|
||||
provider: RACProvider
|
||||
application: Application
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||
self.endpoint = self.executor.current_stage.endpoint
|
||||
self.provider = self.executor.current_stage.provider
|
||||
self.application = self.executor.current_stage.application
|
||||
# Check policies bound to endpoint directly
|
||||
engine = PolicyEngine(self.endpoint, self.request.user, self.request)
|
||||
engine.use_cache = False
|
||||
engine.build()
|
||||
passing = engine.result
|
||||
if not passing.passing:
|
||||
return self.executor.stage_invalid(", ".join(passing.messages))
|
||||
# Check if we're already at the maximum connection limit
|
||||
all_tokens = ConnectionToken.filter_not_expired(
|
||||
endpoint=self.endpoint,
|
||||
).exclude(endpoint__maximum_connections__lte=-1)
|
||||
if all_tokens.count() >= self.endpoint.maximum_connections:
|
||||
msg = [_("Maximum connection limit reached.")]
|
||||
# Check if any other tokens exist for the current user, and inform them
|
||||
# they are already connected
|
||||
if all_tokens.filter(session__user=self.request.user).exists():
|
||||
msg.append(_("(You are already connected in another tab/window)"))
|
||||
return self.executor.stage_invalid(" ".join(msg))
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
|
||||
token = ConnectionToken.objects.create(
|
||||
provider=self.provider,
|
||||
endpoint=self.endpoint,
|
||||
settings=self.executor.plan.context.get("connection_settings", {}),
|
||||
session=AuthenticatedSession.objects.filter(
|
||||
session_key=self.request.session.session_key
|
||||
).first(),
|
||||
expires=now() + timedelta_from_string(self.provider.connection_expiry),
|
||||
expiring=True,
|
||||
)
|
||||
Event.new(
|
||||
EventAction.AUTHORIZE_APPLICATION,
|
||||
authorized_application=self.application,
|
||||
flow=self.executor.plan.flow_pk,
|
||||
endpoint=self.endpoint.name,
|
||||
).from_http(self.request)
|
||||
setattr(
|
||||
self.executor.current_stage,
|
||||
"destination",
|
||||
self.request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_rac:if-rac",
|
||||
kwargs={
|
||||
"token": str(token.token),
|
||||
},
|
||||
)
|
||||
),
|
||||
)
|
||||
return super().get_challenge(*args, **kwargs)
|
|
@ -10,3 +10,7 @@ CELERY_BEAT_SCHEDULE = {
|
|||
"options": {"queue": "authentik_scheduled"},
|
||||
}
|
||||
}
|
||||
|
||||
INSTALLED_APPS = [
|
||||
"authentik.enterprise.providers.rac",
|
||||
]
|
||||
|
|
|
@ -6,6 +6,7 @@ import django_filters
|
|||
from django.db.models.aggregates import Count
|
||||
from django.db.models.fields.json import KeyTextTransform, KeyTransform
|
||||
from django.db.models.functions import ExtractDay, ExtractHour
|
||||
from django.db.models.query_utils import Q
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
|
@ -87,7 +88,12 @@ class EventsFilter(django_filters.FilterSet):
|
|||
we need to remove the dashes that a client may send. We can't use a
|
||||
UUIDField for this, as some models might not have a UUID PK"""
|
||||
value = str(value).replace("-", "")
|
||||
return queryset.filter(context__model__pk=value)
|
||||
query = Q(context__model__pk=value)
|
||||
try:
|
||||
query |= Q(context__model__pk=int(value))
|
||||
except ValueError:
|
||||
pass
|
||||
return queryset.filter(query)
|
||||
|
||||
class Meta:
|
||||
model = Event
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from prometheus_client import Gauge
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.config import CONFIG, ENV_PREFIX
|
||||
|
||||
GAUGE_TASKS = Gauge(
|
||||
"authentik_system_tasks",
|
||||
|
@ -21,3 +22,24 @@ class AuthentikEventsConfig(ManagedAppConfig):
|
|||
def reconcile_load_events_signals(self):
|
||||
"""Load events signals"""
|
||||
self.import_module("authentik.events.signals")
|
||||
|
||||
def reconcile_check_deprecations(self):
|
||||
"""Check for config deprecations"""
|
||||
from authentik.events.models import Event, EventAction
|
||||
|
||||
for key_replace, msg in CONFIG.deprecations.items():
|
||||
key, replace = key_replace
|
||||
key_env = f"{ENV_PREFIX}_{key.replace('.', '__')}".upper()
|
||||
replace_env = f"{ENV_PREFIX}_{replace.replace('.', '__')}".upper()
|
||||
if Event.objects.filter(
|
||||
action=EventAction.CONFIGURATION_ERROR, context__deprecated_option=key
|
||||
).exists():
|
||||
continue
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
deprecated_option=key,
|
||||
deprecated_env=key_env,
|
||||
replacement_option=replace,
|
||||
replacement_env=replace_env,
|
||||
message=msg,
|
||||
).save()
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
"""ASN Enricher"""
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||
|
||||
from django.http import HttpRequest
|
||||
from geoip2.errors import GeoIP2Error
|
||||
from geoip2.models import ASN
|
||||
from sentry_sdk import Hub
|
||||
|
||||
from authentik.events.context_processors.mmdb import MMDBContextProcessor
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.api.v3.config import Capabilities
|
||||
from authentik.events.models import Event
|
||||
|
||||
|
||||
class ASNDict(TypedDict):
|
||||
"""ASN Details"""
|
||||
|
||||
asn: int
|
||||
as_org: str | None
|
||||
network: str | None
|
||||
|
||||
|
||||
class ASNContextProcessor(MMDBContextProcessor):
|
||||
"""ASN Database reader wrapper"""
|
||||
|
||||
def capability(self) -> Optional["Capabilities"]:
|
||||
from authentik.api.v3.config import Capabilities
|
||||
|
||||
return Capabilities.CAN_ASN
|
||||
|
||||
def path(self) -> str | None:
|
||||
return CONFIG.get("events.context_processors.asn")
|
||||
|
||||
def enrich_event(self, event: "Event"):
|
||||
asn = self.asn_dict(event.client_ip)
|
||||
if not asn:
|
||||
return
|
||||
event.context["asn"] = asn
|
||||
|
||||
def enrich_context(self, request: HttpRequest) -> dict:
|
||||
return {
|
||||
"asn": self.asn_dict(ClientIPMiddleware.get_client_ip(request)),
|
||||
}
|
||||
|
||||
def asn(self, ip_address: str) -> Optional[ASN]:
|
||||
"""Wrapper for Reader.asn"""
|
||||
with Hub.current.start_span(
|
||||
op="authentik.events.asn.asn",
|
||||
description=ip_address,
|
||||
):
|
||||
if not self.configured():
|
||||
return None
|
||||
self.check_expired()
|
||||
try:
|
||||
return self.reader.asn(ip_address)
|
||||
except (GeoIP2Error, ValueError):
|
||||
return None
|
||||
|
||||
def asn_to_dict(self, asn: ASN | None) -> ASNDict:
|
||||
"""Convert ASN to dict"""
|
||||
if not asn:
|
||||
return {}
|
||||
asn_dict: ASNDict = {
|
||||
"asn": asn.autonomous_system_number,
|
||||
"as_org": asn.autonomous_system_organization,
|
||||
"network": str(asn.network) if asn.network else None,
|
||||
}
|
||||
return asn_dict
|
||||
|
||||
def asn_dict(self, ip_address: str) -> Optional[ASNDict]:
|
||||
"""Wrapper for self.asn that returns a dict"""
|
||||
asn = self.asn(ip_address)
|
||||
if not asn:
|
||||
return None
|
||||
return self.asn_to_dict(asn)
|
||||
|
||||
|
||||
ASN_CONTEXT_PROCESSOR = ASNContextProcessor()
|
|
@ -0,0 +1,43 @@
|
|||
"""Base event enricher"""
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from django.http import HttpRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.api.v3.config import Capabilities
|
||||
from authentik.events.models import Event
|
||||
|
||||
|
||||
class EventContextProcessor:
|
||||
"""Base event enricher"""
|
||||
|
||||
def capability(self) -> Optional["Capabilities"]:
|
||||
"""Return the capability this context processor provides"""
|
||||
return None
|
||||
|
||||
def configured(self) -> bool:
|
||||
"""Return true if this context processor is configured"""
|
||||
return False
|
||||
|
||||
def enrich_event(self, event: "Event"):
|
||||
"""Modify event"""
|
||||
raise NotImplementedError
|
||||
|
||||
def enrich_context(self, request: HttpRequest) -> dict:
|
||||
"""Modify context"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@cache
|
||||
def get_context_processors() -> list[EventContextProcessor]:
|
||||
"""Get a list of all configured context processors"""
|
||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||
|
||||
processors_types = [ASN_CONTEXT_PROCESSOR, GEOIP_CONTEXT_PROCESSOR]
|
||||
processors = []
|
||||
for _type in processors_types:
|
||||
if _type.configured():
|
||||
processors.append(_type)
|
||||
return processors
|
|
@ -0,0 +1,86 @@
|
|||
"""events GeoIP Reader"""
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||
|
||||
from django.http import HttpRequest
|
||||
from geoip2.errors import GeoIP2Error
|
||||
from geoip2.models import City
|
||||
from sentry_sdk.hub import Hub
|
||||
|
||||
from authentik.events.context_processors.mmdb import MMDBContextProcessor
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.api.v3.config import Capabilities
|
||||
from authentik.events.models import Event
|
||||
|
||||
|
||||
class GeoIPDict(TypedDict):
|
||||
"""GeoIP Details"""
|
||||
|
||||
continent: str
|
||||
country: str
|
||||
lat: float
|
||||
long: float
|
||||
city: str
|
||||
|
||||
|
||||
class GeoIPContextProcessor(MMDBContextProcessor):
|
||||
"""Slim wrapper around GeoIP API"""
|
||||
|
||||
def capability(self) -> Optional["Capabilities"]:
|
||||
from authentik.api.v3.config import Capabilities
|
||||
|
||||
return Capabilities.CAN_GEO_IP
|
||||
|
||||
def path(self) -> str | None:
|
||||
return CONFIG.get("events.context_processors.geoip")
|
||||
|
||||
def enrich_event(self, event: "Event"):
|
||||
city = self.city_dict(event.client_ip)
|
||||
if not city:
|
||||
return
|
||||
event.context["geo"] = city
|
||||
|
||||
def enrich_context(self, request: HttpRequest) -> dict:
|
||||
# Different key `geoip` vs `geo` for legacy reasons
|
||||
return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))}
|
||||
|
||||
def city(self, ip_address: str) -> Optional[City]:
|
||||
"""Wrapper for Reader.city"""
|
||||
with Hub.current.start_span(
|
||||
op="authentik.events.geo.city",
|
||||
description=ip_address,
|
||||
):
|
||||
if not self.configured():
|
||||
return None
|
||||
self.check_expired()
|
||||
try:
|
||||
return self.reader.city(ip_address)
|
||||
except (GeoIP2Error, ValueError):
|
||||
return None
|
||||
|
||||
def city_to_dict(self, city: City | None) -> GeoIPDict:
|
||||
"""Convert City to dict"""
|
||||
if not city:
|
||||
return {}
|
||||
city_dict: GeoIPDict = {
|
||||
"continent": city.continent.code,
|
||||
"country": city.country.iso_code,
|
||||
"lat": city.location.latitude,
|
||||
"long": city.location.longitude,
|
||||
"city": "",
|
||||
}
|
||||
if city.city.name:
|
||||
city_dict["city"] = city.city.name
|
||||
return city_dict
|
||||
|
||||
def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
|
||||
"""Wrapper for self.city that returns a dict"""
|
||||
city = self.city(ip_address)
|
||||
if not city:
|
||||
return None
|
||||
return self.city_to_dict(city)
|
||||
|
||||
|
||||
GEOIP_CONTEXT_PROCESSOR = GeoIPContextProcessor()
|
|
@ -0,0 +1,53 @@
|
|||
"""Common logic for reading MMDB files"""
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from geoip2.database import Reader
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.context_processors.base import EventContextProcessor
|
||||
|
||||
|
||||
class MMDBContextProcessor(EventContextProcessor):
|
||||
"""Common logic for reading MaxMind DB files, including re-loading if the file has changed"""
|
||||
|
||||
def __init__(self):
|
||||
self.reader: Optional[Reader] = None
|
||||
self._last_mtime: float = 0.0
|
||||
self.logger = get_logger()
|
||||
self.open()
|
||||
|
||||
def path(self) -> str | None:
|
||||
"""Get the path to the MMDB file to load"""
|
||||
raise NotImplementedError
|
||||
|
||||
def open(self):
|
||||
"""Get GeoIP Reader, if configured, otherwise none"""
|
||||
path = self.path()
|
||||
if path == "" or not path:
|
||||
return
|
||||
try:
|
||||
self.reader = Reader(path)
|
||||
self._last_mtime = Path(path).stat().st_mtime
|
||||
self.logger.info("Loaded MMDB database", last_write=self._last_mtime, file=path)
|
||||
except OSError as exc:
|
||||
self.logger.warning("Failed to load MMDB database", path=path, exc=exc)
|
||||
|
||||
def check_expired(self):
|
||||
"""Check if the modification date of the MMDB database has
|
||||
changed, and reload it if so"""
|
||||
path = self.path()
|
||||
if path == "" or not path:
|
||||
return
|
||||
try:
|
||||
mtime = Path(path).stat().st_mtime
|
||||
diff = self._last_mtime < mtime
|
||||
if diff > 0:
|
||||
self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
|
||||
self.open()
|
||||
except OSError as exc:
|
||||
self.logger.warning("Failed to check MMDB age", exc=exc)
|
||||
|
||||
def configured(self) -> bool:
|
||||
"""Return true if this context processor is configured"""
|
||||
return bool(self.reader)
|
|
@ -1,100 +0,0 @@
|
|||
"""events GeoIP Reader"""
|
||||
from os import stat
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
from geoip2.database import Reader
|
||||
from geoip2.errors import GeoIP2Error
|
||||
from geoip2.models import City
|
||||
from sentry_sdk.hub import Hub
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class GeoIPDict(TypedDict):
|
||||
"""GeoIP Details"""
|
||||
|
||||
continent: str
|
||||
country: str
|
||||
lat: float
|
||||
long: float
|
||||
city: str
|
||||
|
||||
|
||||
class GeoIPReader:
|
||||
"""Slim wrapper around GeoIP API"""
|
||||
|
||||
def __init__(self):
|
||||
self.__reader: Optional[Reader] = None
|
||||
self.__last_mtime: float = 0.0
|
||||
self.__open()
|
||||
|
||||
def __open(self):
|
||||
"""Get GeoIP Reader, if configured, otherwise none"""
|
||||
path = CONFIG.get("geoip")
|
||||
if path == "" or not path:
|
||||
return
|
||||
try:
|
||||
self.__reader = Reader(path)
|
||||
self.__last_mtime = stat(path).st_mtime
|
||||
LOGGER.info("Loaded GeoIP database", last_write=self.__last_mtime)
|
||||
except OSError as exc:
|
||||
LOGGER.warning("Failed to load GeoIP database", exc=exc)
|
||||
|
||||
def __check_expired(self):
|
||||
"""Check if the modification date of the GeoIP database has
|
||||
changed, and reload it if so"""
|
||||
path = CONFIG.get("geoip")
|
||||
try:
|
||||
mtime = stat(path).st_mtime
|
||||
diff = self.__last_mtime < mtime
|
||||
if diff > 0:
|
||||
LOGGER.info("Found new GeoIP Database, reopening", diff=diff)
|
||||
self.__open()
|
||||
except OSError as exc:
|
||||
LOGGER.warning("Failed to check GeoIP age", exc=exc)
|
||||
return
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Check if GeoIP is enabled"""
|
||||
return bool(self.__reader)
|
||||
|
||||
def city(self, ip_address: str) -> Optional[City]:
|
||||
"""Wrapper for Reader.city"""
|
||||
with Hub.current.start_span(
|
||||
op="authentik.events.geo.city",
|
||||
description=ip_address,
|
||||
):
|
||||
if not self.enabled:
|
||||
return None
|
||||
self.__check_expired()
|
||||
try:
|
||||
return self.__reader.city(ip_address)
|
||||
except (GeoIP2Error, ValueError):
|
||||
return None
|
||||
|
||||
def city_to_dict(self, city: City) -> GeoIPDict:
|
||||
"""Convert City to dict"""
|
||||
city_dict: GeoIPDict = {
|
||||
"continent": city.continent.code,
|
||||
"country": city.country.iso_code,
|
||||
"lat": city.location.latitude,
|
||||
"long": city.location.longitude,
|
||||
"city": "",
|
||||
}
|
||||
if city.city.name:
|
||||
city_dict["city"] = city.city.name
|
||||
return city_dict
|
||||
|
||||
def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
|
||||
"""Wrapper for self.city that returns a dict"""
|
||||
city = self.city(ip_address)
|
||||
if not city:
|
||||
return None
|
||||
return self.city_to_dict(city)
|
||||
|
||||
|
||||
GEOIP_READER = GeoIPReader()
|
|
@ -20,6 +20,7 @@ from authentik.core.models import (
|
|||
User,
|
||||
UserSourceConnection,
|
||||
)
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken
|
||||
from authentik.events.models import Event, EventAction, Notification
|
||||
from authentik.events.utils import model_to_dict
|
||||
from authentik.flows.models import FlowToken, Stage
|
||||
|
@ -54,6 +55,7 @@ IGNORED_MODELS = (
|
|||
SCIMUser,
|
||||
SCIMGroup,
|
||||
Reputation,
|
||||
ConnectionToken,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from authentik.core.middleware import (
|
|||
SESSION_KEY_IMPERSONATE_USER,
|
||||
)
|
||||
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
|
||||
from authentik.events.geo import GEOIP_READER
|
||||
from authentik.events.context_processors.base import get_context_processors
|
||||
from authentik.events.utils import (
|
||||
cleanse_dict,
|
||||
get_user,
|
||||
|
@ -36,9 +36,10 @@ from authentik.events.utils import (
|
|||
)
|
||||
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.lib.utils.http import get_client_ip, get_http_session
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
from authentik.tenants.models import Tenant
|
||||
from authentik.tenants.utils import DEFAULT_TENANT
|
||||
|
@ -244,22 +245,16 @@ class Event(SerializerModel, ExpiringModel):
|
|||
self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
|
||||
self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
|
||||
# User 255.255.255.255 as fallback if IP cannot be determined
|
||||
self.client_ip = get_client_ip(request)
|
||||
# Apply GeoIP Data, when enabled
|
||||
self.with_geoip()
|
||||
self.client_ip = ClientIPMiddleware.get_client_ip(request)
|
||||
# Enrich event data
|
||||
for processor in get_context_processors():
|
||||
processor.enrich_event(self)
|
||||
# If there's no app set, we get it from the requests too
|
||||
if not self.app:
|
||||
self.app = Event._get_app_from_request(request)
|
||||
self.save()
|
||||
return self
|
||||
|
||||
def with_geoip(self): # pragma: no cover
|
||||
"""Apply GeoIP Data, when enabled"""
|
||||
city = GEOIP_READER.city_dict(self.client_ip)
|
||||
if not city:
|
||||
return
|
||||
self.context["geo"] = city
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if self._state.adding:
|
||||
LOGGER.info(
|
||||
|
@ -466,7 +461,7 @@ class NotificationTransport(SerializerModel):
|
|||
}
|
||||
mail = TemplateEmailMessage(
|
||||
subject=subject_prefix + context["title"],
|
||||
to=[notification.user.email],
|
||||
to=[f"{notification.user.name} <{notification.user.email}>"],
|
||||
language=notification.user.locale(),
|
||||
template_name="email/event_notification.html",
|
||||
template_context=context,
|
||||
|
|
|
@ -45,9 +45,14 @@ def get_login_event(request: HttpRequest) -> Optional[Event]:
|
|||
|
||||
|
||||
@receiver(user_logged_out)
|
||||
def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
|
||||
def on_user_logged_out(sender, request: HttpRequest, user: User, **kwargs):
|
||||
"""Log successfully logout"""
|
||||
Event.new(EventAction.LOGOUT).from_http(request, user=user)
|
||||
# Check if this even comes from the user_login stage's middleware, which will set an extra
|
||||
# argument
|
||||
event = Event.new(EventAction.LOGOUT)
|
||||
if "event_extra" in kwargs:
|
||||
event.context.update(kwargs["event_extra"])
|
||||
event.from_http(request, user=user)
|
||||
|
||||
|
||||
@receiver(user_write)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Event API tests"""
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
@ -11,6 +12,9 @@ from authentik.events.models import (
|
|||
NotificationSeverity,
|
||||
TransportMode,
|
||||
)
|
||||
from authentik.events.utils import model_to_dict
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
|
||||
|
||||
class TestEventsAPI(APITestCase):
|
||||
|
@ -20,6 +24,25 @@ class TestEventsAPI(APITestCase):
|
|||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_filter_model_pk_int(self):
|
||||
"""Test event list with context_model_pk and integer PKs"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
)
|
||||
event = Event.new(EventAction.MODEL_CREATED, model=model_to_dict(provider))
|
||||
event.save()
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:event-list"),
|
||||
data={
|
||||
"context_model_pk": provider.pk,
|
||||
"context_model_app": "authentik_providers_oauth2",
|
||||
"context_model_name": "oauth2provider",
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["pagination"]["count"], 1)
|
||||
|
||||
def test_top_n(self):
|
||||
"""Test top_per_user"""
|
||||
event = Event.new(EventAction.AUTHORIZE_APPLICATION)
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
"""Test ASN Wrapper"""
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.events.context_processors.asn import ASNContextProcessor
|
||||
|
||||
|
||||
class TestASN(TestCase):
|
||||
"""Test ASN Wrapper"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.reader = ASNContextProcessor()
|
||||
|
||||
def test_simple(self):
|
||||
"""Test simple asn wrapper"""
|
||||
# IPs from
|
||||
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-ASN-Test.json
|
||||
self.assertEqual(
|
||||
self.reader.asn_dict("1.0.0.1"),
|
||||
{
|
||||
"asn": 15169,
|
||||
"as_org": "Google Inc.",
|
||||
"network": "1.0.0.0/24",
|
||||
},
|
||||
)
|
|
@ -1,14 +1,14 @@
|
|||
"""Test GeoIP Wrapper"""
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.events.geo import GeoIPReader
|
||||
from authentik.events.context_processors.geoip import GeoIPContextProcessor
|
||||
|
||||
|
||||
class TestGeoIP(TestCase):
|
||||
"""Test GeoIP Wrapper"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.reader = GeoIPReader()
|
||||
self.reader = GeoIPContextProcessor()
|
||||
|
||||
def test_simple(self):
|
||||
"""Test simple city wrapper"""
|
|
@ -17,12 +17,13 @@ from django.db.models.base import Model
|
|||
from django.http.request import HttpRequest
|
||||
from django.utils import timezone
|
||||
from django.views.debug import SafeExceptionReporterFilter
|
||||
from geoip2.models import City
|
||||
from geoip2.models import ASN, City
|
||||
from guardian.utils import get_anonymous_user
|
||||
|
||||
from authentik.blueprints.v1.common import YAMLTag
|
||||
from authentik.core.models import User
|
||||
from authentik.events.geo import GEOIP_READER
|
||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||
from authentik.policies.types import PolicyRequest
|
||||
|
||||
# Special keys which are *not* cleaned, even when the default filter
|
||||
|
@ -123,7 +124,9 @@ def sanitize_item(value: Any) -> Any:
|
|||
if isinstance(value, (HttpRequest, WSGIRequest)):
|
||||
return ...
|
||||
if isinstance(value, City):
|
||||
return GEOIP_READER.city_to_dict(value)
|
||||
return GEOIP_CONTEXT_PROCESSOR.city_to_dict(value)
|
||||
if isinstance(value, ASN):
|
||||
return ASN_CONTEXT_PROCESSOR.asn_to_dict(value)
|
||||
if isinstance(value, Path):
|
||||
return str(value)
|
||||
if isinstance(value, Exception):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Generated by Django 4.2.6 on 2023-10-28 14:24
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import migrations
|
||||
from django.db import migrations, models
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
|
@ -31,4 +31,19 @@ class Migration(migrations.Migration):
|
|||
|
||||
operations = [
|
||||
migrations.RunPython(set_oobe_flow_authentication),
|
||||
migrations.AlterField(
|
||||
model_name="flow",
|
||||
name="authentication",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("none", "None"),
|
||||
("require_authenticated", "Require Authenticated"),
|
||||
("require_unauthenticated", "Require Unauthenticated"),
|
||||
("require_superuser", "Require Superuser"),
|
||||
("require_outpost", "Require Outpost"),
|
||||
],
|
||||
default="none",
|
||||
help_text="Required level of authentication and authorization to access a flow.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -31,6 +31,7 @@ class FlowAuthenticationRequirement(models.TextChoices):
|
|||
REQUIRE_AUTHENTICATED = "require_authenticated"
|
||||
REQUIRE_UNAUTHENTICATED = "require_unauthenticated"
|
||||
REQUIRE_SUPERUSER = "require_superuser"
|
||||
REQUIRE_OUTPOST = "require_outpost"
|
||||
|
||||
|
||||
class NotConfiguredAction(models.TextChoices):
|
||||
|
|
|
@ -23,6 +23,7 @@ from authentik.flows.models import (
|
|||
)
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
LOGGER = get_logger()
|
||||
PLAN_CONTEXT_PENDING_USER = "pending_user"
|
||||
|
@ -141,6 +142,10 @@ class FlowPlanner:
|
|||
and not request.user.is_superuser
|
||||
):
|
||||
raise FlowNonApplicableException()
|
||||
if self.flow.authentication == FlowAuthenticationRequirement.REQUIRE_OUTPOST:
|
||||
outpost_user = ClientIPMiddleware.get_outpost_user(request)
|
||||
if not outpost_user:
|
||||
raise FlowNonApplicableException()
|
||||
|
||||
def plan(
|
||||
self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
|
||||
|
|
|
@ -8,6 +8,7 @@ from django.test import RequestFactory, TestCase
|
|||
from django.urls import reverse
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
|
||||
from authentik.blueprints.tests import reconcile_app
|
||||
from authentik.core.models import User
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
||||
|
@ -15,9 +16,12 @@ from authentik.flows.markers import ReevaluateMarker, StageMarker
|
|||
from authentik.flows.models import FlowAuthenticationRequirement, FlowDesignation, FlowStageBinding
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key
|
||||
from authentik.lib.tests.utils import dummy_get_response
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||
from authentik.outposts.models import Outpost
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.policies.types import PolicyResult
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.dummy.models import DummyStage
|
||||
|
||||
POLICY_RETURN_FALSE = PropertyMock(return_value=PolicyResult(False))
|
||||
|
@ -68,6 +72,34 @@ class TestFlowPlanner(TestCase):
|
|||
planner.allow_empty_flows = True
|
||||
planner.plan(request)
|
||||
|
||||
@reconcile_app("authentik_outposts")
|
||||
def test_authentication_outpost(self):
|
||||
"""Test flow authentication (outpost)"""
|
||||
flow = create_test_flow()
|
||||
flow.authentication = FlowAuthenticationRequirement.REQUIRE_OUTPOST
|
||||
request = self.request_factory.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
request.user = AnonymousUser()
|
||||
with self.assertRaises(FlowNonApplicableException):
|
||||
planner = FlowPlanner(flow)
|
||||
planner.allow_empty_flows = True
|
||||
planner.plan(request)
|
||||
|
||||
outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||
request = self.request_factory.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
HTTP_X_AUTHENTIK_OUTPOST_TOKEN=outpost.token.key,
|
||||
HTTP_X_AUTHENTIK_REMOTE_IP="1.2.3.4",
|
||||
)
|
||||
request.user = AnonymousUser()
|
||||
middleware = ClientIPMiddleware(dummy_get_response)
|
||||
middleware(request)
|
||||
|
||||
planner = FlowPlanner(flow)
|
||||
planner.allow_empty_flows = True
|
||||
planner.plan(request)
|
||||
|
||||
@patch(
|
||||
"authentik.policies.engine.PolicyEngine.result",
|
||||
POLICY_RETURN_FALSE,
|
||||
|
|
|
@ -35,6 +35,7 @@ REDIS_ENV_KEYS = [
|
|||
]
|
||||
|
||||
DEPRECATIONS = {
|
||||
"geoip": "events.context_processors.geoip",
|
||||
"redis.broker_url": "broker.url",
|
||||
"redis.broker_transport_options": "broker.transport_options",
|
||||
"redis.cache_timeout": "cache.timeout",
|
||||
|
@ -112,6 +113,8 @@ class ConfigLoader:
|
|||
|
||||
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
||||
|
||||
deprecations: dict[tuple[str, str], str] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.__config = {}
|
||||
|
@ -138,9 +141,9 @@ class ConfigLoader:
|
|||
self.update_from_file(env_file)
|
||||
self.update_from_env()
|
||||
self.update(self.__config, kwargs)
|
||||
self.check_deprecations()
|
||||
self.deprecations = self.check_deprecations()
|
||||
|
||||
def check_deprecations(self):
|
||||
def check_deprecations(self) -> dict[str, str]:
|
||||
"""Warn if any deprecated configuration options are used"""
|
||||
|
||||
def _pop_deprecated_key(current_obj, dot_parts, index):
|
||||
|
@ -153,8 +156,10 @@ class ConfigLoader:
|
|||
current_obj.pop(dot_part)
|
||||
return value
|
||||
|
||||
deprecation_replacements = {}
|
||||
for deprecation, replacement in DEPRECATIONS.items():
|
||||
if self.get(deprecation, default=UNSET) is not UNSET:
|
||||
if self.get(deprecation, default=UNSET) is UNSET:
|
||||
continue
|
||||
message = (
|
||||
f"'{deprecation}' has been deprecated in favor of '{replacement}'! "
|
||||
+ "Please update your configuration."
|
||||
|
@ -163,15 +168,11 @@ class ConfigLoader:
|
|||
"warning",
|
||||
message,
|
||||
)
|
||||
try:
|
||||
from authentik.events.models import Event, EventAction
|
||||
|
||||
Event.new(EventAction.CONFIGURATION_ERROR, message=message).save()
|
||||
except ImportError:
|
||||
continue
|
||||
deprecation_replacements[(deprecation, replacement)] = message
|
||||
|
||||
deprecated_attr = _pop_deprecated_key(self.__config, deprecation.split("."), 0)
|
||||
self.set(replacement, deprecated_attr.value)
|
||||
self.set(replacement, deprecated_attr)
|
||||
return deprecation_replacements
|
||||
|
||||
def log(self, level: str, message: str, **kwargs):
|
||||
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
|
||||
|
@ -317,7 +318,9 @@ class ConfigLoader:
|
|||
|
||||
def set(self, path: str, value: Any, sep="."):
|
||||
"""Set value using same syntax as get()"""
|
||||
set_path_in_dict(self.raw, path, Attr(value), sep=sep)
|
||||
if not isinstance(value, Attr):
|
||||
value = Attr(value)
|
||||
set_path_in_dict(self.raw, path, value, sep=sep)
|
||||
|
||||
|
||||
CONFIG = ConfigLoader()
|
||||
|
|
|
@ -8,6 +8,8 @@ postgresql:
|
|||
password: "env://POSTGRES_PASSWORD"
|
||||
use_pgbouncer: false
|
||||
use_pgpool: false
|
||||
test:
|
||||
name: test_authentik
|
||||
|
||||
listen:
|
||||
listen_http: 0.0.0.0:9000
|
||||
|
@ -106,7 +108,10 @@ cookie_domain: null
|
|||
disable_update_check: false
|
||||
disable_startup_analytics: false
|
||||
avatars: env://AUTHENTIK_AUTHENTIK__AVATARS?gravatar,initials
|
||||
events:
|
||||
context_processors:
|
||||
geoip: "/geoip/GeoLite2-City.mmdb"
|
||||
asn: "/geoip/GeoLite2-ASN.mmdb"
|
||||
|
||||
footer_links: []
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@ from django.test import RequestFactory, TestCase
|
|||
|
||||
from authentik.core.models import Token, TokenIntents, UserTypes
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip
|
||||
from authentik.lib.views import bad_request_message
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
|
||||
class TestHTTP(TestCase):
|
||||
|
@ -22,12 +22,12 @@ class TestHTTP(TestCase):
|
|||
def test_normal(self):
|
||||
"""Test normal request"""
|
||||
request = self.factory.get("/")
|
||||
self.assertEqual(get_client_ip(request), "127.0.0.1")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||
|
||||
def test_forward_for(self):
|
||||
"""Test x-forwarded-for request"""
|
||||
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
|
||||
self.assertEqual(get_client_ip(request), "127.0.0.2")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
|
||||
|
||||
def test_fake_outpost(self):
|
||||
"""Test faked IP which is overridden by an outpost"""
|
||||
|
@ -38,28 +38,28 @@ class TestHTTP(TestCase):
|
|||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
OUTPOST_REMOTE_IP_HEADER: "1.2.3.4",
|
||||
OUTPOST_TOKEN_HEADER: "abc",
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||
ClientIPMiddleware.outpost_token_header: "abc",
|
||||
},
|
||||
)
|
||||
self.assertEqual(get_client_ip(request), "127.0.0.1")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||
# Invalid, user doesn't have permissions
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
OUTPOST_REMOTE_IP_HEADER: "1.2.3.4",
|
||||
OUTPOST_TOKEN_HEADER: token.key,
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||
ClientIPMiddleware.outpost_token_header: token.key,
|
||||
},
|
||||
)
|
||||
self.assertEqual(get_client_ip(request), "127.0.0.1")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||
# Valid
|
||||
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
self.user.save()
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
OUTPOST_REMOTE_IP_HEADER: "1.2.3.4",
|
||||
OUTPOST_TOKEN_HEADER: token.key,
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||
ClientIPMiddleware.outpost_token_header: token.key,
|
||||
},
|
||||
)
|
||||
self.assertEqual(get_client_ip(request), "1.2.3.4")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "1.2.3.4")
|
||||
|
|
|
@ -1,89 +1,39 @@
|
|||
"""http helpers"""
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from django.http import HttpRequest
|
||||
from requests.sessions import Session
|
||||
from sentry_sdk.hub import Hub
|
||||
from django.conf import settings
|
||||
from requests.sessions import PreparedRequest, Session
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik import get_full_version
|
||||
|
||||
OUTPOST_REMOTE_IP_HEADER = "HTTP_X_AUTHENTIK_REMOTE_IP"
|
||||
OUTPOST_TOKEN_HEADER = "HTTP_X_AUTHENTIK_OUTPOST_TOKEN" # nosec
|
||||
DEFAULT_IP = "255.255.255.255"
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def _get_client_ip_from_meta(meta: dict[str, Any]) -> str:
|
||||
"""Attempt to get the client's IP by checking common HTTP Headers.
|
||||
Returns none if no IP Could be found
|
||||
|
||||
No additional validation is done here as requests are expected to only arrive here
|
||||
via the go proxy, which deals with validating these headers for us"""
|
||||
headers = (
|
||||
"HTTP_X_FORWARDED_FOR",
|
||||
"REMOTE_ADDR",
|
||||
)
|
||||
for _header in headers:
|
||||
if _header in meta:
|
||||
ips: list[str] = meta.get(_header).split(",")
|
||||
return ips[0].strip()
|
||||
return DEFAULT_IP
|
||||
|
||||
|
||||
def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
|
||||
"""Get the actual remote IP when set by an outpost. Only
|
||||
allowed when the request is authenticated, by an outpost internal service account"""
|
||||
from authentik.core.models import Token, TokenIntents, UserTypes
|
||||
|
||||
if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
|
||||
return None
|
||||
fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
|
||||
token = (
|
||||
Token.filter_not_expired(
|
||||
key=request.META.get(OUTPOST_TOKEN_HEADER), intent=TokenIntents.INTENT_API
|
||||
)
|
||||
.select_related("user")
|
||||
.first()
|
||||
)
|
||||
if not token:
|
||||
LOGGER.warning("Attempted remote-ip override without token", fake_ip=fake_ip)
|
||||
return None
|
||||
user = token.user
|
||||
if user.type != UserTypes.INTERNAL_SERVICE_ACCOUNT:
|
||||
LOGGER.warning(
|
||||
"Remote-IP override: user doesn't have permission",
|
||||
user=user,
|
||||
fake_ip=fake_ip,
|
||||
)
|
||||
return None
|
||||
# Update sentry scope to include correct IP
|
||||
user = Hub.current.scope._user
|
||||
if not user:
|
||||
user = {}
|
||||
user["ip_address"] = fake_ip
|
||||
Hub.current.scope.set_user(user)
|
||||
return fake_ip
|
||||
|
||||
|
||||
def get_client_ip(request: Optional[HttpRequest]) -> str:
|
||||
"""Attempt to get the client's IP by checking common HTTP Headers.
|
||||
Returns none if no IP Could be found"""
|
||||
if not request:
|
||||
return DEFAULT_IP
|
||||
override = _get_outpost_override_ip(request)
|
||||
if override:
|
||||
return override
|
||||
return _get_client_ip_from_meta(request.META)
|
||||
|
||||
|
||||
def authentik_user_agent() -> str:
|
||||
"""Get a common user agent"""
|
||||
return f"authentik@{get_full_version()}"
|
||||
|
||||
|
||||
class DebugSession(Session):
|
||||
"""requests session which logs http requests and responses"""
|
||||
|
||||
def send(self, req: PreparedRequest, *args, **kwargs):
|
||||
request_id = str(uuid4())
|
||||
LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
|
||||
resp = super().send(req, *args, **kwargs)
|
||||
LOGGER.debug(
|
||||
"HTTP response received",
|
||||
uid=request_id,
|
||||
status=resp.status_code,
|
||||
body=resp.text,
|
||||
headers=resp.headers,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
def get_http_session() -> Session:
|
||||
"""Get a requests session with common headers"""
|
||||
session = Session()
|
||||
session = DebugSession() if settings.DEBUG else Session()
|
||||
session.headers["User-Agent"] = authentik_user_agent()
|
||||
return session
|
||||
|
|
|
@ -9,16 +9,17 @@ from rest_framework.fields import BooleanField, CharField, DateTimeField
|
|||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import JSONField, ModelSerializer, ValidationError
|
||||
from rest_framework.serializers import ModelSerializer, ValidationError
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik import get_build_hash
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.providers.rac.models import RACProvider
|
||||
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
|
||||
from authentik.outposts.models import (
|
||||
Outpost,
|
||||
OutpostConfig,
|
||||
|
@ -34,7 +35,7 @@ from authentik.providers.radius.models import RadiusProvider
|
|||
class OutpostSerializer(ModelSerializer):
|
||||
"""Outpost Serializer"""
|
||||
|
||||
config = JSONField(validators=[is_dict], source="_config")
|
||||
config = JSONDictField(source="_config")
|
||||
# Need to set allow_empty=True for the embedded outpost with no providers
|
||||
# is checked for other providers in the API Viewset
|
||||
providers = PrimaryKeyRelatedField(
|
||||
|
@ -47,12 +48,23 @@ class OutpostSerializer(ModelSerializer):
|
|||
source="service_connection", read_only=True
|
||||
)
|
||||
|
||||
def validate_name(self, name: str) -> str:
|
||||
"""Validate name (especially for embedded outpost)"""
|
||||
if not self.instance:
|
||||
return name
|
||||
if self.instance.managed == MANAGED_OUTPOST and name != MANAGED_OUTPOST_NAME:
|
||||
raise ValidationError("Embedded outpost's name cannot be changed")
|
||||
if self.instance.name == MANAGED_OUTPOST_NAME:
|
||||
self.instance.managed = MANAGED_OUTPOST
|
||||
return name
|
||||
|
||||
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
|
||||
"""Check that all providers match the type of the outpost"""
|
||||
type_map = {
|
||||
OutpostType.LDAP: LDAPProvider,
|
||||
OutpostType.PROXY: ProxyProvider,
|
||||
OutpostType.RADIUS: RadiusProvider,
|
||||
OutpostType.RAC: RACProvider,
|
||||
None: Provider,
|
||||
}
|
||||
for provider in providers:
|
||||
|
@ -95,7 +107,7 @@ class OutpostSerializer(ModelSerializer):
|
|||
class OutpostDefaultConfigSerializer(PassiveSerializer):
|
||||
"""Global default outpost config"""
|
||||
|
||||
config = JSONField(read_only=True)
|
||||
config = JSONDictField(read_only=True)
|
||||
|
||||
|
||||
class OutpostHealthSerializer(PassiveSerializer):
|
||||
|
|
|
@ -15,6 +15,7 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
|
|||
["outpost", "uid", "version"],
|
||||
)
|
||||
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
|
||||
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
|
||||
|
||||
|
||||
class AuthentikOutpostConfig(ManagedAppConfig):
|
||||
|
@ -35,14 +36,17 @@ class AuthentikOutpostConfig(ManagedAppConfig):
|
|||
DockerServiceConnection,
|
||||
KubernetesServiceConnection,
|
||||
Outpost,
|
||||
OutpostConfig,
|
||||
OutpostType,
|
||||
)
|
||||
|
||||
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
|
||||
outpost.managed = MANAGED_OUTPOST
|
||||
outpost.save()
|
||||
return
|
||||
outpost, updated = Outpost.objects.update_or_create(
|
||||
defaults={
|
||||
"name": "authentik Embedded Outpost",
|
||||
"type": OutpostType.PROXY,
|
||||
"name": MANAGED_OUTPOST_NAME,
|
||||
},
|
||||
managed=MANAGED_OUTPOST,
|
||||
)
|
||||
|
@ -51,10 +55,4 @@ class AuthentikOutpostConfig(ManagedAppConfig):
|
|||
outpost.service_connection = KubernetesServiceConnection.objects.first()
|
||||
elif DockerServiceConnection.objects.exists():
|
||||
outpost.service_connection = DockerServiceConnection.objects.first()
|
||||
outpost.config = OutpostConfig(
|
||||
kubernetes_disabled_components=[
|
||||
"deployment",
|
||||
"secret",
|
||||
]
|
||||
)
|
||||
outpost.save()
|
||||
|
|
|
@ -6,16 +6,18 @@ from typing import Any, Optional
|
|||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from dacite.core import from_dict
|
||||
from dacite.data import Data
|
||||
from django.http.request import QueryDict
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.core.channels import AuthJsonConsumer
|
||||
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
|
||||
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
|
||||
|
||||
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
|
||||
OUTPOST_GROUP_INSTANCE = "group_outpost_%(outpost_pk)s_%(instance)s"
|
||||
|
||||
|
||||
class WebsocketMessageInstruction(IntEnum):
|
||||
|
@ -42,25 +44,23 @@ class WebsocketMessage:
|
|||
args: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OutpostConsumer(AuthJsonConsumer):
|
||||
class OutpostConsumer(JsonWebsocketConsumer):
|
||||
"""Handler for Outposts that connect over websockets for health checks and live updates"""
|
||||
|
||||
outpost: Optional[Outpost] = None
|
||||
logger: BoundLogger
|
||||
|
||||
last_uid: Optional[str] = None
|
||||
instance_uid: Optional[str] = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.logger = get_logger()
|
||||
|
||||
def connect(self):
|
||||
super().connect()
|
||||
uuid = self.scope["url_route"]["kwargs"]["pk"]
|
||||
user = self.scope["user"]
|
||||
outpost = (
|
||||
get_objects_for_user(self.user, "authentik_outposts.view_outpost")
|
||||
.filter(pk=uuid)
|
||||
.first()
|
||||
get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
|
||||
)
|
||||
if not outpost:
|
||||
raise DenyConnection()
|
||||
|
@ -71,13 +71,19 @@ class OutpostConsumer(AuthJsonConsumer):
|
|||
self.logger.warning("runtime error during accept", exc=exc)
|
||||
raise DenyConnection()
|
||||
self.outpost = outpost
|
||||
self.last_uid = self.channel_name
|
||||
query = QueryDict(self.scope["query_string"].decode())
|
||||
self.instance_uid = query.get("instance_uuid", self.channel_name)
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||
)
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
OUTPOST_GROUP_INSTANCE
|
||||
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
|
||||
self.channel_name,
|
||||
)
|
||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid,
|
||||
uid=self.instance_uid,
|
||||
expected=self.outpost.config.kubernetes_replicas,
|
||||
).inc()
|
||||
|
||||
|
@ -86,34 +92,37 @@ class OutpostConsumer(AuthJsonConsumer):
|
|||
async_to_sync(self.channel_layer.group_discard)(
|
||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||
)
|
||||
if self.outpost and self.last_uid:
|
||||
if self.instance_uid:
|
||||
async_to_sync(self.channel_layer.group_discard)(
|
||||
OUTPOST_GROUP_INSTANCE
|
||||
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
|
||||
self.channel_name,
|
||||
)
|
||||
if self.outpost and self.instance_uid:
|
||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid,
|
||||
uid=self.instance_uid,
|
||||
expected=self.outpost.config.kubernetes_replicas,
|
||||
).dec()
|
||||
|
||||
def receive_json(self, content: Data, **kwargs):
|
||||
msg = from_dict(WebsocketMessage, content)
|
||||
uid = msg.args.get("uuid", self.channel_name)
|
||||
self.last_uid = uid
|
||||
|
||||
if not self.outpost:
|
||||
raise DenyConnection()
|
||||
|
||||
state = OutpostState.for_instance_uid(self.outpost, uid)
|
||||
state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
|
||||
state.last_seen = datetime.now()
|
||||
state.hostname = msg.args.pop("hostname", "")
|
||||
|
||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||
state.version = msg.args.pop("version", None)
|
||||
state.build_hash = msg.args.pop("buildHash", "")
|
||||
state.args = msg.args
|
||||
state.args.update(msg.args)
|
||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||
return
|
||||
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid or "",
|
||||
uid=self.instance_uid or "",
|
||||
version=state.version or "",
|
||||
).set_to_current_time()
|
||||
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)
|
||||
|
|
|
@ -43,6 +43,10 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
|
|||
self.api = AppsV1Api(controller.client)
|
||||
self.outpost = self.controller.outpost
|
||||
|
||||
@property
|
||||
def noop(self) -> bool:
|
||||
return self.is_embedded
|
||||
|
||||
@staticmethod
|
||||
def reconciler_name() -> str:
|
||||
return "deployment"
|
||||
|
|
|
@ -24,6 +24,10 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
|
|||
super().__init__(controller)
|
||||
self.api = CoreV1Api(controller.client)
|
||||
|
||||
@property
|
||||
def noop(self) -> bool:
|
||||
return self.is_embedded
|
||||
|
||||
@staticmethod
|
||||
def reconciler_name() -> str:
|
||||
return "secret"
|
||||
|
|
|
@ -77,7 +77,10 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
|
|||
|
||||
@property
|
||||
def noop(self) -> bool:
|
||||
return (not self._crd_exists()) or (self.is_embedded)
|
||||
if not self._crd_exists():
|
||||
self.logger.debug("CRD doesn't exist")
|
||||
return True
|
||||
return self.is_embedded
|
||||
|
||||
def _crd_exists(self) -> bool:
|
||||
"""Check if the Prometheus ServiceMonitor exists"""
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""k8s utils"""
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from kubernetes.client.models.v1_container_port import V1ContainerPort
|
||||
from kubernetes.client.models.v1_service_port import V1ServicePort
|
||||
|
@ -37,9 +38,12 @@ def compare_port(
|
|||
|
||||
|
||||
def compare_ports(
|
||||
current: list[V1ServicePort | V1ContainerPort], reference: list[V1ServicePort | V1ContainerPort]
|
||||
current: Optional[list[V1ServicePort | V1ContainerPort]],
|
||||
reference: Optional[list[V1ServicePort | V1ContainerPort]],
|
||||
):
|
||||
"""Compare ports of a list"""
|
||||
if not current or not reference:
|
||||
raise NeedsRecreate()
|
||||
if len(current) != len(reference):
|
||||
raise NeedsRecreate()
|
||||
for port in reference:
|
||||
|
|
|
@ -81,7 +81,10 @@ class KubernetesController(BaseController):
|
|||
def up(self):
|
||||
try:
|
||||
for reconcile_key in self.reconcile_order:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
reconciler.up()
|
||||
|
||||
except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc:
|
||||
|
@ -95,7 +98,10 @@ class KubernetesController(BaseController):
|
|||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||
continue
|
||||
with capture_logs() as logs:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
reconciler.up()
|
||||
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
||||
return all_logs
|
||||
|
@ -105,7 +111,10 @@ class KubernetesController(BaseController):
|
|||
def down(self):
|
||||
try:
|
||||
for reconcile_key in self.reconcile_order:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
self.logger.debug("Tearing down object", name=reconcile_key)
|
||||
reconciler.down()
|
||||
|
||||
|
@ -120,7 +129,10 @@ class KubernetesController(BaseController):
|
|||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||
continue
|
||||
with capture_logs() as logs:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
reconciler.down()
|
||||
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
||||
return all_logs
|
||||
|
@ -130,7 +142,10 @@ class KubernetesController(BaseController):
|
|||
def get_static_deployment(self) -> str:
|
||||
documents = []
|
||||
for reconcile_key in self.reconcile_order:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
if reconciler.noop:
|
||||
continue
|
||||
documents.append(reconciler.get_reference_object().to_dict())
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# Generated by Django 4.2.6 on 2023-10-14 19:23
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("authentik_outposts", "0020_alter_outpost_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="outpost",
|
||||
name="type",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("proxy", "Proxy"),
|
||||
("ldap", "Ldap"),
|
||||
("radius", "Radius"),
|
||||
("rac", "Rac"),
|
||||
],
|
||||
default="proxy",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -90,11 +90,12 @@ class OutpostModel(Model):
|
|||
|
||||
|
||||
class OutpostType(models.TextChoices):
|
||||
"""Outpost types, currently only the reverse proxy is available"""
|
||||
"""Outpost types"""
|
||||
|
||||
PROXY = "proxy"
|
||||
LDAP = "ldap"
|
||||
RADIUS = "radius"
|
||||
RAC = "rac"
|
||||
|
||||
|
||||
def default_outpost_config(host: Optional[str] = None):
|
||||
|
@ -459,7 +460,7 @@ class OutpostState:
|
|||
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
|
||||
"""Get state for a single instance"""
|
||||
key = f"{outpost.state_cache_prefix}/{uid}"
|
||||
default_data = {"uid": uid, "channel_ids": []}
|
||||
default_data = {"uid": uid}
|
||||
data = cache.get(key, default_data)
|
||||
if isinstance(data, str):
|
||||
cache.delete(key)
|
||||
|
|
|
@ -17,6 +17,8 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
|
|||
from structlog.stdlib import get_logger
|
||||
from yaml import safe_load
|
||||
|
||||
from authentik.enterprise.providers.rac.controllers.docker import RACDockerController
|
||||
from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
|
@ -71,6 +73,11 @@ def controller_for_outpost(outpost: Outpost) -> Optional[type[BaseController]]:
|
|||
return RadiusDockerController
|
||||
if isinstance(service_connection, KubernetesServiceConnection):
|
||||
return RadiusKubernetesController
|
||||
if outpost.type == OutpostType.RAC:
|
||||
if isinstance(service_connection, DockerServiceConnection):
|
||||
return RACDockerController
|
||||
if isinstance(service_connection, KubernetesServiceConnection):
|
||||
return RACKubernetesController
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
@ -2,11 +2,13 @@
|
|||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.blueprints.tests import reconcile_app
|
||||
from authentik.core.models import PropertyMapping
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.outposts.api.outposts import OutpostSerializer
|
||||
from authentik.outposts.models import OutpostType, default_outpost_config
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||
from authentik.outposts.models import Outpost, OutpostType, default_outpost_config
|
||||
from authentik.providers.ldap.models import LDAPProvider
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
|
||||
|
@ -22,7 +24,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_outpost_validaton(self):
|
||||
@reconcile_app("authentik_outposts")
|
||||
def test_managed_name_change(self):
|
||||
"""Test name change for embedded outpost"""
|
||||
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||
self.assertIsNotNone(embedded_outpost)
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
|
||||
{"name": "foo"},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content, {"name": ["Embedded outpost's name cannot be changed"]}
|
||||
)
|
||||
|
||||
@reconcile_app("authentik_outposts")
|
||||
def test_managed_without_managed(self):
|
||||
"""Test name change for embedded outpost"""
|
||||
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||
self.assertIsNotNone(embedded_outpost)
|
||||
embedded_outpost.managed = ""
|
||||
embedded_outpost.save()
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
|
||||
{"name": "foo"},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
embedded_outpost.refresh_from_db()
|
||||
self.assertEqual(embedded_outpost.managed, MANAGED_OUTPOST)
|
||||
|
||||
def test_outpost_validation(self):
|
||||
"""Test Outpost validation"""
|
||||
valid = OutpostSerializer(
|
||||
data={
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Websocket tests"""
|
||||
from dataclasses import asdict
|
||||
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.routing import URLRouter
|
||||
from channels.testing import WebsocketCommunicator
|
||||
from django.test import TransactionTestCase
|
||||
|
@ -35,6 +36,7 @@ class TestOutpostWS(TransactionTestCase):
|
|||
communicator = WebsocketCommunicator(
|
||||
URLRouter(websocket.websocket_urlpatterns), f"/ws/outpost/{self.outpost.pk}/"
|
||||
)
|
||||
with self.assertRaises(DenyConnection):
|
||||
connected, _ = await communicator.connect()
|
||||
self.assertFalse(connected)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Outpost Websocket URLS"""
|
||||
from django.urls import path
|
||||
|
||||
from authentik.core.channels import TokenOutpostMiddleware
|
||||
from authentik.outposts.api.outposts import OutpostViewSet
|
||||
from authentik.outposts.api.service_connections import (
|
||||
DockerServiceConnectionViewSet,
|
||||
|
@ -11,7 +12,10 @@ from authentik.outposts.consumer import OutpostConsumer
|
|||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
|
||||
websocket_urlpatterns = [
|
||||
path("ws/outpost/<uuid:pk>/", ChannelsLoggingMiddleware(OutpostConsumer.as_asgi())),
|
||||
path(
|
||||
"ws/outpost/<uuid:pk>/",
|
||||
ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())),
|
||||
),
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
"""Serializer for policy execution"""
|
||||
from rest_framework.fields import BooleanField, CharField, DictField, JSONField, ListField
|
||||
from rest_framework.fields import BooleanField, CharField, DictField, ListField
|
||||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.models import User
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ class PolicyTestSerializer(PassiveSerializer):
|
|||
"""Test policy execution for a user with context"""
|
||||
|
||||
user = PrimaryKeyRelatedField(queryset=User.objects.all())
|
||||
context = JSONField(required=False, validators=[is_dict])
|
||||
context = JSONDictField(required=False)
|
||||
|
||||
|
||||
class PolicyTestResultSerializer(PassiveSerializer):
|
||||
|
|
|
@ -7,9 +7,9 @@ from structlog.stdlib import get_logger
|
|||
|
||||
from authentik.flows.planner import PLAN_CONTEXT_SSO
|
||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||
from authentik.lib.utils.http import get_client_ip
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
LOGGER = get_logger()
|
||||
if TYPE_CHECKING:
|
||||
|
@ -49,7 +49,7 @@ class PolicyEvaluator(BaseEvaluator):
|
|||
"""Update context based on http request"""
|
||||
# update website/docs/expressions/_objects.md
|
||||
# update website/docs/expressions/_functions.md
|
||||
self._context["ak_client_ip"] = ip_address(get_client_ip(request))
|
||||
self._context["ak_client_ip"] = ip_address(ClientIPMiddleware.get_client_ip(request))
|
||||
self._context["http_request"] = request
|
||||
|
||||
def handle_error(self, exc: Exception, expression_source: str):
|
||||
|
|
|
@ -47,6 +47,7 @@ class ReputationSerializer(ModelSerializer):
|
|||
"identifier",
|
||||
"ip",
|
||||
"ip_geo_data",
|
||||
"ip_asn_data",
|
||||
"score",
|
||||
"updated",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 4.2.7 on 2023-12-05 22:20
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("authentik_policies_reputation", "0005_reputation_expires_reputation_expiring"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="reputation",
|
||||
name="ip_asn_data",
|
||||
field=models.JSONField(default=dict),
|
||||
),
|
||||
]
|
|
@ -13,9 +13,9 @@ from structlog import get_logger
|
|||
from authentik.core.models import ExpiringModel
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.http import get_client_ip
|
||||
from authentik.policies.models import Policy
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
LOGGER = get_logger()
|
||||
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
|
||||
|
@ -44,7 +44,7 @@ class ReputationPolicy(Policy):
|
|||
return "ak-policy-reputation-form"
|
||||
|
||||
def passes(self, request: PolicyRequest) -> PolicyResult:
|
||||
remote_ip = get_client_ip(request.http_request)
|
||||
remote_ip = ClientIPMiddleware.get_client_ip(request.http_request)
|
||||
query = Q()
|
||||
if self.check_ip:
|
||||
query |= Q(ip=remote_ip)
|
||||
|
@ -76,6 +76,7 @@ class Reputation(ExpiringModel, SerializerModel):
|
|||
identifier = models.TextField()
|
||||
ip = models.GenericIPAddressField()
|
||||
ip_geo_data = models.JSONField(default=dict)
|
||||
ip_asn_data = models.JSONField(default=dict)
|
||||
score = models.BigIntegerField(default=0)
|
||||
|
||||
expires = models.DateTimeField(default=reputation_expiry)
|
||||
|
|
|
@ -7,9 +7,9 @@ from structlog.stdlib import get_logger
|
|||
|
||||
from authentik.core.signals import login_failed
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.http import get_client_ip
|
||||
from authentik.policies.reputation.models import CACHE_KEY_PREFIX
|
||||
from authentik.policies.reputation.tasks import save_reputation
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.identification.signals import identification_failed
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
@ -18,7 +18,7 @@ CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_reputation")
|
|||
|
||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||
"""Update score for IP and User"""
|
||||
remote_ip = get_client_ip(request)
|
||||
remote_ip = ClientIPMiddleware.get_client_ip(request)
|
||||
|
||||
try:
|
||||
# We only update the cache here, as its faster than writing to the DB
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
from django.core.cache import cache
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.geo import GEOIP_READER
|
||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
|
@ -26,7 +27,8 @@ def save_reputation(self: MonitoredTask):
|
|||
ip=score["ip"],
|
||||
identifier=score["identifier"],
|
||||
)
|
||||
rep.ip_geo_data = GEOIP_READER.city_dict(score["ip"]) or {}
|
||||
rep.ip_geo_data = GEOIP_CONTEXT_PROCESSOR.city_dict(score["ip"]) or {}
|
||||
rep.ip_asn_data = ASN_CONTEXT_PROCESSOR.asn_dict(score["ip"]) or {}
|
||||
rep.score = score["score"]
|
||||
objects_to_update.append(rep)
|
||||
Reputation.objects.bulk_update(objects_to_update, ["score", "ip_geo_data"])
|
||||
|
|
|
@ -8,8 +8,7 @@ from django.db.models import Model
|
|||
from django.http import HttpRequest
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.geo import GEOIP_READER
|
||||
from authentik.lib.utils.http import get_client_ip
|
||||
from authentik.events.context_processors.base import get_context_processors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.core.models import User
|
||||
|
@ -39,12 +38,8 @@ class PolicyRequest:
|
|||
def set_http_request(self, request: HttpRequest): # pragma: no cover
|
||||
"""Load data from HTTP request, including geoip when enabled"""
|
||||
self.http_request = request
|
||||
if not GEOIP_READER.enabled:
|
||||
return
|
||||
client_ip = get_client_ip(request)
|
||||
if not client_ip:
|
||||
return
|
||||
self.context["geoip"] = GEOIP_READER.city(client_ip)
|
||||
for processor in get_context_processors():
|
||||
self.context.update(processor.enrich_context(request))
|
||||
|
||||
@property
|
||||
def should_cache(self) -> bool:
|
||||
|
|
|
@ -7,8 +7,8 @@ GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials"
|
|||
GRANT_TYPE_PASSWORD = "password" # nosec
|
||||
GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
CLIENT_ASSERTION_TYPE = "client_assertion_type"
|
||||
CLIENT_ASSERTION = "client_assertion"
|
||||
CLIENT_ASSERTION_TYPE = "client_assertion_type"
|
||||
CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||
|
||||
PROMPT_NONE = "none"
|
||||
|
@ -18,9 +18,9 @@ PROMPT_LOGIN = "login"
|
|||
SCOPE_OPENID = "openid"
|
||||
SCOPE_OPENID_PROFILE = "profile"
|
||||
SCOPE_OPENID_EMAIL = "email"
|
||||
SCOPE_OFFLINE_ACCESS = "offline_access"
|
||||
|
||||
# https://www.iana.org/assignments/oauth-parameters/\
|
||||
# oauth-parameters.xhtml#pkce-code-challenge-method
|
||||
# https://www.iana.org/assignments/oauth-parameters/auth-parameters.xhtml#pkce-code-challenge-method
|
||||
PKCE_METHOD_PLAIN = "plain"
|
||||
PKCE_METHOD_S256 = "S256"
|
||||
|
||||
|
@ -36,6 +36,12 @@ SCOPE_GITHUB_USER_READ = "read:user"
|
|||
SCOPE_GITHUB_USER_EMAIL = "user:email"
|
||||
# Read info about teams
|
||||
SCOPE_GITHUB_ORG_READ = "read:org"
|
||||
SCOPE_GITHUB = {
|
||||
SCOPE_GITHUB_USER,
|
||||
SCOPE_GITHUB_USER_READ,
|
||||
SCOPE_GITHUB_USER_EMAIL,
|
||||
SCOPE_GITHUB_ORG_READ,
|
||||
}
|
||||
|
||||
ACR_AUTHENTIK_DEFAULT = "goauthentik.io/providers/oauth2/default"
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue