Merge branch 'main' into 5165-password-strength-indicator

* main: (160 commits)
  website: update hackathon with prize pool (#6170)
  web: bump @babel/plugin-transform-runtime from 7.22.6 to 7.22.7 in /web (#6166)
  web: bump @babel/core from 7.22.6 to 7.22.7 in /web (#6165)
  web: bump @babel/plugin-proposal-decorators from 7.22.6 to 7.22.7 in /web (#6167)
  web: bump @babel/preset-env from 7.22.6 to 7.22.7 in /web (#6168)
  website: bump prettier from 2.8.8 to 3.0.0 in /website (#6155)
  web: bump storybook from 7.0.25 to 7.0.26 in /web (#6162)
  core: bump goauthentik.io/api/v3 from 3.2023054.2 to 3.2023054.4 (#6154)
  core: bump golang.org/x/oauth2 from 0.9.0 to 0.10.0 (#6153)
  web: bump @storybook/addon-essentials from 7.0.25 to 7.0.26 in /web (#6158)
  ci: bump actions/setup-node from 3.6.0 to 3.7.0 (#6156)
  web: bump core-js from 3.31.0 to 3.31.1 in /web (#6160)
  web: bump @storybook/addon-links from 7.0.25 to 7.0.26 in /web (#6159)
  web: bump @storybook/web-components-vite from 7.0.25 to 7.0.26 in /web (#6163)
  web: bump lit from 2.7.5 to 2.7.6 in /web (#6161)
  core: bump lxml from 4.9.2 to 4.9.3 (#6151)
  web: bump @babel/core from 7.22.5 to 7.22.6 in /web (#6143)
  web: bump @babel/plugin-transform-runtime from 7.22.5 to 7.22.6 in /web (#6142)
  web: bump @babel/preset-env from 7.22.5 to 7.22.6 in /web (#6144)
  web: bump @babel/plugin-proposal-decorators from 7.22.5 to 7.22.6 in /web (#6141)
  ...
This commit is contained in:
Ken Sternberg 2023-07-06 08:05:05 -07:00
commit 465820b002
196 changed files with 10328 additions and 6791 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2023.5.3 current_version = 2023.5.4
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@ -0,0 +1,17 @@
---
name: Hackathon Idea
about: Propose an idea for the hackathon
title: ""
labels: hackathon
assignees: ""
---
**Describe the idea**
A clear concise description of the idea you want to implement
You're also free to work on existing GitHub issues, whether they be feature requests or bugs, just link the existing GitHub issue here.
<!-- Don't modify below here -->
If you want to help working on this idea or want to contribute in any other way, react to this issue with a :rocket:

View File

@ -24,6 +24,18 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "web:" prefix: "web:"
groups:
sentry:
patterns:
- "@sentry/*"
babel:
patterns:
- "@babel/*"
- "babel-*"
storybook:
patterns:
- "@storybook/*"
- "*storybook*"
- package-ecosystem: npm - package-ecosystem: npm
directory: "/website" directory: "/website"
schedule: schedule:
@ -32,6 +44,10 @@ updates:
open-pull-requests-limit: 10 open-pull-requests-limit: 10
commit-message: commit-message:
prefix: "website:" prefix: "website:"
groups:
docusaurus:
patterns:
- "@docusaurus/*"
- package-ecosystem: pip - package-ecosystem: pip
directory: "/" directory: "/"
schedule: schedule:

19
.github/stale.yml vendored
View File

@ -1,19 +0,0 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 7
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
- pr_wanted
- enhancement
- bug/confirmed
- enhancement/confirmed
- question
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
only: issues

View File

@ -218,6 +218,7 @@ jobs:
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }} ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
build-args: | build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
- name: Comment on PR - name: Comment on PR
if: github.event_name == 'pull_request' if: github.event_name == 'pull_request'
@ -262,5 +263,6 @@ jobs:
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}-arm64 ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}-arm64
build-args: | build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
platforms: linux/arm64 platforms: linux/arm64

View File

@ -95,6 +95,7 @@ jobs:
file: ${{ matrix.type }}.Dockerfile file: ${{ matrix.type }}.Dockerfile
build-args: | build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
context: . context: .
@ -119,7 +120,7 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

View File

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -31,7 +31,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -47,7 +47,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -63,7 +63,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -95,7 +95,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

View File

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -29,7 +29,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"
@ -50,7 +50,7 @@ jobs:
- build-docs-only - build-docs-only
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

View File

@ -16,10 +16,5 @@ jobs:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
with: with:
ref: main ref: main
- id: main-state - run: |
run: | git push origin --force main:next
state=$(curl -fsSL -H "Accept: application/vnd.github+json" -H "Authorization: Bearer ${{ github.token }}" "https://api.github.com/repos/${{ github.repository }}/commits/HEAD/state" | jq -r '.state')
echo "state=${state}" >> $GITHUB_OUTPUT
- if: ${{ steps.main-state.outputs.state == 'success' }}
run: |
git push origin next --force

View File

@ -43,6 +43,7 @@ jobs:
ghcr.io/goauthentik/server:latest ghcr.io/goauthentik/server:latest
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
build-args: | build-args: |
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost: build-outpost:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -90,6 +91,7 @@ jobs:
file: ${{ matrix.type }}.Dockerfile file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
build-args: | build-args: |
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost-binary: build-outpost-binary:
timeout-minutes: 120 timeout-minutes: 120
@ -108,7 +110,7 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
cache: "npm" cache: "npm"

33
.github/workflows/repo-stale.yml vendored Normal file
View File

@ -0,0 +1,33 @@
name: 'authentik-repo-stale'
on:
schedule:
- cron: '30 1 * * *'
workflow_dispatch:
permissions:
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- id: generate_token
uses: tibdex/github-app-token@v1
with:
app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
- uses: actions/stale@v8
with:
repo-token: ${{ steps.generate_token.outputs.token }}
days-before-stale: 60
days-before-close: 7
exempt-issue-labels: pinned,security,pr_wanted,enhancement,bug/confirmed,enhancement/confirmed,question
stale-issue-label: wontfix
stale-issue-message: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
# Don't stale PRs, so only apply to PRs with a non-existent label
only-pr-labels: foo

View File

@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
- uses: actions/setup-node@v3.6.0 - uses: actions/setup-node@v3.7.0
with: with:
node-version: "20" node-version: "20"
registry-url: "https://registry.npmjs.org" registry-url: "https://registry.npmjs.org"

1
.gitignore vendored
View File

@ -166,6 +166,7 @@ dmypy.json
# SageMath parsed files # SageMath parsed files
# Environments # Environments
**/.DS_Store
# Spyder project settings # Spyder project settings

27
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,27 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: PDB attach Server",
"type": "python",
"request": "attach",
"connect": {
"host": "localhost",
"port": 6800
},
"justMyCode": true,
"django": true
},
{
"name": "Python: PDB attach Worker",
"type": "python",
"request": "attach",
"connect": {
"host": "localhost",
"port": 6900
},
"justMyCode": true,
"django": true
},
]
}

View File

@ -65,15 +65,18 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
# Stage 6: Run # Stage 6: Run
FROM docker.io/python:3.11.4-slim-bullseye AS final-image FROM docker.io/python:3.11.4-slim-bullseye AS final-image
ARG GIT_BUILD_HASH
ARG VERSION
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
LABEL org.opencontainers.image.url https://goauthentik.io LABEL org.opencontainers.image.url https://goauthentik.io
LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info. LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info.
LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version ${VERSION}
LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
WORKDIR / WORKDIR /
ARG GIT_BUILD_HASH
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
COPY --from=poetry-locker /work/requirements.txt / COPY --from=poetry-locker /work/requirements.txt /
COPY --from=poetry-locker /work/requirements-dev.txt / COPY --from=poetry-locker /work/requirements-dev.txt /
COPY --from=geoip /usr/share/GeoIP /geoip COPY --from=geoip /usr/share/GeoIP /geoip

View File

@ -2,7 +2,7 @@
from os import environ from os import environ
from typing import Optional from typing import Optional
__version__ = "2023.5.3" __version__ = "2023.5.4"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -8,6 +8,7 @@ from rest_framework.viewsets import ViewSet
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.lib.utils.reflection import get_apps from authentik.lib.utils.reflection import get_apps
from authentik.policies.event_matcher.models import model_choices
class AppSerializer(PassiveSerializer): class AppSerializer(PassiveSerializer):
@ -29,3 +30,17 @@ class AppsViewSet(ViewSet):
for app in sorted(get_apps(), key=lambda app: app.name): for app in sorted(get_apps(), key=lambda app: app.name):
data.append({"name": app.name, "label": app.verbose_name}) data.append({"name": app.name, "label": app.verbose_name})
return Response(AppSerializer(data, many=True).data) return Response(AppSerializer(data, many=True).data)
class ModelViewSet(ViewSet):
"""Read-only view list all installed models"""
permission_classes = [IsAdminUser]
@extend_schema(responses={200: AppSerializer(many=True)})
def list(self, request: Request) -> Response:
"""Read-only view list all installed models"""
data = []
for name, label in model_choices():
data.append({"name": name, "label": label})
return Response(AppSerializer(data, many=True).data)

View File

@ -1,5 +1,4 @@
"""authentik administration overview""" """authentik administration overview"""
import os
import platform import platform
from datetime import datetime from datetime import datetime
from sys import version as python_version from sys import version as python_version
@ -34,7 +33,6 @@ class RuntimeDict(TypedDict):
class SystemSerializer(PassiveSerializer): class SystemSerializer(PassiveSerializer):
"""Get system information.""" """Get system information."""
env = SerializerMethodField()
http_headers = SerializerMethodField() http_headers = SerializerMethodField()
http_host = SerializerMethodField() http_host = SerializerMethodField()
http_is_secure = SerializerMethodField() http_is_secure = SerializerMethodField()
@ -43,10 +41,6 @@ class SystemSerializer(PassiveSerializer):
server_time = SerializerMethodField() server_time = SerializerMethodField()
embedded_outpost_host = SerializerMethodField() embedded_outpost_host = SerializerMethodField()
def get_env(self, request: Request) -> dict[str, str]:
"""Get Environment"""
return os.environ.copy()
def get_http_headers(self, request: Request) -> dict[str, str]: def get_http_headers(self, request: Request) -> dict[str, str]:
"""Get HTTP Request headers""" """Get HTTP Request headers"""
headers = {} headers = {}

View File

@ -19,7 +19,7 @@ class WorkerView(APIView):
def get(self, request: Request) -> Response: def get(self, request: Request) -> Response:
"""Get currently connected worker count.""" """Get currently connected worker count."""
count = len(CELERY_APP.control.ping(timeout=0.5)) count = len(CELERY_APP.control.ping(timeout=0.5))
# In debug we run with `CELERY_TASK_ALWAYS_EAGER`, so tasks are ran on the main process # In debug we run with `task_always_eager`, so tasks are ran on the main process
if settings.DEBUG: # pragma: no cover if settings.DEBUG: # pragma: no cover
count += 1 count += 1
return Response({"count": count}) return Response({"count": count})

View File

@ -94,6 +94,11 @@ class TestAdminAPI(TestCase):
response = self.client.get(reverse("authentik_api:apps-list")) response = self.client.get(reverse("authentik_api:apps-list"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_models(self):
"""Test models API"""
response = self.client.get(reverse("authentik_api:models-list"))
self.assertEqual(response.status_code, 200)
@reconcile_app("authentik_outposts") @reconcile_app("authentik_outposts")
def test_system(self): def test_system(self):
"""Test system API""" """Test system API"""

View File

@ -1,7 +1,7 @@
"""API URLs""" """API URLs"""
from django.urls import path from django.urls import path
from authentik.admin.api.meta import AppsViewSet from authentik.admin.api.meta import AppsViewSet, ModelViewSet
from authentik.admin.api.metrics import AdministrationMetricsViewSet from authentik.admin.api.metrics import AdministrationMetricsViewSet
from authentik.admin.api.system import SystemView from authentik.admin.api.system import SystemView
from authentik.admin.api.tasks import TaskViewSet from authentik.admin.api.tasks import TaskViewSet
@ -11,6 +11,7 @@ from authentik.admin.api.workers import WorkerView
api_urlpatterns = [ api_urlpatterns = [
("admin/system_tasks", TaskViewSet, "admin_system_tasks"), ("admin/system_tasks", TaskViewSet, "admin_system_tasks"),
("admin/apps", AppsViewSet, "apps"), ("admin/apps", AppsViewSet, "apps"),
("admin/models", ModelViewSet, "models"),
path( path(
"admin/metrics/", "admin/metrics/",
AdministrationMetricsViewSet.as_view(), AdministrationMetricsViewSet.as_view(),

View File

@ -1,4 +1,5 @@
"""API Authentication""" """API Authentication"""
from hmac import compare_digest
from typing import Any, Optional from typing import Any, Optional
from django.conf import settings from django.conf import settings
@ -78,7 +79,7 @@ def token_secret_key(value: str) -> Optional[User]:
and return the service account for the managed outpost""" and return the service account for the managed outpost"""
from authentik.outposts.apps import MANAGED_OUTPOST from authentik.outposts.apps import MANAGED_OUTPOST
if value != settings.SECRET_KEY: if not compare_digest(value, settings.SECRET_KEY):
return None return None
outposts = Outpost.objects.filter(managed=MANAGED_OUTPOST) outposts = Outpost.objects.filter(managed=MANAGED_OUTPOST)
if not outposts: if not outposts:

View File

@ -10,8 +10,6 @@ API Browser - {{ tenant.branding_title }}
<script src="{% static 'dist/standalone/api-browser/index.js' %}?version={{ version }}" type="module"></script> <script src="{% static 'dist/standalone/api-browser/index.js' %}?version={{ version }}" type="module"></script>
<meta name="theme-color" content="#151515" media="(prefers-color-scheme: light)"> <meta name="theme-color" content="#151515" media="(prefers-color-scheme: light)">
<meta name="theme-color" content="#151515" media="(prefers-color-scheme: dark)"> <meta name="theme-color" content="#151515" media="(prefers-color-scheme: dark)">
<link rel="icon" href="{{ tenant.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
{% endblock %} {% endblock %}
{% block body %} {% block body %}

View File

@ -82,7 +82,10 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
def retrieve_file(self) -> str: def retrieve_file(self) -> str:
"""Get blueprint from path""" """Get blueprint from path"""
try: try:
full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path)) base = Path(CONFIG.y("blueprints_dir"))
full_path = base.joinpath(Path(self.path)).resolve()
if not str(full_path).startswith(str(base.resolve())):
raise BlueprintRetrievalFailed("Invalid blueprint path")
with full_path.open("r", encoding="utf-8") as _file: with full_path.open("r", encoding="utf-8") as _file:
return _file.read() return _file.read()
except (IOError, OSError) as exc: except (IOError, OSError) as exc:

View File

@ -1,34 +1,15 @@
"""authentik managed models tests""" """authentik managed models tests"""
from typing import Callable, Type
from django.apps import apps
from django.test import TestCase from django.test import TestCase
from authentik.blueprints.v1.importer import is_model_allowed from authentik.blueprints.models import BlueprintInstance, BlueprintRetrievalFailed
from authentik.lib.models import SerializerModel from authentik.lib.generators import generate_id
class TestModels(TestCase): class TestModels(TestCase):
"""Test Models""" """Test Models"""
def test_retrieve_file(self):
def serializer_tester_factory(test_model: Type[SerializerModel]) -> Callable: """Test retrieve_file"""
"""Test serializer""" instance = BlueprintInstance.objects.create(name=generate_id(), path="../etc/hosts")
with self.assertRaises(BlueprintRetrievalFailed):
def tester(self: TestModels): instance.retrieve()
if test_model._meta.abstract: # pragma: no cover
return
model_class = test_model()
self.assertTrue(isinstance(model_class, SerializerModel))
self.assertIsNotNone(model_class.serializer)
return tester
for app in apps.get_app_configs():
if not app.label.startswith("authentik"):
continue
for model in app.get_models():
if not is_model_allowed(model):
continue
setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model))

View File

@ -0,0 +1,34 @@
"""authentik managed models tests"""
from typing import Callable, Type
from django.apps import apps
from django.test import TestCase
from authentik.blueprints.v1.importer import is_model_allowed
from authentik.lib.models import SerializerModel
class TestModels(TestCase):
"""Test Models"""
def serializer_tester_factory(test_model: Type[SerializerModel]) -> Callable:
"""Test serializer"""
def tester(self: TestModels):
if test_model._meta.abstract: # pragma: no cover
return
model_class = test_model()
self.assertTrue(isinstance(model_class, SerializerModel))
self.assertIsNotNone(model_class.serializer)
return tester
for app in apps.get_app_configs():
if not app.label.startswith("authentik"):
continue
for model in app.get_models():
if not is_model_allowed(model):
continue
setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model))

View File

@ -1,5 +1,6 @@
"""Groups API Viewset""" """Groups API Viewset"""
from json import loads from json import loads
from typing import Optional
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.http import Http404 from django.http import Http404
@ -52,6 +53,14 @@ class GroupSerializer(ModelSerializer):
num_pk = IntegerField(read_only=True) num_pk = IntegerField(read_only=True)
def validate_parent(self, parent: Optional[Group]):
"""Validate group parent (if set), ensuring the parent isn't itself"""
if not self.instance or not parent:
return parent
if str(parent.group_uuid) == str(self.instance.group_uuid):
raise ValidationError("Cannot set group as parent of itself.")
return parent
class Meta: class Meta:
model = Group model = Group
fields = [ fields = [

View File

@ -68,11 +68,12 @@ from authentik.core.models import (
TokenIntents, TokenIntents,
User, User,
) )
from authentik.events.models import EventAction from authentik.events.models import Event, EventAction
from authentik.flows.exceptions import FlowNonApplicableException from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import FlowToken from authentik.flows.models import FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
from authentik.flows.views.executor import QS_KEY_TOKEN from authentik.flows.views.executor import QS_KEY_TOKEN
from authentik.lib.config import CONFIG
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
from authentik.stages.email.tasks import send_mails from authentik.stages.email.tasks import send_mails
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
@ -568,6 +569,58 @@ class UserViewSet(UsedByMixin, ModelViewSet):
send_mails(email_stage, message) send_mails(email_stage, message)
return Response(status=204) return Response(status=204)
@permission_required("authentik_core.impersonate")
@extend_schema(
request=OpenApiTypes.NONE,
responses={
"204": OpenApiResponse(description="Successfully started impersonation"),
"401": OpenApiResponse(description="Access denied"),
},
)
@action(detail=True, methods=["POST"])
def impersonate(self, request: Request, pk: int) -> Response:
"""Impersonate a user"""
if not CONFIG.y_bool("impersonation"):
LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401)
if not request.user.has_perm("impersonate"):
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return Response(status=401)
user_to_be = self.get_object()
request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER] = request.user
request.session[SESSION_KEY_IMPERSONATE_USER] = user_to_be
Event.new(EventAction.IMPERSONATION_STARTED).from_http(request, user_to_be)
return Response(status=201)
@extend_schema(
request=OpenApiTypes.NONE,
responses={
"204": OpenApiResponse(description="Successfully started impersonation"),
},
)
@action(detail=False, methods=["GET"])
def impersonate_end(self, request: Request) -> Response:
"""End Impersonation a user"""
if (
SESSION_KEY_IMPERSONATE_USER not in request.session
or SESSION_KEY_IMPERSONATE_ORIGINAL_USER not in request.session
):
LOGGER.debug("Can't end impersonation", user=request.user)
return Response(status=204)
original_user = request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]
del request.session[SESSION_KEY_IMPERSONATE_USER]
del request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]
Event.new(EventAction.IMPERSONATION_ENDED).from_http(request, original_user)
return Response(status=204)
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet: def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
"""Custom filter_queryset method which ignores guardian, but still supports sorting""" """Custom filter_queryset method which ignores guardian, but still supports sorting"""
for backend in list(self.filter_backends): for backend in list(self.filter_backends):

View File

@ -0,0 +1,40 @@
"""Run worker"""
from sys import exit as sysexit
from tempfile import tempdir
from celery.apps.worker import Worker
from django.core.management.base import BaseCommand
from django.db import close_old_connections
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
class Command(BaseCommand):
"""Run worker"""
def handle(self, **options):
close_old_connections()
if CONFIG.y_bool("remote_debug"):
import debugpy
debugpy.listen(("0.0.0.0", 6900)) # nosec
worker: Worker = CELERY_APP.Worker(
no_color=False,
quiet=True,
optimization="fair",
max_tasks_per_child=1,
autoscale=(3, 1),
task_events=True,
beat=True,
schedule_filename=f"{tempdir}/celerybeat-schedule",
queues=["authentik", "authentik_scheduled", "authentik_events"],
)
for task in CELERY_APP.tasks:
LOGGER.debug("Registered task", task=task)
worker.start()
sysexit(worker.exitcode)

View File

@ -11,7 +11,7 @@ def backport_is_backchannel(apps: Apps, schema_editor: BaseDatabaseSchemaEditor)
for model in BackchannelProvider.__subclasses__(): for model in BackchannelProvider.__subclasses__():
try: try:
for obj in model.objects.all(): for obj in model.objects.only("is_backchannel"):
obj.is_backchannel = True obj.is_backchannel = True
obj.save() obj.save()
except (DatabaseError, InternalError, ProgrammingError): except (DatabaseError, InternalError, ProgrammingError):

View File

@ -8,7 +8,8 @@
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1"> <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1">
<title>{% block title %}{% trans title|default:tenant.branding_title %}{% endblock %}</title> <title>{% block title %}{% trans title|default:tenant.branding_title %}{% endblock %}</title>
<link rel="shortcut icon" type="image/png" href="{% static 'dist/assets/icons/icon.png' %}"> <link rel="icon" href="{{ tenant.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
{% block head_before %} {% block head_before %}
{% endblock %} {% endblock %}
<link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}"> <link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}">

View File

@ -6,8 +6,6 @@
<script src="{% static 'dist/admin/AdminInterface.js' %}?version={{ version }}" type="module"></script> <script src="{% static 'dist/admin/AdminInterface.js' %}?version={{ version }}" type="module"></script>
<meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)"> <meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)">
<meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)"> <meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)">
<link rel="icon" href="{{ tenant.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
{% include "base/header_js.html" %} {% include "base/header_js.html" %}
{% endblock %} {% endblock %}

View File

@ -5,8 +5,6 @@
{% block head_before %} {% block head_before %}
{{ block.super }} {{ block.super }}
<link rel="prefetch" href="{{ flow.background_url }}" /> <link rel="prefetch" href="{{ flow.background_url }}" />
<link rel="icon" href="{{ tenant.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
{% if flow.compatibility_mode and not inspector %} {% if flow.compatibility_mode and not inspector %}
<script>ShadyDOM = { force: !navigator.webdriver };</script> <script>ShadyDOM = { force: !navigator.webdriver };</script>
{% endif %} {% endif %}

View File

@ -6,8 +6,6 @@
<script src="{% static 'dist/user/UserInterface.js' %}?version={{ version }}" type="module"></script> <script src="{% static 'dist/user/UserInterface.js' %}?version={{ version }}" type="module"></script>
<meta name="theme-color" content="#1c1e21" media="(prefers-color-scheme: light)"> <meta name="theme-color" content="#1c1e21" media="(prefers-color-scheme: light)">
<meta name="theme-color" content="#1c1e21" media="(prefers-color-scheme: dark)"> <meta name="theme-color" content="#1c1e21" media="(prefers-color-scheme: dark)">
<link rel="icon" href="{{ tenant.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
{% include "base/header_js.html" %} {% include "base/header_js.html" %}
{% endblock %} {% endblock %}

View File

@ -67,3 +67,16 @@ class TestGroupsAPI(APITestCase):
}, },
) )
self.assertEqual(res.status_code, 404) self.assertEqual(res.status_code, 404)
def test_parent_self(self):
"""Test parent"""
group = Group.objects.create(name=generate_id())
self.client.force_login(self.admin)
res = self.client.patch(
reverse("authentik_api:group-detail", kwargs={"pk": group.pk}),
data={
"pk": self.user.pk + 3,
"parent": group.pk,
},
)
self.assertEqual(res.status_code, 400)

View File

@ -1,14 +1,14 @@
"""impersonation tests""" """impersonation tests"""
from json import loads from json import loads
from django.test.testcases import TestCase
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import User from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
class TestImpersonation(TestCase): class TestImpersonation(APITestCase):
"""impersonation tests""" """impersonation tests"""
def setUp(self) -> None: def setUp(self) -> None:
@ -23,10 +23,10 @@ class TestImpersonation(TestCase):
self.other_user.save() self.other_user.save()
self.client.force_login(self.user) self.client.force_login(self.user)
self.client.get( self.client.post(
reverse( reverse(
"authentik_core:impersonate-init", "authentik_api:user-impersonate",
kwargs={"user_id": self.other_user.pk}, kwargs={"pk": self.other_user.pk},
) )
) )
@ -35,7 +35,7 @@ class TestImpersonation(TestCase):
self.assertEqual(response_body["user"]["username"], self.other_user.username) self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], self.user.username) self.assertEqual(response_body["original"]["username"], self.user.username)
self.client.get(reverse("authentik_core:impersonate-end")) self.client.get(reverse("authentik_api:user-impersonate-end"))
response = self.client.get(reverse("authentik_api:user-me")) response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode()) response_body = loads(response.content.decode())
@ -46,9 +46,7 @@ class TestImpersonation(TestCase):
"""test impersonation without permissions""" """test impersonation without permissions"""
self.client.force_login(self.other_user) self.client.force_login(self.other_user)
self.client.get( self.client.get(reverse("authentik_api:user-impersonate", kwargs={"pk": self.user.pk}))
reverse("authentik_core:impersonate-init", kwargs={"user_id": self.user.pk})
)
response = self.client.get(reverse("authentik_api:user-me")) response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode()) response_body = loads(response.content.decode())
@ -58,5 +56,5 @@ class TestImpersonation(TestCase):
"""test un-impersonation without impersonating first""" """test un-impersonation without impersonating first"""
self.client.force_login(self.other_user) self.client.force_login(self.other_user)
response = self.client.get(reverse("authentik_core:impersonate-end")) response = self.client.get(reverse("authentik_api:user-impersonate-end"))
self.assertRedirects(response, reverse("authentik_core:if-user")) self.assertEqual(response.status_code, 204)

