add tenant api tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt 2023-12-05 17:42:28 +01:00
parent 5c56dab82f
commit cdfbb48cb6
No known key found for this signature in database
GPG Key ID: 9C3FA22FABF1AA8D
6 changed files with 155 additions and 10 deletions

View File

@ -1,7 +1,7 @@
"""Serializer for tenants models""" """Serializer for tenants models"""
from hmac import compare_digest from hmac import compare_digest
from django.http import Http404 from django.http import HttpResponseNotFound
from django_tenants.utils import get_tenant from django_tenants.utils import get_tenant
from rest_framework import permissions from rest_framework import permissions
from rest_framework.authentication import get_authorization_header from rest_framework.authentication import get_authorization_header
@ -19,13 +19,15 @@ from authentik.lib.config import CONFIG
from authentik.tenants.models import Domain, Tenant from authentik.tenants.models import Domain, Tenant
class TenantManagementKeyPermission(permissions.BasePermission): class TenantApiKeyPermission(permissions.BasePermission):
"""Authentication based on tenant_management_key""" """Authentication based on tenants.api_key"""
def has_permission(self, request: Request, view: View) -> bool: def has_permission(self, request: Request, view: View) -> bool:
key = CONFIG.get("tenants.api_key", "")
if not key:
return False
token = validate_auth(get_authorization_header(request)) token = validate_auth(get_authorization_header(request))
key = CONFIG.get("tenants.api_key") if token is None:
if compare_digest("", key):
return False return False
return compare_digest(token, key) return compare_digest(token, key)
@ -53,12 +55,13 @@ class TenantViewSet(ModelViewSet):
"domains__domain", "domains__domain",
] ]
ordering = ["schema_name"] ordering = ["schema_name"]
permission_classes = [TenantManagementKeyPermission] authentication_classes = []
permission_classes = [TenantApiKeyPermission]
filter_backends = [OrderingFilter, SearchFilter] filter_backends = [OrderingFilter, SearchFilter]
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
if not CONFIG.get_bool("tenants.enabled", True): if not CONFIG.get_bool("tenants.enabled", True):
return Http404() return HttpResponseNotFound()
return super().dispatch(request, *args, **kwargs) return super().dispatch(request, *args, **kwargs)
@ -81,9 +84,15 @@ class DomainViewSet(ModelViewSet):
"tenant__schema_name", "tenant__schema_name",
] ]
ordering = ["domain"] ordering = ["domain"]
permission_classes = [TenantManagementKeyPermission] authentication_classes = []
permission_classes = [TenantApiKeyPermission]
filter_backends = [OrderingFilter, SearchFilter] filter_backends = [OrderingFilter, SearchFilter]
def dispatch(self, request, *args, **kwargs):
if not CONFIG.get_bool("tenants.enabled", True):
return HttpResponseNotFound()
return super().dispatch(request, *args, **kwargs)
class SettingsSerializer(ModelSerializer): class SettingsSerializer(ModelSerializer):
"""Settings Serializer""" """Settings Serializer"""

View File

