Compare commits

1 Commit

Author: Jens Langhammer
SHA1: 7d1efd7450
Message: root: replace celery queues with priority
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Date: 2023-10-18 12:50:24 +02:00
324 changed files with 9221 additions and 16186 deletions
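
The substance of this compare is one mechanical change applied across many files: every Celery beat entry and ad-hoc task dispatch drops its dedicated queue (authentik_scheduled, authentik_events) and instead carries a numeric priority resolved from the new worker.priority.* configuration keys. Condensed from the settings hunks shown below, the pattern looks like this (Python, with the imports as they appear in the affected modules):

from celery.schedules import crontab

from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand

# Before: scheduled tasks were routed to a dedicated queue.
CELERY_BEAT_SCHEDULE = {
    "admin_latest_version": {
        "task": "authentik.admin.tasks.update_latest_version",
        "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
        "options": {"queue": "authentik_scheduled"},
    }
}

# After: the same entry carries a numeric priority read from configuration.
CELERY_BEAT_SCHEDULE = {
    "admin_latest_version": {
        "task": "authentik.admin.tasks.update_latest_version",
        "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
        "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
    }
}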

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 2023.10.6
current_version = 2023.8.3
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@ -2,39 +2,36 @@ name: "Setup authentik testing environment"
description: "Setup authentik testing environment"
inputs:
postgresql_version:
postgresql_tag:
description: "Optional postgresql image tag"
default: "12"
runs:
using: "composite"
steps:
- name: Install poetry & deps
- name: Install poetry
shell: bash
run: |
pipx install poetry || true
sudo apt-get update
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext
sudo apt update
sudo apt install -y libpq-dev openssl libxmlsec1-dev pkg-config gettext
- name: Setup python and restore poetry
uses: actions/setup-python@v4
uses: actions/setup-python@v3
with:
python-version-file: 'pyproject.toml'
python-version: "3.11"
cache: "poetry"
- name: Setup node
uses: actions/setup-node@v3
with:
node-version-file: web/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: web/package-lock.json
- name: Setup go
uses: actions/setup-go@v4
with:
go-version-file: "go.mod"
- name: Setup dependencies
shell: bash
run: |
export PSQL_TAG=${{ inputs.postgresql_version }}
export PSQL_TAG=${{ inputs.postgresql_tag }}
docker-compose -f .github/actions/setup/docker-compose.yml up -d
poetry env use python3.11
poetry install
cd web && npm ci
- name: Generate config

View File

@ -11,7 +11,6 @@ on:
pull_request:
branches:
- main
- version-*
env:
POSTGRES_DB: authentik
@ -48,38 +47,25 @@ jobs:
- name: run migrations
run: poetry run python -m lifecycle.migrate
test-migrations-from-stable:
name: test-migrations-from-stable - PostgreSQL ${{ matrix.psql }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
psql:
- 12-alpine
- 15-alpine
- 16-alpine
continue-on-error: true
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup authentik env
uses: ./.github/actions/setup
with:
postgresql_version: ${{ matrix.psql }}
- name: checkout stable
run: |
# Delete all poetry envs
rm -rf /home/runner/.cache/pypoetry
# Copy current, latest config to local
cp authentik/lib/default.yml local.env.yml
cp -R .github ..
cp -R scripts ..
git checkout version/$(python -c "from authentik import __version__; print(__version__)")
git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
rm -rf .github/ scripts/
mv ../.github ../scripts .
- name: Setup authentik env (ensure stable deps are installed)
uses: ./.github/actions/setup
with:
postgresql_version: ${{ matrix.psql }}
- name: run migrations to stable
run: poetry run python -m lifecycle.migrate
- name: checkout current code
@ -89,13 +75,9 @@ jobs:
git reset --hard HEAD
git clean -d -fx .
git checkout $GITHUB_SHA
# Delete previous poetry env
rm -rf $(poetry env info --path)
poetry install
- name: Setup authentik env (ensure latest deps are installed)
uses: ./.github/actions/setup
with:
postgresql_version: ${{ matrix.psql }}
- name: migrate to latest
run: poetry run python -m lifecycle.migrate
test-unittest:
@ -114,7 +96,7 @@ jobs:
- name: Setup authentik env
uses: ./.github/actions/setup
with:
postgresql_version: ${{ matrix.psql }}
postgresql_tag: ${{ matrix.psql }}
- name: run unittest
run: |
poetry run make test
@ -203,9 +185,6 @@ jobs:
build:
needs: ci-core-mark
runs-on: ubuntu-latest
permissions:
# Needed to upload container images to ghcr.io
packages: write
timeout-minutes: 120
steps:
- uses: actions/checkout@v4
@ -256,9 +235,6 @@ jobs:
build-arm64:
needs: ci-core-mark
runs-on: ubuntu-latest
permissions:
# Needed to upload container images to ghcr.io
packages: write
timeout-minutes: 120
steps:
- uses: actions/checkout@v4

View File

@ -9,7 +9,6 @@ on:
pull_request:
branches:
- main
- version-*
jobs:
lint-golint:
@ -66,9 +65,6 @@ jobs:
- ldap
- radius
runs-on: ubuntu-latest
permissions:
# Needed to upload container images to ghcr.io
packages: write
steps:
- uses: actions/checkout@v4
with:
@ -128,9 +124,9 @@ jobs:
- uses: actions/setup-go@v4
with:
go-version-file: "go.mod"
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: web/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: web/package-lock.json
- name: Generate API

View File

@ -9,7 +9,6 @@ on:
pull_request:
branches:
- main
- version-*
jobs:
lint-eslint:
@ -22,9 +21,9 @@ jobs:
- tests/wdio
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: ${{ matrix.project }}/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: ${{ matrix.project }}/package-lock.json
- working-directory: ${{ matrix.project }}/
@ -38,9 +37,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: web/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: web/package-lock.json
- working-directory: web/
@ -60,9 +59,9 @@ jobs:
- tests/wdio
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: ${{ matrix.project }}/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: ${{ matrix.project }}/package-lock.json
- working-directory: ${{ matrix.project }}/
@ -76,9 +75,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: web/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: web/package-lock.json
- working-directory: web/
@ -108,9 +107,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: web/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: web/package-lock.json
- working-directory: web/

View File

@ -9,16 +9,15 @@ on:
pull_request:
branches:
- main
- version-*
jobs:
lint-prettier:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: website/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/
@ -30,9 +29,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: website/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/
@ -51,9 +50,9 @@ jobs:
- build-docs-only
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: website/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/

View File

@ -6,7 +6,6 @@ on:
workflow_dispatch:
permissions:
# Needed to be able to push to the next branch
contents: write
jobs:

View File

@ -7,9 +7,6 @@ on:
jobs:
build-server:
runs-on: ubuntu-latest
permissions:
# Needed to upload container images to ghcr.io
packages: write
steps:
- uses: actions/checkout@v4
- name: Set up QEMU
@ -30,10 +27,8 @@ jobs:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: make empty clients
run: |
mkdir -p ./gen-ts-api
mkdir -p ./gen-go-api
- name: make empty ts client
run: mkdir -p ./gen-ts-client
- name: Build Docker Image
uses: docker/build-push-action@v5
with:
@ -55,9 +50,6 @@ jobs:
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost:
runs-on: ubuntu-latest
permissions:
# Needed to upload container images to ghcr.io
packages: write
strategy:
fail-fast: false
matrix:
@ -77,10 +69,6 @@ jobs:
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
- name: make empty clients
run: |
mkdir -p ./gen-ts-api
mkdir -p ./gen-go-api
- name: Docker Login Registry
uses: docker/login-action@v3
with:
@ -105,16 +93,12 @@ jobs:
ghcr.io/goauthentik/${{ matrix.type }}:latest
file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64
context: .
build-args: |
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost-binary:
timeout-minutes: 120
runs-on: ubuntu-latest
permissions:
# Needed to upload binaries to the release
contents: write
strategy:
fail-fast: false
matrix:
@ -129,9 +113,9 @@ jobs:
- uses: actions/setup-go@v4
with:
go-version-file: "go.mod"
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: web/package.json
node-version: "20"
cache: "npm"
cache-dependency-path: web/package-lock.json
- name: Build web

View File

@ -16,7 +16,6 @@ jobs:
echo "PG_PASS=$(openssl rand -base64 32)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env
docker buildx install
mkdir -p ./gen-ts-api
docker build -t testing:latest .
echo "AUTHENTIK_IMAGE=testing" >> .env
echo "AUTHENTIK_TAG=latest" >> .env

View File

@ -6,8 +6,8 @@ on:
workflow_dispatch:
permissions:
# Needed to update issues and PRs
issues: write
pull-requests: write
jobs:
stale:

View File

@ -17,9 +17,9 @@ jobs:
- uses: actions/checkout@v4
with:
token: ${{ steps.generate_token.outputs.token }}
- uses: actions/setup-node@v4
- uses: actions/setup-node@v3
with:
node-version-file: web/package.json
node-version: "20"
registry-url: "https://registry.npmjs.org"
- name: Generate API Client
run: make gen-client-ts

View File

@ -1,7 +1,5 @@
# syntax=docker/dockerfile:1
# Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/node:21 as website-builder
FROM --platform=${BUILDPLATFORM} docker.io/node:20 as website-builder
ENV NODE_ENV=production
@ -9,7 +7,7 @@ WORKDIR /work/website
RUN --mount=type=bind,target=/work/website/package.json,src=./website/package.json \
--mount=type=bind,target=/work/website/package-lock.json,src=./website/package-lock.json \
--mount=type=cache,id=npm-website,sharing=shared,target=/root/.npm \
--mount=type=cache,target=/root/.npm \
npm ci --include=dev
COPY ./website /work/website/
@ -19,7 +17,7 @@ COPY ./SECURITY.md /work/
RUN npm run build-docs-only
# Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/node:21 as web-builder
FROM --platform=${BUILDPLATFORM} docker.io/node:20 as web-builder
ENV NODE_ENV=production
@ -27,7 +25,7 @@ WORKDIR /work/web
RUN --mount=type=bind,target=/work/web/package.json,src=./web/package.json \
--mount=type=bind,target=/work/web/package-lock.json,src=./web/package-lock.json \
--mount=type=cache,id=npm-web,sharing=shared,target=/root/.npm \
--mount=type=cache,target=/root/.npm \
npm ci --include=dev
COPY ./web /work/web/
@ -37,14 +35,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build
# Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.21.4-bookworm AS go-builder
ARG TARGETOS
ARG TARGETARCH
ARG TARGETVARIANT
ARG GOOS=$TARGETOS
ARG GOARCH=$TARGETARCH
FROM docker.io/golang:1.21.3-bookworm AS go-builder
WORKDIR /go/src/goauthentik.io
@ -64,12 +55,12 @@ COPY ./go.sum /go/src/goauthentik.io/go.sum
ENV CGO_ENABLED=0
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server
RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \
go build -o /go/authentik ./cmd/server
# Stage 4: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v6.0 as geoip
FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City"
ENV GEOIPUPDATE_VERBOSE="true"
@ -91,9 +82,7 @@ ENV VENV_PATH="/ak-root/venv" \
POETRY_VIRTUALENVS_CREATE=false \
PATH="/ak-root/venv/bin:$PATH"
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && \
# Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev

View File

@ -56,9 +56,9 @@ test: ## Run the server tests and produce a coverage report (locally)
coverage report
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
isort $(PY_SOURCES)
black $(PY_SOURCES)
ruff $(PY_SOURCES)
isort authentik $(PY_SOURCES)
black authentik $(PY_SOURCES)
ruff authentik $(PY_SOURCES)
codespell -w $(CODESPELL_ARGS)
lint: ## Lint the python and golang sources

View File

@ -2,7 +2,7 @@
from os import environ
from typing import Optional
__version__ = "2023.10.6"
__version__ = "2023.8.3"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -1,12 +1,13 @@
"""authentik admin settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"admin_latest_version": {
"task": "authentik.admin.tasks.update_latest_version",
"schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}
}

View File

@ -21,9 +21,7 @@ _other_urls = []
for _authentik_app in get_apps():
try:
api_urls = import_module(f"{_authentik_app.name}.urls")
except ModuleNotFoundError:
continue
except ImportError as exc:
except (ModuleNotFoundError, ImportError) as exc:
LOGGER.warning("Could not import app's URLs", app_name=_authentik_app.name, exc=exc)
continue
if not hasattr(api_urls, "api_urlpatterns"):

View File

@ -40,7 +40,7 @@ class ManagedAppConfig(AppConfig):
meth()
self._logger.debug("Successfully reconciled", name=name)
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.warning("Failed to run reconcile", name=name, exc=exc)
self._logger.debug("Failed to run reconcile", name=name, exc=exc)
class AuthentikBlueprintsConfig(ManagedAppConfig):

View File

@ -1,17 +1,18 @@
"""blueprint Settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"blueprints_v1_discover": {
"task": "authentik.blueprints.v1.tasks.blueprints_discovery",
"schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
"blueprints_v1_cleanup": {
"task": "authentik.blueprints.v1.tasks.clear_failed_blueprints",
"schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
}

View File

@ -584,17 +584,12 @@ class EntryInvalidError(SentryIgnoredException):
entry_model: Optional[str]
entry_id: Optional[str]
validation_error: Optional[ValidationError]
serializer: Optional[Serializer] = None
def __init__(
self, *args: object, validation_error: Optional[ValidationError] = None, **kwargs
) -> None:
def __init__(self, *args: object, validation_error: Optional[ValidationError] = None) -> None:
super().__init__(*args)
self.entry_model = None
self.entry_id = None
self.validation_error = validation_error
for key, value in kwargs.items():
setattr(self, key, value)
@staticmethod
def from_entry(

View File

@ -255,10 +255,7 @@ class Importer:
try:
full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import))
except ValueError as exc:
raise EntryInvalidError.from_entry(
exc,
entry,
) from exc
raise EntryInvalidError.from_entry(exc, entry) from exc
always_merger.merge(full_data, updated_identifiers)
serializer_kwargs["data"] = full_data
@ -275,7 +272,6 @@ class Importer:
f"Serializer errors {serializer.errors}",
validation_error=exc,
entry=entry,
serializer=serializer,
) from exc
return serializer
@ -304,14 +300,12 @@ class Importer:
)
return False
# Validate each single entry
serializer = None
try:
serializer = self._validate_single(entry)
except EntryInvalidError as exc:
# For deleting objects we don't need the serializer to be valid
if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT:
serializer = exc.serializer
else:
continue
self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc)
if raise_errors:
raise exc

View File

@ -75,14 +75,14 @@ class BlueprintEventHandler(FileSystemEventHandler):
return
if event.is_directory:
return
root = Path(CONFIG.get("blueprints_dir")).absolute()
path = Path(event.src_path).absolute()
rel_path = str(path.relative_to(root))
if isinstance(event, FileCreatedEvent):
LOGGER.debug("new blueprint file created, starting discovery", path=rel_path)
blueprints_discovery.delay(rel_path)
LOGGER.debug("new blueprint file created, starting discovery")
blueprints_discovery.delay()
if isinstance(event, FileModifiedEvent):
for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
path = Path(event.src_path)
root = Path(CONFIG.get("blueprints_dir")).absolute()
rel_path = str(path.relative_to(root))
for instance in BlueprintInstance.objects.filter(path=rel_path):
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
apply_blueprint.delay(instance.pk.hex)
@ -98,32 +98,39 @@ def blueprints_find_dict():
return blueprints
def blueprints_find() -> list[BlueprintFile]:
def blueprints_find():
"""Find blueprints and return valid ones"""
blueprints = []
root = Path(CONFIG.get("blueprints_dir"))
for path in root.rglob("**/*.yaml"):
rel_path = path.relative_to(root)
# Check if any part in the path starts with a dot and assume a hidden file
if any(part for part in path.parts if part.startswith(".")):
continue
LOGGER.debug("found blueprint", path=str(path))
with open(path, "r", encoding="utf-8") as blueprint_file:
try:
raw_blueprint = load(blueprint_file.read(), BlueprintLoader)
except YAMLError as exc:
raw_blueprint = None
LOGGER.warning("failed to parse blueprint", exc=exc, path=str(rel_path))
LOGGER.warning("failed to parse blueprint", exc=exc, path=str(path))
if not raw_blueprint:
continue
metadata = raw_blueprint.get("metadata", None)
version = raw_blueprint.get("version", 1)
if version != 1:
LOGGER.warning("invalid blueprint version", version=version, path=str(rel_path))
LOGGER.warning("invalid blueprint version", version=version, path=str(path))
continue
file_hash = sha512(path.read_bytes()).hexdigest()
blueprint = BlueprintFile(str(rel_path), version, file_hash, int(path.stat().st_mtime))
blueprint = BlueprintFile(
str(path.relative_to(root)), version, file_hash, int(path.stat().st_mtime)
)
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
blueprints.append(blueprint)
LOGGER.debug(
"parsed & loaded blueprint",
hash=file_hash,
path=str(path),
)
return blueprints
@ -131,12 +138,10 @@ def blueprints_find() -> list[BlueprintFile]:
throws=(DatabaseError, ProgrammingError, InternalError), base=MonitoredTask, bind=True
)
@prefill_task
def blueprints_discovery(self: MonitoredTask, path: Optional[str] = None):
def blueprints_discovery(self: MonitoredTask):
"""Find blueprints and check if they need to be created in the database"""
count = 0
for blueprint in blueprints_find():
if path and blueprint.path != path:
continue
check_blueprint_v1_file(blueprint)
count += 1
self.set_status(
@ -166,11 +171,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
metadata={},
)
instance.save()
LOGGER.info(
"Creating new blueprint instance from file", instance=instance, path=instance.path
)
if instance.last_applied_hash != blueprint.hash:
LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
apply_blueprint.delay(str(instance.pk))

View File

@ -98,7 +98,6 @@ class ApplicationSerializer(ModelSerializer):
class ApplicationViewSet(UsedByMixin, ModelViewSet):
"""Application Viewset"""
# pylint: disable=no-member
queryset = Application.objects.all().prefetch_related("provider")
serializer_class = ApplicationSerializer
search_fields = [

View File

@ -139,7 +139,6 @@ class UserAccountSerializer(PassiveSerializer):
class GroupViewSet(UsedByMixin, ModelViewSet):
"""Group Viewset"""
# pylint: disable=no-member
queryset = Group.objects.all().select_related("parent").prefetch_related("users")
serializer_class = GroupSerializer
search_fields = ["name", "is_superuser"]

View File

@ -38,7 +38,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
managed = ReadOnlyField()
component = SerializerMethodField()
icon = ReadOnlyField(source="icon_url")
icon = ReadOnlyField(source="get_icon")
def get_component(self, obj: Source) -> str:
"""Get object component so that we know how to edit the object"""

View File

@ -171,11 +171,6 @@ class UserSerializer(ModelSerializer):
raise ValidationError("Setting a user to internal service account is not allowed.")
return user_type
def validate(self, attrs: dict) -> dict:
if self.instance and self.instance.type == UserTypes.INTERNAL_SERVICE_ACCOUNT:
raise ValidationError("Can't modify internal service account users")
return super().validate(attrs)
class Meta:
model = User
fields = [

View File

@ -44,7 +44,6 @@ class PropertyMappingEvaluator(BaseEvaluator):
if request:
req.http_request = request
self._context["request"] = req
req.context.update(**kwargs)
self._context.update(**kwargs)
self.dry_run = dry_run

View File

@ -17,15 +17,9 @@ class Command(BaseCommand):
"""Run worker"""
def add_arguments(self, parser):
parser.add_argument(
"-b",
"--beat",
action="store_false",
help="When set, this worker will _not_ run Beat (scheduled) tasks",
)
parser.add_argument("-b", "--beat", action="store_true")
def handle(self, **options):
LOGGER.debug("Celery options", **options)
close_old_connections()
if CONFIG.get_bool("remote_debug"):
import debugpy
@ -39,7 +33,6 @@ class Command(BaseCommand):
task_events=True,
beat=options.get("beat", True),
schedule_filename=f"{tempdir}/celerybeat-schedule",
queues=["authentik", "authentik_scheduled", "authentik_events"],
)
for task in CELERY_APP.tasks:
LOGGER.debug("Registered task", task=task)

View File

@ -97,7 +97,6 @@ class SourceFlowManager:
if self.request.user.is_authenticated:
new_connection.user = self.request.user
new_connection = self.update_connection(new_connection, **kwargs)
# pylint: disable=no-member
new_connection.save()
return Action.LINK, new_connection

View File

@ -13,6 +13,7 @@
{% block head_before %}
{% endblock %}
<link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}">
<link rel="stylesheet" type="text/css" href="{% static 'dist/theme-dark.css' %}" media="(prefers-color-scheme: dark)">
<link rel="stylesheet" type="text/css" href="{% static 'dist/custom.css' %}" data-inject>
<script src="{% static 'dist/poly.js' %}?version={{ version }}" type="module"></script>
<script src="{% static 'dist/standalone/loading/index.js' %}?version={{ version }}" type="module"></script>

View File

@ -16,8 +16,8 @@ You've logged out of {{ application }}.
{% block card %}
<form method="POST" class="pf-c-form">
<p>
{% blocktrans with application=application.name branding_title=tenant.branding_title %}
You've logged out of {{ application }}. You can go back to the overview to launch another application, or log out of your {{ branding_title }} account.
{% blocktrans with application=application.name %}
You've logged out of {{ application }}. You can go back to the overview to launch another application, or log out of your authentik account.
{% endblocktrans %}
</p>

View File

@ -6,7 +6,6 @@
{% block head_before %}
<link rel="prefetch" href="/static/dist/assets/images/flow_background.jpg" />
<link rel="stylesheet" type="text/css" href="{% static 'dist/patternfly.min.css' %}">
<link rel="stylesheet" type="text/css" href="{% static 'dist/theme-dark.css' %}" media="(prefers-color-scheme: dark)">
{% include "base/header_js.html" %}
{% endblock %}

View File

@ -1,10 +1,13 @@
"""authentik crypto app config"""
from datetime import datetime
from typing import Optional
from typing import TYPE_CHECKING, Optional
from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.generators import generate_id
if TYPE_CHECKING:
from authentik.crypto.models import CertificateKeyPair
MANAGED_KEY = "goauthentik.io/crypto/jwt-managed"
@ -20,37 +23,33 @@ class AuthentikCryptoConfig(ManagedAppConfig):
"""Load crypto tasks"""
self.import_module("authentik.crypto.tasks")
def _create_update_cert(self):
def _create_update_cert(self, cert: Optional["CertificateKeyPair"] = None):
from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair
common_name = "authentik Internal JWT Certificate"
builder = CertificateBuilder(common_name)
builder = CertificateBuilder("authentik Internal JWT Certificate")
builder.build(
subject_alt_names=["goauthentik.io"],
validity_days=360,
)
CertificateKeyPair.objects.update_or_create(
managed=MANAGED_KEY,
defaults={
"name": common_name,
"certificate_data": builder.certificate,
"key_data": builder.private_key,
},
)
if not cert:
cert = CertificateKeyPair()
builder.cert = cert
builder.cert.managed = MANAGED_KEY
builder.save()
def reconcile_managed_jwt_cert(self):
"""Ensure managed JWT certificate"""
from authentik.crypto.models import CertificateKeyPair
cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter(
managed=MANAGED_KEY
).first()
now = datetime.now()
if not cert or (
now < cert.certificate.not_valid_before or now > cert.certificate.not_valid_after
):
certs = CertificateKeyPair.objects.filter(managed=MANAGED_KEY)
if not certs.exists():
self._create_update_cert()
return
cert: CertificateKeyPair = certs.first()
now = datetime.now()
if now < cert.certificate.not_valid_before or now > cert.certificate.not_valid_after:
self._create_update_cert(cert)
def reconcile_self_signed(self):
"""Create self-signed keypair"""
@ -62,10 +61,4 @@ class AuthentikCryptoConfig(ManagedAppConfig):
return
builder = CertificateBuilder(name)
builder.build(subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"])
CertificateKeyPair.objects.get_or_create(
name=name,
defaults={
"certificate_data": builder.certificate,
"key_data": builder.private_key,
},
)
builder.save()

View File

@ -1,12 +1,13 @@
"""Crypto task Settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"crypto_certificate_discovery": {
"task": "authentik.crypto.tasks.certificate_discovery",
"schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
}

View File

@ -136,9 +136,6 @@ class LicenseKey:
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if LicenseUsage.objects.filter(record_date__gte=threshold).exists():
return
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),

View File

@ -1,12 +1,13 @@
"""Enterprise additional settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"enterprise_calculate_license": {
"task": "authentik.enterprise.tasks.calculate_license",
"schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/2"),
"options": {"queue": "authentik_scheduled"},
"schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/8"),
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}
}

View File

@ -6,4 +6,5 @@ from authentik.root.celery import CELERY_APP
@CELERY_APP.task()
def calculate_license():
"""Calculate licensing status"""
LicenseKey.get_total().record_usage()
total = LicenseKey.get_total()
total.record_usage()

View File

@ -27,7 +27,6 @@ from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel
from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.scim.models import SCIMGroup, SCIMUser
from authentik.stages.authenticator_static.models import StaticToken
@ -53,13 +52,11 @@ IGNORED_MODELS = (
RefreshToken,
SCIMUser,
SCIMGroup,
Reputation,
)
def should_log_model(model: Model) -> bool:
"""Return true if operation on `model` should be logged"""
# Check for silk by string so this comparison doesn't fail when silk isn't installed
if model.__module__.startswith("silk"):
return False
return model.__class__ not in IGNORED_MODELS
@ -96,30 +93,21 @@ class AuditMiddleware:
of models"""
get_response: Callable[[HttpRequest], HttpResponse]
anonymous_user: User = None
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
def _ensure_fallback_user(self):
"""Defer fetching anonymous user until we have to"""
if self.anonymous_user:
return
from guardian.shortcuts import get_anonymous_user
self.anonymous_user = get_anonymous_user()
def connect(self, request: HttpRequest):
"""Connect signal for automatic logging"""
self._ensure_fallback_user()
user = getattr(request, "user", self.anonymous_user)
if not user.is_authenticated:
user = self.anonymous_user
if not hasattr(request, "user"):
return
if not getattr(request.user, "is_authenticated", False):
return
if not hasattr(request, "request_id"):
return
post_save_handler = partial(self.post_save_handler, user=user, request=request)
pre_delete_handler = partial(self.pre_delete_handler, user=user, request=request)
m2m_changed_handler = partial(self.m2m_changed_handler, user=user, request=request)
post_save_handler = partial(self.post_save_handler, user=request.user, request=request)
pre_delete_handler = partial(self.pre_delete_handler, user=request.user, request=request)
m2m_changed_handler = partial(self.m2m_changed_handler, user=request.user, request=request)
post_save.connect(
post_save_handler,
dispatch_uid=request.request_id,

View File

@ -217,7 +217,6 @@ class Event(SerializerModel, ExpiringModel):
"path": request.path,
"method": request.method,
"args": cleanse_dict(QueryDict(request.META.get("QUERY_STRING", ""))),
"user_agent": request.META.get("HTTP_USER_AGENT", ""),
}
# Special case for events created during flow execution
# since they keep the http query within a wrapped query

View File

@ -1,12 +1,13 @@
"""Event Settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"events_notification_cleanup": {
"task": "authentik.events.tasks.notification_cleanup",
"schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
}

View File

@ -13,7 +13,6 @@ from authentik.events.tasks import event_notification_handler, gdpr_cleanup
from authentik.flows.models import Stage
from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.config import CONFIG
from authentik.stages.invitation.models import Invitation
from authentik.stages.invitation.signals import invitation_used
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
@ -93,5 +92,4 @@ def event_post_save_notification(sender, instance: Event, **_):
@receiver(pre_delete, sender=User)
def event_user_pre_delete_cleanup(sender, instance: User, **_):
"""If gdpr_compliance is enabled, remove all the user's events"""
if CONFIG.get_bool("gdpr_compliance", True):
gdpr_cleanup.delay(instance.pk)

View File

@ -20,6 +20,7 @@ from authentik.events.monitored_tasks import (
TaskResultStatus,
prefill_task,
)
from authentik.lib.config import CONFIG
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.root.celery import CELERY_APP
@ -89,7 +90,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
user.pk,
str(trigger.pk),
],
queue="authentik_events",
priority=CONFIG.get_int("worker.priority.events"),
)
if transport.send_once:
break
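
Ad-hoc dispatches follow the same substitution: rather than routing the task to the authentik_events queue, the call above passes priority=CONFIG.get_int("worker.priority.events") at dispatch time. A self-contained sketch of that pattern, using a placeholder task rather than authentik's actual task:

from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP


@CELERY_APP.task()
def send_notification(notification_pk: str):
    """Placeholder task used only to illustrate the dispatch pattern."""


# Queue-based routing is replaced by a per-call priority resolved from configuration.
send_notification.apply_async(
    args=["some-notification-pk"],
    priority=CONFIG.get_int("worker.priority.events"),
)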

View File

@ -53,15 +53,7 @@ class TestEvents(TestCase):
"""Test plain from_http"""
event = Event.new("unittest").from_http(self.factory.get("/"))
self.assertEqual(
event.context,
{
"http_request": {
"args": {},
"method": "GET",
"path": "/",
"user_agent": "",
}
},
event.context, {"http_request": {"args": {}, "method": "GET", "path": "/"}}
)
def test_from_http_clean_querystring(self):
@ -75,7 +67,6 @@ class TestEvents(TestCase):
"args": {"token": SafeExceptionReporterFilter.cleansed_substitute},
"method": "GET",
"path": "/",
"user_agent": "",
}
},
)
@ -92,7 +83,6 @@ class TestEvents(TestCase):
"args": {"token": SafeExceptionReporterFilter.cleansed_substitute},
"method": "GET",
"path": "/",
"user_agent": "",
}
},
)

View File

@ -5,13 +5,12 @@ from dataclasses import asdict, is_dataclass
from datetime import date, datetime, time, timedelta
from enum import Enum
from pathlib import Path
from types import GeneratorType, NoneType
from types import GeneratorType
from typing import Any, Optional
from uuid import UUID
from django.contrib.auth.models import AnonymousUser
from django.core.handlers.wsgi import WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
from django.db import models
from django.db.models.base import Model
from django.http.request import HttpRequest
@ -154,20 +153,7 @@ def sanitize_item(value: Any) -> Any:
return value.isoformat()
if isinstance(value, timedelta):
return str(value.total_seconds())
if callable(value):
return {
"type": "callable",
"name": value.__name__,
"module": value.__module__,
}
# List taken from the stdlib's JSON encoder (_make_iterencode, encoder.py:415)
if isinstance(value, (bool, int, float, NoneType, list, tuple, dict)):
return value
try:
return DjangoJSONEncoder().default(value)
except TypeError:
return str(value)
return str(value)
def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:

View File

@ -1,34 +0,0 @@
# Generated by Django 4.2.6 on 2023-10-28 14:24
from django.apps.registry import Apps
from django.db import migrations
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def set_oobe_flow_authentication(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from guardian.shortcuts import get_anonymous_user
Flow = apps.get_model("authentik_flows", "Flow")
User = apps.get_model("authentik_core", "User")
db_alias = schema_editor.connection.alias
users = User.objects.using(db_alias).exclude(username="akadmin")
try:
users = users.exclude(pk=get_anonymous_user().pk)
# pylint: disable=broad-except
except Exception: # nosec
pass
if users.exists():
Flow.objects.filter(slug="initial-setup").update(authentication="require_superuser")
class Migration(migrations.Migration):
dependencies = [
("authentik_flows", "0026_alter_flow_options"),
]
operations = [
migrations.RunPython(set_oobe_flow_authentication),
]

View File

@ -167,11 +167,7 @@ class ChallengeStageView(StageView):
stage_type=self.__class__.__name__, method="get_challenge"
).time(),
):
try:
challenge = self.get_challenge(*args, **kwargs)
except StageInvalidException as exc:
self.logger.debug("Got StageInvalidException", exc=exc)
return self.executor.stage_invalid()
with Hub.current.start_span(
op="authentik.flow.stage._get_challenge",
description=self.__class__.__name__,

View File

@ -114,3 +114,8 @@ web:
worker:
concurrency: 2
priority:
default: 4
scheduled: 9
sync: 9
events: 8
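
These keys are the defaults that the CONFIG.get_int("worker.priority.*") lookups throughout this change fall back to. Assuming the broker treats lower numbers as more urgent (how kombu's Redis transport generally behaves, not something stated in this diff), scheduled and sync housekeeping work yields to regular and event-driven tasks:

from authentik.lib.config import CONFIG

# Values resolved from the defaults added above.
CONFIG.get_int("worker.priority.default")    # 4
CONFIG.get_int("worker.priority.scheduled")  # 9
CONFIG.get_int("worker.priority.sync")       # 9
CONFIG.get_int("worker.priority.events")     # 8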

View File

@ -18,7 +18,7 @@ from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer, is_dict
from authentik.core.models import Provider
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import (
Outpost,
OutpostConfig,
@ -47,16 +47,6 @@ class OutpostSerializer(ModelSerializer):
source="service_connection", read_only=True
)
def validate_name(self, name: str) -> str:
"""Validate name (especially for embedded outpost)"""
if not self.instance:
return name
if self.instance.managed == MANAGED_OUTPOST and name != MANAGED_OUTPOST_NAME:
raise ValidationError("Embedded outpost's name cannot be changed")
if self.instance.name == MANAGED_OUTPOST_NAME:
self.instance.managed = MANAGED_OUTPOST
return name
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
"""Check that all providers match the type of the outpost"""
type_map = {

View File

@ -15,7 +15,6 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
["outpost", "uid", "version"],
)
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
class AuthentikOutpostConfig(ManagedAppConfig):
@ -36,17 +35,14 @@ class AuthentikOutpostConfig(ManagedAppConfig):
DockerServiceConnection,
KubernetesServiceConnection,
Outpost,
OutpostConfig,
OutpostType,
)
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
outpost.managed = MANAGED_OUTPOST
outpost.save()
return
outpost, updated = Outpost.objects.update_or_create(
defaults={
"name": "authentik Embedded Outpost",
"type": OutpostType.PROXY,
"name": MANAGED_OUTPOST_NAME,
},
managed=MANAGED_OUTPOST,
)
@ -55,4 +51,10 @@ class AuthentikOutpostConfig(ManagedAppConfig):
outpost.service_connection = KubernetesServiceConnection.objects.first()
elif DockerServiceConnection.objects.exists():
outpost.service_connection = DockerServiceConnection.objects.first()
outpost.config = OutpostConfig(
kubernetes_disabled_components=[
"deployment",
"secret",
]
)
outpost.save()

View File

@ -43,10 +43,6 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
self.api = AppsV1Api(controller.client)
self.outpost = self.controller.outpost
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod
def reconciler_name() -> str:
return "deployment"

View File

@ -24,10 +24,6 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
super().__init__(controller)
self.api = CoreV1Api(controller.client)
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod
def reconciler_name() -> str:
return "secret"

View File

@ -77,10 +77,7 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
@property
def noop(self) -> bool:
if not self._crd_exists():
self.logger.debug("CRD doesn't exist")
return True
return self.is_embedded
return (not self._crd_exists()) or (self.is_embedded)
def _crd_exists(self) -> bool:
"""Check if the Prometheus ServiceMonitor exists"""

View File

@ -344,21 +344,11 @@ class Outpost(SerializerModel, ManagedModel):
user_created = False
if not user:
user: User = User.objects.create(username=self.user_identifier)
user_created = True
attrs = {
"type": UserTypes.INTERNAL_SERVICE_ACCOUNT,
"name": f"Outpost {self.name} Service-Account",
"path": USER_PATH_OUTPOSTS,
}
dirty = False
for key, value in attrs.items():
if getattr(user, key) != value:
dirty = True
setattr(user, key, value)
if user.has_usable_password():
user.set_unusable_password()
dirty = True
if dirty:
user_created = True
user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
user.name = f"Outpost {self.name} Service-Account"
user.path = USER_PATH_OUTPOSTS
user.save()
if user_created:
self.build_user_permissions(user)

View File

@ -1,27 +1,28 @@
"""Outposts Settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"outposts_controller": {
"task": "authentik.outposts.tasks.outpost_controller_all",
"schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
"outposts_service_connection_check": {
"task": "authentik.outposts.tasks.outpost_service_connection_monitor",
"schedule": crontab(minute="3-59/15"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
"outpost_token_ensurer": {
"task": "authentik.outposts.tasks.outpost_token_ensurer",
"schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
"outpost_connection_discovery": {
"task": "authentik.outposts.tasks.outpost_connection_discovery",
"schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
}

View File

@ -2,13 +2,11 @@
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.blueprints.tests import reconcile_app
from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.outposts.api.outposts import OutpostSerializer
from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import Outpost, OutpostType, default_outpost_config
from authentik.outposts.models import OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider
@ -24,36 +22,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
self.user = create_test_admin_user()
self.client.force_login(self.user)
@reconcile_app("authentik_outposts")
def test_managed_name_change(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content, {"name": ["Embedded outpost's name cannot be changed"]}
)
@reconcile_app("authentik_outposts")
def test_managed_without_managed(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
embedded_outpost.managed = ""
embedded_outpost.save()
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 200)
embedded_outpost.refresh_from_db()
self.assertEqual(embedded_outpost.managed, MANAGED_OUTPOST)
def test_outpost_validation(self):
def test_outpost_validaton(self):
"""Test Outpost validation"""
valid = OutpostSerializer(
data={

View File

@ -1,10 +1,12 @@
"""Reputation Settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
CELERY_BEAT_SCHEDULE = {
"policies_reputation_save": {
"task": "authentik.policies.reputation.tasks.save_reputation",
"schedule": crontab(minute="1-59/5"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
}

View File

@ -1,27 +0,0 @@
# Generated by Django 5.0 on 2023-12-22 23:20
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_oauth2", "0016_alter_refreshtoken_token"),
]
operations = [
migrations.AddField(
model_name="accesstoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="authorizationcode",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="refreshtoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
]

View File

@ -296,7 +296,6 @@ class BaseGrantModel(models.Model):
revoked = models.BooleanField(default=False)
_scope = models.TextField(default="", verbose_name=_("Scopes"))
auth_time = models.DateTimeField(verbose_name="Authentication time")
session_id = models.CharField(default="", blank=True)
@property
def scope(self) -> list[str]:

View File

@ -85,25 +85,6 @@ class TestAuthorize(OAuthTestCase):
)
OAuthAuthorizationParams.from_request(request)
def test_blocked_redirect_uri(self):
"""test missing/invalid redirect URI"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="data:local.invalid",
)
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "data:localhost",
},
)
OAuthAuthorizationParams.from_request(request)
def test_invalid_redirect_uri_empty(self):
"""test missing/invalid redirect URI"""
provider = OAuth2Provider.objects.create(

View File

@ -1,187 +0,0 @@
"""Test token view"""
from base64 import b64encode, urlsafe_b64encode
from hashlib import sha256
from django.test import RequestFactory
from django.urls import reverse
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.constants import GRANT_TYPE_AUTHORIZATION_CODE
from authentik.providers.oauth2.models import AuthorizationCode, OAuth2Provider
from authentik.providers.oauth2.tests.utils import OAuthTestCase
class TestTokenPKCE(OAuthTestCase):
"""Test token view"""
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.app = Application.objects.create(name=generate_id(), slug="test")
def test_pkce_missing_in_token(self):
"""Test full with pkce"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
challenge = generate_id()
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"code_challenge": challenge,
"code_challenge_method": "S256",
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
code: AuthorizationCode = AuthorizationCode.objects.filter(user=user).first()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": f"foo://localhost?code={code.code}&state={state}",
},
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
# Missing the code_verifier here
"redirect_uri": "foo://localhost",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertJSONEqual(
response.content,
{"error": "invalid_request", "error_description": "The request is otherwise malformed"},
)
self.assertEqual(response.status_code, 400)
def test_pkce_correct_s256(self):
"""Test full with pkce"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
verifier = generate_id()
challenge = (
urlsafe_b64encode(sha256(verifier.encode("ascii")).digest())
.decode("utf-8")
.replace("=", "")
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"code_challenge": challenge,
"code_challenge_method": "S256",
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
code: AuthorizationCode = AuthorizationCode.objects.filter(user=user).first()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": f"foo://localhost?code={code.code}&state={state}",
},
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
"code_verifier": verifier,
"redirect_uri": "foo://localhost",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertEqual(response.status_code, 200)
def test_pkce_correct_plain(self):
"""Test full with pkce"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
verifier = generate_id()
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"code_challenge": verifier,
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
code: AuthorizationCode = AuthorizationCode.objects.filter(user=user).first()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": f"foo://localhost?code={code.code}&state={state}",
},
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
"code_verifier": verifier,
"redirect_uri": "foo://localhost",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertEqual(response.status_code, 200)

View File

@ -188,7 +188,6 @@ def authenticate_provider(request: HttpRequest) -> Optional[OAuth2Provider]:
if client_id != provider.client_id or client_secret != provider.client_secret:
LOGGER.debug("(basic) Provider for basic auth does not exist")
return None
CTX_AUTH_VIA.set("oauth_client_secret")
return provider

View File

@ -1,7 +1,6 @@
"""authentik OAuth2 Authorization views"""
from dataclasses import dataclass, field
from datetime import timedelta
from hashlib import sha256
from json import dumps
from re import error as RegexError
from re import fullmatch
@ -75,7 +74,6 @@ PLAN_CONTEXT_PARAMS = "goauthentik.io/providers/oauth2/params"
SESSION_KEY_LAST_LOGIN_UID = "authentik/providers/oauth2/last_login_uid"
ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN}
FORBIDDEN_URI_SCHEMES = {"javascript", "data", "vbscript"}
@dataclass(slots=True)
@ -176,10 +174,6 @@ class OAuthAuthorizationParams:
self.check_scope()
self.check_nonce()
self.check_code_challenge()
if self.request:
raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
def check_redirect_uri(self):
"""Redirect URI validation."""
@ -217,9 +211,10 @@ class OAuthAuthorizationParams:
expected=allowed_redirect_urls,
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
# Check against forbidden schemes
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
if self.request:
raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
def check_scope(self):
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
@ -287,7 +282,6 @@ class OAuthAuthorizationParams:
expires=now + timedelta_from_string(self.provider.access_code_validity),
scope=self.scope,
nonce=self.nonce,
session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
)
if self.code_challenge and self.code_challenge_method:
@ -575,7 +569,6 @@ class OAuthFulfillmentStage(StageView):
expires=access_token_expiry,
provider=self.provider,
auth_time=auth_event.created if auth_event else now,
session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
)
id_token = IDToken.new(self.provider, token, self.request)

View File

@ -6,7 +6,6 @@ from hashlib import sha256
from re import error as RegexError
from re import fullmatch
from typing import Any, Optional
from urllib.parse import urlparse
from django.http import HttpRequest, HttpResponse
from django.utils import timezone
@ -18,7 +17,6 @@ from jwt import PyJWK, PyJWT, PyJWTError, decode
from sentry_sdk.hub import Hub
from structlog.stdlib import get_logger
from authentik.core.middleware import CTX_AUTH_VIA
from authentik.core.models import (
USER_ATTRIBUTE_EXPIRES,
USER_ATTRIBUTE_GENERATED,
@ -55,7 +53,6 @@ from authentik.providers.oauth2.models import (
RefreshToken,
)
from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth
from authentik.providers.oauth2.views.authorize import FORBIDDEN_URI_SCHEMES
from authentik.sources.oauth.models import OAuthSource
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
@ -207,10 +204,6 @@ class TokenParams:
).from_http(request)
raise TokenError("invalid_client")
# Check against forbidden schemes
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
raise TokenError("invalid_request")
self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first()
if not self.authorization_code:
LOGGER.warning("Code does not exist", code=raw_code)
@ -228,10 +221,7 @@ class TokenParams:
raise TokenError("invalid_grant")
# Validate PKCE parameters.
if self.authorization_code.code_challenge:
# Authorization code had PKCE but we didn't get one
if not self.code_verifier:
raise TokenError("invalid_request")
if self.code_verifier:
if self.authorization_code.code_challenge_method == PKCE_METHOD_S256:
new_code_challenge = (
urlsafe_b64encode(sha256(self.code_verifier.encode("ascii")).digest())
@ -458,7 +448,6 @@ class TokenView(View):
if not self.provider:
LOGGER.warning("OAuth2Provider does not exist", client_id=client_id)
raise TokenError("invalid_client")
CTX_AUTH_VIA.set("oauth_client_secret")
self.params = TokenParams.parse(request, self.provider, client_id, client_secret)
with Hub.current.start_span(
@ -493,7 +482,6 @@ class TokenView(View):
# Keep same scopes as previous token
scope=self.params.authorization_code.scope,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
access_token.id_token = IDToken.new(
self.provider,
@ -509,7 +497,6 @@ class TokenView(View):
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
id_token = IDToken.new(
self.provider,
@ -547,7 +534,6 @@ class TokenView(View):
# Keep same scopes as previous token
scope=self.params.refresh_token.scope,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
)
access_token.id_token = IDToken.new(
self.provider,
@ -563,7 +549,6 @@ class TokenView(View):
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
)
id_token = IDToken.new(
self.provider,
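
Beyond the queue work, this compare also touches the OAuth2 token view's PKCE handling: one side rejects a token request outright when the authorization code was issued with a code_challenge but no code_verifier is supplied, while the other only verifies when a verifier is present. For reference, the S256 comparison performed once a verifier is available reduces to the following restatement (the helper name is ours, not authentik's):

from base64 import urlsafe_b64encode
from hashlib import sha256


def matches_s256_challenge(code_verifier: str, code_challenge: str) -> bool:
    """The stored challenge is the base64url-encoded SHA-256 of the verifier,
    with '=' padding stripped, as shown in the hunks above."""
    computed = (
        urlsafe_b64encode(sha256(code_verifier.encode("ascii")).digest())
        .decode("utf-8")
        .replace("=", "")
    )
    return computed == code_challenge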

View File

@ -1,6 +1,4 @@
"""proxy provider tasks"""
from hashlib import sha256
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError
@ -25,7 +23,6 @@ def proxy_set_defaults():
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
@ -33,6 +30,6 @@ def proxy_on_logout(session_id: str):
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": hashed_session_id,
"session_id": session_id,
},
)
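
One side of the hunk above hashes the session ID before it is broadcast to proxy outposts over the channel layer; the other sends it verbatim. The hashing follows this pattern (helper name is illustrative):

from hashlib import sha256

def hash_session_id(session_id: str) -> str:
    # SHA-256 hex digest of the session ID, as in the hashed_session_id line above
    return sha256(session_id.encode("ascii")).hexdigest()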

View File

@ -21,7 +21,6 @@ class RadiusProviderSerializer(ProviderSerializer):
# an admin might have to view it
"shared_secret",
"outpost_set",
"mfa_support",
]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
@ -56,7 +55,6 @@ class RadiusOutpostConfigSerializer(ModelSerializer):
"auth_flow_slug",
"client_networks",
"shared_secret",
"mfa_support",
]

View File

@ -1,21 +0,0 @@
# Generated by Django 4.2.6 on 2023-10-18 15:09
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_radius", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="radiusprovider",
name="mfa_support",
field=models.BooleanField(
default=True,
help_text="When enabled, code-based multi-factor authentication can be used by appending a semicolon and the TOTP code to the password. This should only be enabled if all users that will bind to this provider have a TOTP device configured, as otherwise a password may incorrectly be rejected if it contains a semicolon.",
verbose_name="MFA Support",
),
),
]

View File

@ -27,17 +27,6 @@ class RadiusProvider(OutpostModel, Provider):
),
)
mfa_support = models.BooleanField(
default=True,
verbose_name="MFA Support",
help_text=_(
"When enabled, code-based multi-factor authentication can be used by appending a "
"semicolon and the TOTP code to the password. This should only be enabled if all "
"users that will bind to this provider have a TOTP device configured, as otherwise "
"a password may incorrectly be rejected if it contains a semicolon."
),
)
@property
def launch_url(self) -> Optional[str]:
"""Radius never has a launch URL"""

View File

@ -46,9 +46,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
def to_scim(self, obj: Group) -> SCIMGroupSchema:
"""Convert authentik user into SCIM"""
raw_scim_group = {
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:Group",),
}
raw_scim_group = {}
for mapping in (
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
):

View File

@ -15,14 +15,12 @@ from pydanticscim.user import User as BaseUser
class User(BaseUser):
"""Modified User schema with added externalId field"""
schemas: tuple[str] = ("urn:ietf:params:scim:schemas:core:2.0:User",)
externalId: Optional[str] = None
class Group(BaseGroup):
"""Modified Group schema with added externalId field"""
schemas: tuple[str] = ("urn:ietf:params:scim:schemas:core:2.0:Group",)
externalId: Optional[str] = None

View File

@ -39,9 +39,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
def to_scim(self, obj: User) -> SCIMUserSchema:
"""Convert authentik user into SCIM"""
raw_scim_user = {
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:User",),
}
raw_scim_user = {}
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
if not isinstance(mapping, SCIMMapping):
continue

View File

@ -1,12 +1,13 @@
"""SCIM task Settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"providers_scim_sync": {
"task": "authentik.providers.scim.tasks.scim_sync_all",
"schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
}

View File

@ -61,11 +61,7 @@ class SCIMGroupTests(TestCase):
self.assertEqual(mock.request_history[1].method, "POST")
self.assertJSONEqual(
mock.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
{"externalId": str(group.pk), "displayName": group.name},
)
@Mocker()
@ -100,11 +96,7 @@ class SCIMGroupTests(TestCase):
validate(body, loads(schema.read()))
self.assertEqual(
body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
{"externalId": str(group.pk), "displayName": group.name},
)
group.save()
self.assertEqual(mock.call_count, 4)
@ -137,11 +129,7 @@ class SCIMGroupTests(TestCase):
self.assertEqual(mock.request_history[1].method, "POST")
self.assertJSONEqual(
mock.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
{"externalId": str(group.pk), "displayName": group.name},
)
group.delete()
self.assertEqual(mock.call_count, 4)

View File

@ -89,7 +89,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual(
mocker.request_history[3].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [],
"active": True,
"externalId": user.uid,
@ -100,11 +99,7 @@ class SCIMMembershipTests(TestCase):
)
self.assertJSONEqual(
mocker.request_history[5].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
{"externalId": str(group.pk), "displayName": group.name},
)
with Mocker() as mocker:
@ -123,7 +118,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
@ -131,6 +125,7 @@ class SCIMMembershipTests(TestCase):
"value": [{"value": user_scim_id}],
}
],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
},
)
@ -179,7 +174,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual(
mocker.request_history[3].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"displayName": "",
"emails": [],
@ -190,11 +184,7 @@ class SCIMMembershipTests(TestCase):
)
self.assertJSONEqual(
mocker.request_history[5].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
{"externalId": str(group.pk), "displayName": group.name},
)
with Mocker() as mocker:
@ -213,7 +203,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
@ -221,6 +210,7 @@ class SCIMMembershipTests(TestCase):
"value": [{"value": user_scim_id}],
}
],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
},
)
@ -240,7 +230,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "remove",
@ -248,5 +237,6 @@ class SCIMMembershipTests(TestCase):
"value": [{"value": user_scim_id}],
}
],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
},
)

View File

@ -57,7 +57,7 @@ class SCIMUserTests(TestCase):
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
name=uid,
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 2)
@ -66,7 +66,6 @@ class SCIMUserTests(TestCase):
self.assertJSONEqual(
mock.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
@ -77,11 +76,11 @@ class SCIMUserTests(TestCase):
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"familyName": "",
"formatted": uid,
"givenName": uid,
},
"displayName": f"{uid} {uid}",
"displayName": uid,
"userName": uid,
},
)
@ -110,7 +109,7 @@ class SCIMUserTests(TestCase):
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
name=uid,
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 2)
@ -122,7 +121,6 @@ class SCIMUserTests(TestCase):
self.assertEqual(
body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
@ -131,11 +129,11 @@ class SCIMUserTests(TestCase):
"value": f"{uid}@goauthentik.io",
}
],
"displayName": f"{uid} {uid}",
"displayName": uid,
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"familyName": "",
"formatted": uid,
"givenName": uid,
},
"userName": uid,
@ -166,7 +164,7 @@ class SCIMUserTests(TestCase):
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
name=uid,
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 2)
@ -175,7 +173,6 @@ class SCIMUserTests(TestCase):
self.assertJSONEqual(
mock.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
@ -186,11 +183,11 @@ class SCIMUserTests(TestCase):
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"familyName": "",
"formatted": uid,
"givenName": uid,
},
"displayName": f"{uid} {uid}",
"displayName": uid,
"userName": uid,
},
)
@ -230,7 +227,7 @@ class SCIMUserTests(TestCase):
)
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
name=uid,
email=f"{uid}@goauthentik.io",
)
@ -243,7 +240,6 @@ class SCIMUserTests(TestCase):
self.assertJSONEqual(
mock.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
@ -254,11 +250,11 @@ class SCIMUserTests(TestCase):
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"familyName": "",
"formatted": uid,
"givenName": uid,
},
"displayName": f"{uid} {uid}",
"displayName": uid,
"userName": uid,
},
)
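
The assertions above expect authentik's single name field to be split into SCIM givenName/familyName. The real values come from the provider's SCIM property mappings; this helper only illustrates the split the tests encode:

def scim_name(full_name: str) -> dict[str, str]:
    # "Jane Doe" -> givenName "Jane", familyName "Doe", formatted unchanged
    given, _, family = full_name.partition(" ")
    return {"givenName": given, "familyName": family, "formatted": full_name}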

View File

@ -32,19 +32,13 @@ class PermissionSerializer(ModelSerializer):
def get_app_label_verbose(self, instance: Permission) -> str:
"""Human-readable app label"""
try:
return apps.get_app_config(instance.content_type.app_label).verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
def get_model_verbose(self, instance: Permission) -> str:
"""Human-readable model name"""
try:
return apps.get_model(
instance.content_type.app_label, instance.content_type.model
)._meta.verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
class Meta:
model = Permission

View File

@ -24,19 +24,13 @@ class ExtraRoleObjectPermissionSerializer(RoleObjectPermissionSerializer):
def get_app_label_verbose(self, instance: GroupObjectPermission) -> str:
"""Get app label from permission's model"""
try:
return apps.get_app_config(instance.content_type.app_label).verbose_name
except LookupError:
return instance.content_type.app_label
def get_model_verbose(self, instance: GroupObjectPermission) -> str:
"""Get model label from permission's model"""
try:
return apps.get_model(
instance.content_type.app_label, instance.content_type.model
)._meta.verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
def get_object_description(self, instance: GroupObjectPermission) -> Optional[str]:
"""Get model description from attached model. This operation takes at least
@ -44,10 +38,7 @@ class ExtraRoleObjectPermissionSerializer(RoleObjectPermissionSerializer):
view_ permission on the object"""
app_label = instance.content_type.app_label
model = instance.content_type.model
try:
model_class = apps.get_model(app_label, model)
except LookupError:
return None
objects = get_objects_for_group(instance.group, f"{app_label}.view_{model}", model_class)
obj = objects.first()
if not obj:

View File

@ -24,19 +24,13 @@ class ExtraUserObjectPermissionSerializer(UserObjectPermissionSerializer):
def get_app_label_verbose(self, instance: UserObjectPermission) -> str:
"""Get app label from permission's model"""
try:
return apps.get_app_config(instance.content_type.app_label).verbose_name
except LookupError:
return instance.content_type.app_label
def get_model_verbose(self, instance: UserObjectPermission) -> str:
"""Get model label from permission's model"""
try:
return apps.get_model(
instance.content_type.app_label, instance.content_type.model
)._meta.verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
def get_object_description(self, instance: UserObjectPermission) -> Optional[str]:
"""Get model description from attached model. This operation takes at least
@ -44,10 +38,7 @@ class ExtraUserObjectPermissionSerializer(UserObjectPermissionSerializer):
view_ permission on the object"""
app_label = instance.content_type.app_label
model = instance.content_type.model
try:
model_class = apps.get_model(app_label, model)
except LookupError:
return None
objects = get_objects_for_user(instance.user, f"{app_label}.view_{model}", model_class)
obj = objects.first()
if not obj:
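
All three RBAC serializer diffs above wrap the ContentType lookups in try/except so permissions pointing at removed apps or models degrade to a dotted label instead of raising. The shared pattern, condensed into one sketch:

from django.apps import apps
from django.contrib.contenttypes.models import ContentType

def verbose_model_label(ct: ContentType) -> str:
    # Human-readable model name, falling back to "app_label.model" when the model is gone
    try:
        return str(apps.get_model(ct.app_label, ct.model)._meta.verbose_name)
    except LookupError:
        return f"{ct.app_label}.{ct.model}"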

View File

@ -339,17 +339,23 @@ CELERY = {
"clean_expired_models": {
"task": "authentik.core.tasks.clean_expired_models",
"schedule": crontab(minute="2-59/5"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
"user_cleanup": {
"task": "authentik.core.tasks.clean_temporary_users",
"schedule": crontab(minute="9-59/5"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
},
"task_create_missing_queues": True,
"task_default_queue": "authentik",
"task_default_priority": 4,
"task_inherit_parent_priority": True,
"broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
"broker_transport_options": {
"queue_order_strategy": "priority",
"priority_steps": list(range(10)),
},
"result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
}
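
With the broker transport options above in place, tasks are dispatched at an explicit priority instead of into dedicated queues. A sketch using a task and config key that appear elsewhere in this diff (with the Redis transport, lower numbers are typically served first):

from authentik.core.tasks import clean_expired_models
from authentik.lib.config import CONFIG

clean_expired_models.apply_async(
    priority=CONFIG.get_int("worker.priority.scheduled"),
)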

View File

@ -1,12 +1,13 @@
"""LDAP Settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"sources_ldap_sync": {
"task": "authentik.sources.ldap.tasks.ldap_sync_all",
"schedule": crontab(minute=fqdn_rand("sources_ldap_sync"), hour="*/2"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}
}

View File

@ -58,10 +58,12 @@ def ldap_sync_single(source_pk: str):
group(
ldap_sync_paginator(source, UserLDAPSynchronizer)
+ ldap_sync_paginator(source, GroupLDAPSynchronizer),
priority=CONFIG.get_int("worker.priority.sync"),
),
# Membership sync needs to run afterwards
group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer),
priority=CONFIG.get_int("worker.priority.sync"),
),
)
task()

View File

@ -30,8 +30,6 @@ class SourceTypeSerializer(PassiveSerializer):
authorization_url = CharField(read_only=True, allow_null=True)
access_token_url = CharField(read_only=True, allow_null=True)
profile_url = CharField(read_only=True, allow_null=True)
oidc_well_known_url = CharField(read_only=True, allow_null=True)
oidc_jwks_url = CharField(read_only=True, allow_null=True)
class OAuthSourceSerializer(SourceSerializer):
@ -54,15 +52,11 @@ class OAuthSourceSerializer(SourceSerializer):
@extend_schema_field(SourceTypeSerializer)
def get_type(self, instance: OAuthSource) -> SourceTypeSerializer:
"""Get source's type configuration"""
return SourceTypeSerializer(instance.source_type).data
return SourceTypeSerializer(instance.type).data
def validate(self, attrs: dict) -> dict:
session = get_http_session()
source_type = registry.find_type(attrs["provider_type"])
well_known = attrs.get("oidc_well_known_url") or source_type.oidc_well_known_url
inferred_oidc_jwks_url = None
well_known = attrs.get("oidc_well_known_url")
if well_known and well_known != "":
try:
well_known_config = session.get(well_known)
@ -71,23 +65,24 @@ class OAuthSourceSerializer(SourceSerializer):
text = exc.response.text if exc.response else str(exc)
raise ValidationError({"oidc_well_known_url": text})
config = well_known_config.json()
if "issuer" not in config:
raise ValidationError({"oidc_well_known_url": "Invalid well-known configuration"})
attrs["authorization_url"] = config.get("authorization_endpoint", "")
attrs["access_token_url"] = config.get("token_endpoint", "")
attrs["profile_url"] = config.get("userinfo_endpoint", "")
inferred_oidc_jwks_url = config.get("jwks_uri", "")
try:
attrs["authorization_url"] = config["authorization_endpoint"]
attrs["access_token_url"] = config["token_endpoint"]
attrs["profile_url"] = config["userinfo_endpoint"]
attrs["oidc_jwks_url"] = config["jwks_uri"]
except (IndexError, KeyError) as exc:
raise ValidationError(
{"oidc_well_known_url": f"Invalid well-known configuration: {exc}"}
)
# Prefer user-entered URL to inferred URL to default URL
jwks_url = attrs.get("oidc_jwks_url") or inferred_oidc_jwks_url or source_type.oidc_jwks_url
jwks_url = attrs.get("oidc_jwks_url")
if jwks_url and jwks_url != "":
attrs["oidc_jwks_url"] = jwks_url
try:
jwks_config = session.get(jwks_url)
jwks_config.raise_for_status()
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
raise ValidationError({"oidc_jwks_url": text})
raise ValidationError({"jwks_url": text})
config = jwks_config.json()
attrs["oidc_jwks"] = config

View File

@ -36,8 +36,8 @@ class BaseOAuthClient:
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"""Fetch user profile information."""
profile_url = self.source.source_type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url:
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
response = self.do_request("get", profile_url, token=token)
try:
@ -57,8 +57,8 @@ class BaseOAuthClient:
def get_redirect_url(self, parameters=None):
"""Build authentication redirect url."""
authorization_url = self.source.source_type.authorization_url or ""
if self.source.source_type.urls_customizable and self.source.authorization_url:
authorization_url = self.source.type.authorization_url or ""
if self.source.type.urls_customizable and self.source.authorization_url:
authorization_url = self.source.authorization_url
if authorization_url == "":
Event.new(

View File

@ -28,8 +28,8 @@ class OAuthClient(BaseOAuthClient):
if raw_token is not None and verifier is not None:
token = self.parse_raw_token(raw_token)
try:
access_token_url = self.source.source_type.access_token_url or ""
if self.source.source_type.urls_customizable and self.source.access_token_url:
access_token_url = self.source.type.access_token_url or ""
if self.source.type.urls_customizable and self.source.access_token_url:
access_token_url = self.source.access_token_url
response = self.do_request(
"post",
@ -54,8 +54,8 @@ class OAuthClient(BaseOAuthClient):
"""Fetch the OAuth request token. Only required for OAuth 1.0."""
callback = self.request.build_absolute_uri(self.callback)
try:
request_token_url = self.source.source_type.request_token_url or ""
if self.source.source_type.urls_customizable and self.source.request_token_url:
request_token_url = self.source.type.request_token_url or ""
if self.source.type.urls_customizable and self.source.request_token_url:
request_token_url = self.source.request_token_url
response = self.do_request(
"post",

View File

@ -76,8 +76,8 @@ class OAuth2Client(BaseOAuthClient):
if SESSION_KEY_OAUTH_PKCE in self.request.session:
args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
try:
access_token_url = self.source.source_type.access_token_url or ""
if self.source.source_type.urls_customizable and self.source.access_token_url:
access_token_url = self.source.type.access_token_url or ""
if self.source.type.urls_customizable and self.source.access_token_url:
access_token_url = self.source.access_token_url
response = self.session.request(
"post", access_token_url, data=args, headers=self._default_headers, **request_kwargs
@ -140,8 +140,8 @@ class UserprofileHeaderAuthClient(OAuth2Client):
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
profile_url = self.source.source_type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url:
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
response = self.session.request(
"get",

View File

@ -1,5 +1,5 @@
"""OAuth Client models"""
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Type
from django.db import models
from django.http.request import HttpRequest
@ -55,7 +55,7 @@ class OAuthSource(Source):
oidc_jwks = models.JSONField(default=dict, blank=True)
@property
def source_type(self) -> type["SourceType"]:
def type(self) -> type["SourceType"]:
"""Return the provider instance for this source"""
from authentik.sources.oauth.types.registry import registry
@ -65,14 +65,15 @@ class OAuthSource(Source):
def component(self) -> str:
return "ak-source-oauth-form"
# we're using Type[] instead of type[] here since type[] interferes with the property above
@property
def serializer(self) -> type[Serializer]:
def serializer(self) -> Type[Serializer]:
from authentik.sources.oauth.api.source import OAuthSourceSerializer
return OAuthSourceSerializer
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
provider_type = self.source_type
provider_type = self.type
provider = provider_type()
icon = self.get_icon
if not icon:
@ -84,7 +85,7 @@ class OAuthSource(Source):
)
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
provider_type = self.source_type
provider_type = self.type
icon = self.get_icon
if not icon:
icon = provider_type().icon_url()

View File

@ -1,12 +0,0 @@
"""OAuth source settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"update_oauth_source_oidc_well_known": {
"task": "authentik.sources.oauth.tasks.update_well_known_jwks",
"schedule": crontab(minute=fqdn_rand("update_well_known_jwks"), hour="*/3"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -1,70 +0,0 @@
"""OAuth Source tasks"""
from json import dumps
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP
from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger()
@CELERY_APP.task(bind=True, base=MonitoredTask)
def update_well_known_jwks(self: MonitoredTask):
"""Update OAuth sources' config from well_known, and JWKS info from the configured URL"""
session = get_http_session()
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
try:
well_known_config = session.get(source.oidc_well_known_url)
well_known_config.raise_for_status()
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
LOGGER.warning("Failed to update well_known", source=source, exc=exc, text=text)
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
continue
config = well_known_config.json()
try:
dirty = False
source_attr_key = (
("authorization_url", "authorization_endpoint"),
("access_token_url", "token_endpoint"),
("profile_url", "userinfo_endpoint"),
("oidc_jwks_url", "jwks_uri"),
)
for source_attr, config_key in source_attr_key:
# Check if we're actually changing anything to only
# save when something has changed
if getattr(source, source_attr, "") != config[config_key]:
dirty = True
setattr(source, source_attr, config[config_key])
except (IndexError, KeyError) as exc:
LOGGER.warning(
"Failed to update well_known",
source=source,
exc=exc,
)
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
continue
if dirty:
LOGGER.info("Updating sources' OpenID Configuration", source=source)
source.save()
for source in OAuthSource.objects.all().exclude(oidc_jwks_url=""):
try:
jwks_config = session.get(source.oidc_jwks_url)
jwks_config.raise_for_status()
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
LOGGER.warning("Failed to update JWKS", source=source, exc=exc, text=text)
result.messages.append(f"Failed to update JWKS for {source.slug}")
continue
config = jwks_config.json()
if dumps(source.oidc_jwks, sort_keys=True) != dumps(config, sort_keys=True):
source.oidc_jwks = config
LOGGER.info("Updating sources' JWKS", source=source)
source.save()
self.set_status(result)

View File

@ -1,48 +0,0 @@
"""Test OAuth Source tasks"""
from django.test import TestCase
from requests_mock import Mocker
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.tasks import update_well_known_jwks
class TestOAuthSourceTasks(TestCase):
"""Test OAuth Source tasks"""
def setUp(self) -> None:
self.source = OAuthSource.objects.create(
name="test",
slug="test",
provider_type="openidconnect",
authorization_url="",
profile_url="",
consumer_key="",
)
@Mocker()
def test_well_known_jwks(self, mock: Mocker):
"""Test well_known update"""
self.source.oidc_well_known_url = "http://foo/.well-known/openid-configuration"
self.source.save()
mock.get(
self.source.oidc_well_known_url,
json={
"authorization_endpoint": "foo",
"token_endpoint": "foo",
"userinfo_endpoint": "foo",
"jwks_uri": "http://foo/jwks",
},
)
mock.get("http://foo/jwks", json={"foo": "bar"})
update_well_known_jwks() # pylint: disable=no-value-for-parameter
self.source.refresh_from_db()
self.assertEqual(self.source.authorization_url, "foo")
self.assertEqual(self.source.access_token_url, "foo")
self.assertEqual(self.source.profile_url, "foo")
self.assertEqual(self.source.oidc_jwks_url, "http://foo/jwks")
self.assertEqual(
self.source.oidc_jwks,
{
"foo": "bar",
},
)

View File

@ -50,7 +50,6 @@ class TestOAuthSource(TestCase):
def test_api_validate_openid_connect(self):
"""Test API validation (with OIDC endpoints)"""
openid_config = {
"issuer": "foo",
"authorization_endpoint": "http://mock/oauth/authorize",
"token_endpoint": "http://mock/oauth/token",
"userinfo_endpoint": "http://mock/oauth/userinfo",

View File

@ -4,8 +4,8 @@ from typing import Any
from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger()
@ -20,7 +20,7 @@ class AzureADOAuthRedirect(OAuthRedirect):
}
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback"""
client_class = UserprofileHeaderAuthClient
@ -50,8 +50,4 @@ class AzureADType(SourceType):
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
profile_url = "https://login.microsoftonline.com/common/openid/userinfo"
oidc_well_known_url = (
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
)
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
profile_url = "https://graph.microsoft.com/v1.0/me"

View File

@ -23,8 +23,8 @@ class GitHubOAuth2Client(OAuth2Client):
def get_github_emails(self, token: dict[str, str]) -> list[dict[str, Any]]:
"""Get Emails from the GitHub API"""
profile_url = self.source.source_type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url:
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
profile_url += "/emails"
response = self.do_request("get", profile_url, token=token)
@ -76,7 +76,3 @@ class GitHubType(SourceType):
authorization_url = "https://github.com/login/oauth/authorize"
access_token_url = "https://github.com/login/oauth/access_token" # nosec
profile_url = "https://api.github.com/user"
oidc_well_known_url = (
"https://token.actions.githubusercontent.com/.well-known/openid-configuration"
)
oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks"

View File

@ -40,5 +40,3 @@ class GoogleType(SourceType):
authorization_url = "https://accounts.google.com/o/oauth2/auth"
access_token_url = "https://oauth2.googleapis.com/token" # nosec
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"
oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration"
oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs"

View File

@ -26,8 +26,8 @@ class MailcowOAuth2Client(OAuth2Client):
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
profile_url = self.source.source_type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url:
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
response = self.session.request(
"get",

View File

@ -23,7 +23,7 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", None)
return info.get("sub", "")
def get_user_enroll_context(
self,

View File

@ -3,8 +3,8 @@ from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -17,7 +17,7 @@ class OktaOAuthRedirect(OAuthRedirect):
}
class OktaOAuth2Callback(OpenIDConnectOAuth2Callback):
class OktaOAuth2Callback(OAuthCallback):
"""Okta OAuth2 Callback"""
# Okta has the same quirk as azure and throws an error if the access token
@ -25,6 +25,9 @@ class OktaOAuth2Callback(OpenIDConnectOAuth2Callback):
# see https://github.com/goauthentik/authentik/issues/1910
client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context(
self,
info: dict[str, Any],

View File

@ -12,9 +12,8 @@ class PatreonOAuthRedirect(OAuthRedirect):
"""Patreon OAuth2 Redirect"""
def get_additional_parameters(self, source: OAuthSource): # pragma: no cover
# https://docs.patreon.com/#scopes
return {
"scope": ["identity", "identity[email]"],
"scope": ["openid", "email", "profile"],
}

View File

@ -36,8 +36,6 @@ class SourceType:
authorization_url: Optional[str] = None
access_token_url: Optional[str] = None
profile_url: Optional[str] = None
oidc_well_known_url: Optional[str] = None
oidc_jwks_url: Optional[str] = None
def icon_url(self) -> str:
"""Get Icon URL for login"""

View File

@ -3,8 +3,8 @@ from json import dumps
from typing import Any, Optional
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -27,11 +27,14 @@ class TwitchOAuthRedirect(OAuthRedirect):
}
class TwitchOAuth2Callback(OpenIDConnectOAuth2Callback):
class TwitchOAuth2Callback(OAuthCallback):
"""Twitch OAuth2 Callback"""
client_class = TwitchClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context(
self,
info: dict[str, Any],

View File

@ -25,7 +25,7 @@ class OAuthClientMixin:
if self.client_class is not None:
# pylint: disable=not-callable
return self.client_class(source, self.request, **kwargs)
if source.source_type.request_token_url or source.request_token_url:
if source.type.request_token_url or source.request_token_url:
client = OAuthClient(source, self.request, **kwargs)
else:
client = OAuth2Client(source, self.request, **kwargs)

View File

@ -1,12 +1,13 @@
"""Plex source settings"""
from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"check_plex_token": {
"task": "authentik.sources.plex.tasks.check_plex_token_all",
"schedule": crontab(minute=fqdn_rand("check_plex_token"), hour="*/3"),
"options": {"queue": "authentik_scheduled"},
"options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
},
}

View File

@ -12,6 +12,7 @@ from authentik.flows.challenge import (
Challenge,
ChallengeResponse,
ChallengeTypes,
ErrorDetailSerializer,
WithUserInfoChallenge,
)
from authentik.flows.stage import ChallengeStageView
@ -23,7 +24,6 @@ from authentik.stages.authenticator_sms.models import (
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
SESSION_KEY_SMS_DEVICE = "authentik/stages/authenticator_sms/sms_device"
PLAN_CONTEXT_PHONE = "phone"
class AuthenticatorSMSChallenge(WithUserInfoChallenge):
@ -48,8 +48,6 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse):
def validate(self, attrs: dict) -> dict:
"""Check"""
if "code" not in attrs:
if "phone_number" not in attrs:
raise ValidationError("phone_number required")
self.device.phone_number = attrs["phone_number"]
self.stage.validate_and_send(attrs["phone_number"])
return super().validate(attrs)
@ -77,9 +75,9 @@ class AuthenticatorSMSStageView(ChallengeStageView):
def _has_phone_number(self) -> Optional[str]:
context = self.executor.plan.context
if PLAN_CONTEXT_PHONE in context.get(PLAN_CONTEXT_PROMPT, {}):
if "phone" in context.get(PLAN_CONTEXT_PROMPT, {}):
self.logger.debug("got phone number from plan context")
return context.get(PLAN_CONTEXT_PROMPT, {}).get(PLAN_CONTEXT_PHONE)
return context.get(PLAN_CONTEXT_PROMPT, {}).get("phone")
if SESSION_KEY_SMS_DEVICE in self.request.session:
self.logger.debug("got phone number from device in session")
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
@ -115,17 +113,10 @@ class AuthenticatorSMSStageView(ChallengeStageView):
try:
self.validate_and_send(phone_number)
except ValidationError as exc:
# We had a phone number given already (at this point only possible from flow
# context), but an error occurred while sending a number (most likely)
# due to a duplicate device, so delete the number we got given, reset the state
# (ish) and retry
device.phone_number = ""
self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}).pop(
PLAN_CONTEXT_PHONE, None
)
self.request.session.pop(SESSION_KEY_SMS_DEVICE, None)
self.logger.warning("failed to send SMS message to pre-set number", exc=exc)
return self.get(request, *args, **kwargs)
response = AuthenticatorSMSChallengeResponse()
response._errors.setdefault("phone_number", [])
response._errors["phone_number"].append(ErrorDetailSerializer(exc.detail))
return self.challenge_invalid(response)
return super().get(request, *args, **kwargs)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:

Some files were not shown because too many files have changed in this diff