View File

@ -8,7 +8,7 @@ from authentik.core.api.utils import PassiveSerializer
from authentik.flows.challenge import Challenge from authentik.flows.challenge import Challenge
@dataclass @dataclass(slots=True)
class UILoginButton: class UILoginButton:
"""Dataclass for Source's ui_login_button""" """Dataclass for Source's ui_login_button"""

View File

@ -16,7 +16,7 @@ from authentik.core.api.providers import ProviderViewSet
from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSet from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSet
from authentik.core.api.tokens import TokenViewSet from authentik.core.api.tokens import TokenViewSet
from authentik.core.api.users import UserViewSet from authentik.core.api.users import UserViewSet
from authentik.core.views import apps, impersonate from authentik.core.views import apps
from authentik.core.views.debug import AccessDeniedView from authentik.core.views.debug import AccessDeniedView
from authentik.core.views.interface import FlowInterfaceView, InterfaceView from authentik.core.views.interface import FlowInterfaceView, InterfaceView
from authentik.core.views.session import EndSessionView from authentik.core.views.session import EndSessionView
@ -38,17 +38,6 @@ urlpatterns = [
apps.RedirectToAppLaunch.as_view(), apps.RedirectToAppLaunch.as_view(),
name="application-launch", name="application-launch",
), ),
# Impersonation
path(
"-/impersonation/<int:user_id>/",
impersonate.ImpersonateInitView.as_view(),
name="impersonate-init",
),
path(
"-/impersonation/end/",
impersonate.ImpersonateEndView.as_view(),
name="impersonate-end",
),
# Interfaces # Interfaces
path( path(
"if/admin/", "if/admin/",

View File

@ -1,60 +0,0 @@
"""authentik impersonation views"""
from django.http import HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, redirect
from django.views import View
from structlog.stdlib import get_logger
from authentik.core.middleware import (
SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
SESSION_KEY_IMPERSONATE_USER,
)
from authentik.core.models import User
from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG
LOGGER = get_logger()
class ImpersonateInitView(View):
"""Initiate Impersonation"""
def get(self, request: HttpRequest, user_id: int) -> HttpResponse:
"""Impersonation handler, checks permissions"""
if not CONFIG.y_bool("impersonation"):
LOGGER.debug("User attempted to impersonate", user=request.user)
return HttpResponse("Unauthorized", status=401)
if not request.user.has_perm("impersonate"):
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return HttpResponse("Unauthorized", status=401)
user_to_be = get_object_or_404(User, pk=user_id)
request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER] = request.user
request.session[SESSION_KEY_IMPERSONATE_USER] = user_to_be
Event.new(EventAction.IMPERSONATION_STARTED).from_http(request, user_to_be)
return redirect("authentik_core:if-user")
class ImpersonateEndView(View):
"""End User impersonation"""
def get(self, request: HttpRequest) -> HttpResponse:
"""End Impersonation handler"""
if (
SESSION_KEY_IMPERSONATE_USER not in request.session
or SESSION_KEY_IMPERSONATE_ORIGINAL_USER not in request.session
):
LOGGER.debug("Can't end impersonation", user=request.user)
return redirect("authentik_core:if-user")
original_user = request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]
del request.session[SESSION_KEY_IMPERSONATE_USER]
del request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]
Event.new(EventAction.IMPERSONATION_ENDED).from_http(request, original_user)
return redirect("authentik_core:root-redirect")

