core: rework recursive group membership (#6017)

* rework checking group membership and add `user.all_groups` to get full list of groups

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* refactor some more for better performance

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* migrate things to use all_groups

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* update release notes

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix for django 4.2

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2023-08-18 17:31:39 +02:00
committed by GitHub
parent 9e17b1bad3
commit 8bba3c0a9b
5 changed files with 58 additions and 34 deletions

View File

@ -207,7 +207,7 @@ class UserSelfSerializer(ModelSerializer):
)
def get_groups(self, _: User):
"""Return only the group names a user is member of"""
for group in self.instance.ak_groups.all():
for group in self.instance.all_groups().order_by("name"):
yield {
"name": group.name,
"pk": group.pk,

View File

@ -113,27 +113,7 @@ class Group(SerializerModel):
def is_member(self, user: "User") -> bool:
"""Recursively check if `user` is member of us, or any parent."""
query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid = %s
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth - 1
FROM authentik_core_group,parents
WHERE (
authentik_core_group.parent_id = parents.group_uuid and
parents.relative_depth > -20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid;
"""
groups = Group.objects.raw(query, [self.group_uuid])
return user.ak_groups.filter(pk__in=[group.pk for group in groups]).exists()
return user.all_groups().filter(group_uuid=self.group_uuid).exists()
def __str__(self):
return f"Group {self.name}"
@ -176,13 +156,45 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""Get the default user path"""
return User._meta.get_field("path").default
def all_groups(self) -> QuerySet[Group]:
"""Recursively get all groups this user is a member of.
At least one query is done to get the direct groups of the user, with groups
there are at most 3 queries done"""
direct_groups = tuple(
str(x) for x in self.ak_groups.all().values_list("pk", flat=True).iterator()
)
if len(direct_groups) < 1:
return Group.objects.none()
query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid IN (%s)
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth + 1
FROM authentik_core_group, parents
WHERE (
authentik_core_group.group_uuid = parents.parent_id and
parents.relative_depth < 20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid, name
ORDER BY name;
"""
group_pks = [group.pk for group in Group.objects.raw(query, direct_groups).iterator()]
return Group.objects.filter(pk__in=group_pks)
def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]:
"""Get a dictionary containing the attributes from all groups the user belongs to,
including the users attributes"""
final_attributes = {}
if request and hasattr(request, "tenant"):
always_merger.merge(final_attributes, request.tenant.attributes)
for group in self.ak_groups.all().order_by("name"):
for group in self.all_groups().order_by("name"):
always_merger.merge(final_attributes, group.attributes)
always_merger.merge(final_attributes, self.attributes)
return final_attributes
@ -196,7 +208,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
@cached_property
def is_superuser(self) -> bool:
"""Get supseruser status based on membership in a group with superuser status"""
return self.ak_groups.filter(is_superuser=True).exists()
return self.all_groups().filter(is_superuser=True).exists()
@property
def is_staff(self) -> bool:

View File

@ -21,22 +21,26 @@ class TestGroups(TestCase):
"""Test parent membership"""
user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id())
first = Group.objects.create(name=generate_id())
second = Group.objects.create(name=generate_id(), parent=first)
second.users.add(user)
self.assertTrue(first.is_member(user))
self.assertFalse(first.is_member(user2))
parent = Group.objects.create(name=generate_id())
child = Group.objects.create(name=generate_id(), parent=parent)
child.users.add(user)
self.assertTrue(child.is_member(user))
self.assertTrue(parent.is_member(user))
self.assertFalse(child.is_member(user2))
self.assertFalse(parent.is_member(user2))
def test_group_membership_parent_extra(self):
"""Test parent membership"""
user = User.objects.create(username=generate_id())
user2 = User.objects.create(username=generate_id())
first = Group.objects.create(name=generate_id())
second = Group.objects.create(name=generate_id(), parent=first)
parent = Group.objects.create(name=generate_id())
second = Group.objects.create(name=generate_id(), parent=parent)
third = Group.objects.create(name=generate_id(), parent=second)
second.users.add(user)
self.assertTrue(first.is_member(user))
self.assertFalse(first.is_member(user2))
self.assertTrue(parent.is_member(user))
self.assertFalse(parent.is_member(user2))
self.assertTrue(second.is_member(user))
self.assertFalse(second.is_member(user2))
self.assertFalse(third.is_member(user))
self.assertFalse(third.is_member(user2))