diff --git a/authentik/flows/views.py b/authentik/flows/views.py index 830f46382..8a79d20bd 100644 --- a/authentik/flows/views.py +++ b/authentik/flows/views.py @@ -14,12 +14,7 @@ from django.utils.decorators import method_decorator from django.views.decorators.clickjacking import xframe_options_sameorigin from django.views.generic import View from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import ( - OpenApiParameter, - OpenApiResponse, - PolymorphicProxySerializer, - extend_schema, -) +from drf_spectacular.utils import OpenApiParameter, PolymorphicProxySerializer, extend_schema from rest_framework.permissions import AllowAny from rest_framework.views import APIView from sentry_sdk import capture_exception diff --git a/authentik/policies/expression/evaluator.py b/authentik/policies/expression/evaluator.py index b8302b00a..25b70292e 100644 --- a/authentik/policies/expression/evaluator.py +++ b/authentik/policies/expression/evaluator.py @@ -3,8 +3,10 @@ from ipaddress import ip_address, ip_network from typing import TYPE_CHECKING, Optional from django.http import HttpRequest +from django_otp import devices_for_user from structlog.stdlib import get_logger +from authentik.core.models import User from authentik.flows.planner import PLAN_CONTEXT_SSO from authentik.lib.expression.evaluator import BaseEvaluator from authentik.lib.utils.http import get_client_ip @@ -28,6 +30,7 @@ class PolicyEvaluator(BaseEvaluator): self._messages = [] self._context["ak_logger"] = get_logger(policy_name) self._context["ak_message"] = self.expr_func_message + self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator self._context["ip_address"] = ip_address self._context["ip_network"] = ip_network self._filename = policy_name or "PolicyEvaluator" @@ -36,6 +39,19 @@ class PolicyEvaluator(BaseEvaluator): """Wrapper to append to messages list, which is returned with PolicyResult""" self._messages.append(message) + def expr_func_user_has_authenticator( + self, user: User, device_type: Optional[str] = None + ) -> bool: + """Check if a user has any authenticator devices, optionally matching *device_type*""" + user_devices = devices_for_user(user) + if device_type: + for device in user_devices: + device_class = device.__class__.__name__.lower().replace("device", "") + if device_class == device_type: + return True + return False + return len(user_devices) > 0 + def set_policy_request(self, request: PolicyRequest): """Update context based on policy request (if http request is given, update that too)""" # update website/docs/expressions/_objects.md diff --git a/website/docs/policies/expression.mdx b/website/docs/policies/expression.mdx index 4017a2c96..7403be683 100644 --- a/website/docs/policies/expression.mdx +++ b/website/docs/policies/expression.mdx @@ -25,6 +25,23 @@ ak_message("Access denied") return False ``` +### `ak_user_has_authenticator(user: User, device_type: Optional[str] = None)` (2021.9+) + +Check if a user has any authenticator devices. Only fully validated devices are counted. + +Optionally, you can filter a specific device type. The following options are valid: + +- `totp` +- `duo` +- `static` +- `webauthn` + +Example: + +```python +return ak_user_has_authenticator(request.user) +``` + import Functions from '../expressions/_functions.md'