View File

@ -41,6 +41,7 @@ class TaskResult:
def with_error(self, exc: Exception) -> "TaskResult": def with_error(self, exc: Exception) -> "TaskResult":
"""Since errors might not always be pickle-able, set the traceback""" """Since errors might not always be pickle-able, set the traceback"""
# TODO: Mark exception somehow so that is rendered as <pre> in frontend
self.messages.append(exception_to_string(exc)) self.messages.append(exception_to_string(exc))
return self return self
@ -69,8 +70,10 @@ class TaskInfo:
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")) return cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*"))
@staticmethod @staticmethod
def by_name(name: str) -> Optional["TaskInfo"]: def by_name(name: str) -> Optional["TaskInfo"] | Optional[list["TaskInfo"]]:
"""Get TaskInfo Object by name""" """Get TaskInfo Object by name"""
if "*" in name:
return cache.get_many(cache.keys(CACHE_KEY_PREFIX + name)).values()
return cache.get(CACHE_KEY_PREFIX + name, None) return cache.get(CACHE_KEY_PREFIX + name, None)
def delete(self): def delete(self):

View File

@ -23,7 +23,8 @@ class DiagramElement:
style: list[str] = field(default_factory=lambda: ["[", "]"]) style: list[str] = field(default_factory=lambda: ["[", "]"])
def __str__(self) -> str: def __str__(self) -> str:
element = f'{self.identifier}{self.style[0]}"{self.description}"{self.style[1]}' description = self.description.replace('"', "#quot;")
element = f'{self.identifier}{self.style[0]}"{description}"{self.style[1]}'
if self.action is not None: if self.action is not None:
if self.action != "": if self.action != "":
element = f"--{self.action}--> {element}" element = f"--{self.action}--> {element}"

View File

@ -154,7 +154,7 @@ class AutosubmitChallenge(Challenge):
"""Autosubmit challenge used to send and navigate a POST request""" """Autosubmit challenge used to send and navigate a POST request"""
url = CharField() url = CharField()
attrs = DictField(child=CharField()) attrs = DictField(child=CharField(allow_blank=True), allow_empty=True)
title = CharField(required=False) title = CharField(required=False)
component = CharField(default="ak-stage-autosubmit") component = CharField(default="ak-stage-autosubmit")

View File

@ -30,7 +30,7 @@ class StageMarker:
return binding return binding
@dataclass @dataclass(slots=True)
class ReevaluateMarker(StageMarker): class ReevaluateMarker(StageMarker):
"""Reevaluate Marker, forces stage's policies to be evaluated again.""" """Reevaluate Marker, forces stage's policies to be evaluated again."""

View File

@ -45,7 +45,7 @@ def cache_key(flow: Flow, user: Optional[User] = None) -> str:
return prefix return prefix
@dataclass @dataclass(slots=True)
class FlowPlan: class FlowPlan:
"""This data-class is the output of a FlowPlanner. It holds a flat list """This data-class is the output of a FlowPlanner. It holds a flat list
of all Stages that should be run.""" of all Stages that should be run."""

View File

@ -204,12 +204,12 @@ class ChallengeStageView(StageView):
for field, errors in response.errors.items(): for field, errors in response.errors.items():
for error in errors: for error in errors:
full_errors.setdefault(field, []) full_errors.setdefault(field, [])
full_errors[field].append( field_error = {
{
"string": str(error), "string": str(error),
"code": error.code,
} }
) if hasattr(error, "code"):
field_error["code"] = error.code
full_errors[field].append(field_error)
challenge_response.initial_data["response_errors"] = full_errors challenge_response.initial_data["response_errors"] = full_errors
if not challenge_response.is_valid(): if not challenge_response.is_valid():
self.logger.error( self.logger.error(

View File

@ -0,0 +1,28 @@
"""flow views tests"""
from django.test import TestCase
from authentik.flows.challenge import AutosubmitChallenge, ChallengeTypes
class TestChallenges(TestCase):
"""Test generic challenges"""
def test_autosubmit_blank(self):
"""Test blank autosubmit"""
challenge = AutosubmitChallenge(
data={
"type": ChallengeTypes.NATIVE.value,
"url": "http://localhost",
"attrs": {},
}
)
self.assertTrue(challenge.is_valid(raise_exception=True))
# Test with an empty value
challenge = AutosubmitChallenge(
data={
"type": ChallengeTypes.NATIVE.value,
"url": "http://localhost",
"attrs": {"foo": ""},
}
)
self.assertTrue(challenge.is_valid(raise_exception=True))

View File

@ -26,6 +26,7 @@ redis:
cache_timeout_reputation: 300 cache_timeout_reputation: 300
debug: false debug: false
remote_debug: false
log_level: info log_level: info

View File

@ -28,7 +28,7 @@ class WebsocketMessageInstruction(IntEnum):
TRIGGER_UPDATE = 2 TRIGGER_UPDATE = 2
@dataclass @dataclass(slots=True)
class WebsocketMessage: class WebsocketMessage:
"""Complete Websocket Message that is being sent""" """Complete Websocket Message that is being sent"""

View File

@ -6,7 +6,7 @@ from rest_framework.viewsets import ModelViewSet
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.policies.api.policies import PolicySerializer from authentik.policies.api.policies import PolicySerializer
from authentik.policies.event_matcher.models import EventMatcherPolicy, app_choices from authentik.policies.event_matcher.models import EventMatcherPolicy, app_choices, model_choices
class EventMatcherPolicySerializer(PolicySerializer): class EventMatcherPolicySerializer(PolicySerializer):
@ -15,15 +15,30 @@ class EventMatcherPolicySerializer(PolicySerializer):
app = ChoiceField( app = ChoiceField(
choices=app_choices(), choices=app_choices(),
required=False, required=False,
allow_blank=True, allow_null=True,
help_text=_( help_text=_(
"Match events created by selected application. When left empty, " "Match events created by selected application. When left empty, "
"all applications are matched." "all applications are matched."
), ),
) )
model = ChoiceField(
choices=model_choices(),
required=False,
allow_null=True,
help_text=_(
"Match events created by selected model. "
"When left empty, all models are matched. When an app is selected, "
"all the application's models are matched."
),
)
def validate(self, attrs: dict) -> dict: def validate(self, attrs: dict) -> dict:
if attrs["action"] == "" and attrs["client_ip"] == "" and attrs["app"] == "": if (
attrs["action"] == ""
and attrs["client_ip"] == ""
and attrs["app"] == ""
and attrs["model"] == ""
):
raise ValidationError(_("At least one criteria must be set.")) raise ValidationError(_("At least one criteria must be set."))
return super().validate(attrs) return super().validate(attrs)
@ -33,6 +48,7 @@ class EventMatcherPolicySerializer(PolicySerializer):
"action", "action",
"client_ip", "client_ip",
"app", "app",
"model",
] ]

View File

@ -0,0 +1,21 @@
# Generated by Django 4.1.7 on 2023-05-29 15:24
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_policies_event_matcher", "0021_alter_eventmatcherpolicy_app"),
]
operations = [
migrations.AddField(
model_name="eventmatcherpolicy",
name="model",
field=models.TextField(
blank=True,
default="",
help_text="Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched.",
),
),
]

View File

