ci: update pyright (#3546)
parent 03a3f1bd6f
commit 62f93c83d4
@@ -27,7 +27,7 @@ runs:
       docker-compose -f .github/actions/setup/docker-compose.yml up -d
       poetry env use python3.10
       poetry install
-      npm install -g pyright@1.1.136
+      cd web && npm ci
   - name: Generate config
     shell: poetry run python {0}
     run: |
Makefile (12 changed lines)
@@ -148,25 +148,25 @@ website-watch:
 
 # These targets are use by GitHub actions to allow usage of matrix
 # which makes the YAML File a lot smaller
-
+PY_SOURCES=authentik tests lifecycle
 ci--meta-debug:
	python -V
	node --version
 
 ci-pylint: ci--meta-debug
-	pylint authentik tests lifecycle
+	pylint $(PY_SOURCES)
 
 ci-black: ci--meta-debug
-	black --check authentik tests lifecycle
+	black --check $(PY_SOURCES)
 
 ci-isort: ci--meta-debug
-	isort --check authentik tests lifecycle
+	isort --check $(PY_SOURCES)
 
 ci-bandit: ci--meta-debug
-	bandit -r authentik tests lifecycle
+	bandit -r $(PY_SOURCES)
 
 ci-pyright: ci--meta-debug
-	pyright e2e lifecycle
+	./web/node_modules/.bin/pyright $(PY_SOURCES)
 
 ci-pending-migrations: ci--meta-debug
	ak makemigrations --check
@@ -16,7 +16,7 @@ from authentik.providers.oauth2.models import RefreshToken
 LOGGER = get_logger()
 
 
-def validate_auth(header: bytes) -> str:
+def validate_auth(header: bytes) -> Optional[str]:
     """Validate that the header is in a correct format,
     returns type and credentials"""
     auth_credentials = header.decode().strip()
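The `validate_auth` change above is representative of most edits in this commit: functions that can return `None` now say so in their annotation, so pyright's strict None checks pass. A minimal sketch of the pattern; the function body here is illustrative, not the authentik implementation:

```python
from typing import Optional


def validate_auth(header: bytes) -> Optional[str]:
    """Return the credential part of a Basic auth header, or None if it is malformed."""
    auth_credentials = header.decode().strip()
    auth_type, _, credentials = auth_credentials.partition(" ")
    if auth_type.lower() != "basic" or not credentials:
        return None  # the Optional[str] annotation makes this branch legal for pyright
    return credentials


print(validate_auth(b"Basic dXNlcjpwYXNz"))  # dXNlcjpwYXNz
print(validate_auth(b"Bearer"))              # None
```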
@@ -4,7 +4,7 @@ from glob import glob
 from pathlib import Path
 
 import django.contrib.postgres.fields
-from dacite import from_dict
+from dacite.core import from_dict
 from django.apps.registry import Apps
 from django.conf import settings
 from django.db import migrations, models
@@ -105,9 +105,9 @@ class Blueprint:
 
     version: int = field(default=1)
     entries: list[BlueprintEntry] = field(default_factory=list)
-    context: dict = field(default_factory=dict)
 
     metadata: Optional[BlueprintMetadata] = field(default=None)
+    context: Optional[dict] = field(default_factory=dict)
 
 
 class YAMLTag:
@@ -1,5 +1,5 @@
 """Blueprint exporter"""
-from typing import Iterator
+from typing import Iterable
 from uuid import UUID
 
 from django.apps import apps
@@ -34,7 +34,7 @@ class Exporter:
         Event,
     ]
 
-    def get_entries(self) -> Iterator[BlueprintEntry]:
+    def get_entries(self) -> Iterable[BlueprintEntry]:
         """Get blueprint entries"""
         for model in apps.get_models():
             if not is_model_allowed(model):
@@ -96,7 +96,7 @@ class FlowExporter(Exporter):
             "pbm_uuid", flat=True
         )
 
-    def walk_stages(self) -> Iterator[BlueprintEntry]:
+    def walk_stages(self) -> Iterable[BlueprintEntry]:
         """Convert all stages attached to self.flow into BlueprintEntry objects"""
         stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses()
         for stage in stages:
@@ -104,13 +104,13 @@ class FlowExporter(Exporter):
                 pass
             yield BlueprintEntry.from_model(stage, "name")
 
-    def walk_stage_bindings(self) -> Iterator[BlueprintEntry]:
+    def walk_stage_bindings(self) -> Iterable[BlueprintEntry]:
         """Convert all bindings attached to self.flow into BlueprintEntry objects"""
         bindings = FlowStageBinding.objects.filter(target=self.flow).select_related()
         for binding in bindings:
             yield BlueprintEntry.from_model(binding, "target", "stage", "order")
 
-    def walk_policies(self) -> Iterator[BlueprintEntry]:
+    def walk_policies(self) -> Iterable[BlueprintEntry]:
         """Walk over all policies. This is done at the beginning of the export for stages that have
         a direct foreign key to a policy."""
         # Special case for PromptStage as that has a direct M2M to policy, we have to ensure
@@ -121,21 +121,21 @@ class FlowExporter(Exporter):
         for policy in policies:
             yield BlueprintEntry.from_model(policy)
 
-    def walk_policy_bindings(self) -> Iterator[BlueprintEntry]:
+    def walk_policy_bindings(self) -> Iterable[BlueprintEntry]:
         """Walk over all policybindings relative to us. This is run at the end of the export, as
         we are sure all objects exist now."""
         bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related()
         for binding in bindings:
             yield BlueprintEntry.from_model(binding, "policy", "target", "order")
 
-    def walk_stage_prompts(self) -> Iterator[BlueprintEntry]:
+    def walk_stage_prompts(self) -> Iterable[BlueprintEntry]:
         """Walk over all prompts associated with any PromptStages"""
         prompt_stages = PromptStage.objects.filter(flow=self.flow)
         for stage in prompt_stages:
             for prompt in stage.fields.all():
                 yield BlueprintEntry.from_model(prompt)
 
-    def get_entries(self) -> Iterator[BlueprintEntry]:
+    def get_entries(self) -> Iterable[BlueprintEntry]:
        entries = []
        entries.append(BlueprintEntry.from_model(self.flow, "slug"))
        if self.with_stage_prompts:
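All of the exporter's generator methods switch their declared return type from `Iterator[BlueprintEntry]` to `Iterable[BlueprintEntry]`. Both annotations type-check for a generator function; the broader `Iterable` simply promises less to callers. A small, self-contained sketch of the same pattern with placeholder names, not the authentik classes:

```python
from typing import Iterable, Iterator


def walk_squares(limit: int) -> Iterable[int]:
    """A generator may be annotated as Iterable: callers only need to loop over it."""
    for i in range(limit):
        yield i * i


def walk_squares_iter(limit: int) -> Iterator[int]:
    """Annotating as Iterator also works, but additionally promises next()/iteration state."""
    yield from walk_squares(limit)


print(list(walk_squares(4)))        # [0, 1, 4, 9]
print(next(walk_squares_iter(4)))   # 0
```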
@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from copy import deepcopy
 from typing import Any, Optional
 
-from dacite import from_dict
+from dacite.core import from_dict
 from dacite.exceptions import DaciteError
 from deepmerge import always_merger
 from django.db import transaction
@@ -143,7 +143,8 @@ class Importer:
         if not is_model_allowed(model):
             raise EntryInvalidError(f"Model {model} not allowed")
         if issubclass(model, BaseMetaModel):
-            serializer = model.serializer()(data=entry.get_attrs(self.__import))
+            serializer_class: type[Serializer] = model.serializer()
+            serializer = serializer_class(data=entry.get_attrs(self.__import))
             try:
                 serializer.is_valid(raise_exception=True)
             except ValidationError as exc:
@@ -1,6 +1,4 @@
 """Base models"""
-from typing import Optional
-
 from django.apps import apps
 from django.db.models import Model
 from rest_framework.serializers import Serializer
@@ -51,7 +49,7 @@ class MetaModelRegistry:
             models.append(value)
         return models
 
-    def get_model(self, app_label: str, model_id: str) -> Optional[type[Model]]:
+    def get_model(self, app_label: str, model_id: str) -> type[Model]:
         """Get model checks if any virtual models are registered, and falls back
         to actual django models"""
         if app_label.lower() == self.virtual_prefix:
@@ -4,7 +4,7 @@ from hashlib import sha512
 from pathlib import Path
 from typing import Optional
 
-from dacite import from_dict
+from dacite.core import from_dict
 from django.db import DatabaseError, InternalError, ProgrammingError
 from django.utils.text import slugify
 from django.utils.timezone import now
@@ -77,7 +77,9 @@ def blueprints_find():
             LOGGER.warning("invalid blueprint version", version=version, path=str(path))
             continue
         file_hash = sha512(path.read_bytes()).hexdigest()
-        blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime)
+        blueprint = BlueprintFile(
+            str(path.relative_to(root)), version, file_hash, int(path.stat().st_mtime)
+        )
         blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
         blueprints.append(blueprint)
         LOGGER.info(
@@ -136,6 +138,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
 def apply_blueprint(self: MonitoredTask, instance_pk: str):
     """Apply single blueprint"""
     self.save_on_success = False
+    instance: Optional[BlueprintInstance] = None
     try:
         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
         self.set_uid(slugify(instance.name))
@@ -170,7 +173,9 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
         BlueprintRetrievalFailed,
         EntryInvalidError,
     ) as exc:
-        instance.status = BlueprintInstanceStatus.ERROR
+        if instance:
+            instance.status = BlueprintInstanceStatus.ERROR
         self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
     finally:
-        instance.save()
+        if instance:
+            instance.save()
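Declaring `instance: Optional[BlueprintInstance] = None` before the `try` block (previous hunk) and guarding later accesses with `if instance:` is the usual way to satisfy pyright when a variable may never be assigned on the error path. A generic sketch of the idea, with placeholder classes and a hypothetical `apply` helper rather than the authentik task:

```python
from typing import Optional


class Record:
    status: str = "ok"

    def save(self) -> None:
        print("saved with status", self.status)


def apply(record_id: int) -> None:
    record: Optional[Record] = None  # bound up front, so it exists on every path
    try:
        record = {1: Record()}.get(record_id)
        if record is None:
            raise LookupError(record_id)
        print("applying", record_id)
    except LookupError:
        if record:  # narrowing: inside this branch pyright knows record is a Record
            record.status = "error"
    finally:
        if record:
            record.save()


apply(1)  # applying 1 / saved with status ok
apply(2)  # lookup fails, nothing to save
```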
@@ -9,7 +9,7 @@ from django.db.models.signals import post_save, pre_delete
 
 from authentik import __version__
 from authentik.core.models import User
-from authentik.events.middleware import IGNORED_MODELS
+from authentik.events.middleware import should_log_model
 from authentik.events.models import Event, EventAction
 from authentik.events.utils import model_to_dict
 
@@ -50,7 +50,7 @@ class Command(BaseCommand):
         # pylint: disable=unused-argument
         def post_save_handler(sender, instance: Model, created: bool, **_):
             """Signal handler for all object's post_save"""
-            if isinstance(instance, IGNORED_MODELS):
+            if not should_log_model(instance):
                 return
 
             action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED
@@ -66,7 +66,7 @@ class Command(BaseCommand):
         # pylint: disable=unused-argument
         def pre_delete_handler(sender, instance: Model, **_):
             """Signal handler for all object's pre_delete"""
-            if isinstance(instance, IGNORED_MODELS):  # pragma: no cover
+            if not should_log_model(instance):  # pragma: no cover
                 return
 
             Event.new(EventAction.MODEL_DELETED, model=model_to_dict(instance)).set_user(
@@ -1,6 +1,6 @@
 """authentik admin Middleware to impersonate users"""
 from contextvars import ContextVar
-from typing import Callable
+from typing import Callable, Optional
 from uuid import uuid4
 
 from django.http import HttpRequest, HttpResponse
@@ -13,9 +13,9 @@ RESPONSE_HEADER_ID = "X-authentik-id"
 KEY_AUTH_VIA = "auth_via"
 KEY_USER = "user"
 
-CTX_REQUEST_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "request_id", default=None)
-CTX_HOST = ContextVar(STRUCTLOG_KEY_PREFIX + "host", default=None)
-CTX_AUTH_VIA = ContextVar(STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None)
+CTX_REQUEST_ID = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "request_id", default=None)
+CTX_HOST = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "host", default=None)
+CTX_AUTH_VIA = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None)
 
 
 class ImpersonateMiddleware:
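`ContextVar` is generic, so parameterising it as `ContextVar[Optional[str]]` tells pyright that `get()` may return `None` while `set()` accepts strings; without the parameter, the `default=None` argument typically makes pyright infer `ContextVar[None]` and reject later string assignments. A standalone sketch of the idiom, with illustrative variable names rather than the authentik middleware:

```python
from contextvars import ContextVar
from typing import Optional

# The explicit type parameter allows both None (the default) and str values.
CTX_REQUEST_ID: ContextVar[Optional[str]] = ContextVar("request_id", default=None)


def current_request_id() -> str:
    request_id = CTX_REQUEST_ID.get()  # Optional[str]
    return request_id or "no-request-id"


print(current_request_id())  # no-request-id
CTX_REQUEST_ID.set("abc123")
print(current_request_id())  # abc123
```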
@@ -52,5 +52,5 @@ def create_test_cert() -> CertificateKeyPair:
         subject_alt_names=["goauthentik.io"],
         validity_days=360,
     )
-    builder.name = generate_id()
+    builder.common_name = generate_id()
     return builder.save()
@@ -26,7 +26,7 @@ class CertificateBuilder:
         self.common_name = "authentik Self-signed Certificate"
         self.cert = CertificateKeyPair()
 
-    def save(self) -> Optional[CertificateKeyPair]:
+    def save(self) -> CertificateKeyPair:
         """Save generated certificate as model"""
         if not self.__certificate:
             raise ValueError("Certificated hasn't been built yet")
@@ -6,12 +6,7 @@ from uuid import uuid4
 
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import hashes
-from cryptography.hazmat.primitives.asymmetric.ec import (
-    EllipticCurvePrivateKey,
-    EllipticCurvePublicKey,
-)
-from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
-from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
+from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES, PUBLIC_KEY_TYPES
 from cryptography.hazmat.primitives.serialization import load_pem_private_key
 from cryptography.x509 import Certificate, load_pem_x509_certificate
 from django.db import models
@@ -42,8 +37,8 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
     )
 
     _cert: Optional[Certificate] = None
-    _private_key: Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey] = None
-    _public_key: Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey] = None
+    _private_key: Optional[PRIVATE_KEY_TYPES] = None
+    _public_key: Optional[PUBLIC_KEY_TYPES] = None
 
     @property
     def serializer(self) -> Serializer:
@@ -61,7 +56,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
         return self._cert
 
     @property
-    def public_key(self) -> Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey]:
+    def public_key(self) -> Optional[PUBLIC_KEY_TYPES]:
         """Get public key of the private key"""
         if not self._public_key:
             self._public_key = self.private_key.public_key()
@@ -70,7 +65,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
     @property
     def private_key(
         self,
-    ) -> Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey]:
+    ) -> Optional[PRIVATE_KEY_TYPES]:
         """Get python cryptography PrivateKey instance"""
         if not self._private_key and self.key_data != "":
             try:
@@ -19,7 +19,7 @@ from authentik.flows.models import FlowToken
 from authentik.lib.sentry import before_send
 from authentik.lib.utils.errors import exception_to_string
 
-IGNORED_MODELS = [
+IGNORED_MODELS = (
     Event,
     Notification,
     UserObjectPermission,
@@ -27,12 +27,14 @@ IGNORED_MODELS = [
     StaticToken,
     Session,
     FlowToken,
-]
-if settings.DEBUG:
-    from silk.models import Request, Response, SQLQuery
-
-    IGNORED_MODELS += [Request, Response, SQLQuery]
-IGNORED_MODELS = tuple(IGNORED_MODELS)
+)
+
+
+def should_log_model(model: Model) -> bool:
+    """Return true if operation on `model` should be logged"""
+    if model.__module__.startswith("silk"):
+        return False
+    return not isinstance(model, IGNORED_MODELS)
 
 
 class AuditMiddleware:
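Two things happen in the hunk above: the ignore list becomes a literal tuple, so `isinstance(..., IGNORED_MODELS)` has a type pyright accepts without the later `tuple()` reassignment, and the debug-only silk models are excluded by a module-name check instead of a conditional import. A reduced sketch of the same shape, using stand-in classes rather than the real Django models:

```python
class Event: ...
class Notification: ...
class FlowToken: ...


IGNORED_MODELS = (
    Event,
    Notification,
    FlowToken,
)


def should_log_model(model: object) -> bool:
    """Return True if an operation on `model` should be logged."""
    if type(model).__module__.startswith("silk"):
        return False
    return not isinstance(model, IGNORED_MODELS)


print(should_log_model(Event()))   # False, Event is ignored
print(should_log_model(object()))  # True
```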
@@ -109,7 +111,7 @@ class AuditMiddleware:
         user: User, request: HttpRequest, sender, instance: Model, created: bool, **_
     ):
         """Signal handler for all object's post_save"""
-        if isinstance(instance, IGNORED_MODELS):
+        if not should_log_model(instance):
             return
 
         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED
@@ -119,7 +121,7 @@ class AuditMiddleware:
     # pylint: disable=unused-argument
     def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_):
         """Signal handler for all object's pre_delete"""
-        if isinstance(instance, IGNORED_MODELS):  # pragma: no cover
+        if not should_log_model(instance):  # pragma: no cover
             return
 
         EventNewThread(
@@ -152,6 +152,7 @@ class FlowExecutorView(APIView):
         token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first()
         if not token:
             return None
+        plan = None
         try:
             plan = token.plan
         except (AttributeError, EOFError, ImportError, IndexError) as exc:
@@ -20,7 +20,7 @@ ENV_PREFIX = "AUTHENTIK"
 ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local")
 
 
-def get_path_from_dict(root: dict, path: str, sep=".", default=None):
+def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
     """Recursively walk through `root`, checking each part of `path` split by `sep`.
     If at any point a dict does not exist, return default"""
     for comp in path.split(sep):
@@ -180,7 +180,7 @@ class ConfigLoader:
             # pyright: reportGeneralTypeIssues=false
             if comp not in root:
                 root[comp] = {}
-            root = root.get(comp)
+            root = root.get(comp, {})
         root[path_parts[-1]] = value
 
     def y_bool(self, path: str, default=False) -> bool:
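Passing a default to `dict.get` keeps the intermediate value a plain `dict` instead of `Optional[dict]`, so the subscript assignment that follows type-checks. A tiny sketch of that nested-dict walk with a hypothetical `set_path` helper, not authentik's ConfigLoader:

```python
def set_path(root: dict, path: str, value: object, sep: str = ".") -> None:
    """Create nested dicts along `path` and assign `value` to the final key."""
    parts = path.split(sep)
    for comp in parts[:-1]:
        if comp not in root:
            root[comp] = {}
        root = root.get(comp, {})  # the {} default keeps the inferred type a plain dict
    root[parts[-1]] = value


config: dict = {}
set_path(config, "postgresql.host", "localhost")
print(config)  # {'postgresql': {'host': 'localhost'}}
```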
@@ -12,5 +12,4 @@ class TestReflectionUtils(TestCase):
 
     def test_path_to_class(self):
         """Test path_to_class"""
-        self.assertIsNone(path_to_class(None))
         self.assertEqual(path_to_class("datetime.datetime"), datetime)
@@ -29,10 +29,8 @@ def class_to_path(cls: type) -> str:
     return f"{cls.__module__}.{cls.__name__}"
 
 
-def path_to_class(path: str | None) -> type | None:
+def path_to_class(path: str = "") -> type:
     """Import module and return class"""
-    if not path:
-        return None
     parts = path.split(".")
     package = ".".join(parts[:-1])
     _class = getattr(import_module(package), parts[-1])
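With the `None` branch gone, `path_to_class` can be annotated as always returning a `type`, which is also why the `assertIsNone` test in the previous hunk is dropped. A freestanding sketch of the narrowed signature, same idea but without any authentik imports:

```python
from importlib import import_module


def path_to_class(path: str = "") -> type:
    """Import a dotted path like 'datetime.datetime' and return the class."""
    parts = path.split(".")
    package = ".".join(parts[:-1])
    return getattr(import_module(package), parts[-1])


print(path_to_class("datetime.datetime"))  # <class 'datetime.datetime'>
```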
@@ -5,7 +5,7 @@ from enum import IntEnum
 from typing import Any, Optional
 
 from channels.exceptions import DenyConnection
-from dacite import from_dict
+from dacite.core import from_dict
 from dacite.data import Data
 from guardian.shortcuts import get_objects_for_user
 from structlog.stdlib import BoundLogger, get_logger
@@ -2,7 +2,7 @@
 from dataclasses import asdict, dataclass, field
 from typing import TYPE_CHECKING
 
-from dacite import from_dict
+from dacite.core import from_dict
 from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi
 
 from authentik.outposts.controllers.base import FIELD_MANAGER
@@ -4,7 +4,7 @@ from datetime import datetime
 from typing import Iterable, Optional
 from uuid import uuid4
 
-from dacite import from_dict
+from dacite.core import from_dict
 from django.contrib.auth.models import Permission
 from django.core.cache import cache
 from django.db import IntegrityError, models, transaction
@@ -74,7 +74,7 @@ class OutpostConfig:
     kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls")
     kubernetes_service_type: str = field(default="ClusterIP")
     kubernetes_disabled_components: list[str] = field(default_factory=list)
-    kubernetes_image_pull_secrets: Optional[list[str]] = field(default_factory=list)
+    kubernetes_image_pull_secrets: list[str] = field(default_factory=list)
 
 
 class OutpostModel(Model):
@@ -74,10 +74,14 @@ def outpost_service_connection_state(connection_pk: Any):
     )
     if not connection:
         return
+    cls = None
     if isinstance(connection, DockerServiceConnection):
         cls = DockerClient
     if isinstance(connection, KubernetesServiceConnection):
         cls = KubernetesClient
+    if not cls:
+        LOGGER.warning("No class found for service connection", connection=connection)
+        return
     try:
         with cls(connection) as client:
             state = client.fetch_state()
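Initialising `cls = None` before the `isinstance` checks and bailing out when nothing matched keeps `cls` bound on every path, so pyright no longer reports it as possibly unbound at `with cls(connection)`. A sketch of the shape of that fix, with hypothetical client classes and a made-up `client_for` helper standing in for the Celery task:

```python
from typing import Optional


class DockerClient:
    def __init__(self, connection: object) -> None: ...


class KubernetesClient:
    def __init__(self, connection: object) -> None: ...


def client_for(connection: object) -> Optional[DockerClient | KubernetesClient]:
    cls = None  # bound before the checks, so every later read is defined
    if isinstance(connection, dict):
        cls = DockerClient
    if isinstance(connection, list):
        cls = KubernetesClient
    if not cls:
        print("No class found for service connection")
        return None
    return cls(connection)


print(client_for({}))    # a DockerClient instance
print(client_for(42.0))  # falls through the guard, prints the warning, returns None
```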
@@ -11,7 +11,7 @@ from urllib.parse import urlparse, urlunparse
 
 from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
-from dacite import from_dict
+from dacite.core import from_dict
 from django.db import models
 from django.http import HttpRequest
 from django.utils import dateformat, timezone
|
@ -2,7 +2,7 @@
|
|||
from dataclasses import asdict, dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from dacite import from_dict
|
||||
from dacite.core import from_dict
|
||||
from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi
|
||||
|
||||
from authentik.outposts.controllers.base import FIELD_MANAGER
|
||||
|
|
|
@ -39,8 +39,8 @@ class BaseOAuthClient:
|
|||
profile_url = self.source.type.profile_url or ""
|
||||
if self.source.type.urls_customizable and self.source.profile_url:
|
||||
profile_url = self.source.profile_url
|
||||
try:
|
||||
response = self.do_request("get", profile_url, token=token)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text)
|
||||
|
|
|
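The OAuth client hunks in this commit all apply the same mechanical fix: the request call moves out of the `try` block so that `response` is provably bound when the `except` handler logs `response.text`; only `raise_for_status()` stays guarded. A minimal sketch of the after state, using `requests` directly instead of authentik's client wrappers (function name and return value are illustrative):

```python
import requests
from requests.exceptions import RequestException


def fetch_profile(profile_url: str) -> dict:
    # Assigning response outside the try means it is always bound when the
    # except block runs; with the assignment inside the try, pyright flags
    # response.text below as possibly unbound.
    response = requests.get(profile_url, timeout=10)
    try:
        response.raise_for_status()
    except RequestException as exc:
        print("Unable to fetch user profile", exc, response.text)
        return {}
    return response.json()
```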
@@ -138,12 +138,12 @@ class UserprofileHeaderAuthClient(OAuth2Client):
         profile_url = self.source.type.profile_url or ""
         if self.source.type.urls_customizable and self.source.profile_url:
             profile_url = self.source.profile_url
-        try:
-            response = self.session.request(
+        response = self.session.request(
             "get",
             profile_url,
             headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
         )
+        try:
             response.raise_for_status()
         except RequestException as exc:
             LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)
@@ -1,5 +1,5 @@
 """GitHub OAuth Views"""
-from typing import Any, Optional
+from typing import Any
 
 from requests.exceptions import RequestException
 
@@ -21,14 +21,14 @@ class GitHubOAuthRedirect(OAuthRedirect):
 class GitHubOAuth2Client(OAuth2Client):
     """GitHub OAuth2 Client"""
 
-    def get_github_emails(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
+    def get_github_emails(self, token: dict[str, str]) -> list[dict[str, Any]]:
         """Get Emails from the GitHub API"""
         profile_url = self.source.type.profile_url or ""
         if self.source.type.urls_customizable and self.source.profile_url:
             profile_url = self.source.profile_url
         profile_url += "/emails"
-        try:
-            response = self.do_request("get", profile_url, token=token)
+        response = self.do_request("get", profile_url, token=token)
+        try:
             response.raise_for_status()
         except RequestException as exc:
             self.logger.warning("Unable to fetch github emails", exc=exc)
@@ -29,11 +29,11 @@ class MailcowOAuth2Client(OAuth2Client):
         profile_url = self.source.type.profile_url or ""
         if self.source.type.urls_customizable and self.source.profile_url:
             profile_url = self.source.profile_url
-        try:
-            response = self.session.request(
+        response = self.session.request(
             "get",
             f"{profile_url}?access_token={token['access_token']}",
         )
+        try:
             response.raise_for_status()
         except RequestException as exc:
             LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)
@@ -13,9 +13,11 @@ from django_otp.models import Device
 from rest_framework.fields import CharField, JSONField
 from rest_framework.serializers import ValidationError
 from structlog.stdlib import get_logger
-from webauthn import generate_authentication_options, verify_authentication_response
-from webauthn.helpers import base64url_to_bytes, options_to_json
+from webauthn.authentication.generate_authentication_options import generate_authentication_options
+from webauthn.authentication.verify_authentication_response import verify_authentication_response
+from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
 from webauthn.helpers.exceptions import InvalidAuthenticationResponse
+from webauthn.helpers.options_to_json import options_to_json
 from webauthn.helpers.structs import AuthenticationCredential
 
 from authentik.core.api.utils import PassiveSerializer
|
@ -4,7 +4,8 @@ from time import sleep
|
|||
from django.test.client import RequestFactory
|
||||
from django.urls.base import reverse
|
||||
from rest_framework.serializers import ValidationError
|
||||
from webauthn.helpers import base64url_to_bytes, bytes_to_base64url
|
||||
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
|
||||
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.flows.models import Flow, FlowStageBinding, NotConfiguredAction
|
||||
|
|
|
@@ -5,15 +5,19 @@ from django.http import HttpRequest, HttpResponse
 from django.http.request import QueryDict
 from rest_framework.fields import CharField, JSONField
 from rest_framework.serializers import ValidationError
-from webauthn import generate_registration_options, options_to_json, verify_registration_response
-from webauthn.helpers import bytes_to_base64url
+from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
 from webauthn.helpers.exceptions import InvalidRegistrationResponse
+from webauthn.helpers.options_to_json import options_to_json
 from webauthn.helpers.structs import (
     AuthenticatorSelectionCriteria,
     PublicKeyCredentialCreationOptions,
     RegistrationCredential,
 )
-from webauthn.registration.verify_registration_response import VerifiedRegistration
+from webauthn.registration.generate_registration_options import generate_registration_options
+from webauthn.registration.verify_registration_response import (
+    VerifiedRegistration,
+    verify_registration_response,
+)
 
 from authentik.core.models import User
 from authentik.flows.challenge import (
@@ -2,7 +2,7 @@
 from base64 import b64decode
 
 from django.urls import reverse
-from webauthn.helpers import bytes_to_base64url
+from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
 
 from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 from authentik.flows.markers import StageMarker
@@ -62,6 +62,8 @@ if __name__ == "__main__":
     try:
         for migration in Path(__file__).parent.absolute().glob("system_migrations/*.py"):
             spec = spec_from_file_location("lifecycle.system_migrations", migration)
+            if not spec:
+                continue
             mod = module_from_spec(spec)
             # pyright: reportGeneralTypeIssues=false
             spec.loader.exec_module(mod)
@@ -3,14 +3,17 @@ ignore = [
     "**/migrations/**",
     "**/node_modules/**"
 ]
-
 reportMissingTypeStubs = false
 strictParameterNoneValue = true
 strictDictionaryInference = true
 strictListInference = true
+reportOptionalMemberAccess = false
+# Sadly pyright still has issues with enums, and they fall under general type issues
+# so we have to disable those for now
+reportGeneralTypeIssues = false
 verboseOutput = false
-pythonVersion = "3.9"
-pythonPlatform = "Linux"
+pythonVersion = "3.10"
+pythonPlatform = "All"
 
 [tool.black]
 line-length = 100
@@ -198,7 +198,7 @@ class TestProviderLDAP(SeleniumTestCase):
             search_scope=SUBTREE,
             attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
         )
-        response = _connection.response
+        response: dict = _connection.response
         # Remove raw_attributes to make checking easier
         for obj in response:
             del obj["raw_attributes"]
@@ -26,7 +26,7 @@ from tests.e2e.utils import SeleniumTestCase, retry
 CONFIG_PATH = "/tmp/dex.yml"  # nosec
 
 
-class OAUth1Callback(OAuthCallback):
+class OAuth1Callback(OAuthCallback):
     """OAuth1 Callback with custom getters"""
 
     def get_user_id(self, info: dict[str, str]) -> str:
@@ -47,7 +47,7 @@ class OAUth1Callback(OAuthCallback):
 class OAUth1Type(SourceType):
     """OAuth1 Type definition"""
 
-    callback_view = OAUth1Callback
+    callback_view = OAuth1Callback
     name = "OAuth1"
     slug = "oauth1"
 
@@ -20,7 +20,7 @@ from selenium.webdriver.common.by import By
 from selenium.webdriver.common.keys import Keys
 from selenium.webdriver.remote.webdriver import WebDriver
 from selenium.webdriver.remote.webelement import WebElement
-from selenium.webdriver.support.ui import WebDriverWait
+from selenium.webdriver.support.wait import WebDriverWait
 from structlog.stdlib import get_logger
 
 from authentik.core.api.users import UserSerializer
@@ -143,7 +143,9 @@ class SeleniumTestCase(StaticLiveServerTestCase):
         """same as self.url() but show URL in shell"""
         return f"{self.live_server_url}/if/user/#{view}"
 
-    def get_shadow_root(self, selector: str, container: Optional[WebElement] = None) -> WebElement:
+    def get_shadow_root(
+        self, selector: str, container: Optional[WebElement | WebDriver] = None
+    ) -> WebElement:
         """Get shadow root element's inner shadowRoot"""
         if not container:
             container = self.driver
@@ -62,6 +62,7 @@
         "lit": "^2.3.1",
         "moment": "^2.29.4",
         "prettier": "^2.7.1",
+        "pyright": "^1.1.269",
         "rapidoc": "^9.3.3",
         "rollup": "^2.79.0",
         "rollup-plugin-copy": "^3.4.0",
@@ -7361,6 +7362,18 @@
         "node": ">=6"
       }
     },
+    "node_modules/pyright": {
+      "version": "1.1.269",
+      "resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.269.tgz",
+      "integrity": "sha512-n3Q1ccQ4nzMmFGC8B6WUmuoylrkxrknlvpt1ODDbmXUFJlMQSNGLIoZYFZlnP0lt0b4tpO+nDaK1q0lI0nQaxA==",
+      "bin": {
+        "pyright": "index.js",
+        "pyright-langserver": "langserver.index.js"
+      },
+      "engines": {
+        "node": ">=12.0.0"
+      }
+    },
     "node_modules/qrjs": {
       "version": "0.1.2",
       "resolved": "https://registry.npmjs.org/qrjs/-/qrjs-0.1.2.tgz",
@@ -14573,6 +14586,11 @@
       "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.1.1.tgz",
       "integrity": "sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A=="
     },
+    "pyright": {
+      "version": "1.1.269",
+      "resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.269.tgz",
+      "integrity": "sha512-n3Q1ccQ4nzMmFGC8B6WUmuoylrkxrknlvpt1ODDbmXUFJlMQSNGLIoZYFZlnP0lt0b4tpO+nDaK1q0lI0nQaxA=="
+    },
     "qrjs": {
       "version": "0.1.2",
       "resolved": "https://registry.npmjs.org/qrjs/-/qrjs-0.1.2.tgz",
@@ -105,6 +105,7 @@
         "lit": "^2.3.1",
         "moment": "^2.29.4",
         "prettier": "^2.7.1",
+        "pyright": "^1.1.269",
         "rapidoc": "^9.3.3",
         "rollup": "^2.79.0",
         "rollup-plugin-copy": "^3.4.0",
@@ -31,7 +31,7 @@ Generally speaking, authentik is a Django application, ran by gunicorn, proxied
 
 Most functions and classes have type-hints and docstrings, so it is recommended to install a Python Type-checking Extension in your IDE to navigate around the code.
 
-Before committing code, run `make lint` to ensure your code is formatted well. This also requires `pyright@1.1.136`, which can be installed with npm.
+Before committing code, run `make lint` to ensure your code is formatted well. This also requires `pyright`, which is installed in the `web/` folder to make dependency management easier.
 
 Run `make gen` to generate an updated OpenAPI document for any changes you made.
 