@ -6,6 +6,7 @@ import django.db.models.deletion
import django_tenants.postgresql_backend.base import django_tenants.postgresql_backend.base
from django.db import migrations, models from django.db import migrations, models
import authentik.tenants.models
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -42,7 +43,7 @@ class Migration(migrations.Migration):
db_index=True, db_index=True,
max_length=63, max_length=63,
unique=True, unique=True,
validators=[django_tenants.postgresql_backend.base._check_schema_name], validators=[authentik.tenants.models._validate_schema_name],
), ),
), ),
( (

View File

@ -1,7 +1,9 @@
"""Tenant models""" """Tenant models"""
import re
from uuid import uuid4 from uuid import uuid4
from django.apps import apps from django.apps import apps
from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from django.dispatch import receiver from django.dispatch import receiver
@ -16,10 +18,25 @@ from authentik.lib.models import SerializerModel
LOGGER = get_logger() LOGGER = get_logger()
VALID_SCHEMA_NAME = re.compile(r"^t_[a-z0-9]{1,61}$")
def _validate_schema_name(name):
if not VALID_SCHEMA_NAME.match(name):
raise ValidationError(
_(
"Schema name must start with t_, only contain lowercase letters and numbers and be less than 63 characters."
)
)
class Tenant(TenantMixin, SerializerModel): class Tenant(TenantMixin, SerializerModel):
"""Tenant""" """Tenant"""
tenant_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) tenant_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
schema_name = models.CharField(
max_length=63, unique=True, db_index=True, validators=[_validate_schema_name]
)
name = models.TextField() name = models.TextField()
auto_create_schema = True auto_create_schema = True

View File

View File

@ -0,0 +1,118 @@
"""Test Tenant API"""
from json import loads
from django.core.management import call_command
from django.db import connection
from django.urls import reverse
from rest_framework.test import APILiveServerTestCase, APITestCase, APITransactionTestCase
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id
TENANTS_API_KEY = generate_id()
HEADERS = {"Authorization": f"Bearer {TENANTS_API_KEY}"}
class TestAPI(APITransactionTestCase):
"""Test api view"""
def _fixture_teardown(self):
for db_name in self._databases_names(include_mirrors=False):
call_command(
"flush",
verbosity=0,
interactive=False,
database=db_name,
reset_sequences=False,
allow_cascade=True,
inhibit_post_migrate=False,
)
def setUp(self):
call_command("migrate_schemas", schema="template", tenant=True)
def assertSchemaExists(self, schema_name):
with connection.cursor() as cursor:
cursor.execute(
f"SELECT * FROM information_schema.schemata WHERE schema_name = '{schema_name}';"
)
self.assertEqual(cursor.rowcount, 1)
cursor.execute(
"SELECT * FROM information_schema.tables WHERE table_schema = 'template';"
)
expected_tables = cursor.rowcount
cursor.execute(
f"SELECT * FROM information_schema.tables WHERE table_schema = '{schema_name}';"
)
self.assertEqual(cursor.rowcount, expected_tables)
def assertSchemaDoesntExist(self, schema_name):
with connection.cursor() as cursor:
cursor.execute(
f"SELECT * FROM information_schema.schemata WHERE schema_name = '{schema_name}';"
)
self.assertEqual(cursor.rowcount, 0)
@CONFIG.patch("outposts.disable_embedded_outpost", True)
@CONFIG.patch("tenants.enabled", True)
@CONFIG.patch("tenants.api_key", TENANTS_API_KEY)
def test_tenant_create_delete(self):
"""Test Tenant creation API Endpoint"""
response = self.client.post(
reverse(
"authentik_api:tenant-list",
),
data={"name": generate_id(), "schema_name": "t_" + generate_id(length=63 - 2).lower()},
headers=HEADERS,
)
self.assertEqual(response.status_code, 201)
body = loads(response.content.decode())
self.assertSchemaExists(body["schema_name"])
response = self.client.delete(
reverse(
"authentik_api:tenant-detail",
kwargs={"pk": body["tenant_uuid"]},
),
headers=HEADERS,
)
self.assertEqual(response.status_code, 204)
self.assertSchemaDoesntExist(body["schema_name"])
@CONFIG.patch("outposts.disable_embedded_outpost", True)
@CONFIG.patch("tenants.enabled", True)
@CONFIG.patch("tenants.api_key", TENANTS_API_KEY)
def test_unauthenticated(self):
"""Test Tenant creation API Endpoint"""
response = self.client.get(
reverse(
"authentik_api:tenant-list",
),
)
self.assertEqual(response.status_code, 403)
@CONFIG.patch("outposts.disable_embedded_outpost", True)
@CONFIG.patch("tenants.enabled", True)
@CONFIG.patch("tenants.api_key", "")
def test_no_api_key_configured(self):
"""Test Tenant creation API Endpoint"""
response = self.client.get(
reverse(
"authentik_api:tenant-list",
),
)
self.assertEqual(response.status_code, 403)
@CONFIG.patch("tenants.enabled", False)
@CONFIG.patch("tenants.api_key", TENANTS_API_KEY)
def test_api_disabled(self):
"""Test Tenant creation API Endpoint"""
response = self.client.get(
reverse(
"authentik_api:tenant-list",
),
headers=HEADERS,
)
self.assertEqual(response.status_code, 404)

View File

@ -13,7 +13,7 @@ COMMIT;
class Migration(BaseMigration): class Migration(BaseMigration):
def needs_migration(self) -> bool: def needs_migration(self) -> bool:
self.cur.execute( self.cur.execute(
"select * from information_schema.tables where table_name =" " 'django_migrations';" "select * from information_schema.tables where table_name = 'django_migrations';"
) )
# No migration table, assume new installation # No migration table, assume new installation
if not bool(self.cur.rowcount): if not bool(self.cur.rowcount):