@ -0,0 +1,103 @@
# Generated by Django 4.1.7 on 2023-06-21 12:45
from django.apps.registry import Apps
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def replace_defaults(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
eventmatcherpolicy = apps.get_model("authentik_policies_event_matcher", "eventmatcherpolicy")
for policy in eventmatcherpolicy.objects.using(db_alias).all():
changed = False
if policy.action == "":
policy.action = None
changed = True
if policy.app == "":
policy.app = None
changed = True
if policy.client_ip == "":
policy.client_ip = None
changed = True
if policy.model == "":
policy.model = None
changed = True
if not changed:
continue
policy.save()
class Migration(migrations.Migration):
dependencies = [
("authentik_policies_event_matcher", "0022_eventmatcherpolicy_model"),
]
operations = [
migrations.AlterField(
model_name="eventmatcherpolicy",
name="action",
field=models.TextField(
choices=[
("login", "Login"),
("login_failed", "Login Failed"),
("logout", "Logout"),
("user_write", "User Write"),
("suspicious_request", "Suspicious Request"),
("password_set", "Password Set"),
("secret_view", "Secret View"),
("secret_rotate", "Secret Rotate"),
("invitation_used", "Invite Used"),
("authorize_application", "Authorize Application"),
("source_linked", "Source Linked"),
("impersonation_started", "Impersonation Started"),
("impersonation_ended", "Impersonation Ended"),
("flow_execution", "Flow Execution"),
("policy_execution", "Policy Execution"),
("policy_exception", "Policy Exception"),
("property_mapping_exception", "Property Mapping Exception"),
("system_task_execution", "System Task Execution"),
("system_task_exception", "System Task Exception"),
("system_exception", "System Exception"),
("configuration_error", "Configuration Error"),
("model_created", "Model Created"),
("model_updated", "Model Updated"),
("model_deleted", "Model Deleted"),
("email_sent", "Email Sent"),
("update_available", "Update Available"),
("custom_", "Custom Prefix"),
],
default=None,
help_text="Match created events with this action type. When left empty, all action types will be matched.",
null=True,
),
),
migrations.AlterField(
model_name="eventmatcherpolicy",
name="app",
field=models.TextField(
default=None,
help_text="Match events created by selected application. When left empty, all applications are matched.",
null=True,
),
),
migrations.AlterField(
model_name="eventmatcherpolicy",
name="client_ip",
field=models.TextField(
default=None,
help_text="Matches Event's Client IP (strict matching, for network matching use an Expression Policy)",
null=True,
),
),
migrations.AlterField(
model_name="eventmatcherpolicy",
name="model",
field=models.TextField(
default=None,
help_text="Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched.",
null=True,
),
),
migrations.RunPython(replace_defaults),
]

View File

@ -1,13 +1,19 @@
"""Event Matcher models""" """Event Matcher models"""
from itertools import chain
from django.apps import apps from django.apps import apps
from django.db import models from django.db import models
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from structlog.stdlib import get_logger
from authentik.blueprints.v1.importer import is_model_allowed
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.policies.models import Policy from authentik.policies.models import Policy
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger()
def app_choices() -> list[tuple[str, str]]: def app_choices() -> list[tuple[str, str]]:
"""Get a list of all installed applications that create events. """Get a list of all installed applications that create events.
@ -19,27 +25,50 @@ def app_choices() -> list[tuple[str, str]]:
return choices return choices
def model_choices() -> list[tuple[str, str]]:
"""Get a list of all installed models
Returns a list of tuples containing (dotted.model.path, name)"""
choices = []
for model in apps.get_models():
if not is_model_allowed(model):
continue
name = f"{model._meta.app_label}.{model._meta.model_name}"
choices.append((name, model._meta.verbose_name))
return choices
class EventMatcherPolicy(Policy): class EventMatcherPolicy(Policy):
"""Passes when Event matches selected criteria.""" """Passes when Event matches selected criteria."""
action = models.TextField( action = models.TextField(
choices=EventAction.choices, choices=EventAction.choices,
blank=True, null=True,
default=None,
help_text=_( help_text=_(
"Match created events with this action type. " "Match created events with this action type. "
"When left empty, all action types will be matched." "When left empty, all action types will be matched."
), ),
) )
app = models.TextField( app = models.TextField(
blank=True, null=True,
default="", default=None,
help_text=_( help_text=_(
"Match events created by selected application. " "Match events created by selected application. "
"When left empty, all applications are matched." "When left empty, all applications are matched."
), ),
) )
model = models.TextField(
null=True,
default=None,
help_text=_(
"Match events created by selected model. "
"When left empty, all models are matched. When an app is selected, "
"all the application's models are matched."
),
)
client_ip = models.TextField( client_ip = models.TextField(
blank=True, null=True,
default=None,
help_text=_( help_text=_(
"Matches Event's Client IP (strict matching, " "Matches Event's Client IP (strict matching, "
"for network matching use an Expression Policy)" "for network matching use an Expression Policy)"
@ -60,13 +89,55 @@ class EventMatcherPolicy(Policy):
if "event" not in request.context: if "event" not in request.context:
return PolicyResult(False) return PolicyResult(False)
event: Event = request.context["event"] event: Event = request.context["event"]
if event.action == self.action: matches: list[PolicyResult] = []
return PolicyResult(True, "Action matched.") messages = []
if event.client_ip == self.client_ip: checks = [
return PolicyResult(True, "Client IP matched.") self.passes_action,
if event.app == self.app: self.passes_client_ip,
return PolicyResult(True, "App matched.") self.passes_app,
return PolicyResult(False) self.passes_model,
]
for checker in checks:
result = checker(request, event)
if result is None:
continue
LOGGER.info(
"Event matcher check result",
checker=checker.__name__,
result=result,
)
matches.append(result)
passing = any(x.passing for x in matches)
messages = chain(*[x.messages for x in matches])
result = PolicyResult(passing, *messages)
result.source_results = matches
return result
def passes_action(self, request: PolicyRequest, event: Event) -> PolicyResult | None:
"""Check if `self.action` matches"""
if self.action is None:
return None
return PolicyResult(self.action == event.action, "Action matched.")
def passes_client_ip(self, request: PolicyRequest, event: Event) -> PolicyResult | None:
"""Check if `self.client_ip` matches"""
if self.client_ip is None:
return None
return PolicyResult(self.client_ip == event.client_ip, "Client IP matched.")
def passes_app(self, request: PolicyRequest, event: Event) -> PolicyResult | None:
"""Check if `self.app` matches"""
if self.app is None:
return None
return PolicyResult(self.app == event.app, "App matched.")
def passes_model(self, request: PolicyRequest, event: Event) -> PolicyResult | None:
"""Check if `self.model` is set, and pass if it matches the event's model"""
if self.model is None:
return None
event_model_info = event.context.get("model", {})
event_model = f"{event_model_info.get('app')}.{event_model_info.get('model_name')}"
return PolicyResult(event_model == self.model, "Model matched.")
class Meta(Policy.PolicyMeta): class Meta(Policy.PolicyMeta):
verbose_name = _("Event Matcher Policy") verbose_name = _("Event Matcher Policy")

View File

@ -42,6 +42,22 @@ class TestEventMatcherPolicy(TestCase):
self.assertTrue(response.passing) self.assertTrue(response.passing)
self.assertTupleEqual(response.messages, ("App matched.",)) self.assertTupleEqual(response.messages, ("App matched.",))
def test_match_model(self):
"""Test match model"""
event = Event.new(EventAction.LOGIN)
event.context = {
"model": {
"app": "foo",
"model_name": "bar",
}
}
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(model="foo.bar")
response = policy.passes(request)
self.assertTrue(response.passing)
self.assertTupleEqual(response.messages, ("Model matched.",))
def test_drop(self): def test_drop(self):
"""Test drop event""" """Test drop event"""
event = Event.new(EventAction.LOGIN) event = Event.new(EventAction.LOGIN)
@ -52,6 +68,19 @@ class TestEventMatcherPolicy(TestCase):
response = policy.passes(request) response = policy.passes(request)
self.assertFalse(response.passing) self.assertFalse(response.passing)
def test_drop_multiple(self):
"""Test drop event"""
event = Event.new(EventAction.LOGIN)
event.app = "foo"
event.client_ip = "1.2.3.4"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.5", app="bar"
)
response = policy.passes(request)
self.assertFalse(response.passing)
def test_invalid(self): def test_invalid(self):
"""Test passing event""" """Test passing event"""
request = PolicyRequest(get_anonymous_user()) request = PolicyRequest(get_anonymous_user())

View File

@ -132,9 +132,9 @@ class TestPolicyProcess(TestCase):
) )
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
http_request = self.factory.get(reverse("authentik_core:impersonate-end")) http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
http_request.user = self.user http_request.user = self.user
http_request.resolver_match = resolve(reverse("authentik_core:impersonate-end")) http_request.resolver_match = resolve(reverse("authentik_api:user-impersonate-end"))
request = PolicyRequest(self.user) request = PolicyRequest(self.user)
request.set_http_request(http_request) request.set_http_request(http_request)

View File

@ -19,7 +19,7 @@ LOGGER = get_logger()
CACHE_PREFIX = "goauthentik.io/policies/" CACHE_PREFIX = "goauthentik.io/policies/"
@dataclass @dataclass(slots=True)
class PolicyRequest: class PolicyRequest:
"""Data-class to hold policy request data""" """Data-class to hold policy request data"""
@ -27,14 +27,14 @@ class PolicyRequest:
http_request: Optional[HttpRequest] http_request: Optional[HttpRequest]
obj: Optional[Model] obj: Optional[Model]
context: dict[str, Any] context: dict[str, Any]
debug: bool = False debug: bool
def __init__(self, user: User): def __init__(self, user: User):
super().__init__()
self.user = user self.user = user
self.http_request = None self.http_request = None
self.obj = None self.obj = None
self.context = {} self.context = {}
self.debug = False
def set_http_request(self, request: HttpRequest): # pragma: no cover def set_http_request(self, request: HttpRequest): # pragma: no cover
"""Load data from HTTP request, including geoip when enabled""" """Load data from HTTP request, including geoip when enabled"""
@ -67,7 +67,7 @@ class PolicyRequest:
return text + ">" return text + ">"
@dataclass @dataclass(slots=True)
class PolicyResult: class PolicyResult:
"""Result from evaluating a policy.""" """Result from evaluating a policy."""
@ -81,7 +81,6 @@ class PolicyResult:
log_messages: Optional[list[dict]] log_messages: Optional[list[dict]]
def __init__(self, passing: bool, *messages: str): def __init__(self, passing: bool, *messages: str):
super().__init__()
self.passing = passing self.passing = passing
self.messages = messages self.messages = messages
self.raw_result = None self.raw_result = None

View File

@ -29,6 +29,7 @@ class LDAPProviderSerializer(ProviderSerializer):
"outpost_set", "outpost_set",
"search_mode", "search_mode",
"bind_mode", "bind_mode",
"mfa_support",
] ]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs extra_kwargs = ProviderSerializer.Meta.extra_kwargs
@ -99,6 +100,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
"gid_start_number", "gid_start_number",
"search_mode", "search_mode",
"bind_mode", "bind_mode",
"mfa_support",
] ]

View File

@ -0,0 +1,37 @@
# Generated by Django 4.1.7 on 2023-06-19 17:30
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_ldap", "0002_ldapprovider_bind_mode"),
]
operations = [
migrations.AddField(
model_name="ldapprovider",
name="mfa_support",
field=models.BooleanField(
default=True,
help_text="When enabled, code-based multi-factor authentication can be used by appending a semicolon and the TOTP code to the password. This should only be enabled if all users that will bind to this provider have a TOTP device configured, as otherwise a password may incorrectly be rejected if it contains a semicolon.",
verbose_name="MFA Support",
),
),
migrations.AlterField(
model_name="ldapprovider",
name="gid_start_number",
field=models.IntegerField(
default=4000,
help_text="The start for gidNumbers, this number is added to a number generated from the group.pk to make sure that the numbers aren't too low for POSIX groups. Default is 4000 to ensure that we don't collide with local groups or users primary groups gidNumber",
),
),
migrations.AlterField(
model_name="ldapprovider",
name="uid_start_number",
field=models.IntegerField(
default=2000,
help_text="The start for uidNumbers, this number is added to the user.pk to make sure that the numbers aren't too low for POSIX users. Default is 2000 to ensure that we don't collide with local users uidNumber",
),
),
]

View File

@ -50,7 +50,7 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
uid_start_number = models.IntegerField( uid_start_number = models.IntegerField(
default=2000, default=2000,
help_text=_( help_text=_(
"The start for uidNumbers, this number is added to the user.Pk to make sure that the " "The start for uidNumbers, this number is added to the user.pk to make sure that the "
"numbers aren't too low for POSIX users. Default is 2000 to ensure that we don't " "numbers aren't too low for POSIX users. Default is 2000 to ensure that we don't "
"collide with local users uidNumber" "collide with local users uidNumber"
), ),
@ -60,7 +60,7 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
default=4000, default=4000,
help_text=_( help_text=_(
"The start for gidNumbers, this number is added to a number generated from the " "The start for gidNumbers, this number is added to a number generated from the "
"group.Pk to make sure that the numbers aren't too low for POSIX groups. Default " "group.pk to make sure that the numbers aren't too low for POSIX groups. Default "
"is 4000 to ensure that we don't collide with local groups or users " "is 4000 to ensure that we don't collide with local groups or users "
"primary groups gidNumber" "primary groups gidNumber"
), ),
@ -69,6 +69,17 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
bind_mode = models.TextField(default=APIAccessMode.DIRECT, choices=APIAccessMode.choices) bind_mode = models.TextField(default=APIAccessMode.DIRECT, choices=APIAccessMode.choices)
search_mode = models.TextField(default=APIAccessMode.DIRECT, choices=APIAccessMode.choices) search_mode = models.TextField(default=APIAccessMode.DIRECT, choices=APIAccessMode.choices)
mfa_support = models.BooleanField(
default=True,
verbose_name="MFA Support",
help_text=_(
"When enabled, code-based multi-factor authentication can be used by appending a "
"semicolon and the TOTP code to the password. This should only be enabled if all "
"users that will bind to this provider have a TOTP device configured, as otherwise "
"a password may incorrectly be rejected if it contains a semicolon."
),
)
@property @property
def launch_url(self) -> Optional[str]: def launch_url(self) -> Optional[str]:
"""LDAP never has a launch URL""" """LDAP never has a launch URL"""

View File

@ -19,6 +19,11 @@ SCOPE_OPENID = "openid"
SCOPE_OPENID_PROFILE = "profile" SCOPE_OPENID_PROFILE = "profile"
SCOPE_OPENID_EMAIL = "email" SCOPE_OPENID_EMAIL = "email"
# https://www.iana.org/assignments/oauth-parameters/\
# oauth-parameters.xhtml#pkce-code-challenge-method
PKCE_METHOD_PLAIN = "plain"
PKCE_METHOD_S256 = "S256"
TOKEN_TYPE = "Bearer" # nosec TOKEN_TYPE = "Bearer" # nosec
SCOPE_AUTHENTIK_API = "goauthentik.io/api" SCOPE_AUTHENTIK_API = "goauthentik.io/api"

View File

@ -41,7 +41,7 @@ class SubModes(models.TextChoices):
) )
@dataclass @dataclass(slots=True)
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
class IDToken: class IDToken:
"""The primary extension that OpenID Connect makes to OAuth 2.0 to enable End-Users to be """The primary extension that OpenID Connect makes to OAuth 2.0 to enable End-Users to be

View File

