core: hotfix group membership check (#6584)

This commit is contained in:
Jens L 2023-08-20 23:47:13 +02:00 committed by GitHub
parent cecf7a0200
commit 0472ef583c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 4 deletions

View File

@ -160,8 +160,8 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""Recursively get all groups this user is a member of. """Recursively get all groups this user is a member of.
At least one query is done to get the direct groups of the user, with groups At least one query is done to get the direct groups of the user, with groups
there are at most 3 queries done""" there are at most 3 queries done"""
direct_groups = tuple( direct_groups = list(
str(x) for x in self.ak_groups.all().values_list("pk", flat=True).iterator() x for x in self.ak_groups.all().values_list("pk", flat=True).iterator()
) )
if len(direct_groups) < 1: if len(direct_groups) < 1:
return Group.objects.none() return Group.objects.none()
@ -169,7 +169,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
WITH RECURSIVE parents AS ( WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group FROM authentik_core_group
WHERE authentik_core_group.group_uuid IN (%s) WHERE authentik_core_group.group_uuid = ANY(%s)
UNION ALL UNION ALL
@ -185,7 +185,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
GROUP BY group_uuid, name GROUP BY group_uuid, name
ORDER BY name; ORDER BY name;
""" """
group_pks = [group.pk for group in Group.objects.raw(query, direct_groups).iterator()] group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()]
return Group.objects.filter(pk__in=group_pks) return Group.objects.filter(pk__in=group_pks)
def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]: def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]:

View File

@ -13,7 +13,9 @@ class TestGroups(TestCase):
user = User.objects.create(username=generate_id()) user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id()) user2 = User.objects.create(username=generate_id())
group = Group.objects.create(name=generate_id()) group = Group.objects.create(name=generate_id())
other_group = Group.objects.create(name=generate_id())
group.users.add(user) group.users.add(user)
other_group.users.add(user)
self.assertTrue(group.is_member(user)) self.assertTrue(group.is_member(user))
self.assertFalse(group.is_member(user2)) self.assertFalse(group.is_member(user2))