*: fix multiple tests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-11-16 10:34:51 +01:00
parent 425b87a6d0
commit 638e8d741f
10 changed files with 50 additions and 25 deletions

View File

@ -86,7 +86,7 @@ class SystemSerializer(PassiveSerializer):
def get_embedded_outpost_host(self, request: Request) -> str: def get_embedded_outpost_host(self, request: Request) -> str:
"""Get the FQDN configured on the embedded outpost""" """Get the FQDN configured on the embedded outpost"""
outposts = Outpost.objects.filter(managed=MANAGED_OUTPOST) outposts = Outpost.objects.filter(managed=MANAGED_OUTPOST)
if not outposts.exists(): if not outposts.exists(): # pragma: no cover
return "" return ""
return outposts.first().config.authentik_host return outposts.first().config.authentik_host

View File

@ -74,6 +74,7 @@ class TestAdminTasks(TestCase):
action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"} action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"}
) )
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"}) Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})
clear_update_notifications() clear_update_notifications()
self.assertFalse( self.assertFalse(
Event.objects.filter( Event.objects.filter(

View File

@ -1,18 +0,0 @@
"""Throttling classes"""
from typing import Type
from django.views import View
from rest_framework.request import Request
from rest_framework.throttling import ScopedRateThrottle
class SessionThrottle(ScopedRateThrottle):
"""Throttle based on session key"""
def allow_request(self, request: Request, view):
if request._request.user.is_superuser:
return True
return super().allow_request(request, view)
def get_cache_key(self, request: Request, view: Type[View]) -> str:
return f"authentik-throttle-session-{request._request.session.session_key}"

View File

@ -53,8 +53,7 @@ class RequestIDMiddleware:
response = self.get_response(request) response = self.get_response(request)
response[RESPONSE_HEADER_ID] = request.request_id response[RESPONSE_HEADER_ID] = request.request_id
setattr(response, "ak_context", {}) setattr(response, "ak_context", {})
if auth_via := LOCAL.authentik.get(KEY_AUTH_VIA, None): response.ak_context.update(LOCAL.authentik)
response.ak_context[KEY_AUTH_VIA] = auth_via
response.ak_context[KEY_USER] = request.user.username response.ak_context[KEY_USER] = request.user.username
for key in list(LOCAL.authentik.keys()): for key in list(LOCAL.authentik.keys()):
del LOCAL.authentik[key] del LOCAL.authentik[key]

View File

@ -23,7 +23,7 @@ def model_tester_factory(test_model: Type[Stage]) -> Callable:
model_class = test_model() model_class = test_model()
self.assertTrue(issubclass(model_class.type, StageView)) self.assertTrue(issubclass(model_class.type, StageView))
self.assertIsNotNone(test_model.component) self.assertIsNotNone(test_model.component)
_ = test_model.ui_user_settings _ = model_class.ui_user_settings
return tester return tester

View File

@ -3,6 +3,7 @@ from django.test import RequestFactory, TestCase
from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents, User from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents, User
from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip
from authentik.lib.views import bad_request_message
class TestHTTP(TestCase): class TestHTTP(TestCase):
@ -12,6 +13,11 @@ class TestHTTP(TestCase):
self.user = User.objects.get(username="akadmin") self.user = User.objects.get(username="akadmin")
self.factory = RequestFactory() self.factory = RequestFactory()
def test_bad_request_message(self):
"""test bad_request_message"""
request = self.factory.get("/")
self.assertEqual(bad_request_message(request, "foo").status_code, 400)
def test_normal(self): def test_normal(self):
"""Test normal request""" """Test normal request"""
request = self.factory.get("/") request = self.factory.get("/")

View File

@ -37,7 +37,8 @@ class PytestTestRunner: # pragma: no cover
argv.append("-vv") argv.append("-vv")
if self.failfast: if self.failfast:
argv.append("--exitfirst") argv.append("--exitfirst")
argv.append("--reuse-db") if self.keepdb:
argv.append("--reuse-db")
argv.extend(test_labels) argv.extend(test_labels)
return pytest.main(argv) return pytest.main(argv)

View File

@ -0,0 +1,36 @@
"""OpenID Type tests"""
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
# https://connect2id.com/products/server/docs/api/userinfo
OPENID_USER = {
"sub": "83692",
"name": "Alice Adams",
"email": "alice@example.com",
"department": "Engineering",
"birthdate": "1975-12-31",
"nickname": "foo",
}
class TestTypeOpenID(TestCase):
"""OAuth Source tests"""
def setUp(self):
self.source = OAuthSource.objects.create(
name="test",
slug="test",
provider_type="openidconnect",
authorization_url="",
profile_url="",
consumer_key="",
)
def test_enroll_context(self):
"""Test OpenID Enrollment context"""
ak_context = OpenIDConnectOAuth2Callback().get_user_enroll_context(OPENID_USER)
self.assertEqual(ak_context["username"], OPENID_USER["nickname"])
self.assertEqual(ak_context["email"], OPENID_USER["email"])
self.assertEqual(ak_context["name"], OPENID_USER["name"])

View File

@ -29,7 +29,7 @@ class UserWriteStageView(StageView):
"""Allow use of attributes.foo.bar when writing to a user, with full """Allow use of attributes.foo.bar when writing to a user, with full
recursion""" recursion"""
parts = key.replace("_", ".").split(".") parts = key.replace("_", ".").split(".")
if len(parts) < 1: if len(parts) < 1: # pragma: no cover
return return
# Function will always be called with a key like attribute. # Function will always be called with a key like attribute.
# this is just a sanity check to ensure that is removed # this is just a sanity check to ensure that is removed

View File

@ -92,7 +92,7 @@ class TestUserWriteStage(APITestCase):
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
session.save() session.save()
response = self.client.get( response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
) )