@ -35,6 +35,8 @@ from authentik.lib.views import bad_request_message
from authentik.policies.types import PolicyRequest from authentik.policies.types import PolicyRequest
from authentik.policies.views import PolicyAccessView, RequestValidationError from authentik.policies.views import PolicyAccessView, RequestValidationError
from authentik.providers.oauth2.constants import ( from authentik.providers.oauth2.constants import (
PKCE_METHOD_PLAIN,
PKCE_METHOD_S256,
PROMPT_CONSENT, PROMPT_CONSENT,
PROMPT_LOGIN, PROMPT_LOGIN,
PROMPT_NONE, PROMPT_NONE,
@ -74,7 +76,7 @@ SESSION_KEY_LAST_LOGIN_UID = "authentik/providers/oauth2/last_login_uid"
ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN} ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN}
@dataclass @dataclass(slots=True)
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
class OAuthAuthorizationParams: class OAuthAuthorizationParams:
"""Parameters required to authorize an OAuth Client""" """Parameters required to authorize an OAuth Client"""
@ -254,7 +256,10 @@ class OAuthAuthorizationParams:
def check_code_challenge(self): def check_code_challenge(self):
"""PKCE validation of the transformation method.""" """PKCE validation of the transformation method."""
if self.code_challenge and self.code_challenge_method not in ["plain", "S256"]: if self.code_challenge and self.code_challenge_method not in [
PKCE_METHOD_PLAIN,
PKCE_METHOD_S256,
]:
raise AuthorizeError( raise AuthorizeError(
self.redirect_uri, self.redirect_uri,
"invalid_request", "invalid_request",

View File

@ -14,7 +14,7 @@ from authentik.providers.oauth2.utils import TokenResponse, authenticate_provide
LOGGER = get_logger() LOGGER = get_logger()
@dataclass @dataclass(slots=True)
class TokenIntrospectionParams: class TokenIntrospectionParams:
"""Parameters for Token Introspection""" """Parameters for Token Introspection"""

View File

@ -17,6 +17,8 @@ from authentik.providers.oauth2.constants import (
GRANT_TYPE_IMPLICIT, GRANT_TYPE_IMPLICIT,
GRANT_TYPE_PASSWORD, GRANT_TYPE_PASSWORD,
GRANT_TYPE_REFRESH_TOKEN, GRANT_TYPE_REFRESH_TOKEN,
PKCE_METHOD_PLAIN,
PKCE_METHOD_S256,
SCOPE_OPENID, SCOPE_OPENID,
) )
from authentik.providers.oauth2.models import ( from authentik.providers.oauth2.models import (
@ -109,6 +111,7 @@ class ProviderInfoView(View):
"request_parameter_supported": False, "request_parameter_supported": False,
"claims_supported": self.get_claims(provider), "claims_supported": self.get_claims(provider),
"claims_parameter_supported": False, "claims_parameter_supported": False,
"code_challenge_methods_supported": [PKCE_METHOD_PLAIN, PKCE_METHOD_S256],
} }
def get_claims(self, provider: OAuth2Provider) -> list[str]: def get_claims(self, provider: OAuth2Provider) -> list[str]:

View File

@ -39,6 +39,7 @@ from authentik.providers.oauth2.constants import (
GRANT_TYPE_DEVICE_CODE, GRANT_TYPE_DEVICE_CODE,
GRANT_TYPE_PASSWORD, GRANT_TYPE_PASSWORD,
GRANT_TYPE_REFRESH_TOKEN, GRANT_TYPE_REFRESH_TOKEN,
PKCE_METHOD_S256,
TOKEN_TYPE, TOKEN_TYPE,
) )
from authentik.providers.oauth2.errors import DeviceCodeError, TokenError, UserAuthError from authentik.providers.oauth2.errors import DeviceCodeError, TokenError, UserAuthError
@ -58,7 +59,7 @@ from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_ME
LOGGER = get_logger() LOGGER = get_logger()
@dataclass @dataclass(slots=True)
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
class TokenParams: class TokenParams:
"""Token params""" """Token params"""
@ -221,7 +222,7 @@ class TokenParams:
# Validate PKCE parameters. # Validate PKCE parameters.
if self.code_verifier: if self.code_verifier:
if self.authorization_code.code_challenge_method == "S256": if self.authorization_code.code_challenge_method == PKCE_METHOD_S256:
new_code_challenge = ( new_code_challenge = (
urlsafe_b64encode(sha256(self.code_verifier.encode("ascii")).digest()) urlsafe_b64encode(sha256(self.code_verifier.encode("ascii")).digest())
.decode("utf-8") .decode("utf-8")

View File

@ -14,7 +14,7 @@ from authentik.providers.oauth2.utils import TokenResponse, authenticate_provide
LOGGER = get_logger() LOGGER = get_logger()
@dataclass @dataclass(slots=True)
class TokenRevocationParams: class TokenRevocationParams:
"""Parameters for Token Revocation""" """Parameters for Token Revocation"""

View File

@ -31,7 +31,7 @@ ERROR_SIGNATURE_REQUIRED_BUT_ABSENT = (
ERROR_FAILED_TO_VERIFY = "Failed to verify signature" ERROR_FAILED_TO_VERIFY = "Failed to verify signature"
@dataclass @dataclass(slots=True)
class AuthNRequest: class AuthNRequest:
"""AuthNRequest Dataclass""" """AuthNRequest Dataclass"""

View File

@ -12,7 +12,7 @@ from authentik.providers.saml.utils.encoding import decode_base64_and_inflate
from authentik.sources.saml.processors.constants import NS_SAML_PROTOCOL from authentik.sources.saml.processors.constants import NS_SAML_PROTOCOL
@dataclass @dataclass(slots=True)
class LogoutRequest: class LogoutRequest:
"""Logout Request""" """Logout Request"""

View File

@ -35,7 +35,7 @@ def format_pem_certificate(unformatted_cert: str) -> str:
return "\n".join(lines) return "\n".join(lines)
@dataclass @dataclass(slots=True)
class ServiceProviderMetadata: class ServiceProviderMetadata:
"""SP Metadata Dataclass""" """SP Metadata Dataclass"""

View File

@ -130,11 +130,7 @@ class LivenessProbe(bootsteps.StartStopStep):
HEARTBEAT_FILE.touch() HEARTBEAT_FILE.touch()
# Using a string here means the worker doesn't have to serialize CELERY_APP.config_from_object(settings.CELERY)
# the configuration object to child processes.
# - namespace='CELERY' means all celery-related configuration keys
# should have a `CELERY_` prefix.
CELERY_APP.config_from_object(settings, namespace="CELERY")
# Load task modules from all registered Django app configs. # Load task modules from all registered Django app configs.
CELERY_APP.autodiscover_tasks() CELERY_APP.autodiscover_tasks()

View File

@ -182,13 +182,13 @@ REST_FRAMEWORK = {
}, },
} }
REDIS_PROTOCOL_PREFIX = "redis://" _redis_protocol_prefix = "redis://"
REDIS_CELERY_TLS_REQUIREMENTS = "" _redis_celery_tls_requirements = ""
if CONFIG.y_bool("redis.tls", False): if CONFIG.y_bool("redis.tls", False):
REDIS_PROTOCOL_PREFIX = "rediss://" _redis_protocol_prefix = "rediss://"
REDIS_CELERY_TLS_REQUIREMENTS = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}" _redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}"
_redis_url = ( _redis_url = (
f"{REDIS_PROTOCOL_PREFIX}:" f"{_redis_protocol_prefix}:"
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:" f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
f"{int(CONFIG.y('redis.port'))}" f"{int(CONFIG.y('redis.port'))}"
) )
@ -326,12 +326,11 @@ USE_TZ = True
LOCALE_PATHS = ["./locale"] LOCALE_PATHS = ["./locale"]
# Celery settings CELERY = {
# Add a 10 minute timeout to all Celery tasks. "task_soft_time_limit": 600,
CELERY_TASK_SOFT_TIME_LIMIT = 600 "worker_max_tasks_per_child": 50,
CELERY_WORKER_MAX_TASKS_PER_CHILD = 50 "worker_concurrency": 2,
CELERY_WORKER_CONCURRENCY = 2 "beat_schedule": {
CELERY_BEAT_SCHEDULE = {
"clean_expired_models": { "clean_expired_models": {
"task": "authentik.core.tasks.clean_expired_models", "task": "authentik.core.tasks.clean_expired_models",
"schedule": crontab(minute="2-59/5"), "schedule": crontab(minute="2-59/5"),
@ -342,11 +341,12 @@ CELERY_BEAT_SCHEDULE = {
"schedule": crontab(minute="9-59/5"), "schedule": crontab(minute="9-59/5"),
"options": {"queue": "authentik_scheduled"}, "options": {"queue": "authentik_scheduled"},
}, },
},
"task_create_missing_queues": True,
"task_default_queue": "authentik",
"broker_url": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
"result_backend": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
} }
CELERY_TASK_CREATE_MISSING_QUEUES = True
CELERY_TASK_DEFAULT_QUEUE = "authentik"
CELERY_BROKER_URL = f"{_redis_url}/{CONFIG.y('redis.db')}{REDIS_CELERY_TLS_REQUIREMENTS}"
CELERY_RESULT_BACKEND = f"{_redis_url}/{CONFIG.y('redis.db')}{REDIS_CELERY_TLS_REQUIREMENTS}"
# Sentry integration # Sentry integration
env = get_env() env = get_env()
@ -455,7 +455,7 @@ _DISALLOWED_ITEMS = [
"INSTALLED_APPS", "INSTALLED_APPS",
"MIDDLEWARE", "MIDDLEWARE",
"AUTHENTICATION_BACKENDS", "AUTHENTICATION_BACKENDS",
"CELERY_BEAT_SCHEDULE", "CELERY",
] ]
@ -466,7 +466,7 @@ def _update_settings(app_path: str):
INSTALLED_APPS.extend(getattr(settings_module, "INSTALLED_APPS", [])) INSTALLED_APPS.extend(getattr(settings_module, "INSTALLED_APPS", []))
MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
CELERY_BEAT_SCHEDULE.update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
for _attr in dir(settings_module): for _attr in dir(settings_module):
if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:
globals()[_attr] = getattr(settings_module, _attr) globals()[_attr] = getattr(settings_module, _attr)
@ -482,7 +482,7 @@ for _app in INSTALLED_APPS:
_update_settings("data.user_settings") _update_settings("data.user_settings")
if DEBUG: if DEBUG:
CELERY_TASK_ALWAYS_EAGER = True CELERY["task_always_eager"] = True
os.environ[ENV_GIT_HASH_KEY] = "dev" os.environ[ENV_GIT_HASH_KEY] = "dev"
INSTALLED_APPS.append("silk") INSTALLED_APPS.append("silk")
SILKY_PYTHON_PROFILER = True SILKY_PYTHON_PROFILER = True

View File

@ -30,7 +30,7 @@ class PytestTestRunner: # pragma: no cover
self.args.append(f"--randomly-seed={kwargs['randomly_seed']}") self.args.append(f"--randomly-seed={kwargs['randomly_seed']}")
settings.TEST = True settings.TEST = True
settings.CELERY_TASK_ALWAYS_EAGER = True settings.CELERY["task_always_eager"] = True
CONFIG.y_set("avatars", "none") CONFIG.y_set("avatars", "none")
CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb") CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb")
CONFIG.y_set("blueprints_dir", "./blueprints") CONFIG.y_set("blueprints_dir", "./blueprints")

View File

@ -118,10 +118,9 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
"""Get source's sync status""" """Get source's sync status"""
source = self.get_object() source = self.get_object()
results = [] results = []
for sync_class in SYNC_CLASSES: tasks = TaskInfo.by_name(f"ldap_sync:{source.slug}:*")
sync_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower() if tasks:
task = TaskInfo.by_name(f"ldap_sync:{source.slug}:{sync_name}") for task in tasks:
if task:
results.append(task) results.append(task)
return Response(TaskSerializer(results, many=True).data) return Response(TaskSerializer(results, many=True).data)
@ -143,7 +142,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
source = self.get_object() source = self.get_object()
all_objects = {} all_objects = {}
for sync_class in SYNC_CLASSES: for sync_class in SYNC_CLASSES:
class_name = sync_class.__name__.replace("LDAPSynchronizer", "").lower() class_name = sync_class.name()
all_objects.setdefault(class_name, []) all_objects.setdefault(class_name, [])
for obj in sync_class(source).get_objects(size_limit=10): for obj in sync_class(source).get_objects(size_limit=10):
obj: dict obj: dict

View File

@ -2,9 +2,8 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.lib.utils.reflection import class_to_path
from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.tasks import SYNC_CLASSES, ldap_sync from authentik.sources.ldap.tasks import ldap_sync_single
LOGGER = get_logger() LOGGER = get_logger()
@ -21,7 +20,4 @@ class Command(BaseCommand):
if not source: if not source:
LOGGER.warning("Source does not exist", slug=source_slug) LOGGER.warning("Source does not exist", slug=source_slug)
continue continue
for sync_class in SYNC_CLASSES: ldap_sync_single(source)
LOGGER.info("Starting sync", cls=sync_class)
# pylint: disable=no-value-for-parameter
ldap_sync(source.pk, class_to_path(sync_class))

View File

@ -151,7 +151,7 @@ class LDAPSource(Source):
servers.append(Server(server, **server_kwargs)) servers.append(Server(server, **server_kwargs))
else: else:
servers = [Server(self.server_uri, **server_kwargs)] servers = [Server(self.server_uri, **server_kwargs)]
return ServerPool(servers, RANDOM, active=True, exhaust=True) return ServerPool(servers, RANDOM, active=5, exhaust=True)
def connection( def connection(
self, server_kwargs: Optional[dict] = None, connection_kwargs: Optional[dict] = None self, server_kwargs: Optional[dict] = None, connection_kwargs: Optional[dict] = None

View File

@ -4,7 +4,7 @@ from re import split
from typing import Optional from typing import Optional
from ldap3 import BASE from ldap3 import BASE
from ldap3.core.exceptions import LDAPAttributeError from ldap3.core.exceptions import LDAPAttributeError, LDAPUnwillingToPerformResult
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.models import User from authentik.core.models import User
@ -69,7 +69,7 @@ class LDAPPasswordChanger:
attributes=["pwdProperties"], attributes=["pwdProperties"],
) )
root_attrs = list(root_attrs)[0] root_attrs = list(root_attrs)[0]
except (LDAPAttributeError, KeyError, IndexError): except (LDAPAttributeError, LDAPUnwillingToPerformResult, KeyError, IndexError):
return False return False
raw_pwd_properties = root_attrs.get("attributes", {}).get("pwdProperties", None) raw_pwd_properties = root_attrs.get("attributes", {}).get("pwdProperties", None)
if not raw_pwd_properties: if not raw_pwd_properties:
@ -92,7 +92,7 @@ class LDAPPasswordChanger:
return return
try: try:
self._connection.extend.microsoft.modify_password(user_dn, password) self._connection.extend.microsoft.modify_password(user_dn, password)
except LDAPAttributeError: except (LDAPAttributeError, LDAPUnwillingToPerformResult):
self._connection.extend.standard.modify_password(user_dn, new_password=password) self._connection.extend.standard.modify_password(user_dn, new_password=password)
def _ad_check_password_existing(self, password: str, user_dn: str) -> bool: def _ad_check_password_existing(self, password: str, user_dn: str) -> bool:

View File

@ -12,13 +12,9 @@ from authentik.core.models import User
from authentik.core.signals import password_changed from authentik.core.signals import password_changed
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.lib.utils.reflection import class_to_path
from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.password import LDAPPasswordChanger from authentik.sources.ldap.password import LDAPPasswordChanger
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.tasks import ldap_sync_single
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync
from authentik.stages.prompt.signals import password_validate from authentik.stages.prompt.signals import password_validate
LOGGER = get_logger() LOGGER = get_logger()
@ -35,12 +31,7 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
# and the mappings are created with an m2m event # and the mappings are created with an m2m event
if not instance.property_mappings.exists() or not instance.property_mappings_group.exists(): if not instance.property_mappings.exists() or not instance.property_mappings_group.exists():
return return
for sync_class in [ ldap_sync_single.delay(instance.pk)
UserLDAPSynchronizer,
GroupLDAPSynchronizer,
MembershipLDAPSynchronizer,
]:
ldap_sync.delay(instance.pk, class_to_path(sync_class))
@receiver(password_validate) @receiver(password_validate)
@ -66,8 +57,8 @@ def ldap_sync_password(sender, user: User, password: str, **_):
if not sources.exists(): if not sources.exists():
return return
source = sources.first() source = sources.first()
changer = LDAPPasswordChanger(source)
try: try:
changer = LDAPPasswordChanger(source)
changer.change_password(user, password) changer.change_password(user, password)
except LDAPOperationResult as exc: except LDAPOperationResult as exc:
LOGGER.warning("failed to set LDAP password", exc=exc) LOGGER.warning("failed to set LDAP password", exc=exc)

View File

