Merge branch 'main' into application-wizard-2-with-api-and-tests

* main: (184 commits)
  web/admin: user details few tooltip buttons (#6899)
  website/blogs: added emm dashes (#6927)
  core: bump goauthentik.io/api/v3 from 3.2023083.2 to 3.2023083.3 (#6925)
  core: bump ruff from 0.0.289 to 0.0.290 (#6923)
  web: bump the babel group in /web with 2 updates (#6919)
  web: bump the storybook group in /web with 5 updates (#6920)
  web: bump rollup from 3.29.1 to 3.29.2 in /web (#6921)
  core: bump pycryptodome from 3.18.0 to 3.19.0 (#6922)
  core: bump django-filter from 23.2 to 23.3 (#6924)
  core: bump github.com/go-ldap/ldap/v3 from 3.4.5 to 3.4.6 (#6926)
  web: bump API Client version (#6918)
  core: create app transactional api (#6446)
  sources/ldap: add warning when a property mapping returns None or bytes (#6913)
  website: replace login card with png (#6911)
  core: compile backend translations (#6912)
  translate: Updates for file locale/en/LC_MESSAGES/django.po in zh-Hans on branch main (#6910)
  translate: Updates for file locale/en/LC_MESSAGES/django.po in zh_CN on branch main (#6907)
  translate: Updates for file web/xliff/en.xlf in zh_CN on branch main (#6908)
  translate: Updates for file web/xliff/en.xlf in zh-Hans on branch main (#6909)
  web/admin: fix webauthn label order, add raw value (#6905)
  ...
This commit is contained in:
Ken Sternberg 2023-09-19 07:29:25 -07:00
commit 53f89ef2f8
246 changed files with 8723 additions and 6788 deletions

View file

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2023.8.1 current_version = 2023.8.3
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View file

@ -5,6 +5,9 @@ htmlcov
dist/** dist/**
build/** build/**
build_docs/** build_docs/**
Dockerfile *Dockerfile
authentik/enterprise authentik/enterprise
blueprints/local blueprints/local
.git
!gen-ts-api/node_modules
!gen-ts-api/dist/**

View file

@ -23,7 +23,7 @@ runs:
- name: Setup node - name: Setup node
uses: actions/setup-node@v3 uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Setup dependencies - name: Setup dependencies

View file

@ -33,7 +33,7 @@ jobs:
- ruff - ruff
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- name: run job - name: run job
@ -41,7 +41,7 @@ jobs:
test-migrations: test-migrations:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- name: run migrations - name: run migrations
@ -50,7 +50,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
continue-on-error: true continue-on-error: true
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup authentik env - name: Setup authentik env
@ -91,7 +91,7 @@ jobs:
- 12-alpine - 12-alpine
- 15-alpine - 15-alpine
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
with: with:
@ -108,7 +108,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- name: Create k8s Kind Cluster - name: Create k8s Kind Cluster
@ -144,7 +144,7 @@ jobs:
- name: flows - name: flows
glob: tests/e2e/test_flows* glob: tests/e2e/test_flows*
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- name: Setup e2e env (chrome, etc) - name: Setup e2e env (chrome, etc)
@ -186,28 +186,31 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
uses: ./.github/actions/docker-push-variables uses: ./.github/actions/docker-push-variables
id: ev id: ev
env: env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
- name: Login to Container Registry - name: Login to Container Registry
uses: docker/login-action@v2 uses: docker/login-action@v3
if: ${{ steps.ev.outputs.shouldBuild == 'true' }} if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: generate ts client
run: make gen-client-ts
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v4 uses: docker/build-push-action@v5
with: with:
context: .
secrets: | secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
@ -220,6 +223,8 @@ jobs:
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
VERSION=${{ steps.ev.outputs.version }} VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Comment on PR - name: Comment on PR
if: github.event_name == 'pull_request' if: github.event_name == 'pull_request'
continue-on-error: true continue-on-error: true
@ -231,28 +236,31 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
uses: ./.github/actions/docker-push-variables uses: ./.github/actions/docker-push-variables
id: ev id: ev
env: env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
- name: Login to Container Registry - name: Login to Container Registry
uses: docker/login-action@v2 uses: docker/login-action@v3
if: ${{ steps.ev.outputs.shouldBuild == 'true' }} if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: generate ts client
run: make gen-client-ts
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v4 uses: docker/build-push-action@v5
with: with:
context: .
secrets: | secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
@ -266,3 +274,5 @@ jobs:
VERSION=${{ steps.ev.outputs.version }} VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
platforms: linux/arm64 platforms: linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max

View file

@ -14,7 +14,7 @@ jobs:
lint-golint: lint-golint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
@ -31,11 +31,11 @@ jobs:
with: with:
version: v1.52.2 version: v1.52.2
args: --timeout 5000s --verbose args: --timeout 5000s --verbose
skip-pkg-cache: true skip-cache: true
test-unittest: test-unittest:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
@ -64,20 +64,20 @@ jobs:
- radius - radius
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
uses: ./.github/actions/docker-push-variables uses: ./.github/actions/docker-push-variables
id: ev id: ev
env: env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
- name: Login to Container Registry - name: Login to Container Registry
uses: docker/login-action@v2 uses: docker/login-action@v3
if: ${{ steps.ev.outputs.shouldBuild == 'true' }} if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with: with:
registry: ghcr.io registry: ghcr.io
@ -86,7 +86,7 @@ jobs:
- name: Generate API - name: Generate API
run: make gen-client-go run: make gen-client-go
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v4 uses: docker/build-push-action@v5
with: with:
push: ${{ steps.ev.outputs.shouldBuild == 'true' }} push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
tags: | tags: |
@ -99,6 +99,8 @@ jobs:
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
context: . context: .
cache-from: type=gha
cache-to: type=gha,mode=max
build-binary: build-binary:
timeout-minutes: 120 timeout-minutes: 120
needs: needs:
@ -114,15 +116,15 @@ jobs:
goos: [linux] goos: [linux]
goarch: [amd64, arm64] goarch: [amd64, arm64]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Generate API - name: Generate API

View file

@ -14,10 +14,10 @@ jobs:
lint-eslint: lint-eslint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -30,10 +30,10 @@ jobs:
lint-build: lint-build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -46,10 +46,10 @@ jobs:
lint-prettier: lint-prettier:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -62,10 +62,10 @@ jobs:
lint-lit-analyse: lint-lit-analyse:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -94,10 +94,10 @@ jobs:
- ci-web-mark - ci-web-mark
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/

View file

@ -14,10 +14,10 @@ jobs:
lint-prettier: lint-prettier:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/
@ -28,10 +28,10 @@ jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/
@ -49,10 +49,10 @@ jobs:
- build - build
- build-docs-only - build-docs-only
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/

View file

@ -23,7 +23,7 @@ jobs:
language: ["go", "javascript", "python"] language: ["go", "javascript", "python"]
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- name: Initialize CodeQL - name: Initialize CodeQL

View file

@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Check out code - name: Check out code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Cleanup - name: Cleanup
run: | run: |

View file

@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- id: generate_token - id: generate_token
uses: tibdex/github-app-token@v1 uses: tibdex/github-app-token@v2
with: with:
app_id: ${{ secrets.GH_APP_ID }} app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}

View file

@ -29,11 +29,11 @@ jobs:
github.event.pull_request.head.repo.full_name == github.repository) github.event.pull_request.head.repo.full_name == github.repository)
steps: steps:
- id: generate_token - id: generate_token
uses: tibdex/github-app-token@v1 uses: tibdex/github-app-token@v2
with: with:
app_id: ${{ secrets.GH_APP_ID }} app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
- name: Compress images - name: Compress images

View file

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- name: generate docs - name: generate docs

View file

@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
environment: internal-production environment: internal-production
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
ref: main ref: main
- run: | - run: |

View file

@ -8,28 +8,31 @@ jobs:
build-server: build-server:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
uses: ./.github/actions/docker-push-variables uses: ./.github/actions/docker-push-variables
id: ev id: ev
- name: Docker Login Registry - name: Docker Login Registry
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKER_USERNAME }} username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }} password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: make empty ts client
run: mkdir -p ./gen-ts-client
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v4 uses: docker/build-push-action@v5
with: with:
context: .
push: ${{ github.event_name == 'release' }} push: ${{ github.event_name == 'release' }}
secrets: | secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
@ -55,30 +58,30 @@ jobs:
- ldap - ldap
- radius - radius
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
uses: ./.github/actions/docker-push-variables uses: ./.github/actions/docker-push-variables
id: ev id: ev
- name: Docker Login Registry - name: Docker Login Registry
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKER_USERNAME }} username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }} password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@v2 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v4 uses: docker/build-push-action@v5
with: with:
push: ${{ github.event_name == 'release' }} push: ${{ github.event_name == 'release' }}
tags: | tags: |
@ -106,13 +109,13 @@ jobs:
goos: [linux, darwin] goos: [linux, darwin]
goarch: [amd64, arm64] goarch: [amd64, arm64]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Build web - name: Build web
@ -141,7 +144,7 @@ jobs:
- build-outpost-binary - build-outpost-binary
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Run test suite in final docker images - name: Run test suite in final docker images
run: | run: |
echo "PG_PASS=$(openssl rand -base64 32)" >> .env echo "PG_PASS=$(openssl rand -base64 32)" >> .env
@ -157,7 +160,7 @@ jobs:
- build-outpost-binary - build-outpost-binary
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: prepare variables - name: prepare variables
uses: ./.github/actions/docker-push-variables uses: ./.github/actions/docker-push-variables
id: ev id: ev

View file

@ -10,7 +10,7 @@ jobs:
name: Create Release from Tag name: Create Release from Tag
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Pre-release test - name: Pre-release test
run: | run: |
echo "PG_PASS=$(openssl rand -base64 32)" >> .env echo "PG_PASS=$(openssl rand -base64 32)" >> .env
@ -23,7 +23,7 @@ jobs:
docker-compose start postgresql redis docker-compose start postgresql redis
docker-compose run -u root server test-all docker-compose run -u root server test-all
- id: generate_token - id: generate_token
uses: tibdex/github-app-token@v1 uses: tibdex/github-app-token@v2
with: with:
app_id: ${{ secrets.GH_APP_ID }} app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- id: generate_token - id: generate_token
uses: tibdex/github-app-token@v1 uses: tibdex/github-app-token@v2
with: with:
app_id: ${{ secrets.GH_APP_ID }} app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}

View file

@ -16,11 +16,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- id: generate_token - id: generate_token
uses: tibdex/github-app-token@v1 uses: tibdex/github-app-token@v2
with: with:
app_id: ${{ secrets.GH_APP_ID }} app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
- name: Setup authentik env - name: Setup authentik env

View file

@ -1,4 +1,5 @@
# Rename transifex pull requests to have a correct naming # Rename transifex pull requests to have a correct naming
# Also enables auto squash-merge
name: authentik-translation-transifex-rename name: authentik-translation-transifex-rename
on: on:
@ -11,7 +12,7 @@ jobs:
if: ${{ github.event.pull_request.user.login == 'transifex-integration[bot]'}} if: ${{ github.event.pull_request.user.login == 'transifex-integration[bot]'}}
steps: steps:
- id: generate_token - id: generate_token
uses: tibdex/github-app-token@v1 uses: tibdex/github-app-token@v2
with: with:
app_id: ${{ secrets.GH_APP_ID }} app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
@ -37,3 +38,8 @@ jobs:
-H "X-GitHub-Api-Version: 2022-11-28" \ -H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} \ https://api.github.com/repos/${GITHUB_REPOSITORY}/pulls/${{ github.event.pull_request.number }} \
-d "{\"title\":\"translate: ${{ steps.title.outputs.title }}\"}" -d "{\"title\":\"translate: ${{ steps.title.outputs.title }}\"}"
- uses: peter-evans/enable-pull-request-automerge@v3
with:
token: ${{ steps.generate_token.outputs.token }}
pull-request-number: ${{ github.event.pull_request.number }}
merge-method: squash

View file

@ -10,16 +10,16 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- id: generate_token - id: generate_token
uses: tibdex/github-app-token@v1 uses: tibdex/github-app-token@v2
with: with:
app_id: ${{ secrets.GH_APP_ID }} app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
- uses: actions/setup-node@v3.8.1 - uses: actions/setup-node@v3
with: with:
node-version: "20" node-version: "20.5"
registry-url: "https://registry.npmjs.org" registry-url: "https://registry.npmjs.org"
- name: Generate API Client - name: Generate API Client
run: make gen-client-ts run: make gen-client-ts

View file

@ -1,2 +1,23 @@
* @goauthentik/core # Fallback
website/docs/security/** @goauthentik/security * @goauthentik/backend @goauthentik/frontend
# Backend
authentik/ @goauthentik/backend
blueprints/ @goauthentik/backend
cmd/ @goauthentik/backend
internal/ @goauthentik/backend
lifecycle/ @goauthentik/backend
schemas/ @goauthentik/backend
scripts/ @goauthentik/backend
tests/ @goauthentik/backend
# Infrastructure
.github/ @goauthentik/infrastructure
Dockerfile @goauthentik/infrastructure
*Dockerfile @goauthentik/infrastructure
.dockerignore @goauthentik/infrastructure
docker-compose.yml @goauthentik/infrastructure
# Web
web/ @goauthentik/frontend
# Docs & Website
website/ @goauthentik/docs
# Security
website/docs/security/ @goauthentik/security

View file

@ -1,53 +1,65 @@
# Stage 1: Build website # Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/node:20 as website-builder FROM --platform=${BUILDPLATFORM} docker.io/node:20.5 as website-builder
ENV NODE_ENV=production
WORKDIR /work/website
RUN --mount=type=bind,target=/work/website/package.json,src=./website/package.json \
--mount=type=bind,target=/work/website/package-lock.json,src=./website/package-lock.json \
--mount=type=cache,target=/root/.npm \
npm ci --include=dev
COPY ./website /work/website/ COPY ./website /work/website/
COPY ./blueprints /work/blueprints/ COPY ./blueprints /work/blueprints/
COPY ./SECURITY.md /work/ COPY ./SECURITY.md /work/
ENV NODE_ENV=production RUN npm run build-docs-only
WORKDIR /work/website
RUN npm ci --include=dev && npm run build-docs-only
# Stage 2: Build webui # Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/node:20 as web-builder FROM --platform=${BUILDPLATFORM} docker.io/node:20.5 as web-builder
ENV NODE_ENV=production
WORKDIR /work/web
RUN --mount=type=bind,target=/work/web/package.json,src=./web/package.json \
--mount=type=bind,target=/work/web/package-lock.json,src=./web/package-lock.json \
--mount=type=cache,target=/root/.npm \
npm ci --include=dev
COPY ./web /work/web/ COPY ./web /work/web/
COPY ./website /work/website/ COPY ./website /work/website/
COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
ENV NODE_ENV=production RUN npm run build
WORKDIR /work/web
RUN npm ci --include=dev && npm run build
# Stage 3: Poetry to requirements.txt export # Stage 3: Build go proxy
FROM docker.io/python:3.11.5-slim-bookworm AS poetry-locker FROM docker.io/golang:1.21.1-bookworm AS go-builder
WORKDIR /work WORKDIR /go/src/goauthentik.io
COPY ./pyproject.toml /work
COPY ./poetry.lock /work
RUN pip install --no-cache-dir poetry && \ RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
poetry export -f requirements.txt --output requirements.txt && \ --mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
poetry export -f requirements.txt --dev --output requirements-dev.txt --mount=type=cache,target=/go/pkg/mod \
go mod download
# Stage 4: Build go proxy COPY ./cmd /go/src/goauthentik.io/cmd
FROM docker.io/golang:1.21.0-bookworm AS go-builder COPY ./authentik/lib /go/src/goauthentik.io/authentik/lib
COPY ./web/static.go /go/src/goauthentik.io/web/static.go
COPY --from=web-builder /work/web/robots.txt /go/src/goauthentik.io/web/robots.txt
COPY --from=web-builder /work/web/security.txt /go/src/goauthentik.io/web/security.txt
COPY ./internal /go/src/goauthentik.io/internal
COPY ./go.mod /go/src/goauthentik.io/go.mod
COPY ./go.sum /go/src/goauthentik.io/go.sum
WORKDIR /work ENV CGO_ENABLED=0
COPY --from=web-builder /work/web/robots.txt /work/web/robots.txt RUN --mount=type=cache,target=/go/pkg/mod \
COPY --from=web-builder /work/web/security.txt /work/web/security.txt --mount=type=cache,target=/root/.cache/go-build \
go build -o /go/authentik ./cmd/server
COPY ./cmd /work/cmd # Stage 4: MaxMind GeoIP
COPY ./authentik/lib /work/authentik/lib
COPY ./web/static.go /work/web/static.go
COPY ./internal /work/internal
COPY ./go.mod /work/go.mod
COPY ./go.sum /work/go.sum
RUN go build -o /work/bin/authentik ./cmd/server/
# Stage 5: MaxMind GeoIP
FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City" ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City"
@ -61,6 +73,29 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
mkdir -p /usr/share/GeoIP && \ mkdir -p /usr/share/GeoIP && \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 5: Python dependencies
FROM docker.io/python:3.11.5-bookworm AS python-deps
WORKDIR /ak-root/poetry
ENV VENV_PATH="/ak-root/venv" \
POETRY_VIRTUALENVS_CREATE=false \
PATH="/ak-root/venv/bin:$PATH"
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && \
# Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev
RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
--mount=type=bind,target=./poetry.lock,src=./poetry.lock \
--mount=type=cache,target=/root/.cache/pip \
--mount=type=cache,target=/root/.cache/pypoetry \
python -m venv /ak-root/venv/ && \
pip3 install --upgrade pip && \
pip3 install poetry && \
poetry install --only=main --no-ansi --no-interaction
# Stage 6: Run # Stage 6: Run
FROM docker.io/python:3.11.5-slim-bookworm AS final-image FROM docker.io/python:3.11.5-slim-bookworm AS final-image
@ -76,46 +111,45 @@ LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
WORKDIR / WORKDIR /
COPY --from=poetry-locker /work/requirements.txt / # We cannot cache this layer otherwise we'll end up with a bigger image
COPY --from=poetry-locker /work/requirements-dev.txt /
COPY --from=geoip /usr/share/GeoIP /geoip
RUN apt-get update && \ RUN apt-get update && \
# Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev python3-dev && \
# Required for runtime # Required for runtime
apt-get install -y --no-install-recommends libpq5 openssl libxmlsec1-openssl libmaxminddb0 && \ apt-get install -y --no-install-recommends libpq5 openssl libxmlsec1-openssl libmaxminddb0 && \
# Required for bootstrap & healtcheck # Required for bootstrap & healtcheck
apt-get install -y --no-install-recommends runit && \ apt-get install -y --no-install-recommends runit && \
pip install --no-cache-dir -r /requirements.txt && \
apt-get remove --purge -y build-essential pkg-config libxmlsec1-dev libpq-dev python3-dev && \
apt-get autoremove --purge -y && \
apt-get clean && \ apt-get clean && \
rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \ rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \
adduser --system --no-create-home --uid 1000 --group --home /authentik authentik && \ adduser --system --no-create-home --uid 1000 --group --home /authentik authentik && \
mkdir -p /certs /media /blueprints && \ mkdir -p /certs /media /blueprints && \
mkdir -p /authentik/.ssh && \ mkdir -p /authentik/.ssh && \
chown authentik:authentik /certs /media /authentik/.ssh mkdir -p /ak-root && \
chown authentik:authentik /certs /media /authentik/.ssh /ak-root
COPY ./authentik/ /authentik COPY ./authentik/ /authentik
COPY ./pyproject.toml / COPY ./pyproject.toml /
COPY ./poetry.lock /
COPY ./schemas /schemas COPY ./schemas /schemas
COPY ./locale /locale COPY ./locale /locale
COPY ./tests /tests COPY ./tests /tests
COPY ./manage.py / COPY ./manage.py /
COPY ./blueprints /blueprints COPY ./blueprints /blueprints
COPY ./lifecycle/ /lifecycle COPY ./lifecycle/ /lifecycle
COPY --from=go-builder /work/bin/authentik /bin/authentik COPY --from=go-builder /go/authentik /bin/authentik
COPY --from=python-deps /ak-root/venv /ak-root/venv
COPY --from=web-builder /work/web/dist/ /web/dist/ COPY --from=web-builder /work/web/dist/ /web/dist/
COPY --from=web-builder /work/web/authentik/ /web/authentik/ COPY --from=web-builder /work/web/authentik/ /web/authentik/
COPY --from=website-builder /work/website/help/ /website/help/ COPY --from=website-builder /work/website/help/ /website/help/
COPY --from=geoip /usr/share/GeoIP /geoip
USER 1000 USER 1000
ENV TMPDIR /dev/shm/ ENV TMPDIR=/dev/shm/ \
ENV PYTHONUNBUFFERED 1 PYTHONDONTWRITEBYTECODE=1 \
ENV PATH "/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/lifecycle" PYTHONUNBUFFERED=1 \
PATH="/ak-root/venv/bin:$PATH" \
VENV_PATH="/ak-root/venv" \
POETRY_VIRTUALENVS_CREATE=false
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD [ "/lifecycle/ak", "healthcheck" ] HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD [ "/lifecycle/ak", "healthcheck" ]
ENTRYPOINT [ "/usr/local/bin/dumb-init", "--", "/lifecycle/ak" ] ENTRYPOINT [ "dumb-init", "--", "/lifecycle/ak" ]

View file

@ -174,8 +174,19 @@ website-build:
website-watch: website-watch:
cd website && npm run watch cd website && npm run watch
#########################
## Docker
#########################
docker:
DOCKER_BUILDKIT=1 docker build . --progress plain --tag authentik:test
#########################
## CI
#########################
# These targets are use by GitHub actions to allow usage of matrix # These targets are use by GitHub actions to allow usage of matrix
# which makes the YAML File a lot smaller # which makes the YAML File a lot smaller
ci--meta-debug: ci--meta-debug:
python -V python -V
node --version node --version

View file

@ -2,7 +2,7 @@
from os import environ from os import environ
from typing import Optional from typing import Optional
__version__ = "2023.8.1" __version__ = "2023.8.3"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View file

@ -12,7 +12,7 @@ from rest_framework.viewsets import ModelViewSet
from authentik.api.decorators import permission_required from authentik.api.decorators import permission_required
from authentik.blueprints.models import BlueprintInstance from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.oci import OCI_PREFIX from authentik.blueprints.v1.oci import OCI_PREFIX
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
@ -49,7 +49,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
if content == "": if content == "":
return content return content
context = self.instance.context if self.instance else {} context = self.instance.context if self.instance else {}
valid, logs = StringImporter(content, context).validate() valid, logs = Importer.from_string(content, context).validate()
if not valid: if not valid:
text_logs = "\n".join([x["event"] for x in logs]) text_logs = "\n".join([x["event"] for x in logs])
raise ValidationError(_("Failed to validate blueprint: %(logs)s" % {"logs": text_logs})) raise ValidationError(_("Failed to validate blueprint: %(logs)s" % {"logs": text_logs}))

View file

@ -5,7 +5,7 @@ from django.core.management.base import BaseCommand, no_translations
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.blueprints.models import BlueprintInstance from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
LOGGER = get_logger() LOGGER = get_logger()
@ -18,7 +18,7 @@ class Command(BaseCommand):
"""Apply all blueprints in order, abort when one fails to import""" """Apply all blueprints in order, abort when one fails to import"""
for blueprint_path in options.get("blueprints", []): for blueprint_path in options.get("blueprints", []):
content = BlueprintInstance(path=blueprint_path).retrieve() content = BlueprintInstance(path=blueprint_path).retrieve()
importer = StringImporter(content) importer = Importer.from_string(content)
valid, _ = importer.validate() valid, _ = importer.validate()
if not valid: if not valid:
self.stderr.write("blueprint invalid") self.stderr.write("blueprint invalid")

View file

@ -11,7 +11,7 @@ from authentik.blueprints.models import BlueprintInstance
def apply_blueprint(*files: str): def apply_blueprint(*files: str):
"""Apply blueprint before test""" """Apply blueprint before test"""
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
def wrapper_outer(func: Callable): def wrapper_outer(func: Callable):
"""Apply blueprint before test""" """Apply blueprint before test"""
@ -20,7 +20,7 @@ def apply_blueprint(*files: str):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
for file in files: for file in files:
content = BlueprintInstance(path=file).retrieve() content = BlueprintInstance(path=file).retrieve()
StringImporter(content).apply() Importer.from_string(content).apply()
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper

View file

@ -6,7 +6,7 @@ from django.test import TransactionTestCase
from authentik.blueprints.models import BlueprintInstance from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -25,7 +25,7 @@ def blueprint_tester(file_name: Path) -> Callable:
def tester(self: TestPackaged): def tester(self: TestPackaged):
base = Path("blueprints/") base = Path("blueprints/")
rel_path = Path(file_name).relative_to(base) rel_path = Path(file_name).relative_to(base)
importer = StringImporter(BlueprintInstance(path=str(rel_path)).retrieve()) importer = Importer.from_string(BlueprintInstance(path=str(rel_path)).retrieve())
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())

View file

@ -4,7 +4,7 @@ from os import environ
from django.test import TransactionTestCase from django.test import TransactionTestCase
from authentik.blueprints.v1.exporter import FlowExporter from authentik.blueprints.v1.exporter import FlowExporter
from authentik.blueprints.v1.importer import StringImporter, transaction_rollback from authentik.blueprints.v1.importer import Importer, transaction_rollback
from authentik.core.models import Group from authentik.core.models import Group
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
@ -21,14 +21,14 @@ class TestBlueprintsV1(TransactionTestCase):
def test_blueprint_invalid_format(self): def test_blueprint_invalid_format(self):
"""Test blueprint with invalid format""" """Test blueprint with invalid format"""
importer = StringImporter('{"version": 3}') importer = Importer.from_string('{"version": 3}')
self.assertFalse(importer.validate()[0]) self.assertFalse(importer.validate()[0])
importer = StringImporter( importer = Importer.from_string(
'{"version": 1,"entries":[{"identifiers":{},"attrs":{},' '{"version": 1,"entries":[{"identifiers":{},"attrs":{},'
'"model": "authentik_core.User"}]}' '"model": "authentik_core.User"}]}'
) )
self.assertFalse(importer.validate()[0]) self.assertFalse(importer.validate()[0])
importer = StringImporter( importer = Importer.from_string(
'{"version": 1, "entries": [{"attrs": {"name": "test"}, ' '{"version": 1, "entries": [{"attrs": {"name": "test"}, '
'"identifiers": {}, ' '"identifiers": {}, '
'"model": "authentik_core.Group"}]}' '"model": "authentik_core.Group"}]}'
@ -54,7 +54,7 @@ class TestBlueprintsV1(TransactionTestCase):
}, },
) )
importer = StringImporter( importer = Importer.from_string(
'{"version": 1, "entries": [{"attrs": {"name": "test999", "attributes": ' '{"version": 1, "entries": [{"attrs": {"name": "test999", "attributes": '
'{"key": ["updated_value"]}}, "identifiers": {"attributes": {"other_key": ' '{"key": ["updated_value"]}}, "identifiers": {"attributes": {"other_key": '
'["other_value"]}}, "model": "authentik_core.Group"}]}' '["other_value"]}}, "model": "authentik_core.Group"}]}'
@ -103,7 +103,7 @@ class TestBlueprintsV1(TransactionTestCase):
self.assertEqual(len(export.entries), 3) self.assertEqual(len(export.entries), 3)
export_yaml = exporter.export_to_string() export_yaml = exporter.export_to_string()
importer = StringImporter(export_yaml) importer = Importer.from_string(export_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
@ -113,14 +113,14 @@ class TestBlueprintsV1(TransactionTestCase):
"""Test export and import it twice""" """Test export and import it twice"""
count_initial = Prompt.objects.filter(field_key="username").count() count_initial = Prompt.objects.filter(field_key="username").count()
importer = StringImporter(load_fixture("fixtures/static_prompt_export.yaml")) importer = Importer.from_string(load_fixture("fixtures/static_prompt_export.yaml"))
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
count_before = Prompt.objects.filter(field_key="username").count() count_before = Prompt.objects.filter(field_key="username").count()
self.assertEqual(count_initial + 1, count_before) self.assertEqual(count_initial + 1, count_before)
importer = StringImporter(load_fixture("fixtures/static_prompt_export.yaml")) importer = Importer.from_string(load_fixture("fixtures/static_prompt_export.yaml"))
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
self.assertEqual(Prompt.objects.filter(field_key="username").count(), count_before) self.assertEqual(Prompt.objects.filter(field_key="username").count(), count_before)
@ -130,7 +130,7 @@ class TestBlueprintsV1(TransactionTestCase):
ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete() ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete()
Group.objects.filter(name="test").delete() Group.objects.filter(name="test").delete()
environ["foo"] = generate_id() environ["foo"] = generate_id()
importer = StringImporter(load_fixture("fixtures/tags.yaml"), {"bar": "baz"}) importer = Importer.from_string(load_fixture("fixtures/tags.yaml"), {"bar": "baz"})
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first() policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first()
@ -248,7 +248,7 @@ class TestBlueprintsV1(TransactionTestCase):
exporter = FlowExporter(flow) exporter = FlowExporter(flow)
export_yaml = exporter.export_to_string() export_yaml = exporter.export_to_string()
importer = StringImporter(export_yaml) importer = Importer.from_string(export_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
self.assertTrue(UserLoginStage.objects.filter(name=stage_name).exists()) self.assertTrue(UserLoginStage.objects.filter(name=stage_name).exists())
@ -297,7 +297,7 @@ class TestBlueprintsV1(TransactionTestCase):
exporter = FlowExporter(flow) exporter = FlowExporter(flow)
export_yaml = exporter.export_to_string() export_yaml = exporter.export_to_string()
importer = StringImporter(export_yaml) importer = Importer.from_string(export_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())

View file

@ -1,7 +1,7 @@
"""Test blueprints v1""" """Test blueprints v1"""
from django.test import TransactionTestCase from django.test import TransactionTestCase
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
from authentik.core.models import Application, Token, User from authentik.core.models import Application, Token, User
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.flows.models import Flow from authentik.flows.models import Flow
@ -18,7 +18,7 @@ class TestBlueprintsV1ConditionalFields(TransactionTestCase):
self.uid = generate_id() self.uid = generate_id()
import_yaml = load_fixture("fixtures/conditional_fields.yaml", uid=self.uid, user=user.pk) import_yaml = load_fixture("fixtures/conditional_fields.yaml", uid=self.uid, user=user.pk)
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())

View file

@ -1,7 +1,7 @@
"""Test blueprints v1""" """Test blueprints v1"""
from django.test import TransactionTestCase from django.test import TransactionTestCase
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture from authentik.lib.tests.utils import load_fixture
@ -18,7 +18,7 @@ class TestBlueprintsV1Conditions(TransactionTestCase):
"fixtures/conditions_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2 "fixtures/conditions_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2
) )
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
# Ensure objects exist # Ensure objects exist
@ -35,7 +35,7 @@ class TestBlueprintsV1Conditions(TransactionTestCase):
"fixtures/conditions_not_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2 "fixtures/conditions_not_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2
) )
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
# Ensure objects do not exist # Ensure objects do not exist

View file

@ -1,7 +1,7 @@
"""Test blueprints v1""" """Test blueprints v1"""
from django.test import TransactionTestCase from django.test import TransactionTestCase
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture from authentik.lib.tests.utils import load_fixture
@ -15,7 +15,7 @@ class TestBlueprintsV1State(TransactionTestCase):
flow_slug = generate_id() flow_slug = generate_id()
import_yaml = load_fixture("fixtures/state_present.yaml", id=flow_slug) import_yaml = load_fixture("fixtures/state_present.yaml", id=flow_slug)
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
# Ensure object exists # Ensure object exists
@ -30,7 +30,7 @@ class TestBlueprintsV1State(TransactionTestCase):
self.assertEqual(flow.title, "bar") self.assertEqual(flow.title, "bar")
# Ensure importer updates it # Ensure importer updates it
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
flow: Flow = Flow.objects.filter(slug=flow_slug).first() flow: Flow = Flow.objects.filter(slug=flow_slug).first()
@ -41,7 +41,7 @@ class TestBlueprintsV1State(TransactionTestCase):
flow_slug = generate_id() flow_slug = generate_id()
import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug) import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug)
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
# Ensure object exists # Ensure object exists
@ -56,7 +56,7 @@ class TestBlueprintsV1State(TransactionTestCase):
self.assertEqual(flow.title, "bar") self.assertEqual(flow.title, "bar")
# Ensure importer doesn't update it # Ensure importer doesn't update it
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
flow: Flow = Flow.objects.filter(slug=flow_slug).first() flow: Flow = Flow.objects.filter(slug=flow_slug).first()
@ -67,7 +67,7 @@ class TestBlueprintsV1State(TransactionTestCase):
flow_slug = generate_id() flow_slug = generate_id()
import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug) import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug)
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
# Ensure object exists # Ensure object exists
@ -75,7 +75,7 @@ class TestBlueprintsV1State(TransactionTestCase):
self.assertEqual(flow.slug, flow_slug) self.assertEqual(flow.slug, flow_slug)
import_yaml = load_fixture("fixtures/state_absent.yaml", id=flow_slug) import_yaml = load_fixture("fixtures/state_absent.yaml", id=flow_slug)
importer = StringImporter(import_yaml) importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0]) self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply()) self.assertTrue(importer.apply())
flow: Flow = Flow.objects.filter(slug=flow_slug).first() flow: Flow = Flow.objects.filter(slug=flow_slug).first()

View file

@ -12,6 +12,7 @@ from uuid import UUID
from deepmerge import always_merger from deepmerge import always_merger
from django.apps import apps from django.apps import apps
from django.db.models import Model, Q from django.db.models import Model, Q
from rest_framework.exceptions import ValidationError
from rest_framework.fields import Field from rest_framework.fields import Field
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from yaml import SafeDumper, SafeLoader, ScalarNode, SequenceNode from yaml import SafeDumper, SafeLoader, ScalarNode, SequenceNode
@ -207,8 +208,8 @@ class KeyOf(YAMLTag):
): ):
return _entry._state.instance.pbm_uuid return _entry._state.instance.pbm_uuid
return _entry._state.instance.pk return _entry._state.instance.pk
raise EntryInvalidError( raise EntryInvalidError.from_entry(
f"KeyOf: failed to find entry with `id` of `{self.id_from}` and a model instance" f"KeyOf: failed to find entry with `id` of `{self.id_from}` and a model instance", entry
) )
@ -279,7 +280,7 @@ class Format(YAMLTag):
try: try:
return self.format_string % tuple(args) return self.format_string % tuple(args)
except TypeError as exc: except TypeError as exc:
raise EntryInvalidError(exc) raise EntryInvalidError.from_entry(exc, entry)
class Find(YAMLTag): class Find(YAMLTag):
@ -356,13 +357,15 @@ class Condition(YAMLTag):
args.append(arg) args.append(arg)
if not args: if not args:
raise EntryInvalidError("At least one value is required after mode selection.") raise EntryInvalidError.from_entry(
"At least one value is required after mode selection.", entry
)
try: try:
comparator = self._COMPARATORS[self.mode.upper()] comparator = self._COMPARATORS[self.mode.upper()]
return comparator(tuple(bool(x) for x in args)) return comparator(tuple(bool(x) for x in args))
except (TypeError, KeyError) as exc: except (TypeError, KeyError) as exc:
raise EntryInvalidError(exc) raise EntryInvalidError.from_entry(exc, entry)
class If(YAMLTag): class If(YAMLTag):
@ -394,7 +397,7 @@ class If(YAMLTag):
blueprint, blueprint,
) )
except TypeError as exc: except TypeError as exc:
raise EntryInvalidError(exc) raise EntryInvalidError.from_entry(exc, entry)
class Enumerate(YAMLTag, YAMLTagContext): class Enumerate(YAMLTag, YAMLTagContext):
@ -426,9 +429,10 @@ class Enumerate(YAMLTag, YAMLTagContext):
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
if isinstance(self.iterable, EnumeratedItem) and self.iterable.depth == 0: if isinstance(self.iterable, EnumeratedItem) and self.iterable.depth == 0:
raise EntryInvalidError( raise EntryInvalidError.from_entry(
f"{self.__class__.__name__} tag's iterable references this tag's context. " f"{self.__class__.__name__} tag's iterable references this tag's context. "
"This is a noop. Check you are setting depth bigger than 0." "This is a noop. Check you are setting depth bigger than 0.",
entry,
) )
if isinstance(self.iterable, YAMLTag): if isinstance(self.iterable, YAMLTag):
@ -437,9 +441,10 @@ class Enumerate(YAMLTag, YAMLTagContext):
iterable = self.iterable iterable = self.iterable
if not isinstance(iterable, Iterable): if not isinstance(iterable, Iterable):
raise EntryInvalidError( raise EntryInvalidError.from_entry(
f"{self.__class__.__name__}'s iterable must be an iterable " f"{self.__class__.__name__}'s iterable must be an iterable "
"such as a sequence or a mapping" "such as a sequence or a mapping",
entry,
) )
if isinstance(iterable, Mapping): if isinstance(iterable, Mapping):
@ -450,7 +455,7 @@ class Enumerate(YAMLTag, YAMLTagContext):
try: try:
output_class, add_fn = self._OUTPUT_BODIES[self.output_body.upper()] output_class, add_fn = self._OUTPUT_BODIES[self.output_body.upper()]
except KeyError as exc: except KeyError as exc:
raise EntryInvalidError(exc) raise EntryInvalidError.from_entry(exc, entry)
result = output_class() result = output_class()
@ -462,8 +467,8 @@ class Enumerate(YAMLTag, YAMLTagContext):
resolved_body = entry.tag_resolver(self.item_body, blueprint) resolved_body = entry.tag_resolver(self.item_body, blueprint)
result = add_fn(result, resolved_body) result = add_fn(result, resolved_body)
if not isinstance(result, output_class): if not isinstance(result, output_class):
raise EntryInvalidError( raise EntryInvalidError.from_entry(
f"Invalid {self.__class__.__name__} item found: {resolved_body}" f"Invalid {self.__class__.__name__} item found: {resolved_body}", entry
) )
finally: finally:
self.__current_context = tuple() self.__current_context = tuple()
@ -490,12 +495,13 @@ class EnumeratedItem(YAMLTag):
) )
except ValueError as exc: except ValueError as exc:
if self.depth == 0: if self.depth == 0:
raise EntryInvalidError( raise EntryInvalidError.from_entry(
f"{self.__class__.__name__} tags are only usable " f"{self.__class__.__name__} tags are only usable "
f"inside an {Enumerate.__name__} tag" f"inside an {Enumerate.__name__} tag",
entry,
) )
raise EntryInvalidError(f"{self.__class__.__name__} tag: {exc}") raise EntryInvalidError.from_entry(f"{self.__class__.__name__} tag: {exc}", entry)
return context_tag.get_context(entry, blueprint) return context_tag.get_context(entry, blueprint)
@ -509,7 +515,7 @@ class Index(EnumeratedItem):
try: try:
return context[0] return context[0]
except IndexError: # pragma: no cover except IndexError: # pragma: no cover
raise EntryInvalidError(f"Empty/invalid context: {context}") raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry)
class Value(EnumeratedItem): class Value(EnumeratedItem):
@ -521,7 +527,7 @@ class Value(EnumeratedItem):
try: try:
return context[1] return context[1]
except IndexError: # pragma: no cover except IndexError: # pragma: no cover
raise EntryInvalidError(f"Empty/invalid context: {context}") raise EntryInvalidError.from_entry(f"Empty/invalid context: {context}", entry)
class BlueprintDumper(SafeDumper): class BlueprintDumper(SafeDumper):
@ -575,8 +581,26 @@ class BlueprintLoader(SafeLoader):
class EntryInvalidError(SentryIgnoredException): class EntryInvalidError(SentryIgnoredException):
"""Error raised when an entry is invalid""" """Error raised when an entry is invalid"""
serializer_errors: Optional[dict] entry_model: Optional[str]
entry_id: Optional[str]
validation_error: Optional[ValidationError]
def __init__(self, *args: object, serializer_errors: Optional[dict] = None) -> None: def __init__(self, *args: object, validation_error: Optional[ValidationError] = None) -> None:
super().__init__(*args) super().__init__(*args)
self.serializer_errors = serializer_errors self.entry_model = None
self.entry_id = None
self.validation_error = validation_error
@staticmethod
def from_entry(
msg_or_exc: str | Exception, entry: BlueprintEntry, *args, **kwargs
) -> "EntryInvalidError":
"""Create EntryInvalidError with the context of an entry"""
error = EntryInvalidError(msg_or_exc, *args, **kwargs)
if isinstance(msg_or_exc, ValidationError):
error.validation_error = msg_or_exc
# Make sure the model and id are strings, depending where the error happens
# they might still be YAMLTag instances
error.entry_model = str(entry.model)
error.entry_id = str(entry.id)
return error

View file

@ -8,9 +8,9 @@ from dacite.core import from_dict
from dacite.exceptions import DaciteError from dacite.exceptions import DaciteError
from deepmerge import always_merger from deepmerge import always_merger
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import transaction
from django.db.models import Model from django.db.models import Model
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.transaction import atomic
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer, Serializer from rest_framework.serializers import BaseSerializer, Serializer
@ -38,6 +38,7 @@ from authentik.core.models import (
from authentik.events.utils import cleanse_dict from authentik.events.utils import cleanse_dict
from authentik.flows.models import FlowToken, Stage from authentik.flows.models import FlowToken, Stage
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException
from authentik.outposts.models import OutpostServiceConnection from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel from authentik.policies.models import Policy, PolicyBindingModel
@ -72,18 +73,23 @@ def is_model_allowed(model: type[Model]) -> bool:
return model not in excluded_models and issubclass(model, (SerializerModel, BaseMetaModel)) return model not in excluded_models and issubclass(model, (SerializerModel, BaseMetaModel))
class DoRollback(SentryIgnoredException):
"""Exception to trigger a rollback"""
@contextmanager @contextmanager
def transaction_rollback(): def transaction_rollback():
"""Enters an atomic transaction and always triggers a rollback at the end of the block.""" """Enters an atomic transaction and always triggers a rollback at the end of the block."""
atomic = transaction.atomic() try:
# pylint: disable=unnecessary-dunder-call with atomic():
atomic.__enter__() yield
yield raise DoRollback()
atomic.__exit__(IntegrityError, None, None) except DoRollback:
pass
class Importer: class Importer:
"""Import Blueprint from YAML""" """Import Blueprint from raw dict or YAML/JSON"""
logger: BoundLogger logger: BoundLogger
_import: Blueprint _import: Blueprint
@ -98,6 +104,18 @@ class Importer:
always_merger.merge(ctx, context) always_merger.merge(ctx, context)
self._import.context = ctx self._import.context = ctx
@staticmethod
def from_string(yaml_input: str, context: dict | None = None) -> "Importer":
"""Parse YAML string and create blueprint importer from it"""
import_dict = load(yaml_input, BlueprintLoader)
try:
_import = from_dict(
Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState])
)
except DaciteError as exc:
raise EntryInvalidError from exc
return Importer(_import, context)
@property @property
def blueprint(self) -> Blueprint: def blueprint(self) -> Blueprint:
"""Get imported blueprint""" """Get imported blueprint"""
@ -155,7 +173,7 @@ class Importer:
model: type[SerializerModel] = registry.get_model(model_app_label, model_name) model: type[SerializerModel] = registry.get_model(model_app_label, model_name)
# Don't use isinstance since we don't want to check for inheritance # Don't use isinstance since we don't want to check for inheritance
if not is_model_allowed(model): if not is_model_allowed(model):
raise EntryInvalidError(f"Model {model} not allowed") raise EntryInvalidError.from_entry(f"Model {model} not allowed", entry)
if issubclass(model, BaseMetaModel): if issubclass(model, BaseMetaModel):
serializer_class: type[Serializer] = model.serializer() serializer_class: type[Serializer] = model.serializer()
serializer = serializer_class( serializer = serializer_class(
@ -167,8 +185,10 @@ class Importer:
try: try:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
except ValidationError as exc: except ValidationError as exc:
raise EntryInvalidError( raise EntryInvalidError.from_entry(
f"Serializer errors {serializer.errors}", serializer_errors=serializer.errors f"Serializer errors {serializer.errors}",
validation_error=exc,
entry=entry,
) from exc ) from exc
return serializer return serializer
@ -185,12 +205,12 @@ class Importer:
query = self.__query_from_identifier(updated_identifiers) query = self.__query_from_identifier(updated_identifiers)
if not query: if not query:
raise EntryInvalidError("No or invalid identifiers") raise EntryInvalidError.from_entry("No or invalid identifiers", entry)
try: try:
existing_models = model.objects.filter(query) existing_models = model.objects.filter(query)
except FieldError as exc: except FieldError as exc:
raise EntryInvalidError(f"Invalid identifier field: {exc}") from exc raise EntryInvalidError.from_entry(f"Invalid identifier field: {exc}", entry) from exc
serializer_kwargs = {} serializer_kwargs = {}
model_instance = existing_models.first() model_instance = existing_models.first()
@ -204,11 +224,12 @@ class Importer:
serializer_kwargs["instance"] = model_instance serializer_kwargs["instance"] = model_instance
serializer_kwargs["partial"] = True serializer_kwargs["partial"] = True
elif model_instance and entry.state == BlueprintEntryDesiredState.MUST_CREATED: elif model_instance and entry.state == BlueprintEntryDesiredState.MUST_CREATED:
raise EntryInvalidError( raise EntryInvalidError.from_entry(
( (
f"state is set to {BlueprintEntryDesiredState.MUST_CREATED}" f"state is set to {BlueprintEntryDesiredState.MUST_CREATED} "
" and object exists already" "and object exists already",
) ),
entry,
) )
else: else:
self.logger.debug( self.logger.debug(
@ -224,7 +245,7 @@ class Importer:
try: try:
full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import)) full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import))
except ValueError as exc: except ValueError as exc:
raise EntryInvalidError(exc) from exc raise EntryInvalidError.from_entry(exc, entry) from exc
always_merger.merge(full_data, updated_identifiers) always_merger.merge(full_data, updated_identifiers)
serializer_kwargs["data"] = full_data serializer_kwargs["data"] = full_data
@ -237,15 +258,17 @@ class Importer:
try: try:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
except ValidationError as exc: except ValidationError as exc:
raise EntryInvalidError( raise EntryInvalidError.from_entry(
f"Serializer errors {serializer.errors}", serializer_errors=serializer.errors f"Serializer errors {serializer.errors}",
validation_error=exc,
entry=entry,
) from exc ) from exc
return serializer return serializer
def apply(self) -> bool: def apply(self) -> bool:
"""Apply (create/update) models yaml, in database transaction""" """Apply (create/update) models yaml, in database transaction"""
try: try:
with transaction.atomic(): with atomic():
if not self._apply_models(): if not self._apply_models():
self.logger.debug("Reverting changes due to error") self.logger.debug("Reverting changes due to error")
raise IntegrityError raise IntegrityError
@ -254,7 +277,7 @@ class Importer:
self.logger.debug("Committing changes") self.logger.debug("Committing changes")
return True return True
def _apply_models(self) -> bool: def _apply_models(self, raise_errors=False) -> bool:
"""Apply (create/update) models yaml""" """Apply (create/update) models yaml"""
self.__pk_map = {} self.__pk_map = {}
for entry in self._import.entries: for entry in self._import.entries:
@ -274,6 +297,8 @@ class Importer:
if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT: if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT:
continue continue
self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc)
if raise_errors:
raise exc
return False return False
if not serializer: if not serializer:
continue continue
@ -311,7 +336,7 @@ class Importer:
self.logger.debug("entry to delete with no instance, skipping") self.logger.debug("entry to delete with no instance, skipping")
return True return True
def validate(self) -> tuple[bool, list[EventDict]]: def validate(self, raise_validation_errors=False) -> tuple[bool, list[EventDict]]:
"""Validate loaded blueprint export, ensure all models are allowed """Validate loaded blueprint export, ensure all models are allowed
and serializers have no errors""" and serializers have no errors"""
self.logger.debug("Starting blueprint import validation") self.logger.debug("Starting blueprint import validation")
@ -323,7 +348,7 @@ class Importer:
transaction_rollback(), transaction_rollback(),
capture_logs() as logs, capture_logs() as logs,
): ):
successful = self._apply_models() successful = self._apply_models(raise_errors=raise_validation_errors)
if not successful: if not successful:
self.logger.debug("Blueprint validation failed") self.logger.debug("Blueprint validation failed")
for log in logs: for log in logs:
@ -331,17 +356,3 @@ class Importer:
self.logger.debug("Finished blueprint import validation") self.logger.debug("Finished blueprint import validation")
self._import = orig_import self._import = orig_import
return successful, logs return successful, logs
class StringImporter(Importer):
"""Importer that also parses from string"""
def __init__(self, yaml_input: str, context: dict | None = None):
import_dict = load(yaml_input, BlueprintLoader)
try:
_import = from_dict(
Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState])
)
except DaciteError as exc:
raise EntryInvalidError from exc
super().__init__(_import, context)

View file

@ -26,7 +26,7 @@ from authentik.blueprints.models import (
BlueprintRetrievalFailed, BlueprintRetrievalFailed,
) )
from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, EntryInvalidError from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, EntryInvalidError
from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
from authentik.blueprints.v1.oci import OCI_PREFIX from authentik.blueprints.v1.oci import OCI_PREFIX
from authentik.events.monitored_tasks import ( from authentik.events.monitored_tasks import (
@ -190,7 +190,7 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
self.set_uid(slugify(instance.name)) self.set_uid(slugify(instance.name))
blueprint_content = instance.retrieve() blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest() file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = StringImporter(blueprint_content, instance.context) importer = Importer.from_string(blueprint_content, instance.context)
if importer.blueprint.metadata: if importer.blueprint.metadata:
instance.metadata = asdict(importer.blueprint.metadata) instance.metadata = asdict(importer.blueprint.metadata)
valid, logs = importer.validate() valid, logs = importer.validate()

View file

@ -1,6 +1,4 @@
"""Authenticator Devices API Views""" """Authenticator Devices API Views"""
from django_otp import device_classes, devices_for_user
from django_otp.models import Device
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework.fields import BooleanField, CharField, IntegerField, SerializerMethodField from rest_framework.fields import BooleanField, CharField, IntegerField, SerializerMethodField
@ -10,6 +8,8 @@ from rest_framework.response import Response
from rest_framework.viewsets import ViewSet from rest_framework.viewsets import ViewSet
from authentik.core.api.utils import MetaNameSerializer from authentik.core.api.utils import MetaNameSerializer
from authentik.stages.authenticator import device_classes, devices_for_user
from authentik.stages.authenticator.models import Device
class DeviceSerializer(MetaNameSerializer): class DeviceSerializer(MetaNameSerializer):

View file

@ -13,6 +13,7 @@ from authentik.blueprints.v1.common import (
Blueprint, Blueprint,
BlueprintEntry, BlueprintEntry,
BlueprintEntryDesiredState, BlueprintEntryDesiredState,
EntryInvalidError,
KeyOf, KeyOf,
) )
from authentik.blueprints.v1.importer import Importer from authentik.blueprints.v1.importer import Importer
@ -66,11 +67,51 @@ class TransactionApplicationSerializer(PassiveSerializer):
raise ValidationError("Invalid provider model") raise ValidationError("Invalid provider model")
return fq_model_name return fq_model_name
def validate(self, attrs: dict) -> dict:
blueprint = Blueprint()
blueprint.entries.append(
BlueprintEntry(
model=attrs["provider_model"],
state=BlueprintEntryDesiredState.MUST_CREATED,
identifiers={
"name": attrs["provider"]["name"],
},
# Must match the name of the field on `self`
id="provider",
attrs=attrs["provider"],
)
)
app_data = attrs["app"]
app_data["provider"] = KeyOf(None, ScalarNode(tag="", value="provider"))
blueprint.entries.append(
BlueprintEntry(
model="authentik_core.application",
state=BlueprintEntryDesiredState.MUST_CREATED,
identifiers={
"slug": attrs["app"]["slug"],
},
attrs=app_data,
# Must match the name of the field on `self`
id="app",
)
)
importer = Importer(blueprint, {})
try:
valid, _ = importer.validate(raise_validation_errors=True)
if not valid:
raise ValidationError("Invalid blueprint")
except EntryInvalidError as exc:
raise ValidationError(
{
exc.entry_id: exc.validation_error.detail,
}
)
return blueprint
class TransactionApplicationResponseSerializer(PassiveSerializer): class TransactionApplicationResponseSerializer(PassiveSerializer):
"""Transactional creation response""" """Transactional creation response"""
valid = BooleanField()
applied = BooleanField() applied = BooleanField()
logs = ListField(child=CharField()) logs = ListField(child=CharField())
@ -90,38 +131,9 @@ class TransactionalApplicationView(APIView):
"""Convert data into a blueprint, validate it and apply it""" """Convert data into a blueprint, validate it and apply it"""
data = TransactionApplicationSerializer(data=request.data) data = TransactionApplicationSerializer(data=request.data)
data.is_valid(raise_exception=True) data.is_valid(raise_exception=True)
print(data.validated_data)
blueprint = Blueprint() importer = Importer(data.validated_data, {})
blueprint.entries.append( applied = importer.apply()
BlueprintEntry( response = {"applied": False, "logs": []}
model=data.validated_data["provider_model"], response["applied"] = applied
state=BlueprintEntryDesiredState.MUST_CREATED,
identifiers={
"name": data.validated_data["provider"]["name"],
},
id="provider",
attrs=data.validated_data["provider"],
)
)
app_data = data.validated_data["app"]
app_data["provider"] = KeyOf(None, ScalarNode(tag="", value="provider"))
blueprint.entries.append(
BlueprintEntry(
model="authentik_core.application",
state=BlueprintEntryDesiredState.MUST_CREATED,
identifiers={
"slug": data.validated_data["app"]["slug"],
},
attrs=app_data,
)
)
importer = Importer(blueprint, {})
response = {"valid": False, "applied": False, "logs": []}
valid, logs = importer.validate()
response["logs"] = [x["event"] for x in logs]
response["valid"] = valid
if valid:
applied = importer.apply()
response["applied"] = applied
return Response(response, status=200) return Response(response, status=200)

View file

@ -0,0 +1,9 @@
"""custom runserver command"""
from daphne.management.commands.runserver import Command as RunServer
class Command(RunServer):
"""custom runserver command, which doesn't show the misleading django startup message"""
def on_bind(self, server_port):
pass

View file

@ -16,6 +16,9 @@ LOGGER = get_logger()
class Command(BaseCommand): class Command(BaseCommand):
"""Run worker""" """Run worker"""
def add_arguments(self, parser):
parser.add_argument("-b", "--beat", action="store_true")
def handle(self, **options): def handle(self, **options):
close_old_connections() close_old_connections()
if CONFIG.get_bool("remote_debug"): if CONFIG.get_bool("remote_debug"):
@ -26,10 +29,9 @@ class Command(BaseCommand):
no_color=False, no_color=False,
quiet=True, quiet=True,
optimization="fair", optimization="fair",
max_tasks_per_child=1,
autoscale=(3, 1), autoscale=(3, 1),
task_events=True, task_events=True,
beat=True, beat=options.get("beat", True),
schedule_filename=f"{tempdir}/celerybeat-schedule", schedule_filename=f"{tempdir}/celerybeat-schedule",
queues=["authentik", "authentik_scheduled", "authentik_events"], queues=["authentik", "authentik_scheduled", "authentik_events"],
) )

View file

@ -48,7 +48,7 @@ class Action(Enum):
class MessageStage(StageView): class MessageStage(StageView):
"""Show a pre-configured message after the flow is done""" """Show a pre-configured message after the flow is done"""
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Show a pre-configured message after the flow is done""" """Show a pre-configured message after the flow is done"""
message = getattr(self.executor.current_stage, "message", "") message = getattr(self.executor.current_stage, "message", "")
level = getattr(self.executor.current_stage, "level", messages.SUCCESS) level = getattr(self.executor.current_stage, "level", messages.SUCCESS)
@ -59,10 +59,6 @@ class MessageStage(StageView):
) )
return self.executor.stage_ok() return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
class SourceFlowManager: class SourceFlowManager:
"""Help sources decide what they should do after authorization. Based on source settings and """Help sources decide what they should do after authorization. Based on source settings and

View file

@ -13,7 +13,7 @@ class PostUserEnrollmentStage(StageView):
"""Dynamically injected stage which saves the Connection after """Dynamically injected stage which saves the Connection after
the user has been enrolled.""" the user has been enrolled."""
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Stage used after the user has been enrolled""" """Stage used after the user has been enrolled"""
connection: UserSourceConnection = self.executor.plan.context[ connection: UserSourceConnection = self.executor.plan.context[
PLAN_CONTEXT_SOURCES_CONNECTION PLAN_CONTEXT_SOURCES_CONNECTION
@ -27,7 +27,3 @@ class PostUserEnrollmentStage(StageView):
source=connection.source, source=connection.source,
).from_http(self.request) ).from_http(self.request)
return self.executor.stage_ok() return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View file

@ -1,6 +1,4 @@
"""Test Transactional API""" """Test Transactional API"""
from json import loads
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
@ -35,11 +33,32 @@ class TestTransactionalApplicationsAPI(APITestCase):
}, },
}, },
) )
response_body = loads(response.content.decode()) self.assertJSONEqual(response.content.decode(), {"applied": True, "logs": []})
self.assertTrue(response_body["valid"])
self.assertTrue(response_body["applied"])
provider = OAuth2Provider.objects.filter(name=uid).first() provider = OAuth2Provider.objects.filter(name=uid).first()
self.assertIsNotNone(provider) self.assertIsNotNone(provider)
app = Application.objects.filter(slug=uid).first() app = Application.objects.filter(slug=uid).first()
self.assertIsNotNone(app) self.assertIsNotNone(app)
self.assertEqual(app.provider.pk, provider.pk) self.assertEqual(app.provider.pk, provider.pk)
def test_create_transactional_invalid(self):
"""Test transactional Application + provider creation"""
self.client.force_login(self.user)
uid = generate_id()
response = self.client.put(
reverse("authentik_api:core-transactional-application"),
data={
"app": {
"name": uid,
"slug": uid,
},
"provider_model": "authentik_providers_oauth2.oauth2provider",
"provider": {
"name": uid,
"authorization_flow": "",
},
},
)
self.assertJSONEqual(
response.content.decode(),
{"provider": {"authorization_flow": ["This field may not be null."]}},
)

View file

@ -25,10 +25,10 @@ def create_test_admin_user(name: Optional[str] = None, **kwargs) -> User:
"""Generate a test-admin user""" """Generate a test-admin user"""
uid = generate_id(20) if not name else name uid = generate_id(20) if not name else name
group = Group.objects.create(name=uid, is_superuser=True) group = Group.objects.create(name=uid, is_superuser=True)
kwargs.setdefault("email", f"{uid}@goauthentik.io")
kwargs.setdefault("username", uid)
user: User = User.objects.create( user: User = User.objects.create(
username=uid,
name=uid, name=uid,
email=f"{uid}@goauthentik.io",
**kwargs, **kwargs,
) )
user.set_password(uid) user.set_password(uid)

View file

@ -189,6 +189,8 @@ class CertificateKeyPairFilter(FilterSet):
def filter_has_key(self, queryset, name, value): # pragma: no cover def filter_has_key(self, queryset, name, value): # pragma: no cover
"""Only return certificate-key pairs with keys""" """Only return certificate-key pairs with keys"""
if not value:
return queryset
return queryset.exclude(key_data__exact="") return queryset.exclude(key_data__exact="")
class Meta: class Meta:

View file

@ -128,8 +128,26 @@ class TestCrypto(APITestCase):
response = self.client.get( response = self.client.get(
reverse( reverse(
"authentik_api:certificatekeypair-list", "authentik_api:certificatekeypair-list",
) ),
+ f"?name={cert.name}" data={"name": cert.name},
)
self.assertEqual(200, response.status_code)
body = loads(response.content.decode())
api_cert = [x for x in body["results"] if x["name"] == cert.name][0]
self.assertEqual(api_cert["fingerprint_sha1"], cert.fingerprint_sha1)
self.assertEqual(api_cert["fingerprint_sha256"], cert.fingerprint_sha256)
def test_list_has_key_false(self):
"""Test API List with has_key set to false"""
cert = create_test_cert()
cert.key_data = ""
cert.save()
self.client.force_login(create_test_admin_user())
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-list",
),
data={"name": cert.name, "has_key": False},
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
body = loads(response.content.decode()) body = loads(response.content.decode())
@ -144,8 +162,8 @@ class TestCrypto(APITestCase):
response = self.client.get( response = self.client.get(
reverse( reverse(
"authentik_api:certificatekeypair-list", "authentik_api:certificatekeypair-list",
) ),
+ f"?name={cert.name}&include_details=false" data={"name": cert.name, "include_details": False},
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
body = loads(response.content.decode()) body = loads(response.content.decode())
@ -168,8 +186,8 @@ class TestCrypto(APITestCase):
reverse( reverse(
"authentik_api:certificatekeypair-view-certificate", "authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk}, kwargs={"pk": keypair.pk},
) ),
+ "?download", data={"download": True},
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response) self.assertIn("Content-Disposition", response)
@ -189,8 +207,8 @@ class TestCrypto(APITestCase):
reverse( reverse(
"authentik_api:certificatekeypair-view-private-key", "authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk}, kwargs={"pk": keypair.pk},
) ),
+ "?download", data={"download": True},
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response) self.assertIn("Content-Disposition", response)
@ -200,7 +218,7 @@ class TestCrypto(APITestCase):
self.client.force_login(create_test_admin_user()) self.client.force_login(create_test_admin_user())
keypair = create_test_cert() keypair = create_test_cert()
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(
name="test", name=generate_id(),
client_id="test", client_id="test",
client_secret=generate_key(), client_secret=generate_key(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),

View file

@ -26,4 +26,11 @@ class Migration(migrations.Migration):
fields=["key"], name="authentik_e_key_523e13_hash" fields=["key"], name="authentik_e_key_523e13_hash"
), ),
), ),
migrations.AlterModelOptions(
name="licenseusage",
options={
"verbose_name": "License Usage",
"verbose_name_plural": "License Usage Records",
},
),
] ]

View file

@ -15,6 +15,7 @@ from django.contrib.postgres.indexes import HashIndex
from django.db import models from django.db import models
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from jwt import PyJWTError, decode, get_unverified_header from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
@ -187,3 +188,7 @@ class LicenseUsage(ExpiringModel):
within_limits = models.BooleanField() within_limits = models.BooleanField()
record_date = models.DateTimeField(auto_now_add=True) record_date = models.DateTimeField(auto_now_add=True)
class Meta:
verbose_name = _("License Usage")
verbose_name_plural = _("License Usage Records")

View file

@ -1,44 +1,30 @@
"""Enterprise license policies""" """Enterprise license policies"""
from typing import Optional from typing import Optional
from rest_framework.serializers import BaseSerializer
from authentik.core.models import User, UserTypes from authentik.core.models import User, UserTypes
from authentik.enterprise.models import LicenseKey from authentik.enterprise.models import LicenseKey
from authentik.policies.models import Policy
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
from authentik.policies.views import PolicyAccessView from authentik.policies.views import PolicyAccessView
class EnterprisePolicy(Policy):
"""Check that a user is correctly licensed for the request"""
@property
def component(self) -> str:
return ""
@property
def serializer(self) -> type[BaseSerializer]:
raise NotImplementedError
def passes(self, request: PolicyRequest) -> PolicyResult:
if not LicenseKey.get_total().is_valid():
return PolicyResult(False)
if request.user.type != UserTypes.INTERNAL:
return PolicyResult(False)
return PolicyResult(True)
class EnterprisePolicyAccessView(PolicyAccessView): class EnterprisePolicyAccessView(PolicyAccessView):
"""PolicyAccessView which also checks enterprise licensing""" """PolicyAccessView which also checks enterprise licensing"""
def check_license(self):
"""Check license"""
if not LicenseKey.get_total().is_valid():
return False
if self.request.user.type != UserTypes.INTERNAL:
return False
return True
def user_has_access(self, user: Optional[User] = None) -> PolicyResult: def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
user = user or self.request.user user = user or self.request.user
request = PolicyRequest(user) request = PolicyRequest(user)
request.http_request = self.request request.http_request = self.request
result = super().user_has_access(user) result = super().user_has_access(user)
enterprise_result = EnterprisePolicy().passes(request) enterprise_result = self.check_license()
if not enterprise_result.passing: if not enterprise_result:
return enterprise_result return enterprise_result
return result return result

View file

@ -9,7 +9,6 @@ from django.core.exceptions import SuspiciousOperation
from django.db.models import Model from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete from django.db.models.signals import m2m_changed, post_save, pre_delete
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django_otp.plugins.otp_static.models import StaticToken
from guardian.models import UserObjectPermission from guardian.models import UserObjectPermission
from authentik.core.models import ( from authentik.core.models import (
@ -30,6 +29,7 @@ from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel from authentik.policies.models import Policy, PolicyBindingModel
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.scim.models import SCIMGroup, SCIMUser from authentik.providers.scim.models import SCIMGroup, SCIMUser
from authentik.stages.authenticator_static.models import StaticToken
IGNORED_MODELS = ( IGNORED_MODELS = (
Event, Event,

View file

@ -16,7 +16,7 @@ from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required from authentik.api.decorators import permission_required
from authentik.blueprints.v1.exporter import FlowExporter from authentik.blueprints.v1.exporter import FlowExporter
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, StringImporter from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import CacheSerializer, LinkSerializer, PassiveSerializer from authentik.core.api.utils import CacheSerializer, LinkSerializer, PassiveSerializer
from authentik.events.utils import sanitize_dict from authentik.events.utils import sanitize_dict
@ -181,7 +181,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
if not file: if not file:
return Response(data=import_response.initial_data, status=400) return Response(data=import_response.initial_data, status=400)
importer = StringImporter(file.read().decode()) importer = Importer.from_string(file.read().decode())
valid, logs = importer.validate() valid, logs = importer.validate()
import_response.initial_data["logs"] = [sanitize_dict(log) for log in logs] import_response.initial_data["logs"] = [sanitize_dict(log) for log in logs]
import_response.initial_data["success"] = valid import_response.initial_data["success"] = valid

View file

@ -26,3 +26,8 @@ class EmptyFlowException(SentryIgnoredException):
class FlowSkipStageException(SentryIgnoredException): class FlowSkipStageException(SentryIgnoredException):
"""Exception to skip a stage""" """Exception to skip a stage"""
class StageInvalidException(SentryIgnoredException):
"""Exception can be thrown in a `Challenge` or `ChallengeResponse` serializer's
validation to trigger a `executor.stage_invalid()` response"""

View file

@ -23,6 +23,7 @@ from authentik.flows.challenge import (
RedirectChallenge, RedirectChallenge,
WithUserInfoChallenge, WithUserInfoChallenge,
) )
from authentik.flows.exceptions import StageInvalidException
from authentik.flows.models import InvalidResponseAction from authentik.flows.models import InvalidResponseAction
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_PENDING_USER
from authentik.lib.avatars import DEFAULT_AVATAR from authentik.lib.avatars import DEFAULT_AVATAR
@ -100,8 +101,14 @@ class ChallengeStageView(StageView):
def post(self, request: Request, *args, **kwargs) -> HttpResponse: def post(self, request: Request, *args, **kwargs) -> HttpResponse:
"""Handle challenge response""" """Handle challenge response"""
challenge: ChallengeResponse = self.get_response_instance(data=request.data) valid = False
if not challenge.is_valid(): try:
challenge: ChallengeResponse = self.get_response_instance(data=request.data)
valid = challenge.is_valid()
except StageInvalidException as exc:
self.logger.debug("Got StageInvalidException", exc=exc)
return self.executor.stage_invalid()
if not valid:
if self.executor.current_binding.invalid_response_action in [ if self.executor.current_binding.invalid_response_action in [
InvalidResponseAction.RESTART, InvalidResponseAction.RESTART,
InvalidResponseAction.RESTART_WITH_CONTEXT, InvalidResponseAction.RESTART_WITH_CONTEXT,

View file

@ -21,8 +21,9 @@ def view_tester_factory(view_class: type[StageView]) -> Callable:
def tester(self: TestViews): def tester(self: TestViews):
model_class = view_class(self.exec) model_class = view_class(self.exec)
self.assertIsNotNone(model_class.post) if not hasattr(model_class, "dispatch"):
self.assertIsNotNone(model_class.get) self.assertIsNotNone(model_class.post)
self.assertIsNotNone(model_class.get)
return tester return tester

View file

@ -73,22 +73,21 @@ QS_QUERY = "query"
def challenge_types(): def challenge_types():
"""This function returns a class which is an iterator, which returns the """This function returns a mapping which contains all subclasses of challenges
subclasses of Challenge, and Challenge itself.""" subclasses of Challenge, and Challenge itself."""
mapping = {} mapping = {}
classes = all_subclasses(Challenge) for cls in all_subclasses(Challenge):
classes.remove(WithUserInfoChallenge) if cls == WithUserInfoChallenge:
for cls in classes: continue
mapping[cls().fields["component"].default] = cls mapping[cls().fields["component"].default] = cls
return mapping return mapping
def challenge_response_types(): def challenge_response_types():
"""This function returns a class which is an iterator, which returns the """This function returns a mapping which contains all subclasses of challenges
subclasses of Challenge, and Challenge itself.""" subclasses of Challenge, and Challenge itself."""
mapping = {} mapping = {}
classes = all_subclasses(ChallengeResponse) for cls in all_subclasses(ChallengeResponse):
for cls in classes:
mapping[cls(stage=None).fields["component"].default] = cls mapping[cls(stage=None).fields["component"].default] = cls
return mapping return mapping
@ -279,7 +278,7 @@ class FlowExecutorView(APIView):
span.set_data("Method", "GET") span.set_data("Method", "GET")
span.set_data("authentik Stage", self.current_stage_view) span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug) span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.get(request, *args, **kwargs) stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response) return to_stage_response(request, stage_response)
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
return self.handle_exception(exc) return self.handle_exception(exc)
@ -323,7 +322,7 @@ class FlowExecutorView(APIView):
span.set_data("Method", "POST") span.set_data("Method", "POST")
span.set_data("authentik Stage", self.current_stage_view) span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug) span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.post(request, *args, **kwargs) stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response) return to_stage_response(request, stage_response)
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
return self.handle_exception(exc) return self.handle_exception(exc)
@ -346,10 +345,15 @@ class FlowExecutorView(APIView):
def restart_flow(self, keep_context=False) -> HttpResponse: def restart_flow(self, keep_context=False) -> HttpResponse:
"""Restart the currently active flow, optionally keeping the current context""" """Restart the currently active flow, optionally keeping the current context"""
planner = FlowPlanner(self.flow) planner = FlowPlanner(self.flow)
planner.use_cache = False
default_context = None default_context = None
if keep_context: if keep_context:
default_context = self.plan.context default_context = self.plan.context
plan = planner.plan(self.request, default_context) try:
plan = planner.plan(self.request, default_context)
except FlowNonApplicableException as exc:
self._logger.warning("f(exec): Flow restart not applicable to current user", exc=exc)
return self.handle_invalid_flow(exc)
self.request.session[SESSION_KEY_PLAN] = plan self.request.session[SESSION_KEY_PLAN] = plan
kwargs = self.kwargs kwargs = self.kwargs
kwargs.update({"flow_slug": self.flow.slug}) kwargs.update({"flow_slug": self.flow.slug})

View file

@ -1,4 +1,4 @@
# update website/docs/installation/configuration.md # update website/docs/installation/configuration.mdx
# This is the default configuration file # This is the default configuration file
postgresql: postgresql:
host: localhost host: localhost
@ -84,6 +84,9 @@ ldap:
tls: tls:
ciphers: null ciphers: null
reputation:
expiry: 86400
cookie_domain: null cookie_domain: null
disable_update_check: false disable_update_check: false
disable_startup_analytics: false disable_startup_analytics: false

View file

@ -7,7 +7,6 @@ from typing import Any, Iterable, Optional
from cachetools import TLRUCache, cached from cachetools import TLRUCache, cached
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django_otp import devices_for_user
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from sentry_sdk.hub import Hub from sentry_sdk.hub import Hub
@ -20,6 +19,7 @@ from authentik.lib.utils.http import get_http_session
from authentik.policies.models import Policy, PolicyBinding from authentik.policies.models import Policy, PolicyBinding
from authentik.policies.process import PolicyProcess from authentik.policies.process import PolicyProcess
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
from authentik.stages.authenticator import devices_for_user
LOGGER = get_logger() LOGGER = get_logger()

View file

@ -1,7 +1,112 @@
"""logging helpers""" """logging helpers"""
import logging
from logging import Logger from logging import Logger
from os import getpid from os import getpid
import structlog
from authentik.lib.config import CONFIG
LOG_PRE_CHAIN = [
# Add the log level and a timestamp to the event_dict if the log entry
# is not from structlog.
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(),
structlog.processors.StackInfoRenderer(),
]
def get_log_level():
"""Get log level, clamp trace to debug"""
level = CONFIG.get("log_level").upper()
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level
# so that would have to be re-written too
if level == "TRACE":
level = "DEBUG"
return level
def structlog_configure():
"""Configure structlog itself"""
structlog.configure_once(
processors=[
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.contextvars.merge_contextvars,
add_process_id,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso", utc=False),
structlog.processors.StackInfoRenderer(),
structlog.processors.dict_tracebacks,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.make_filtering_bound_logger(
getattr(logging, get_log_level(), logging.WARNING)
),
cache_logger_on_first_use=True,
)
def get_logger_config():
"""Configure python stdlib's logging"""
debug = CONFIG.get_bool("debug")
global_level = get_log_level()
base_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"json": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.processors.JSONRenderer(sort_keys=True),
"foreign_pre_chain": LOG_PRE_CHAIN + [structlog.processors.dict_tracebacks],
},
"console": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.dev.ConsoleRenderer(colors=debug),
"foreign_pre_chain": LOG_PRE_CHAIN,
},
},
"handlers": {
"console": {
"level": "DEBUG",
"class": "logging.StreamHandler",
"formatter": "console" if debug else "json",
},
},
"loggers": {},
}
handler_level_map = {
"": global_level,
"authentik": global_level,
"django": "WARNING",
"django.request": "ERROR",
"celery": "WARNING",
"selenium": "WARNING",
"docker": "WARNING",
"urllib3": "WARNING",
"websockets": "WARNING",
"daphne": "WARNING",
"kubernetes": "INFO",
"asyncio": "WARNING",
"redis": "WARNING",
"silk": "INFO",
"fsevents": "WARNING",
"uvicorn": "WARNING",
"gunicorn": "INFO",
}
for handler_name, level in handler_level_map.items():
base_config["loggers"][handler_name] = {
"handlers": ["console"],
"level": level,
"propagate": False,
}
return base_config
def add_process_id(logger: Logger, method_name: str, event_dict): def add_process_id(logger: Logger, method_name: str, event_dict):
"""Add the current process ID""" """Add the current process ID"""

View file

@ -77,6 +77,7 @@ class PolicyBindingSerializer(ModelSerializer):
"enabled", "enabled",
"order", "order",
"timeout", "timeout",
"failure_result",
] ]
def validate(self, attrs: OrderedDict) -> OrderedDict: def validate(self, attrs: OrderedDict) -> OrderedDict:

View file

@ -0,0 +1,26 @@
# Generated by Django 4.2.5 on 2023-09-13 18:07
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_policies", "0010_alter_policy_name"),
]
operations = [
migrations.AddField(
model_name="policybinding",
name="failure_result",
field=models.BooleanField(
default=False, help_text="Result if the Policy execution fails."
),
),
migrations.AlterField(
model_name="policybinding",
name="timeout",
field=models.PositiveIntegerField(
default=30, help_text="Timeout after which Policy execution is terminated."
),
),
]

View file

@ -85,9 +85,12 @@ class PolicyBinding(SerializerModel):
default=False, default=False,
help_text=_("Negates the outcome of the policy. Messages are unaffected."), help_text=_("Negates the outcome of the policy. Messages are unaffected."),
) )
timeout = models.IntegerField( timeout = models.PositiveIntegerField(
default=30, help_text=_("Timeout after which Policy execution is terminated.") default=30, help_text=_("Timeout after which Policy execution is terminated.")
) )
failure_result = models.BooleanField(
default=False, help_text=_("Result if the Policy execution fails.")
)
order = models.IntegerField() order = models.IntegerField()

View file

@ -98,8 +98,8 @@ class PolicyProcess(PROCESS_CLASS):
# Create policy exception event, only when we're not debugging # Create policy exception event, only when we're not debugging
if not self.request.debug: if not self.request.debug:
self.create_event(EventAction.POLICY_EXCEPTION, message=error_string) self.create_event(EventAction.POLICY_EXCEPTION, message=error_string)
LOGGER.debug("P_ENG(proc): error", exc=src_exc) LOGGER.debug("P_ENG(proc): error, using failure result", exc=src_exc)
policy_result = PolicyResult(False, str(src_exc)) policy_result = PolicyResult(self.binding.failure_result, str(src_exc))
policy_result.source_binding = self.binding policy_result.source_binding = self.binding
should_cache = self.request.should_cache should_cache = self.request.should_cache
if should_cache: if should_cache:

View file

@ -1,5 +1,7 @@
"""Reputation policy API Views""" """Reputation policy API Views"""
from django.utils.translation import gettext_lazy as _
from rest_framework import mixins from rest_framework import mixins
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework.viewsets import GenericViewSet, ModelViewSet
@ -11,6 +13,11 @@ from authentik.policies.reputation.models import Reputation, ReputationPolicy
class ReputationPolicySerializer(PolicySerializer): class ReputationPolicySerializer(PolicySerializer):
"""Reputation Policy Serializer""" """Reputation Policy Serializer"""
def validate(self, attrs: dict) -> dict:
if not attrs.get("check_ip", False) and not attrs.get("check_username", False):
raise ValidationError(_("Either IP or Username must be checked"))
return super().validate(attrs)
class Meta: class Meta:
model = ReputationPolicy model = ReputationPolicy
fields = PolicySerializer.Meta.fields + [ fields = PolicySerializer.Meta.fields + [

View file

@ -0,0 +1,33 @@
# Generated by Django 4.2.4 on 2023-08-31 10:42
from django.db import migrations, models
import authentik.policies.reputation.models
class Migration(migrations.Migration):
dependencies = [
("authentik_policies_reputation", "0004_reputationpolicy_authentik_p_policy__8f0d70_idx"),
]
operations = [
migrations.AddField(
model_name="reputation",
name="expires",
field=models.DateTimeField(
default=authentik.policies.reputation.models.reputation_expiry
),
),
migrations.AddField(
model_name="reputation",
name="expiring",
field=models.BooleanField(default=True),
),
migrations.AlterModelOptions(
name="reputation",
options={
"verbose_name": "Reputation Score",
"verbose_name_plural": "Reputation Scores",
},
),
]

View file

@ -1,13 +1,17 @@
"""authentik reputation request policy""" """authentik reputation request policy"""
from datetime import timedelta
from uuid import uuid4 from uuid import uuid4
from django.db import models from django.db import models
from django.db.models import Sum from django.db.models import Sum
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from structlog import get_logger from structlog import get_logger
from authentik.core.models import ExpiringModel
from authentik.lib.config import CONFIG
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.http import get_client_ip from authentik.lib.utils.http import get_client_ip
from authentik.policies.models import Policy from authentik.policies.models import Policy
@ -17,6 +21,11 @@ LOGGER = get_logger()
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/" CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
def reputation_expiry():
"""Reputation expiry"""
return now() + timedelta(seconds=CONFIG.get_int("reputation.expiry"))
class ReputationPolicy(Policy): class ReputationPolicy(Policy):
"""Return true if request IP/target username's score is below a certain threshold""" """Return true if request IP/target username's score is below a certain threshold"""
@ -59,7 +68,7 @@ class ReputationPolicy(Policy):
verbose_name_plural = _("Reputation Policies") verbose_name_plural = _("Reputation Policies")
class Reputation(SerializerModel): class Reputation(ExpiringModel, SerializerModel):
"""Reputation for user and or IP.""" """Reputation for user and or IP."""
reputation_uuid = models.UUIDField(primary_key=True, unique=True, default=uuid4) reputation_uuid = models.UUIDField(primary_key=True, unique=True, default=uuid4)
@ -69,6 +78,8 @@ class Reputation(SerializerModel):
ip_geo_data = models.JSONField(default=dict) ip_geo_data = models.JSONField(default=dict)
score = models.BigIntegerField(default=0) score = models.BigIntegerField(default=0)
expires = models.DateTimeField(default=reputation_expiry)
updated = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now_add=True)
@property @property
@ -81,4 +92,6 @@ class Reputation(SerializerModel):
return f"Reputation {self.identifier}/{self.ip} @ {self.score}" return f"Reputation {self.identifier}/{self.ip} @ {self.score}"
class Meta: class Meta:
verbose_name = _("Reputation Score")
verbose_name_plural = _("Reputation Scores")
unique_together = ("identifier", "ip") unique_together = ("identifier", "ip")

View file

@ -3,6 +3,8 @@ from django.core.cache import cache
from django.test import RequestFactory, TestCase from django.test import RequestFactory, TestCase
from authentik.core.models import User from authentik.core.models import User
from authentik.lib.generators import generate_id
from authentik.policies.reputation.api import ReputationPolicySerializer
from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy
from authentik.policies.reputation.tasks import save_reputation from authentik.policies.reputation.tasks import save_reputation
from authentik.policies.types import PolicyRequest from authentik.policies.types import PolicyRequest
@ -61,3 +63,8 @@ class TestReputationPolicy(TestCase):
name="reputation-test", threshold=0 name="reputation-test", threshold=0
) )
self.assertTrue(policy.passes(request).passing) self.assertTrue(policy.passes(request).passing)
def test_api(self):
"""Test API Validation"""
no_toggle = ReputationPolicySerializer(data={"name": generate_id(), "threshold": -5})
self.assertFalse(no_toggle.is_valid())

View file

@ -97,6 +97,17 @@ class TestPolicyEngine(TestCase):
self.assertEqual(result.passing, False) self.assertEqual(result.passing, False)
self.assertEqual(result.messages, ("division by zero",)) self.assertEqual(result.messages, ("division by zero",))
def test_engine_policy_error_failure(self):
"""Test policy raising an error flag"""
pbm = PolicyBindingModel.objects.create()
PolicyBinding.objects.create(
target=pbm, policy=self.policy_raises, order=0, failure_result=True
)
engine = PolicyEngine(pbm, self.user)
result = engine.build().result
self.assertEqual(result.passing, True)
self.assertEqual(result.messages, ("division by zero",))
def test_engine_policy_type(self): def test_engine_policy_type(self):
"""Test invalid policy type""" """Test invalid policy type"""
pbm = PolicyBindingModel.objects.create() pbm = PolicyBindingModel.objects.create()

View file

@ -13,10 +13,9 @@ from rest_framework.decorators import action
from rest_framework.fields import CharField, FileField, SerializerMethodField from rest_framework.fields import CharField, FileField, SerializerMethodField
from rest_framework.parsers import MultiPartParser from rest_framework.parsers import MultiPartParser
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from rest_framework.relations import SlugRelatedField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ValidationError from rest_framework.serializers import PrimaryKeyRelatedField, ValidationError
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -168,10 +167,8 @@ class SAMLProviderImportSerializer(PassiveSerializer):
"""Import saml provider from XML Metadata""" """Import saml provider from XML Metadata"""
name = CharField(required=True) name = CharField(required=True)
# Using SlugField because https://github.com/OpenAPITools/openapi-generator/issues/3278 authorization_flow = PrimaryKeyRelatedField(
authorization_flow = SlugRelatedField(
queryset=Flow.objects.filter(designation=FlowDesignation.AUTHORIZATION), queryset=Flow.objects.filter(designation=FlowDesignation.AUTHORIZATION),
slug_field="slug",
) )
file = FileField() file = FileField()

View file

@ -171,6 +171,8 @@ class MetadataProcessor:
entity_descriptor, f"{{{NS_SAML_METADATA}}}IDPSSODescriptor" entity_descriptor, f"{{{NS_SAML_METADATA}}}IDPSSODescriptor"
) )
idp_sso_descriptor.attrib["protocolSupportEnumeration"] = NS_SAML_PROTOCOL idp_sso_descriptor.attrib["protocolSupportEnumeration"] = NS_SAML_PROTOCOL
if self.provider.verification_kp:
idp_sso_descriptor.attrib["WantAuthnRequestsSigned"] = "true"
signing_descriptor = self.get_signing_key_descriptor() signing_descriptor = self.get_signing_key_descriptor()
if signing_descriptor is not None: if signing_descriptor is not None:

View file

@ -89,7 +89,7 @@ class TestSAMLProviderAPI(APITestCase):
{ {
"file": metadata, "file": metadata,
"name": generate_id(), "name": generate_id(),
"authorization_flow": create_test_flow(FlowDesignation.AUTHORIZATION).slug, "authorization_flow": create_test_flow(FlowDesignation.AUTHORIZATION).pk,
}, },
format="multipart", format="multipart",
) )
@ -106,7 +106,7 @@ class TestSAMLProviderAPI(APITestCase):
{ {
"file": metadata, "file": metadata,
"name": generate_id(), "name": generate_id(),
"authorization_flow": create_test_flow().slug, "authorization_flow": create_test_flow().pk,
}, },
format="multipart", format="multipart",
) )

View file

@ -12,7 +12,7 @@ from authentik.lib.xml import lxml_from_string
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.metadata import MetadataProcessor from authentik.providers.saml.processors.metadata import MetadataProcessor
from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser
from authentik.sources.saml.processors.constants import NS_MAP from authentik.sources.saml.processors.constants import NS_MAP, NS_SAML_METADATA
class TestServiceProviderMetadataParser(TestCase): class TestServiceProviderMetadataParser(TestCase):
@ -55,6 +55,24 @@ class TestServiceProviderMetadataParser(TestCase):
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
self.assertTrue(schema.validate(metadata)) self.assertTrue(schema.validate(metadata))
def test_schema_want_authn_requests_signed(self):
"""Test metadata generation with WantAuthnRequestsSigned"""
cert = create_test_cert()
provider = SAMLProvider.objects.create(
name=generate_id(),
authorization_flow=self.flow,
verification_kp=cert,
)
Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=provider,
)
request = self.factory.get("/")
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
idp_sso_descriptor = metadata.findall(f"{{{NS_SAML_METADATA}}}IDPSSODescriptor")[0]
self.assertEqual(idp_sso_descriptor.attrib["WantAuthnRequestsSigned"], "true")
def test_simple(self): def test_simple(self):
"""Test simple metadata without Signing""" """Test simple metadata without Signing"""
metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/simple.xml")) metadata = ServiceProviderMetadataParser().parse(load_fixture("fixtures/simple.xml"))

View file

@ -23,6 +23,8 @@ def post_save_provider(sender: type[Model], instance, created: bool, **_):
@receiver(post_save, sender=Group) @receiver(post_save, sender=Group)
def post_save_scim(sender: type[Model], instance: User | Group, created: bool, **_): def post_save_scim(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler""" """Post save handler"""
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.add.value) scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.add.value)
@ -30,6 +32,8 @@ def post_save_scim(sender: type[Model], instance: User | Group, created: bool, *
@receiver(pre_delete, sender=Group) @receiver(pre_delete, sender=Group)
def pre_delete_scim(sender: type[Model], instance: User | Group, **_): def pre_delete_scim(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler""" """Pre-delete handler"""
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.remove.value) scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.remove.value)
@ -40,6 +44,8 @@ def m2m_changed_scim(
"""Sync group membership""" """Sync group membership"""
if action not in ["post_add", "post_remove"]: if action not in ["post_add", "post_remove"]:
return return
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks # reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups # non-reverse: instance is a User, pk_set is a list of groups
if reverse: if reverse:

View file

@ -172,7 +172,7 @@ class ChannelsLoggingMiddleware:
LOGGER.info( LOGGER.info(
scope["path"], scope["path"],
scheme="ws", scheme="ws",
remote=scope.get("client", [""])[0], remote=headers.get(b"x-forwarded-for", b"").decode(),
user_agent=headers.get(b"user-agent", b"").decode(), user_agent=headers.get(b"user-agent", b"").decode(),
**kwargs, **kwargs,
) )

View file

@ -1,25 +1,21 @@
"""root settings for authentik""" """root settings for authentik"""
import importlib import importlib
import logging
import os import os
from hashlib import sha512 from hashlib import sha512
from pathlib import Path from pathlib import Path
from urllib.parse import quote_plus from urllib.parse import quote_plus
import structlog
from celery.schedules import crontab from celery.schedules import crontab
from sentry_sdk import set_tag from sentry_sdk import set_tag
from authentik import ENV_GIT_HASH_KEY, __version__ from authentik import ENV_GIT_HASH_KEY, __version__
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.logging import add_process_id from authentik.lib.logging import get_logger_config, structlog_configure
from authentik.lib.sentry import sentry_init from authentik.lib.sentry import sentry_init
from authentik.lib.utils.reflection import get_env from authentik.lib.utils.reflection import get_env
from authentik.stages.password import BACKEND_APP_PASSWORD, BACKEND_INBUILT, BACKEND_LDAP from authentik.stages.password import BACKEND_APP_PASSWORD, BACKEND_INBUILT, BACKEND_LDAP
LOGGER = structlog.get_logger()
BASE_DIR = Path(__file__).absolute().parent.parent.parent BASE_DIR = Path(__file__).absolute().parent.parent.parent
STATICFILES_DIRS = [BASE_DIR / Path("web")] STATICFILES_DIRS = [BASE_DIR / Path("web")]
MEDIA_ROOT = BASE_DIR / Path("media") MEDIA_ROOT = BASE_DIR / Path("media")
@ -85,6 +81,7 @@ INSTALLED_APPS = [
"authentik.sources.oauth", "authentik.sources.oauth",
"authentik.sources.plex", "authentik.sources.plex",
"authentik.sources.saml", "authentik.sources.saml",
"authentik.stages.authenticator",
"authentik.stages.authenticator_duo", "authentik.stages.authenticator_duo",
"authentik.stages.authenticator_sms", "authentik.stages.authenticator_sms",
"authentik.stages.authenticator_static", "authentik.stages.authenticator_static",
@ -368,90 +365,9 @@ MEDIA_URL = "/media/"
TEST = False TEST = False
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
# We can't check TEST here as its set later by the test runner
LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level so that would have to be re-written too
if LOG_LEVEL == "TRACE":
LOG_LEVEL = "DEBUG"
structlog.configure_once( structlog_configure()
processors=[ LOGGING = get_logger_config()
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.contextvars.merge_contextvars,
add_process_id,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso", utc=False),
structlog.processors.StackInfoRenderer(),
structlog.processors.dict_tracebacks,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.make_filtering_bound_logger(
getattr(logging, LOG_LEVEL, logging.WARNING)
),
cache_logger_on_first_use=True,
)
LOG_PRE_CHAIN = [
# Add the log level and a timestamp to the event_dict if the log entry
# is not from structlog.
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(),
structlog.processors.StackInfoRenderer(),
]
LOGGING = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"json": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.processors.JSONRenderer(sort_keys=True),
"foreign_pre_chain": LOG_PRE_CHAIN,
},
"console": {
"()": structlog.stdlib.ProcessorFormatter,
"processor": structlog.dev.ConsoleRenderer(colors=DEBUG),
"foreign_pre_chain": LOG_PRE_CHAIN,
},
},
"handlers": {
"console": {
"level": "DEBUG",
"class": "logging.StreamHandler",
"formatter": "console" if DEBUG else "json",
},
},
"loggers": {},
}
_LOGGING_HANDLER_MAP = {
"": LOG_LEVEL,
"authentik": LOG_LEVEL,
"django": "WARNING",
"django.request": "ERROR",
"celery": "WARNING",
"selenium": "WARNING",
"docker": "WARNING",
"urllib3": "WARNING",
"websockets": "WARNING",
"daphne": "WARNING",
"kubernetes": "INFO",
"asyncio": "WARNING",
"redis": "WARNING",
"silk": "INFO",
"fsevents": "WARNING",
}
for handler_name, level in _LOGGING_HANDLER_MAP.items():
LOGGING["loggers"][handler_name] = {
"handlers": ["console"],
"level": level,
"propagate": False,
}
_DISALLOWED_ITEMS = [ _DISALLOWED_ITEMS = [

View file

@ -20,7 +20,7 @@ class PytestTestRunner: # pragma: no cover
self.failfast = failfast self.failfast = failfast
self.keepdb = keepdb self.keepdb = keepdb
self.args = ["-vv", "--full-trace"] self.args = []
if self.failfast: if self.failfast:
self.args.append("--exitfirst") self.args.append("--exitfirst")
if self.keepdb: if self.keepdb:

View file

@ -153,8 +153,10 @@ class BaseLDAPSynchronizer:
try: try:
value = mapping.evaluate(user=None, request=None, ldap=kwargs, dn=object_dn) value = mapping.evaluate(user=None, request=None, ldap=kwargs, dn=object_dn)
if value is None: if value is None:
self._logger.warning("property mapping returned None", mapping=mapping)
continue continue
if isinstance(value, (bytes)): if isinstance(value, (bytes)):
self._logger.warning("property mapping returned bytes", mapping=mapping)
continue continue
object_field = mapping.object_field object_field = mapping.object_field
if object_field.startswith("attributes."): if object_field.startswith("attributes."):

View file

@ -18,6 +18,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
return "groups" return "groups"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,

View file

@ -24,6 +24,9 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer):
return "membership" return "membership"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_groups:
self.message("Group syncing is disabled for this Source")
return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,

View file

@ -20,6 +20,9 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
return "users" return "users"
def get_objects(self, **kwargs) -> Generator: def get_objects(self, **kwargs) -> Generator:
if not self._source.sync_users:
self.message("User syncing is disabled for this Source")
return iter(())
return self.search_paginator( return self.search_paginator(
search_base=self.base_dn_users, search_base=self.base_dn_users,
search_filter=self._source.user_object_filter, search_filter=self._source.user_object_filter,

View file

@ -45,7 +45,13 @@ class FreeIPA(BaseLDAPSynchronizer):
# 389-ds and this will trigger regardless # 389-ds and this will trigger regardless
if "nsaccountlock" not in attributes: if "nsaccountlock" not in attributes:
return return
is_active = attributes.get("nsaccountlock", False) # For some reason, nsaccountlock is not defined properly in the schema as bool
# hence we get it as a list of strings
_is_locked = str(self._flatten(attributes.get("nsaccountlock", ["FALSE"])))
# So we have to attempt to convert it to a bool
is_locked = _is_locked.lower() == "true"
# And then invert it since freeipa saves locked and we save active
is_active = not is_locked
if is_active != user.is_active: if is_active != user.is_active:
user.is_active = is_active user.is_active = is_active
user.save() user.save()

View file

@ -33,7 +33,13 @@ def ldap_sync_all():
ldap_sync_single(source.pk) ldap_sync_single(source.pk)
@CELERY_APP.task() @CELERY_APP.task(
# We take the configured hours timeout time by 2.5 as we run user and
# group in parallel and then membership, so 2x is to cover the serial tasks,
# and 0.5x on top of that to give some more leeway
soft_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 2.5,
task_time_limit=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 2.5,
)
def ldap_sync_single(source_pk: str): def ldap_sync_single(source_pk: str):
"""Sync a single source""" """Sync a single source"""
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()

View file

@ -0,0 +1,111 @@
"""ldap testing utils"""
from ldap3 import MOCK_SYNC, OFFLINE_DS389_1_3_3, Connection, Server
def mock_freeipa_connection(password: str) -> Connection:
"""Create mock FreeIPA-ish connection"""
server = Server("my_fake_server", get_info=OFFLINE_DS389_1_3_3)
_pass = "foo" # noqa # nosec
connection = Connection(
server,
user="cn=my_user,dc=goauthentik,dc=io",
password=_pass,
client_strategy=MOCK_SYNC,
)
# Entry for password checking
connection.strategy.add_entry(
"cn=user,ou=users,dc=goauthentik,dc=io",
{
"name": "test-user",
"uid": "unique-test-group",
"objectClass": "person",
"displayName": "Erin M. Hagens",
},
)
connection.strategy.add_entry(
"cn=group1,ou=groups,dc=goauthentik,dc=io",
{
"cn": "group1",
"uid": "unique-test-group",
"objectClass": "groupOfNames",
"member": ["cn=user0,ou=users,dc=goauthentik,dc=io"],
},
)
# Group without SID
connection.strategy.add_entry(
"cn=group2,ou=groups,dc=goauthentik,dc=io",
{
"cn": "group2",
"objectClass": "groupOfNames",
},
)
connection.strategy.add_entry(
"cn=user0,ou=users,dc=goauthentik,dc=io",
{
"userPassword": password,
"name": "user0_sn",
"uid": "user0_sn",
"objectClass": "person",
},
)
# User without SID
connection.strategy.add_entry(
"cn=user1,ou=users,dc=goauthentik,dc=io",
{
"userPassword": "test1111",
"name": "user1_sn",
"objectClass": "person",
},
)
# Duplicate users
connection.strategy.add_entry(
"cn=user2,ou=users,dc=goauthentik,dc=io",
{
"userPassword": "test2222",
"name": "user2_sn",
"uid": "unique-test2222",
"objectClass": "person",
},
)
connection.strategy.add_entry(
"cn=user3,ou=users,dc=goauthentik,dc=io",
{
"userPassword": "test2222",
"name": "user2_sn",
"uid": "unique-test2222",
"objectClass": "person",
},
)
# Group with posixGroup and memberUid
connection.strategy.add_entry(
"cn=group-posix,ou=groups,dc=goauthentik,dc=io",
{
"cn": "group-posix",
"objectClass": "posixGroup",
"memberUid": ["user-posix"],
},
)
# User with posixAccount
connection.strategy.add_entry(
"cn=user-posix,ou=users,dc=goauthentik,dc=io",
{
"userPassword": password,
"uid": "user-posix",
"cn": "user-posix",
"objectClass": "posixAccount",
},
)
# Locked out user
connection.strategy.add_entry(
"cn=user-nsaccountlock,ou=users,dc=goauthentik,dc=io",
{
"userPassword": password,
"uid": "user-nsaccountlock",
"cn": "user-nsaccountlock",
"objectClass": "person",
"nsaccountlock": ["TRUE"],
},
)
connection.bind()
return connection

View file

@ -4,7 +4,7 @@ from ldap3 import MOCK_SYNC, OFFLINE_SLAPD_2_4, Connection, Server
def mock_slapd_connection(password: str) -> Connection: def mock_slapd_connection(password: str) -> Connection:
"""Create mock AD connection""" """Create mock SLAPD connection"""
server = Server("my_fake_server", get_info=OFFLINE_SLAPD_2_4) server = Server("my_fake_server", get_info=OFFLINE_SLAPD_2_4)
_pass = "foo" # noqa # nosec _pass = "foo" # noqa # nosec
connection = Connection( connection = Connection(

View file

@ -9,7 +9,7 @@ from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.events.monitored_tasks import TaskInfo, TaskResultStatus from authentik.events.monitored_tasks import TaskInfo, TaskResultStatus
from authentik.lib.generators import generate_key from authentik.lib.generators import generate_id, generate_key
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
@ -17,6 +17,7 @@ from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all from authentik.sources.ldap.tasks import ldap_sync, ldap_sync_all
from authentik.sources.ldap.tests.mock_ad import mock_ad_connection from authentik.sources.ldap.tests.mock_ad import mock_ad_connection
from authentik.sources.ldap.tests.mock_freeipa import mock_freeipa_connection
from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection from authentik.sources.ldap.tests.mock_slapd import mock_slapd_connection
LDAP_PASSWORD = generate_key() LDAP_PASSWORD = generate_key()
@ -70,6 +71,28 @@ class LDAPSyncTests(TestCase):
) )
self.assertTrue(events.exists()) self.assertTrue(events.exists())
def test_sync_mapping(self):
"""Test property mappings"""
none = LDAPPropertyMapping.objects.create(
name=generate_id(), object_field="none", expression="return None"
)
byte_mapping = LDAPPropertyMapping.objects.create(
name=generate_id(), object_field="bytes", expression="return b''"
)
self.source.property_mappings.set(
LDAPPropertyMapping.objects.filter(
Q(managed__startswith="goauthentik.io/sources/ldap/default")
| Q(managed__startswith="goauthentik.io/sources/ldap/ms")
)
)
self.source.property_mappings.add(none, byte_mapping)
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
# we basically just test that the mappings don't throw errors
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync_full()
def test_sync_users_ad(self): def test_sync_users_ad(self):
"""Test user sync""" """Test user sync"""
self.source.property_mappings.set( self.source.property_mappings.set(
@ -78,7 +101,6 @@ class LDAPSyncTests(TestCase):
| Q(managed__startswith="goauthentik.io/sources/ldap/ms") | Q(managed__startswith="goauthentik.io/sources/ldap/ms")
) )
) )
self.source.save()
connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_ad_connection(LDAP_PASSWORD))
# Create the user beforehand so we can set attributes and check they aren't removed # Create the user beforehand so we can set attributes and check they aren't removed
@ -112,7 +134,6 @@ class LDAPSyncTests(TestCase):
| Q(managed__startswith="goauthentik.io/sources/ldap/openldap") | Q(managed__startswith="goauthentik.io/sources/ldap/openldap")
) )
) )
self.source.save()
connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) connection = MagicMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection): with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source) user_sync = UserLDAPSynchronizer(self.source)
@ -120,6 +141,23 @@ class LDAPSyncTests(TestCase):
self.assertTrue(User.objects.filter(username="user0_sn").exists()) self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists()) self.assertFalse(User.objects.filter(username="user1_sn").exists())
def test_sync_users_freeipa_ish(self):
"""Test user sync (FreeIPA-ish), mainly testing vendor quirks"""
self.source.object_uniqueness_field = "uid"
self.source.property_mappings.set(
LDAPPropertyMapping.objects.filter(
Q(managed__startswith="goauthentik.io/sources/ldap/default")
| Q(managed__startswith="goauthentik.io/sources/ldap/openldap")
)
)
connection = MagicMock(return_value=mock_freeipa_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync_full()
self.assertTrue(User.objects.filter(username="user0_sn").exists())
self.assertFalse(User.objects.filter(username="user1_sn").exists())
self.assertFalse(User.objects.get(username="user-nsaccountlock").is_active)
def test_sync_groups_ad(self): def test_sync_groups_ad(self):
"""Test group sync""" """Test group sync"""
self.source.property_mappings.set( self.source.property_mappings.set(

View file

@ -30,7 +30,7 @@ def check_plex_token(self: MonitoredTask, source_slug: int):
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."])) self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."]))
except RequestException as exc: except RequestException as exc:
error = exception_to_string(exc) error = exception_to_string(exc)
if len(source.plex_token) < 1: if len(source.plex_token) > 0:
error = error.replace(source.plex_token, "$PLEX_TOKEN") error = error.replace(source.plex_token, "$PLEX_TOKEN")
self.set_status( self.set_status(
TaskResult( TaskResult(

View file

@ -0,0 +1,129 @@
"""Authenticator devices helpers"""
from django.db import transaction
def verify_token(user, device_id, token):
"""
Attempts to verify a :term:`token` against a specific device, identified by
:attr:`~authentik.stages.authenticator.models.Device.persistent_id`.
This wraps the verification process in a transaction to ensure that things
like throttling polices are properly enforced.
:param user: The user supplying the token.
:type user: :class:`~django.contrib.auth.models.User`
:param str device_id: A device's persistent_id value.
:param str token: An OTP token to verify.
:returns: The device that accepted ``token``, if any.
:rtype: :class:`~authentik.stages.authenticator.models.Device` or ``None``
"""
from authentik.stages.authenticator.models import Device
verified = None
with transaction.atomic():
device = Device.from_persistent_id(device_id, for_verify=True)
if (device is not None) and (device.user_id == user.pk) and device.verify_token(token):
verified = device
return verified
def match_token(user, token):
"""
Attempts to verify a :term:`token` on every device attached to the given
user until one of them succeeds.
.. warning::
This originally existed for more convenient integration with the admin
site. Its use is no longer recommended and it is not guaranteed to
interact well with more recent features (such as throttling). Tokens
should always be verified against specific devices.
:param user: The user supplying the token.
:type user: :class:`~django.contrib.auth.models.User`
:param str token: An OTP token to verify.
:returns: The device that accepted ``token``, if any.
:rtype: :class:`~authentik.stages.authenticator.models.Device` or ``None``
"""
with transaction.atomic():
for device in devices_for_user(user, for_verify=True):
if device.verify_token(token):
break
else:
device = None
return device
def devices_for_user(user, confirmed=True, for_verify=False):
"""
Return an iterable of all devices registered to the given user.
Returns an empty iterable for anonymous users.
:param user: standard or custom user object.
:type user: :class:`~django.contrib.auth.models.User`
:param bool confirmed: If ``None``, all matching devices are returned.
Otherwise, this can be any true or false value to limit the query
to confirmed or unconfirmed devices, respectively.
:param bool for_verify: If ``True``, we'll load the devices with
:meth:`~django.db.models.query.QuerySet.select_for_update` to prevent
concurrent verifications from succeeding. In which case, this must be
called inside a transaction.
:rtype: iterable
"""
if user.is_anonymous:
return
for model in device_classes():
device_set = model.objects.devices_for_user(user, confirmed=confirmed)
if for_verify:
device_set = device_set.select_for_update()
yield from device_set
def user_has_device(user, confirmed=True):
"""
Return ``True`` if the user has at least one device.
Returns ``False`` for anonymous users.
:param user: standard or custom user object.
:type user: :class:`~django.contrib.auth.models.User`
:param confirmed: If ``None``, all matching devices are considered.
Otherwise, this can be any true or false value to limit the query
to confirmed or unconfirmed devices, respectively.
"""
try:
next(devices_for_user(user, confirmed=confirmed))
except StopIteration:
has_device = False
else:
has_device = True
return has_device
def device_classes():
"""
Returns an iterable of all loaded device models.
"""
from django.apps import apps # isort: skip
from authentik.stages.authenticator.models import Device
for config in apps.get_app_configs():
for model in config.get_models():
if issubclass(model, Device):
yield model

View file

@ -0,0 +1,10 @@
"""Authenticator"""
from django.apps import AppConfig
class AuthentikStageAuthenticatorConfig(AppConfig):
"""Authenticator App config"""
name = "authentik.stages.authenticator"
label = "authentik_stages_authenticator"
verbose_name = "authentik Stages.Authenticator"

View file

@ -0,0 +1,401 @@
"""Base authenticator models"""
from datetime import timedelta
from django.apps import apps
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
from django.utils import timezone
from django.utils.functional import cached_property
from authentik.core.models import User
from authentik.stages.authenticator.util import random_number_token
class DeviceManager(models.Manager):
"""
The :class:`~django.db.models.Manager` object installed as
``Device.objects``.
"""
def devices_for_user(self, user, confirmed=None):
"""
Returns a queryset for all devices of this class that belong to the
given user.
:param user: The user.
:type user: :class:`~django.contrib.auth.models.User`
:param confirmed: If ``None``, all matching devices are returned.
Otherwise, this can be any true or false value to limit the query
to confirmed or unconfirmed devices, respectively.
"""
devices = self.model.objects.filter(user=user)
if confirmed is not None:
devices = devices.filter(confirmed=bool(confirmed))
return devices
class Device(models.Model):
"""
Abstract base model for a :term:`device` attached to a user. Plugins must
subclass this to define their OTP models.
.. _unsaved_device_warning:
.. warning::
OTP devices are inherently stateful. For example, verifying a token is
logically a mutating operation on the device, which may involve
incrementing a counter or otherwise consuming a token. A device must be
committed to the database before it can be used in any way.
.. attribute:: user
*ForeignKey*: Foreign key to your user model, as configured by
:setting:`AUTH_USER_MODEL` (:class:`~django.contrib.auth.models.User`
by default).
.. attribute:: name
*CharField*: A human-readable name to help the user identify their
devices.
.. attribute:: confirmed
*BooleanField*: A boolean value that tells us whether this device has
been confirmed as valid. It defaults to ``True``, but subclasses or
individual deployments can force it to ``False`` if they wish to create
a device and then ask the user for confirmation. As a rule, built-in
APIs that enumerate devices will only include those that are confirmed.
.. attribute:: objects
A :class:`~authentik.stages.authenticator.models.DeviceManager`.
"""
user = models.ForeignKey(
User,
help_text="The user that this device belongs to.",
on_delete=models.CASCADE,
)
name = models.CharField(max_length=64, help_text="The human-readable name of this device.")
confirmed = models.BooleanField(default=True, help_text="Is this device ready for use?")
objects = DeviceManager()
class Meta:
abstract = True
def __str__(self):
try:
user = self.user
except ObjectDoesNotExist:
user = None
return "{0} ({1})".format(self.name, user)
@property
def persistent_id(self):
"""
A stable device identifier for forms and APIs.
"""
return "{0}/{1}".format(self.model_label(), self.id)
@classmethod
def model_label(cls):
"""
Returns an identifier for this Django model class.
This is just the standard "<app_label>.<model_name>" form.
"""
return "{0}.{1}".format(cls._meta.app_label, cls._meta.model_name)
@classmethod
def from_persistent_id(cls, persistent_id, for_verify=False):
"""
Loads a device from its persistent id::
device == Device.from_persistent_id(device.persistent_id)
:param bool for_verify: If ``True``, we'll load the device with
:meth:`~django.db.models.query.QuerySet.select_for_update` to
prevent concurrent verifications from succeeding. In which case,
this must be called inside a transaction.
"""
device = None
try:
model_label, device_id = persistent_id.rsplit("/", 1)
app_label, model_name = model_label.split(".")
device_cls = apps.get_model(app_label, model_name)
if issubclass(device_cls, Device):
device_set = device_cls.objects.filter(id=int(device_id))
if for_verify:
device_set = device_set.select_for_update()
device = device_set.first()
except (ValueError, LookupError):
pass
return device
def is_interactive(self):
"""
Returns ``True`` if this is an interactive device. The default
implementation returns ``True`` if
:meth:`~authentik.stages.authenticator.models.Device.generate_challenge` has been
overridden, but subclasses are welcome to provide smarter
implementations.
:rtype: bool
"""
return not hasattr(self.generate_challenge, "stub")
def generate_challenge(self):
"""
Generates a challenge value that the user will need to produce a token.
This method is permitted to have side effects, such as transmitting
information to the user through some other channel (email or SMS,
perhaps). And, of course, some devices may need to commit the
challenge to the database.
:returns: A message to the user. This should be a string that fits
comfortably in the template ``'OTP Challenge: {0}'``. This may
return ``None`` if this device is not interactive.
:rtype: string or ``None``
:raises: Any :exc:`~exceptions.Exception` is permitted. Callers should
trap ``Exception`` and report it to the user.
"""
return None
generate_challenge.stub = True
def verify_is_allowed(self):
"""
Checks whether it is permissible to call :meth:`verify_token`. If it is
allowed, returns ``(True, None)``. Otherwise returns ``(False,
data_dict)``, where ``data_dict`` contains extra information, defined
by the implementation.
This method can be used to implement throttling or locking, for
example. Client code should check this method before calling
:meth:`verify_token` and report problems to the user.
To report specific problems, the data dictionary can return include a
``'reason'`` member with a value from the constants in
:class:`VerifyNotAllowed`. Otherwise, an ``'error_message'`` member
should be provided with an error message.
:meth:`verify_token` should also call this method and return False if
verification is not allowed.
:rtype: (bool, dict or ``None``)
"""
return (True, None)
def verify_token(self, token):
"""
Verifies a token. As a rule, the token should no longer be valid if
this returns ``True``.
:param str token: The OTP token provided by the user.
:rtype: bool
"""
return False
class SideChannelDevice(Device):
"""
Abstract base model for a side-channel :term:`device` attached to a user.
This model implements token generation, verification and expiration, so the
concrete devices only have to implement delivery.
"""
token = models.CharField(max_length=16, blank=True, null=True)
valid_until = models.DateTimeField(
default=timezone.now,
help_text="The timestamp of the moment of expiry of the saved token.",
)
class Meta:
abstract = True
def generate_token(self, length=6, valid_secs=300, commit=True):
"""
Generates a token of the specified length, then sets it on the model
and sets the expiration of the token on the model.
Pass 'commit=False' to avoid calling self.save().
:param int length: Number of decimal digits in the generated token.
:param int valid_secs: Amount of seconds the token should be valid.
:param bool commit: Whether to autosave the generated token.
"""
self.token = random_number_token(length)
self.valid_until = timezone.now() + timedelta(seconds=valid_secs)
if commit:
self.save()
def verify_token(self, token):
"""
Verifies a token by content and expiry.
On success, the token is cleared and the device saved.
:param str token: The OTP token provided by the user.
:rtype: bool
"""
_now = timezone.now()
if (self.token is not None) and (token == self.token) and (_now < self.valid_until):
self.token = None
self.valid_until = _now
self.save()
return True
return False
class VerifyNotAllowed:
"""
Constants that may be returned in the ``reason`` member of the extra
information dictionary returned by
:meth:`~authentik.stages.authenticator.models.Device.verify_is_allowed`
.. data:: N_FAILED_ATTEMPTS
Indicates that verification is disallowed because of ``n`` successive
failed attempts. The data dictionary should include the value of ``n``
in member ``failure_count``
"""
N_FAILED_ATTEMPTS = "N_FAILED_ATTEMPTS"
class ThrottlingMixin(models.Model):
"""
Mixin class for models that want throttling behaviour.
This implements exponential back-off for verifying tokens. Subclasses must
implement :meth:`get_throttle_factor`, and must use the
:meth:`verify_is_allowed`, :meth:`throttle_reset` and
:meth:`throttle_increment` methods from within their verify_token() method.
See the implementation of
:class:`~authentik.stages.authenticator.plugins.otp_email.models.EmailDevice` for an example.
"""
throttling_failure_timestamp = models.DateTimeField(
null=True,
blank=True,
default=None,
help_text=(
"A timestamp of the last failed verification attempt. "
"Null if last attempt succeeded."
),
)
throttling_failure_count = models.PositiveIntegerField(
default=0, help_text="Number of successive failed attempts."
)
def verify_is_allowed(self):
"""
If verification is allowed, returns ``(True, None)``.
Otherwise, returns ``(False, data_dict)``.
``data_dict`` contains further information. Currently it can be::
{
'reason': VerifyNotAllowed.N_FAILED_ATTEMPTS,
'failure_count': n
}
where ``n`` is the number of successive failures. See
:class:`~authentik.stages.authenticator.models.VerifyNotAllowed`.
"""
if (
self.throttling_enabled
and self.throttling_failure_count > 0
and self.throttling_failure_timestamp is not None
):
now = timezone.now()
delay = (now - self.throttling_failure_timestamp).total_seconds()
# Required delays should be 1, 2, 4, 8 ...
delay_required = self.get_throttle_factor() * (2 ** (self.throttling_failure_count - 1))
if delay < delay_required:
return (
False,
{
"reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
"failure_count": self.throttling_failure_count,
"locked_until": self.throttling_failure_timestamp
+ timedelta(seconds=delay_required),
},
)
return super().verify_is_allowed()
def throttle_reset(self, commit=True):
"""
Call this method to reset throttling (normally when a verify attempt
succeeded).
Pass 'commit=False' to avoid calling self.save().
"""
self.throttling_failure_timestamp = None
self.throttling_failure_count = 0
if commit:
self.save()
def throttle_increment(self, commit=True):
"""
Call this method to increase throttling (normally when a verify attempt
failed).
Pass 'commit=False' to avoid calling self.save().
"""
self.throttling_failure_timestamp = timezone.now()
self.throttling_failure_count += 1
if commit:
self.save()
@cached_property
def throttling_enabled(self) -> bool:
"""Check if throttling is enabled"""
return self.get_throttle_factor() > 0
def get_throttle_factor(self): # pragma: no cover
"""
This must be implemented to return the throttle factor.
The number of seconds required between verification attempts will be
:math:`c2^{n-1}` where `c` is this factor and `n` is the number of
previous failures. A factor of 1 translates to delays of 1, 2, 4, 8,
etc. seconds. A factor of 0 disables the throttling.
Normally this is just a wrapper for a plugin-specific setting like
:setting:`OTP_EMAIL_THROTTLE_FACTOR`.
"""
raise NotImplementedError()
class Meta:
abstract = True

View file

@ -0,0 +1,199 @@
"""OATH helpers"""
import hmac
from hashlib import sha1
from struct import pack
from time import time
# pylint: disable=invalid-name
def hotp(key: bytes, counter: int, digits=6) -> int:
"""
Implementation of the HOTP algorithm from `RFC 4226
<http://tools.ietf.org/html/rfc4226#section-5>`_.
:param bytes key: The shared secret. A 20-byte string is recommended.
:param int counter: The password counter.
:param int digits: The number of decimal digits to generate.
:returns: The HOTP token.
:rtype: int
>>> key = b'12345678901234567890'
>>> for c in range(10):
... hotp(key, c)
755224
287082
359152
969429
338314
254676
287922
162583
399871
520489
"""
msg = pack(b">Q", counter)
hs = hmac.new(key, msg, sha1).digest()
hs = list(iter(hs))
offset = hs[19] & 0x0F
bin_code = (
(hs[offset] & 0x7F) << 24 | hs[offset + 1] << 16 | hs[offset + 2] << 8 | hs[offset + 3]
)
return bin_code % pow(10, digits)
def totp(key: bytes, step=30, t0=0, digits=6, drift=0) -> int:
"""
Implementation of the TOTP algorithm from `RFC 6238
<http://tools.ietf.org/html/rfc6238#section-4>`_.
:param bytes key: The shared secret. A 20-byte string is recommended.
:param int step: The time step in seconds. The time-based code changes
every ``step`` seconds.
:param int t0: The Unix time at which to start counting time steps.
:param int digits: The number of decimal digits to generate.
:param int drift: The number of time steps to add or remove. Delays and
clock differences might mean that you have to look back or forward a
step or two in order to match a token.
:returns: The TOTP token.
:rtype: int
>>> key = b'12345678901234567890'
>>> now = int(time())
>>> for delta in range(0, 200, 20):
... totp(key, t0=(now-delta))
755224
755224
287082
359152
359152
969429
338314
338314
254676
287922
"""
return TOTP(key, step, t0, digits, drift).token()
class TOTP:
"""
An alternate TOTP interface.
This provides access to intermediate steps of the computation. This is a
living object: the return values of ``t`` and ``token`` will change along
with other properties and with the passage of time.
:param bytes key: The shared secret. A 20-byte string is recommended.
:param int step: The time step in seconds. The time-based code changes
every ``step`` seconds.
:param int t0: The Unix time at which to start counting time steps.
:param int digits: The number of decimal digits to generate.
:param int drift: The number of time steps to add or remove. Delays and
clock differences might mean that you have to look back or forward a
step or two in order to match a token.
>>> key = b'12345678901234567890'
>>> totp = TOTP(key)
>>> totp.time = 0
>>> totp.t()
0
>>> totp.token()
755224
>>> totp.time = 30
>>> totp.t()
1
>>> totp.token()
287082
>>> totp.verify(287082)
True
>>> totp.verify(359152)
False
>>> totp.verify(359152, tolerance=1)
True
>>> totp.drift
1
>>> totp.drift = 0
>>> totp.verify(359152, tolerance=1, min_t=3)
False
>>> totp.drift
0
>>> del totp.time
>>> totp.t0 = int(time()) - 60
>>> totp.t()
2
>>> totp.token()
359152
"""
# pylint: disable=too-many-arguments
def __init__(self, key: bytes, step=30, t0=0, digits=6, drift=0):
self.key = key
self.step = step
self.t0 = t0
self.digits = digits
self.drift = drift
self._time = None
def token(self):
"""The computed TOTP token."""
return hotp(self.key, self.t(), digits=self.digits)
def t(self):
"""The computed time step."""
return ((int(self.time) - self.t0) // self.step) + self.drift
@property
def time(self):
"""
The current time.
By default, this returns time.time() each time it is accessed. If you
want to generate a token at a specific time, you can set this property
to a fixed value instead. Deleting the value returns it to its 'live'
state.
"""
return self._time if (self._time is not None) else time()
@time.setter
def time(self, value):
self._time = value
@time.deleter
def time(self):
self._time = None
def verify(self, token, tolerance=0, min_t=None):
"""
A high-level verification helper.
:param int token: The provided token.
:param int tolerance: The amount of clock drift you're willing to
accommodate, in steps. We'll look for the token at t values in
[t - tolerance, t + tolerance].
:param int min_t: The minimum t value we'll accept. As a rule, this
should be one larger than the largest t value of any previously
accepted token.
:rtype: bool
Iff this returns True, `self.drift` will be updated to reflect the
drift value that was necessary to match the token.
"""
drift_orig = self.drift
verified = False
for offset in range(-tolerance, tolerance + 1):
self.drift = drift_orig + offset
if (min_t is not None) and (self.t() < min_t):
continue
if self.token() == token:
verified = True
break
else:
self.drift = drift_orig
return verified

View file

@ -0,0 +1,220 @@
"""Base authenticator tests"""
from datetime import timedelta
from threading import Thread
from django.contrib.auth.models import AnonymousUser
from django.db import connection
from django.test import TestCase, TransactionTestCase
from django.test.utils import override_settings
from django.utils import timezone
from freezegun import freeze_time
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.stages.authenticator import match_token, user_has_device, verify_token
from authentik.stages.authenticator.models import Device, VerifyNotAllowed
class TestThread(Thread):
"Django testing quirk: threads have to close their DB connections."
__test__ = False
def run(self):
super().run()
connection.close()
class ThrottlingTestMixin:
"""
Generic tests for throttled devices.
Any concrete device implementation that uses throttling should define a
TestCase subclass that includes this as a base class. This will help verify
a correct integration of ThrottlingMixin.
Subclasses are responsible for populating self.device with a device to test
as well as implementing methods to generate tokens to test with.
"""
device: Device
def valid_token(self):
"""Returns a valid token to pass to our device under test."""
raise NotImplementedError()
def invalid_token(self):
"""Returns an invalid token to pass to our device under test."""
raise NotImplementedError()
#
# Tests
#
def test_delay_imposed_after_fail(self):
"""Test delay imposed after fail"""
verified1 = self.device.verify_token(self.invalid_token())
self.assertFalse(verified1)
verified2 = self.device.verify_token(self.valid_token())
self.assertFalse(verified2)
def test_delay_after_fail_expires(self):
"""Test delay after fail expires"""
verified1 = self.device.verify_token(self.invalid_token())
self.assertFalse(verified1)
with freeze_time() as frozen_time:
# With default settings initial delay is 1 second
frozen_time.tick(delta=timedelta(seconds=1.1))
verified2 = self.device.verify_token(self.valid_token())
self.assertTrue(verified2)
def test_throttling_failure_count(self):
"""Test throttling failure count"""
self.assertEqual(self.device.throttling_failure_count, 0)
for _ in range(0, 5):
self.device.verify_token(self.invalid_token())
# Only the first attempt will increase throttling_failure_count,
# the others will all be within 1 second of first
# and therefore not count as attempts.
self.assertEqual(self.device.throttling_failure_count, 1)
def test_verify_is_allowed(self):
"""Test verify allowed"""
# Initially should be allowed
verify_is_allowed1, data1 = self.device.verify_is_allowed()
self.assertEqual(verify_is_allowed1, True)
self.assertEqual(data1, None)
# After failure, verify is not allowed
with freeze_time():
self.device.verify_token(self.invalid_token())
verify_is_allowed2, data2 = self.device.verify_is_allowed()
self.assertEqual(verify_is_allowed2, False)
self.assertEqual(
data2,
{
"reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
"failure_count": 1,
"locked_until": timezone.now() + timezone.timedelta(seconds=1),
},
)
# After a successful attempt, should be allowed again
with freeze_time() as frozen_time:
frozen_time.tick(delta=timedelta(seconds=1.1))
self.device.verify_token(self.valid_token())
verify_is_allowed3, data3 = self.device.verify_is_allowed()
self.assertEqual(verify_is_allowed3, True)
self.assertEqual(data3, None)
@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class APITestCase(TestCase):
"""Test API"""
def setUp(self):
self.alice = create_test_admin_user("alice")
self.bob = create_test_admin_user("bob")
device = self.alice.staticdevice_set.create()
self.valid = generate_id(length=16)
device.token_set.create(token=self.valid)
def test_user_has_device(self):
"""Test user_has_device"""
with self.subTest(user="anonymous"):
self.assertFalse(user_has_device(AnonymousUser()))
with self.subTest(user="alice"):
self.assertTrue(user_has_device(self.alice))
with self.subTest(user="bob"):
self.assertFalse(user_has_device(self.bob))
def test_verify_token(self):
"""Test verify_token"""
device = self.alice.staticdevice_set.first()
verified = verify_token(self.alice, device.persistent_id, "bogus")
self.assertIsNone(verified)
verified = verify_token(self.alice, device.persistent_id, self.valid)
self.assertIsNotNone(verified)
def test_match_token(self):
"""Test match_token"""
verified = match_token(self.alice, "bogus")
self.assertIsNone(verified)
verified = match_token(self.alice, self.valid)
self.assertEqual(verified, self.alice.staticdevice_set.first())
@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class ConcurrencyTestCase(TransactionTestCase):
"""Test concurrent verifications"""
def setUp(self):
self.alice = create_test_admin_user("alice")
self.bob = create_test_admin_user("bob")
self.valid = generate_id(length=16)
for user in [self.alice, self.bob]:
device = user.staticdevice_set.create()
device.token_set.create(token=self.valid)
def test_verify_token(self):
"""Test verify_token in a thread"""
class VerifyThread(Thread):
"""Verifier thread"""
__test__ = False
def __init__(self, user, device_id, token):
super().__init__()
self.user = user
self.device_id = device_id
self.token = token
self.verified = None
def run(self):
self.verified = verify_token(self.user, self.device_id, self.token)
connection.close()
device = self.alice.staticdevice_set.get()
threads = [VerifyThread(device.user, device.persistent_id, self.valid) for _ in range(10)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)
def test_match_token(self):
"""Test match_token in a thread"""
class VerifyThread(Thread):
"""Verifier thread"""
__test__ = False
def __init__(self, user, token):
super().__init__()
self.user = user
self.token = token
self.verified = None
def run(self):
self.verified = match_token(self.user, self.token)
connection.close()
threads = [VerifyThread(self.alice, self.valid) for _ in range(10)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)

View file

@ -0,0 +1,86 @@
"""Authenticator utils"""
import random
import string
from binascii import unhexlify
from os import urandom
from django.core.exceptions import ValidationError
def hex_validator(length=0):
"""
Returns a function to be used as a model validator for a hex-encoded
CharField. This is useful for secret keys of all kinds::
def key_validator(value):
return hex_validator(20)(value)
key = models.CharField(max_length=40,
validators=[key_validator], help_text='A hex-encoded 20-byte secret key')
:param int length: If greater than 0, validation will fail unless the
decoded value is exactly this number of bytes.
:rtype: function
>>> hex_validator()('0123456789abcdef')
>>> hex_validator(8)(b'0123456789abcdef')
>>> hex_validator()('phlebotinum') # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValidationError: ['phlebotinum is not valid hex-encoded data.']
>>> hex_validator(9)('0123456789abcdef') # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValidationError: ['0123456789abcdef does not represent exactly 9 bytes.']
"""
def _validator(value):
try:
if isinstance(value, str):
value = value.encode()
unhexlify(value)
except Exception:
raise ValidationError("{0} is not valid hex-encoded data.".format(value))
if (length > 0) and (len(value) != length * 2):
raise ValidationError("{0} does not represent exactly {1} bytes.".format(value, length))
return _validator
def random_hex(length=20):
"""
Returns a string of random bytes encoded as hex.
This uses :func:`os.urandom`, so it should be suitable for generating
cryptographic keys.
:param int length: The number of (decoded) bytes to return.
:returns: A string of hex digits.
:rtype: str
"""
return urandom(length).hex()
def random_number_token(length=6):
"""
Returns a string of random digits encoded as string.
:param int length: The number of digits to return.
:returns: A string of decimal digits.
:rtype: str
"""
rand = random.SystemRandom()
if hasattr(rand, "choices"):
digits = rand.choices(string.digits, k=length)
else:
digits = (rand.choice(string.digits) for i in range(length))
return "".join(digits)

View file

@ -5,7 +5,6 @@ from django.contrib.auth import get_user_model
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.views import View from django.views import View
from django_otp.models import Device
from duo_client.admin import Admin from duo_client.admin import Admin
from duo_client.auth import Auth from duo_client.auth import Auth
from rest_framework.serializers import BaseSerializer, Serializer from rest_framework.serializers import BaseSerializer, Serializer
@ -14,6 +13,7 @@ from authentik.core.types import UserSettingSerializer
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.http import authentik_user_agent from authentik.lib.utils.http import authentik_user_agent
from authentik.stages.authenticator.models import Device
class AuthenticatorDuoStage(ConfigurableStage, FriendlyNamedStage, Stage): class AuthenticatorDuoStage(ConfigurableStage, FriendlyNamedStage, Stage):

View file

@ -6,7 +6,6 @@ from django.contrib.auth import get_user_model
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.views import View from django.views import View
from django_otp.models import SideChannelDevice
from requests.exceptions import RequestException from requests.exceptions import RequestException
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
@ -21,6 +20,7 @@ from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.stages.authenticator.models import SideChannelDevice
LOGGER = get_logger() LOGGER = get_logger()

View file

@ -1,6 +1,5 @@
"""AuthenticatorStaticStage API Views""" """AuthenticatorStaticStage API Views"""
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from django_otp.plugins.otp_static.models import StaticDevice, StaticToken
from rest_framework import mixins from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser
@ -10,7 +9,11 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.api.authorization import OwnerFilter, OwnerPermissions from authentik.api.authorization import OwnerFilter, OwnerPermissions
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.flows.api.stages import StageSerializer from authentik.flows.api.stages import StageSerializer
from authentik.stages.authenticator_static.models import AuthenticatorStaticStage from authentik.stages.authenticator_static.models import (
AuthenticatorStaticStage,
StaticDevice,
StaticToken,
)
class AuthenticatorStaticStageSerializer(StageSerializer): class AuthenticatorStaticStageSerializer(StageSerializer):

View file

@ -0,0 +1,70 @@
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
(
"authentik_stages_authenticator_static",
"0007_authenticatorstaticstage_token_length_and_more",
),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="StaticDevice",
fields=[
(
"id",
models.AutoField(
verbose_name="ID", serialize=False, auto_created=True, primary_key=True
),
),
(
"name",
models.CharField(
help_text="The human-readable name of this device.", max_length=64
),
),
(
"confirmed",
models.BooleanField(default=True, help_text="Is this device ready for use?"),
),
(
"user",
models.ForeignKey(
help_text="The user that this device belongs to.",
to=settings.AUTH_USER_MODEL,
on_delete=models.CASCADE,
),
),
],
options={
"abstract": False,
},
bases=(models.Model,),
),
migrations.CreateModel(
name="StaticToken",
fields=[
(
"id",
models.AutoField(
verbose_name="ID", serialize=False, auto_created=True, primary_key=True
),
),
("token", models.CharField(max_length=16, db_index=True)),
(
"device",
models.ForeignKey(
related_name="token_set",
to="authentik_stages_authenticator_static.staticdevice",
on_delete=models.CASCADE,
),
),
],
options={},
bases=(models.Model,),
),
]

View file

@ -0,0 +1,33 @@
# Generated by Django 3.0.5 on 2020-04-16 13:41
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_stages_authenticator_static", "0008_initial"),
]
operations = [
migrations.AddField(
model_name="staticdevice",
name="throttling_failure_count",
field=models.PositiveIntegerField(
default=0, help_text="Number of successive failed attempts."
),
),
migrations.AddField(
model_name="staticdevice",
name="throttling_failure_timestamp",
field=models.DateTimeField(
blank=True,
default=None,
help_text="A timestamp of the last failed verification attempt. Null if last attempt succeeded.",
null=True,
),
),
migrations.AlterModelOptions(
name="staticdevice",
options={"verbose_name": "Static device", "verbose_name_plural": "Static devices"},
),
]

View file

@ -1,6 +1,9 @@
"""Static Authenticator models""" """Static Authenticator models"""
from base64 import b32encode
from os import urandom
from typing import Optional from typing import Optional
from django.conf import settings
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.views import View from django.views import View
@ -8,6 +11,8 @@ from rest_framework.serializers import BaseSerializer
from authentik.core.types import UserSettingSerializer from authentik.core.types import UserSettingSerializer
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
from authentik.lib.models import SerializerModel
from authentik.stages.authenticator.models import Device, ThrottlingMixin
class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage): class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage):
@ -46,3 +51,76 @@ class AuthenticatorStaticStage(ConfigurableStage, FriendlyNamedStage, Stage):
class Meta: class Meta:
verbose_name = _("Static Authenticator Stage") verbose_name = _("Static Authenticator Stage")
verbose_name_plural = _("Static Authenticator Stages") verbose_name_plural = _("Static Authenticator Stages")
class StaticDevice(SerializerModel, ThrottlingMixin, Device):
"""
A static :class:`~authentik.stages.authenticator.models.Device` simply consists of random
tokens shared by the database and the user.
These are frequently used as emergency tokens in case a user's normal
device is lost or unavailable. They can be consumed in any order; each
token will be removed from the database as soon as it is used.
This model has no fields of its own, but serves as a container for
:class:`StaticToken` objects.
.. attribute:: token_set
The RelatedManager for our tokens.
"""
@property
def serializer(self) -> type[BaseSerializer]:
from authentik.stages.authenticator_static.api import StaticDeviceSerializer
return StaticDeviceSerializer
def get_throttle_factor(self):
return getattr(settings, "OTP_STATIC_THROTTLE_FACTOR", 1)
def verify_token(self, token):
verify_allowed, _ = self.verify_is_allowed()
if verify_allowed:
match = self.token_set.filter(token=token).first()
if match is not None:
match.delete()
self.throttle_reset()
else:
self.throttle_increment()
else:
match = None
return match is not None
class Meta(Device.Meta):
verbose_name = _("Static device")
verbose_name_plural = _("Static devices")
class StaticToken(models.Model):
"""
A single token belonging to a :class:`StaticDevice`.
.. attribute:: device
*ForeignKey*: A foreign key to :class:`StaticDevice`.
.. attribute:: token
*CharField*: A random string up to 16 characters.
"""
device = models.ForeignKey(StaticDevice, related_name="token_set", on_delete=models.CASCADE)
token = models.CharField(max_length=16, db_index=True)
@staticmethod
def random_token():
"""
Returns a new random string that can be used as a static token.
:rtype: bytes
"""
return b32encode(urandom(5)).decode("utf-8").lower()

View file

@ -1,5 +0,0 @@
"""Static Authenticator settings"""
INSTALLED_APPS = [
"django_otp.plugins.otp_static",
]

View file

@ -1,9 +1,9 @@
"""totp authenticator signals""" """totp authenticator signals"""
from django.db.models.signals import pre_delete from django.db.models.signals import pre_delete
from django.dispatch import receiver from django.dispatch import receiver
from django_otp.plugins.otp_static.models import StaticDevice
from authentik.events.models import Event from authentik.events.models import Event
from authentik.stages.authenticator_static.models import StaticDevice
@receiver(pre_delete, sender=StaticDevice) @receiver(pre_delete, sender=StaticDevice)

View file

@ -1,12 +1,15 @@
"""Static OTP Setup stage""" """Static OTP Setup stage"""
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django_otp.plugins.otp_static.models import StaticDevice, StaticToken
from rest_framework.fields import CharField, ListField from rest_framework.fields import CharField, ListField
from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge
from authentik.flows.stage import ChallengeStageView from authentik.flows.stage import ChallengeStageView
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.stages.authenticator_static.models import AuthenticatorStaticStage from authentik.stages.authenticator_static.models import (
AuthenticatorStaticStage,
StaticDevice,
StaticToken,
)
SESSION_STATIC_DEVICE = "static_device" SESSION_STATIC_DEVICE = "static_device"
SESSION_STATIC_TOKENS = "static_device_tokens" SESSION_STATIC_TOKENS = "static_device_tokens"

View file

@ -1,9 +1,13 @@
"""Test Static API""" """Test Static API"""
from django.test.utils import override_settings
from django.urls import reverse from django.urls import reverse
from django_otp.plugins.otp_static.models import StaticDevice
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.core.models import User from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
from authentik.stages.authenticator.tests import TestCase, ThrottlingTestMixin
from authentik.stages.authenticator_static.models import StaticDevice
class AuthenticatorStaticStageTests(APITestCase): class AuthenticatorStaticStageTests(APITestCase):
@ -18,3 +22,42 @@ class AuthenticatorStaticStageTests(APITestCase):
reverse("authentik_api:staticdevice-detail", kwargs={"pk": dev.pk}) reverse("authentik_api:staticdevice-detail", kwargs={"pk": dev.pk})
) )
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
class DeviceTest(TestCase):
"""A few generic tests to get us started."""
def setUp(self):
self.user = create_test_admin_user("alice")
def test_str(self):
"""Test __str__ of model"""
device = StaticDevice.objects.create(user=self.user, name="Device")
str(device)
def test_str_unpopulated(self):
"""Test __str__ of model"""
device = StaticDevice()
str(device)
@override_settings(
OTP_STATIC_THROTTLE_FACTOR=1,
)
class ThrottlingTestCase(ThrottlingTestMixin, TestCase):
"""Test static device throttling"""
def setUp(self):
user = create_test_admin_user("alice")
self.device = user.staticdevice_set.create()
self.device.token_set.create(token=generate_id(length=16))
self.device.token_set.create(token=generate_id(length=16))
self.device.token_set.create(token=generate_id(length=16))
def valid_token(self):
return self.device.token_set.first().token
def invalid_token(self):
return "bogus"

View file

@ -1,6 +1,5 @@
"""AuthenticatorTOTPStage API Views""" """AuthenticatorTOTPStage API Views"""
from django_filters.rest_framework.backends import DjangoFilterBackend from django_filters.rest_framework.backends import DjangoFilterBackend
from django_otp.plugins.otp_totp.models import TOTPDevice
from rest_framework import mixins from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser
@ -10,7 +9,7 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.api.authorization import OwnerFilter, OwnerPermissions from authentik.api.authorization import OwnerFilter, OwnerPermissions
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.flows.api.stages import StageSerializer from authentik.flows.api.stages import StageSerializer
from authentik.stages.authenticator_totp.models import AuthenticatorTOTPStage from authentik.stages.authenticator_totp.models import AuthenticatorTOTPStage, TOTPDevice
class AuthenticatorTOTPStageSerializer(StageSerializer): class AuthenticatorTOTPStageSerializer(StageSerializer):

Some files were not shown because too many files have changed in this diff Show more