@ -1,9 +1,10 @@
"""Sync LDAP Users and groups into authentik""" """Sync LDAP Users and groups into authentik"""
from typing import Any, Generator from typing import Any, Generator
from django.conf import settings
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from ldap3 import Connection from ldap3 import DEREF_ALWAYS, SUBTREE, Connection
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from authentik.core.exceptions import PropertyMappingExpressionException from authentik.core.exceptions import PropertyMappingExpressionException
@ -29,6 +30,24 @@ class BaseLDAPSynchronizer:
self._messages = [] self._messages = []
self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__) self._logger = get_logger().bind(source=source, syncer=self.__class__.__name__)
@staticmethod
def name() -> str:
"""UI name for the type of object this class synchronizes"""
raise NotImplementedError
def sync_full(self):
"""Run full sync, this function should only be used in tests"""
if not settings.TEST: # noqa
raise RuntimeError(
f"{self.__class__.__name__}.sync_full() should only be used in tests"
)
for page in self.get_objects():
self.sync(page)
def sync(self, page_data: list) -> int:
"""Sync function, implemented in subclass"""
raise NotImplementedError()
@property @property
def messages(self) -> list[str]: def messages(self) -> list[str]:
"""Get all UI messages""" """Get all UI messages"""
@ -60,9 +79,47 @@ class BaseLDAPSynchronizer:
"""Get objects from LDAP, implemented in subclass""" """Get objects from LDAP, implemented in subclass"""
raise NotImplementedError() raise NotImplementedError()
def sync(self) -> int: # pylint: disable=too-many-arguments
"""Sync function, implemented in subclass""" def search_paginator(
raise NotImplementedError() self,
search_base,
search_filter,
search_scope=SUBTREE,
dereference_aliases=DEREF_ALWAYS,
attributes=None,
size_limit=0,
time_limit=0,
types_only=False,
get_operational_attributes=False,
controls=None,
paged_size=5,
paged_criticality=False,
):
"""Search in pages, returns each page"""
cookie = True
while cookie:
self._connection.search(
search_base,
search_filter,
search_scope,
dereference_aliases,
attributes,
size_limit,
time_limit,
types_only,
get_operational_attributes,
controls,
paged_size,
paged_criticality,
None if cookie is True else cookie,
)
try:
cookie = self._connection.result["controls"]["1.2.840.113556.1.4.319"]["value"][
"cookie"
]
except KeyError:
cookie = None
yield self._connection.response
def _flatten(self, value: Any) -> Any: def _flatten(self, value: Any) -> Any:
"""Flatten `value` if its a list""" """Flatten `value` if its a list"""

View File

@ -13,8 +13,12 @@ from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchroniz
class GroupLDAPSynchronizer(BaseLDAPSynchronizer): class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
"""Sync LDAP Users and groups into authentik""" """Sync LDAP Users and groups into authentik"""
@staticmethod
def name() -> str:
return "groups"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
return self._connection.extend.standard.paged_search( return self.search_paginator(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,
search_scope=SUBTREE, search_scope=SUBTREE,
@ -22,13 +26,13 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
**kwargs, **kwargs,
) )
def sync(self) -> int: def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Groups and create authentik_core.Group instances""" """Iterate over all LDAP Groups and create authentik_core.Group instances"""
if not self._source.sync_groups: if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source") self.message("Group syncing is disabled for this Source")
return -1 return -1
group_count = 0 group_count = 0
for group in self.get_objects(): for group in page_data:
if "attributes" not in group: if "attributes" not in group:
continue continue
attributes = group.get("attributes", {}) attributes = group.get("attributes", {})

View File

@ -19,8 +19,12 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
super().__init__(source) super().__init__(source)
self.group_cache: dict[str, Group] = {} self.group_cache: dict[str, Group] = {}
@staticmethod
def name() -> str:
return "membership"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
return self._connection.extend.standard.paged_search( return self.search_paginator(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,
search_scope=SUBTREE, search_scope=SUBTREE,
@ -32,13 +36,13 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
**kwargs, **kwargs,
) )
def sync(self) -> int: def sync(self, page_data: list) -> int:
"""Iterate over all Users and assign Groups using memberOf Field""" """Iterate over all Users and assign Groups using memberOf Field"""
if not self._source.sync_groups: if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source") self.message("Group syncing is disabled for this Source")
return -1 return -1
membership_count = 0 membership_count = 0
for group in self.get_objects(): for group in page_data:
if "attributes" not in group: if "attributes" not in group:
continue continue
members = group.get("attributes", {}).get(self._source.group_membership_field, []) members = group.get("attributes", {}).get(self._source.group_membership_field, [])

View File

@ -15,8 +15,12 @@ from authentik.sources.ldap.sync.vendor.ms_ad import MicrosoftActiveDirectory
class UserLDAPSynchronizer(BaseLDAPSynchronizer): class UserLDAPSynchronizer(BaseLDAPSynchronizer):
"""Sync LDAP Users into authentik""" """Sync LDAP Users into authentik"""
@staticmethod
def name() -> str:
return "users"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
return self._connection.extend.standard.paged_search( return self.search_paginator(
search_base=self.base_dn_users, search_base=self.base_dn_users,
search_filter=self._source.user_object_filter, search_filter=self._source.user_object_filter,
search_scope=SUBTREE, search_scope=SUBTREE,
@ -24,13 +28,13 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
**kwargs, **kwargs,
) )
def sync(self) -> int: def sync(self, page_data: list) -> int:
"""Iterate over all LDAP Users and create authentik_core.User instances""" """Iterate over all LDAP Users and create authentik_core.User instances"""
if not self._source.sync_users: if not self._source.sync_users:
self.message("User syncing is disabled for this Source") self.message("User syncing is disabled for this Source")
return -1 return -1
user_count = 0 user_count = 0
for user in self.get_objects(): for user in page_data:
if "attributes" not in user: if "attributes" not in user:
continue continue
attributes = user.get("attributes", {}) attributes = user.get("attributes", {})

View File

@ -11,6 +11,10 @@ from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
class FreeIPA(BaseLDAPSynchronizer): class FreeIPA(BaseLDAPSynchronizer):
"""FreeIPA-specific LDAP""" """FreeIPA-specific LDAP"""
@staticmethod
def name() -> str:
return "freeipa"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
yield None yield None

View File

@ -42,6 +42,10 @@ class UserAccountControl(IntFlag):
class MicrosoftActiveDirectory(BaseLDAPSynchronizer): class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
"""Microsoft-specific LDAP""" """Microsoft-specific LDAP"""
@staticmethod
def name() -> str:
return "microsoft_ad"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
yield None yield None

View File

@ -1,4 +1,8 @@
"""LDAP Sync tasks""" """LDAP Sync tasks"""
from uuid import uuid4
from celery import chain, group
from django.core.cache import cache
from ldap3.core.exceptions import LDAPException from ldap3.core.exceptions import LDAPException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -8,6 +12,7 @@ from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import class_to_path, path_to_class from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
from authentik.sources.ldap.models import LDAPSource from authentik.sources.ldap.models import LDAPSource
from authentik.sources.ldap.sync.base import BaseLDAPSynchronizer
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
@ -18,14 +23,43 @@ SYNC_CLASSES = [
GroupLDAPSynchronizer, GroupLDAPSynchronizer,
MembershipLDAPSynchronizer, MembershipLDAPSynchronizer,
] ]
CACHE_KEY_PREFIX = "goauthentik.io/sources/ldap/page/"
@CELERY_APP.task() @CELERY_APP.task()
def ldap_sync_all(): def ldap_sync_all():
"""Sync all sources""" """Sync all sources"""
for source in LDAPSource.objects.filter(enabled=True): for source in LDAPSource.objects.filter(enabled=True):
for sync_class in SYNC_CLASSES: ldap_sync_single(source)
ldap_sync.delay(source.pk, class_to_path(sync_class))
@CELERY_APP.task()
def ldap_sync_single(source: LDAPSource):
"""Sync a single source"""
task = chain(
# User and group sync can happen at once, they have no dependencies on each other
group(
ldap_sync_paginator(source, UserLDAPSynchronizer)
+ ldap_sync_paginator(source, GroupLDAPSynchronizer),
),
# Membership sync needs to run afterwards
group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer),
),
)
task()
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:
"""Return a list of task signatures with LDAP pagination data"""
sync_inst: BaseLDAPSynchronizer = sync(source)
signatures = []
for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page)
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
signatures.append(page_sync)
return signatures
@CELERY_APP.task( @CELERY_APP.task(
@ -34,7 +68,7 @@ def ldap_sync_all():
soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
) )
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str): def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source""" """Synchronization of an LDAP Source"""
self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours")) self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours"))
try: try:
@ -43,11 +77,16 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str):
# Because the source couldn't be found, we don't have a UID # Because the source couldn't be found, we don't have a UID
# to set the state with # to set the state with
return return
sync = path_to_class(sync_class) sync: type[BaseLDAPSynchronizer] = path_to_class(sync_class)
self.set_uid(f"{source.slug}:{sync.__name__.replace('LDAPSynchronizer', '').lower()}") uid = page_cache_key.replace(CACHE_KEY_PREFIX, "")
self.set_uid(f"{source.slug}:{sync.name()}:{uid}")
try: try:
sync_inst = sync(source) sync_inst: BaseLDAPSynchronizer = sync(source)
count = sync_inst.sync() page = cache.get(page_cache_key)
if not page:
return
cache.touch(page_cache_key)
count = sync_inst.sync(page)
messages = sync_inst.messages messages = sync_inst.messages
messages.append(f"Synced {count} objects.") messages.append(f"Synced {count} objects.")
self.set_status( self.set_status(
@ -56,6 +95,7 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str):
messages, messages,
) )
) )
cache.delete(page_cache_key)
except LDAPException as exc: except LDAPException as exc:
# No explicit event is created here as .set_status with an error will do that # No explicit event is created here as .set_status with an error will do that
LOGGER.warning(exception_to_string(exc)) LOGGER.warning(exception_to_string(exc))

View File

@ -43,7 +43,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=raw_conn) connection = MagicMock(return_value=raw_conn)
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync_full()
user = User.objects.get(username="user0_sn") user = User.objects.get(username="user0_sn")
# auth_user_by_bind = Mock(return_value=user) # auth_user_by_bind = Mock(return_value=user)
@ -71,7 +71,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync_full()
user = User.objects.get(username="user0_sn") user = User.objects.get(username="user0_sn")
auth_user_by_bind = Mock(return_value=user) auth_user_by_bind = Mock(return_value=user)
@ -98,7 +98,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync_full()
user = User.objects.get(username="user0_sn") user = User.objects.get(username="user0_sn")
auth_user_by_bind = Mock(return_value=user) auth_user_by_bind = Mock(return_value=user)

View File

@ -51,7 +51,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync_full()
self.assertFalse(User.objects.filter(username="user0_sn").exists()) self.assertFalse(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists())
events = Event.objects.filter( events = Event.objects.filter(
@ -87,7 +87,7 @@ class LDAPSyncTests(TestCase):
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync_full()
user = User.objects.filter(username="user0_sn").first() user = User.objects.filter(username="user0_sn").first()
self.assertEqual(user.attributes["foo"], "bar") self.assertEqual(user.attributes["foo"], "bar")
self.assertFalse(user.is_active) self.assertFalse(user.is_active)
@ -106,7 +106,7 @@ class LDAPSyncTests(TestCase):
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists())
@ -128,9 +128,9 @@ class LDAPSyncTests(TestCase):
self.source.sync_parent_group = parent_group self.source.sync_parent_group = parent_group
self.source.save() self.source.save()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync.sync() membership_sync.sync_full()
group: Group = Group.objects.filter(name="test-group").first() group: Group = Group.objects.filter(name="test-group").first()
self.assertIsNotNone(group) self.assertIsNotNone(group)
self.assertEqual(group.parent, parent_group) self.assertEqual(group.parent, parent_group)
@ -152,9 +152,9 @@ class LDAPSyncTests(TestCase):
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync.sync() membership_sync.sync_full()
group = Group.objects.filter(name="group1") group = Group.objects.filter(name="group1")
self.assertTrue(group.exists()) self.assertTrue(group.exists())
@ -177,11 +177,11 @@ class LDAPSyncTests(TestCase):
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save() self.source.save()
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync() user_sync.sync_full()
group_sync = GroupLDAPSynchronizer(self.source) group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync() group_sync.sync_full()
membership_sync = MembershipLDAPSynchronizer(self.source) membership_sync = MembershipLDAPSynchronizer(self.source)
membership_sync.sync() membership_sync.sync_full()
# Test if membership mapping based on memberUid works. # Test if membership mapping based on memberUid works.
posix_group = Group.objects.filter(name="group-posix").first() posix_group = Group.objects.filter(name="group-posix").first()
self.assertTrue(posix_group.users.filter(name="user-posix").exists()) self.assertTrue(posix_group.users.filter(name="user-posix").exists())

View File

@ -1,6 +1,8 @@
"""OpenID Type tests""" """OpenID Type tests"""
from django.test import TestCase from django.test import RequestFactory, TestCase
from requests_mock import Mocker
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
@ -24,9 +26,10 @@ class TestTypeOpenID(TestCase):
slug="test", slug="test",
provider_type="openidconnect", provider_type="openidconnect",
authorization_url="", authorization_url="",
profile_url="", profile_url="http://localhost/userinfo",
consumer_key="", consumer_key="",
) )
self.factory = RequestFactory()
def test_enroll_context(self): def test_enroll_context(self):
"""Test OpenID Enrollment context""" """Test OpenID Enrollment context"""
@ -34,3 +37,19 @@ class TestTypeOpenID(TestCase):
self.assertEqual(ak_context["username"], OPENID_USER["nickname"]) self.assertEqual(ak_context["username"], OPENID_USER["nickname"])
self.assertEqual(ak_context["email"], OPENID_USER["email"]) self.assertEqual(ak_context["email"], OPENID_USER["email"])
self.assertEqual(ak_context["name"], OPENID_USER["name"]) self.assertEqual(ak_context["name"], OPENID_USER["name"])
@Mocker()
def test_userinfo(self, mock: Mocker):
"""Test userinfo API call"""
mock.get("http://localhost/userinfo", json=OPENID_USER)
token = generate_id()
OpenIDConnectOAuth2Callback(request=self.factory.get("/")).get_client(
self.source
).get_profile_info(
{
"token_type": "foo",
"access_token": token,
}
)
self.assertEqual(mock.last_request.query, "")
self.assertEqual(mock.last_request.headers["Authorization"], f"foo {token}")

View File

@ -20,7 +20,7 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
class OpenIDConnectOAuth2Callback(OAuthCallback): class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback""" """OpenIDConnect OAuth2 Callback"""
client_class: UserprofileHeaderAuthClient client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "") return info.get("sub", "")

View File

@ -133,6 +133,12 @@ def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -
device = WebAuthnDevice.objects.filter(credential_id=credential_id).first() device = WebAuthnDevice.objects.filter(credential_id=credential_id).first()
if not device: if not device:
raise ValidationError("Invalid device") raise ValidationError("Invalid device")
# We can only check the device's user if the user we're given isn't anonymous
# as this validation is also used for password-less login where webauthn is the very first
# step done by a user. Only if this validation happens at a later stage we can check
# that the device belongs to the user
if not user.is_anonymous and device.user != user:
raise ValidationError("Invalid device")
stage: AuthenticatorValidateStage = stage_view.executor.current_stage stage: AuthenticatorValidateStage = stage_view.executor.current_stage

View File

@ -37,9 +37,9 @@ from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_ME
COOKIE_NAME_MFA = "authentik_mfa" COOKIE_NAME_MFA = "authentik_mfa"
SESSION_KEY_STAGES = "authentik/stages/authenticator_validate/stages" PLAN_CONTEXT_STAGES = "goauthentik.io/stages/authenticator_validate/stages"
SESSION_KEY_SELECTED_STAGE = "authentik/stages/authenticator_validate/selected_stage" PLAN_CONTEXT_SELECTED_STAGE = "goauthentik.io/stages/authenticator_validate/selected_stage"
SESSION_KEY_DEVICE_CHALLENGES = "authentik/stages/authenticator_validate/device_challenges" PLAN_CONTEXT_DEVICE_CHALLENGES = "goauthentik.io/stages/authenticator_validate/device_challenges"
class SelectableStageSerializer(PassiveSerializer): class SelectableStageSerializer(PassiveSerializer):
@ -73,8 +73,8 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
component = CharField(default="ak-stage-authenticator-validate") component = CharField(default="ak-stage-authenticator-validate")
def _challenge_allowed(self, classes: list): def _challenge_allowed(self, classes: list):
device_challenges: list[dict] = self.stage.request.session.get( device_challenges: list[dict] = self.stage.executor.plan.context.get(
SESSION_KEY_DEVICE_CHALLENGES, [] PLAN_CONTEXT_DEVICE_CHALLENGES, []
) )
if not any(x["device_class"] in classes for x in device_challenges): if not any(x["device_class"] in classes for x in device_challenges):
raise ValidationError("No compatible device class allowed") raise ValidationError("No compatible device class allowed")
@ -104,7 +104,9 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
"""Check which challenge the user has selected. Actual logic only used for SMS stage.""" """Check which challenge the user has selected. Actual logic only used for SMS stage."""
# First check if the challenge is valid # First check if the challenge is valid
allowed = False allowed = False
for device_challenge in self.stage.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []): for device_challenge in self.stage.executor.plan.context.get(
PLAN_CONTEXT_DEVICE_CHALLENGES, []
):
if device_challenge.get("device_class", "") == challenge.get( if device_challenge.get("device_class", "") == challenge.get(
"device_class", "" "device_class", ""
) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""): ) and device_challenge.get("device_uid", "") == challenge.get("device_uid", ""):
@ -122,11 +124,11 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse):
def validate_selected_stage(self, stage_pk: str) -> str: def validate_selected_stage(self, stage_pk: str) -> str:
"""Check that the selected stage is valid""" """Check that the selected stage is valid"""
stages = self.stage.request.session.get(SESSION_KEY_STAGES, []) stages = self.stage.executor.plan.context.get(PLAN_CONTEXT_STAGES, [])
if not any(str(stage.pk) == stage_pk for stage in stages): if not any(str(stage.pk) == stage_pk for stage in stages):
raise ValidationError("Selected stage is invalid") raise ValidationError("Selected stage is invalid")
self.stage.logger.debug("Setting selected stage to ", stage=stage_pk) self.stage.logger.debug("Setting selected stage to ", stage=stage_pk)
self.stage.request.session[SESSION_KEY_SELECTED_STAGE] = stage_pk self.stage.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = stage_pk
return stage_pk return stage_pk
def validate(self, attrs: dict): def validate(self, attrs: dict):
@ -231,7 +233,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
else: else:
self.logger.debug("No pending user, continuing") self.logger.debug("No pending user, continuing")
return self.executor.stage_ok() return self.executor.stage_ok()
self.request.session[SESSION_KEY_DEVICE_CHALLENGES] = challenges self.executor.plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = challenges
# No allowed devices # No allowed devices
if len(challenges) < 1: if len(challenges) < 1:
@ -264,23 +266,23 @@ class AuthenticatorValidateStageView(ChallengeStageView):
if stage.configuration_stages.count() == 1: if stage.configuration_stages.count() == 1:
next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk) next_stage = Stage.objects.get_subclass(pk=stage.configuration_stages.first().pk)
self.logger.debug("Single stage configured, auto-selecting", stage=next_stage) self.logger.debug("Single stage configured, auto-selecting", stage=next_stage)
self.request.session[SESSION_KEY_SELECTED_STAGE] = next_stage self.executor.plan.context[PLAN_CONTEXT_SELECTED_STAGE] = next_stage
# Because that normal execution only happens on post, we directly inject it here and # Because that normal execution only happens on post, we directly inject it here and
# return it # return it
self.executor.plan.insert_stage(next_stage) self.executor.plan.insert_stage(next_stage)
return self.executor.stage_ok() return self.executor.stage_ok()
stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses() stages = Stage.objects.filter(pk__in=stage.configuration_stages.all()).select_subclasses()
self.request.session[SESSION_KEY_STAGES] = stages self.executor.plan.context[PLAN_CONTEXT_STAGES] = stages
return super().get(self.request, *args, **kwargs) return super().get(self.request, *args, **kwargs)
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
res = super().post(request, *args, **kwargs) res = super().post(request, *args, **kwargs)
if ( if (
SESSION_KEY_SELECTED_STAGE in self.request.session PLAN_CONTEXT_SELECTED_STAGE in self.executor.plan.context
and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE and self.executor.current_stage.not_configured_action == NotConfiguredAction.CONFIGURE
): ):
self.logger.debug("Got selected stage in session, running that") self.logger.debug("Got selected stage in context, running that")
stage_pk = self.request.session.get(SESSION_KEY_SELECTED_STAGE) stage_pk = self.executor.plan.context.get(PLAN_CONTEXT_SELECTED_STAGE)
# Because the foreign key to stage.configuration_stage points to # Because the foreign key to stage.configuration_stage points to
# a base stage class, we need to do another lookup # a base stage class, we need to do another lookup
stage = Stage.objects.get_subclass(pk=stage_pk) stage = Stage.objects.get_subclass(pk=stage_pk)
@ -291,8 +293,8 @@ class AuthenticatorValidateStageView(ChallengeStageView):
return res return res
def get_challenge(self) -> AuthenticatorValidationChallenge: def get_challenge(self) -> AuthenticatorValidationChallenge:
challenges = self.request.session.get(SESSION_KEY_DEVICE_CHALLENGES, []) challenges = self.executor.plan.context.get(PLAN_CONTEXT_DEVICE_CHALLENGES, [])
stages = self.request.session.get(SESSION_KEY_STAGES, []) stages = self.executor.plan.context.get(PLAN_CONTEXT_STAGES, [])
stage_challenges = [] stage_challenges = []
for stage in stages: for stage in stages:
serializer = SelectableStageSerializer( serializer = SelectableStageSerializer(
@ -307,6 +309,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
stage_challenges.append(serializer.data) stage_challenges.append(serializer.data)
return AuthenticatorValidationChallenge( return AuthenticatorValidationChallenge(
data={ data={
"component": "ak-stage-authenticator-validate",
"type": ChallengeTypes.NATIVE.value, "type": ChallengeTypes.NATIVE.value,
"device_challenges": challenges, "device_challenges": challenges,
"configuration_stages": stage_challenges, "configuration_stages": stage_challenges,
@ -386,8 +389,3 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device": webauthn_device, "device": webauthn_device,
} }
return self.set_valid_mfa_cookie(response.device) return self.set_valid_mfa_cookie(response.device)
def cleanup(self):
self.request.session.pop(SESSION_KEY_STAGES, None)
self.request.session.pop(SESSION_KEY_SELECTED_STAGE, None)
self.request.session.pop(SESSION_KEY_DEVICE_CHALLENGES, None)

View File

@ -1,26 +1,19 @@
"""Test validator stage""" """Test validator stage"""
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from django.contrib.sessions.middleware import SessionMiddleware
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.urls.base import reverse from django.urls.base import reverse
from rest_framework.exceptions import ValidationError
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import FlowDesignation, FlowStageBinding, NotConfiguredAction from authentik.flows.models import FlowDesignation, FlowStageBinding, NotConfiguredAction
from authentik.flows.planner import FlowPlan from authentik.flows.planner import FlowPlan
from authentik.flows.stage import StageView
from authentik.flows.tests import FlowTestCase from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import SESSION_KEY_PLAN, FlowExecutorView from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id, generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.lib.tests.utils import dummy_get_response
from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice
from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_validate.stage import ( from authentik.stages.authenticator_validate.stage import PLAN_CONTEXT_DEVICE_CHALLENGES
SESSION_KEY_DEVICE_CHALLENGES,
AuthenticatorValidationChallengeResponse,
)
from authentik.stages.identification.models import IdentificationStage, UserFields from authentik.stages.identification.models import IdentificationStage, UserFields
@ -48,6 +41,10 @@ class AuthenticatorValidateStageTests(FlowTestCase):
FlowStageBinding.objects.create(target=flow, stage=conf_stage, order=0) FlowStageBinding.objects.create(target=flow, stage=conf_stage, order=0)
FlowStageBinding.objects.create(target=flow, stage=stage, order=1) FlowStageBinding.objects.create(target=flow, stage=stage, order=1)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 200)
response = self.client.post( response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
{"uid_field": self.user.username}, {"uid_field": self.user.username},
@ -68,6 +65,67 @@ class AuthenticatorValidateStageTests(FlowTestCase):
show_source_labels=False, show_source_labels=False,
) )
def test_not_configured_action_multiple(self):
"""Test not_configured_action"""
conf_stage = IdentificationStage.objects.create(
name=generate_id(),
user_fields=[
UserFields.USERNAME,
],
)
conf_stage2 = IdentificationStage.objects.create(
name=generate_id(),
user_fields=[
UserFields.USERNAME,
],
)
stage = AuthenticatorValidateStage.objects.create(
name=generate_id(),
not_configured_action=NotConfiguredAction.CONFIGURE,
)
stage.configuration_stages.set([conf_stage, conf_stage2])
flow = create_test_flow()
FlowStageBinding.objects.create(target=flow, stage=conf_stage, order=0)
FlowStageBinding.objects.create(target=flow, stage=stage, order=1)
# Get initial identification stage
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 200)
# Answer initial identification stage
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
{"uid_field": self.user.username},
)
self.assertEqual(response.status_code, 302)
# Get list of all configuration stages
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 200)
# Select stage
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
{"selected_stage": conf_stage.pk},
)
self.assertEqual(response.status_code, 302)
# get actual identification stage response
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow,
component="ak-stage-identification",
password_fields=False,
primary_action="Continue",
user_fields=["username"],
sources=[],
show_source_labels=False,
)
def test_stage_validation(self): def test_stage_validation(self):
"""Test serializer validation""" """Test serializer validation"""
self.client.force_login(self.user) self.client.force_login(self.user)
@ -86,12 +144,17 @@ class AuthenticatorValidateStageTests(FlowTestCase):
def test_validate_selected_challenge(self): def test_validate_selected_challenge(self):
"""Test validate_selected_challenge""" """Test validate_selected_challenge"""
# Prepare request with session flow = create_test_flow()
request = self.request_factory.get("/") stage = AuthenticatorValidateStage.objects.create(
name=generate_id(),
not_configured_action=NotConfiguredAction.CONFIGURE,
device_classes=[DeviceClasses.STATIC, DeviceClasses.TOTP],
)
middleware = SessionMiddleware(dummy_get_response) session = self.client.session
middleware.process_request(request) plan = FlowPlan(flow_pk=flow.pk.hex)
request.session[SESSION_KEY_DEVICE_CHALLENGES] = [ plan.append_stage(stage)
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
{ {
"device_class": "static", "device_class": "static",
"device_uid": "1", "device_uid": "1",
@ -101,23 +164,43 @@ class AuthenticatorValidateStageTests(FlowTestCase):
"device_uid": "2", "device_uid": "2",
}, },
] ]
request.session.save() session[SESSION_KEY_PLAN] = plan
session.save()
res = AuthenticatorValidationChallengeResponse() response = self.client.post(
res.stage = StageView(FlowExecutorView()) reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
res.stage.request = request data={
with self.assertRaises(ValidationError): "selected_challenge": {
res.validate_selected_challenge(
{
"device_class": "baz", "device_class": "baz",
"device_uid": "quox", "device_uid": "quox",
"challenge": {},
} }
},
) )
res.validate_selected_challenge( self.assertStageResponse(
{ response,
flow,
response_errors={
"selected_challenge": [{"string": "invalid challenge selected", "code": "invalid"}]
},
component="ak-stage-authenticator-validate",
)
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
data={
"selected_challenge": {
"device_class": "static", "device_class": "static",
"device_uid": "1", "device_uid": "1",
} "challenge": {},
},
},
)
self.assertStageResponse(
response,
flow,
response_errors={"non_field_errors": [{"string": "Empty response", "code": "invalid"}]},
component="ak-stage-authenticator-validate",
) )
@patch( @patch(

View File

@ -22,7 +22,7 @@ from authentik.stages.authenticator_validate.challenge import (
) )
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_validate.stage import ( from authentik.stages.authenticator_validate.stage import (
SESSION_KEY_DEVICE_CHALLENGES, PLAN_CONTEXT_DEVICE_CHALLENGES,
AuthenticatorValidateStageView, AuthenticatorValidateStageView,
) )
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
@ -211,14 +211,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
plan.append_stage(stage) plan.append_stage(stage)
plan.append_stage(UserLoginStage(name=generate_id())) plan.append_stage(UserLoginStage(name=generate_id()))
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
session[SESSION_KEY_PLAN] = plan plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
session[SESSION_KEY_DEVICE_CHALLENGES] = [
{ {
"device_class": device.__class__.__name__.lower().replace("device", ""), "device_class": device.__class__.__name__.lower().replace("device", ""),
"device_uid": device.pk, "device_uid": device.pk,
"challenge": {}, "challenge": {},
} }
] ]
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
) )
@ -283,14 +283,14 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
plan = FlowPlan(flow_pk=flow.pk.hex) plan = FlowPlan(flow_pk=flow.pk.hex)
plan.append_stage(stage) plan.append_stage(stage)
plan.append_stage(UserLoginStage(name=generate_id())) plan.append_stage(UserLoginStage(name=generate_id()))
session[SESSION_KEY_PLAN] = plan plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
session[SESSION_KEY_DEVICE_CHALLENGES] = [
{ {
"device_class": device.__class__.__name__.lower().replace("device", ""), "device_class": device.__class__.__name__.lower().replace("device", ""),
"device_uid": device.pk, "device_uid": device.pk,
"challenge": {}, "challenge": {},
} }
] ]
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
) )

View File

@ -124,6 +124,7 @@ class UserWriteStageView(StageView):
connection: UserSourceConnection = self.executor.plan.context[ connection: UserSourceConnection = self.executor.plan.context[
PLAN_CONTEXT_SOURCES_CONNECTION PLAN_CONTEXT_SOURCES_CONNECTION
] ]
if connection.source.name not in user.attributes[USER_ATTRIBUTE_SOURCES]:
user.attributes[USER_ATTRIBUTE_SOURCES].append(connection.source.name) user.attributes[USER_ATTRIBUTE_SOURCES].append(connection.source.name)
def get(self, request: HttpRequest) -> HttpResponse: def get(self, request: HttpRequest) -> HttpResponse:

View File

@ -97,6 +97,47 @@ class TestUserWriteStage(FlowTestCase):
self.assertEqual(user_qs.first().attributes["foo"], "bar") self.assertEqual(user_qs.first().attributes["foo"], "bar")
self.assertNotIn("some_ignored_attribute", user_qs.first().attributes) self.assertNotIn("some_ignored_attribute", user_qs.first().attributes)
def test_user_update_source(self):
"""Test update of existing user with a source"""
new_password = generate_key()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = User.objects.create(
username="unittest",
email="test@goauthentik.io",
attributes={
USER_ATTRIBUTE_SOURCES: [
self.source.name,
]
},
)
plan.context[PLAN_CONTEXT_SOURCES_CONNECTION] = UserSourceConnection(source=self.source)
plan.context[PLAN_CONTEXT_PROMPT] = {
"username": "test-user-new",
"password": new_password,
"attributes.some.custom-attribute": "test",
"attributes": {
"foo": "bar",
},
"some_ignored_attribute": "bar",
}
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"])
self.assertTrue(user_qs.exists())
self.assertTrue(user_qs.first().check_password(new_password))
self.assertEqual(user_qs.first().attributes["some"]["custom-attribute"], "test")
self.assertEqual(user_qs.first().attributes["foo"], "bar")
self.assertEqual(user_qs.first().attributes[USER_ATTRIBUTE_SOURCES], [self.source.name])
self.assertNotIn("some_ignored_attribute", user_qs.first().attributes)
@patch( @patch(
"authentik.flows.views.executor.to_stage_response", "authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK, TO_STAGE_RESPONSE_MOCK,

View File

@ -3155,9 +3155,12 @@
"description": "When this option is enabled, all executions of this policy will be logged. By default, only execution errors are logged." "description": "When this option is enabled, all executions of this policy will be logged. By default, only execution errors are logged."
}, },
"action": { "action": {
"type": "string", "type": [
"null",
"string"
],
"enum": [ "enum": [
"", null,
"login", "login",
"login_failed", "login_failed",
"logout", "logout",
@ -3190,14 +3193,21 @@
"description": "Match created events with this action type. When left empty, all action types will be matched." "description": "Match created events with this action type. When left empty, all action types will be matched."
}, },
"client_ip": { "client_ip": {
"type": "string", "type": [
"string",
"null"
],
"minLength": 1,
"title": "Client ip", "title": "Client ip",
"description": "Matches Event's Client IP (strict matching, for network matching use an Expression Policy)" "description": "Matches Event's Client IP (strict matching, for network matching use an Expression Policy)"
}, },
"app": { "app": {
"type": "string", "type": [
"null",
"string"
],
"enum": [ "enum": [
"", null,
"authentik.admin", "authentik.admin",
"authentik.api", "authentik.api",
"authentik.crypto", "authentik.crypto",
@ -3249,6 +3259,87 @@
], ],
"title": "App", "title": "App",
"description": "Match events created by selected application. When left empty, all applications are matched." "description": "Match events created by selected application. When left empty, all applications are matched."
},
"model": {
"type": [
"null",
"string"
],
"enum": [
null,
"authentik_crypto.certificatekeypair",
"authentik_events.event",
"authentik_events.notificationtransport",
"authentik_events.notification",
"authentik_events.notificationrule",
"authentik_events.notificationwebhookmapping",
"authentik_flows.flow",
"authentik_flows.flowstagebinding",
"authentik_outposts.dockerserviceconnection",
"authentik_outposts.kubernetesserviceconnection",
"authentik_outposts.outpost",
"authentik_policies_dummy.dummypolicy",
"authentik_policies_event_matcher.eventmatcherpolicy",
"authentik_policies_expiry.passwordexpirypolicy",
"authentik_policies_expression.expressionpolicy",
"authentik_policies_password.passwordpolicy",
"authentik_policies_reputation.reputationpolicy",
"authentik_policies_reputation.reputation",
"authentik_policies.policybinding",
"authentik_providers_ldap.ldapprovider",
"authentik_providers_oauth2.scopemapping",
"authentik_providers_oauth2.oauth2provider",
"authentik_providers_oauth2.authorizationcode",
"authentik_providers_oauth2.accesstoken",
"authentik_providers_oauth2.refreshtoken",
"authentik_providers_proxy.proxyprovider",
"authentik_providers_radius.radiusprovider",
"authentik_providers_saml.samlprovider",
"authentik_providers_saml.samlpropertymapping",
"authentik_providers_scim.scimprovider",
"authentik_providers_scim.scimmapping",
"authentik_sources_ldap.ldapsource",
"authentik_sources_ldap.ldappropertymapping",
"authentik_sources_oauth.oauthsource",
"authentik_sources_oauth.useroauthsourceconnection",
"authentik_sources_plex.plexsource",
"authentik_sources_plex.plexsourceconnection",
"authentik_sources_saml.samlsource",
"authentik_sources_saml.usersamlsourceconnection",
"authentik_stages_authenticator_duo.authenticatorduostage",
"authentik_stages_authenticator_duo.duodevice",
"authentik_stages_authenticator_sms.authenticatorsmsstage",
"authentik_stages_authenticator_sms.smsdevice",
"authentik_stages_authenticator_static.authenticatorstaticstage",
"authentik_stages_authenticator_totp.authenticatortotpstage",
"authentik_stages_authenticator_validate.authenticatorvalidatestage",
"authentik_stages_authenticator_webauthn.authenticatewebauthnstage",
"authentik_stages_authenticator_webauthn.webauthndevice",
"authentik_stages_captcha.captchastage",
"authentik_stages_consent.consentstage",
"authentik_stages_consent.userconsent",
"authentik_stages_deny.denystage",
"authentik_stages_dummy.dummystage",
"authentik_stages_email.emailstage",
"authentik_stages_identification.identificationstage",
"authentik_stages_invitation.invitationstage",
"authentik_stages_invitation.invitation",
"authentik_stages_password.passwordstage",
"authentik_stages_prompt.prompt",
"authentik_stages_prompt.promptstage",
"authentik_stages_user_delete.userdeletestage",
"authentik_stages_user_login.userloginstage",
"authentik_stages_user_logout.userlogoutstage",
"authentik_stages_user_write.userwritestage",
"authentik_tenants.tenant",
"authentik_blueprints.blueprintinstance",
"authentik_core.group",
"authentik_core.user",
"authentik_core.application",
"authentik_core.token"
],
"title": "Model",
"description": "Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched."
} }
}, },
"required": [] "required": []
@ -3542,14 +3633,14 @@
"minimum": -2147483648, "minimum": -2147483648,
"maximum": 2147483647, "maximum": 2147483647,
"title": "Uid start number", "title": "Uid start number",
"description": "The start for uidNumbers, this number is added to the user.Pk to make sure that the numbers aren't too low for POSIX users. Default is 2000 to ensure that we don't collide with local users uidNumber" "description": "The start for uidNumbers, this number is added to the user.pk to make sure that the numbers aren't too low for POSIX users. Default is 2000 to ensure that we don't collide with local users uidNumber"
}, },
"gid_start_number": { "gid_start_number": {
"type": "integer", "type": "integer",
"minimum": -2147483648, "minimum": -2147483648,
"maximum": 2147483647, "maximum": 2147483647,
"title": "Gid start number", "title": "Gid start number",
"description": "The start for gidNumbers, this number is added to a number generated from the group.Pk to make sure that the numbers aren't too low for POSIX groups. Default is 4000 to ensure that we don't collide with local groups or users primary groups gidNumber" "description": "The start for gidNumbers, this number is added to a number generated from the group.pk to make sure that the numbers aren't too low for POSIX groups. Default is 4000 to ensure that we don't collide with local groups or users primary groups gidNumber"
}, },
"search_mode": { "search_mode": {
"type": "string", "type": "string",
@ -3566,6 +3657,11 @@
"cached" "cached"
], ],
"title": "Bind mode" "title": "Bind mode"
},
"mfa_support": {
"type": "boolean",
"title": "MFA Support",
"description": "When enabled, code-based multi-factor authentication can be used by appending a semicolon and the TOTP code to the password. This should only be enabled if all users that will bind to this provider have a TOTP device configured, as otherwise a password may incorrectly be rejected if it contains a semicolon."
} }
}, },
"required": [] "required": []

View File

@ -0,0 +1,36 @@
# This file is used for development and debugging, and should not be used for production instances
version: '3.5'
services:
flower:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.5.4}
restart: unless-stopped
command: worker-status
environment:
AUTHENTIK_REDIS__HOST: redis
AUTHENTIK_POSTGRESQL__HOST: postgresql
AUTHENTIK_POSTGRESQL__USER: ${PG_USER:-authentik}
AUTHENTIK_POSTGRESQL__NAME: ${PG_DB:-authentik}
AUTHENTIK_POSTGRESQL__PASSWORD: ${PG_PASS}
env_file:
- .env
ports:
- "9001:9000"
depends_on:
- postgresql
- redis
server:
environment:
AUTHENTIK_REMOTE_DEBUG: "true"
PYDEVD_THREAD_DUMP_ON_WARN_EVALUATION_TIMEOUT: "true"
ports:
- 6800:6800
worker:
environment:
CELERY_RDB_HOST: "0.0.0.0"
CELERY_RDBSIG: "1"
AUTHENTIK_REMOTE_DEBUG: "true"
PYDEVD_THREAD_DUMP_ON_WARN_EVALUATION_TIMEOUT: "true"
ports:
- 6900:6900

View File

@ -32,7 +32,7 @@ services:
volumes: volumes:
- redis:/data - redis:/data
server: server:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.5.3} image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.5.4}
restart: unless-stopped restart: unless-stopped
command: server command: server
environment: environment:
@ -53,7 +53,7 @@ services:
- postgresql - postgresql
- redis - redis
worker: worker:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.5.3} image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2023.5.4}
restart: unless-stopped restart: unless-stopped
command: worker command: worker
environment: environment:

22
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/Netflix/go-env v0.0.0-20210215222557-e437a7e7f9fb github.com/Netflix/go-env v0.0.0-20210215222557-e437a7e7f9fb
github.com/coreos/go-oidc v2.2.1+incompatible github.com/coreos/go-oidc v2.2.1+incompatible
github.com/garyburd/redigo v1.6.4 github.com/garyburd/redigo v1.6.4
github.com/getsentry/sentry-go v0.21.0 github.com/getsentry/sentry-go v0.22.0
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.5 github.com/go-ldap/ldap/v3 v3.4.5
github.com/go-openapi/runtime v0.26.0 github.com/go-openapi/runtime v0.26.0
@ -22,14 +22,14 @@ require (
github.com/jellydator/ttlcache/v3 v3.0.1 github.com/jellydator/ttlcache/v3 v3.0.1
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484 github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
github.com/pires/go-proxyproto v0.7.0 github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.15.1 github.com/prometheus/client_golang v1.16.0
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
goauthentik.io/api/v3 v3.2023052.1 goauthentik.io/api/v3 v3.2023054.4
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.8.0 golang.org/x/oauth2 v0.10.0
golang.org/x/sync v0.2.0 golang.org/x/sync v0.3.0
gopkg.in/boj/redistore.v1 v1.0.0-20160128113310-fc113767cd6b gopkg.in/boj/redistore.v1 v1.0.0-20160128113310-fc113767cd6b
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
layeh.com/radius v0.0.0-20210819152912-ad72663a72ab layeh.com/radius v0.0.0-20210819152912-ad72663a72ab
@ -67,18 +67,18 @@ require (
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.42.0 // indirect github.com/prometheus/common v0.42.0 // indirect
github.com/prometheus/procfs v0.9.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
go.mongodb.org/mongo-driver v1.11.3 // indirect go.mongodb.org/mongo-driver v1.11.3 // indirect
go.opentelemetry.io/otel v1.14.0 // indirect go.opentelemetry.io/otel v1.14.0 // indirect
go.opentelemetry.io/otel/trace v1.14.0 // indirect go.opentelemetry.io/otel/trace v1.14.0 // indirect
golang.org/x/crypto v0.7.0 // indirect golang.org/x/crypto v0.11.0 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.12.0 // indirect
golang.org/x/sys v0.8.0 // indirect golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.9.0 // indirect golang.org/x/text v0.11.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/square/go-jose.v2 v2.5.1 // indirect gopkg.in/square/go-jose.v2 v2.5.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

1496
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -29,4 +29,4 @@ func UserAgent() string {
return fmt.Sprintf("authentik@%s", FullVersion()) return fmt.Sprintf("authentik@%s", FullVersion())
} }
const VERSION = "2023.5.3" const VERSION = "2023.5.4"

View File

@ -14,5 +14,3 @@ const (
HeaderAuthentikRemoteIP = "X-authentik-remote-ip" HeaderAuthentikRemoteIP = "X-authentik-remote-ip"
HeaderAuthentikOutpostToken = "X-authentik-outpost-token" HeaderAuthentikOutpostToken = "X-authentik-outpost-token"
) )
const CodePasswordSeparator = ";"

Some files were not shown because too many files have changed in this diff Show More