core: rework recursive group membership (#6017)
* rework checking group membership and add `user.all_groups` to get full list of groups Signed-off-by: Jens Langhammer <jens@goauthentik.io> * refactor some more for better performance Signed-off-by: Jens Langhammer <jens@goauthentik.io> * migrate things to use all_groups Signed-off-by: Jens Langhammer <jens@goauthentik.io> * update release notes Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix for django 4.2 Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		@ -207,7 +207,7 @@ class UserSelfSerializer(ModelSerializer):
 | 
			
		||||
    )
 | 
			
		||||
    def get_groups(self, _: User):
 | 
			
		||||
        """Return only the group names a user is member of"""
 | 
			
		||||
        for group in self.instance.ak_groups.all():
 | 
			
		||||
        for group in self.instance.all_groups().order_by("name"):
 | 
			
		||||
            yield {
 | 
			
		||||
                "name": group.name,
 | 
			
		||||
                "pk": group.pk,
 | 
			
		||||
 | 
			
		||||
@ -113,27 +113,7 @@ class Group(SerializerModel):
 | 
			
		||||
 | 
			
		||||
    def is_member(self, user: "User") -> bool:
 | 
			
		||||
        """Recursively check if `user` is member of us, or any parent."""
 | 
			
		||||
        query = """
 | 
			
		||||
        WITH RECURSIVE parents AS (
 | 
			
		||||
            SELECT authentik_core_group.*, 0 AS relative_depth
 | 
			
		||||
            FROM authentik_core_group
 | 
			
		||||
            WHERE authentik_core_group.group_uuid = %s
 | 
			
		||||
 | 
			
		||||
            UNION ALL
 | 
			
		||||
 | 
			
		||||
            SELECT authentik_core_group.*, parents.relative_depth - 1
 | 
			
		||||
            FROM authentik_core_group,parents
 | 
			
		||||
            WHERE (
 | 
			
		||||
                authentik_core_group.parent_id = parents.group_uuid and
 | 
			
		||||
                parents.relative_depth > -20
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        SELECT group_uuid
 | 
			
		||||
        FROM parents
 | 
			
		||||
        GROUP BY group_uuid;
 | 
			
		||||
        """
 | 
			
		||||
        groups = Group.objects.raw(query, [self.group_uuid])
 | 
			
		||||
        return user.ak_groups.filter(pk__in=[group.pk for group in groups]).exists()
 | 
			
		||||
        return user.all_groups().filter(group_uuid=self.group_uuid).exists()
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"Group {self.name}"
 | 
			
		||||
@ -176,13 +156,45 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
 | 
			
		||||
        """Get the default user path"""
 | 
			
		||||
        return User._meta.get_field("path").default
 | 
			
		||||
 | 
			
		||||
    def all_groups(self) -> QuerySet[Group]:
 | 
			
		||||
        """Recursively get all groups this user is a member of.
 | 
			
		||||
        At least one query is done to get the direct groups of the user, with groups
 | 
			
		||||
        there are at most 3 queries done"""
 | 
			
		||||
        direct_groups = tuple(
 | 
			
		||||
            str(x) for x in self.ak_groups.all().values_list("pk", flat=True).iterator()
 | 
			
		||||
        )
 | 
			
		||||
        if len(direct_groups) < 1:
 | 
			
		||||
            return Group.objects.none()
 | 
			
		||||
        query = """
 | 
			
		||||
        WITH RECURSIVE parents AS (
 | 
			
		||||
            SELECT authentik_core_group.*, 0 AS relative_depth
 | 
			
		||||
            FROM authentik_core_group
 | 
			
		||||
            WHERE authentik_core_group.group_uuid IN (%s)
 | 
			
		||||
 | 
			
		||||
            UNION ALL
 | 
			
		||||
 | 
			
		||||
            SELECT authentik_core_group.*, parents.relative_depth + 1
 | 
			
		||||
            FROM authentik_core_group, parents
 | 
			
		||||
            WHERE (
 | 
			
		||||
                authentik_core_group.group_uuid = parents.parent_id and
 | 
			
		||||
                parents.relative_depth < 20
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        SELECT group_uuid
 | 
			
		||||
        FROM parents
 | 
			
		||||
        GROUP BY group_uuid, name
 | 
			
		||||
        ORDER BY name;
 | 
			
		||||
        """
 | 
			
		||||
        group_pks = [group.pk for group in Group.objects.raw(query, direct_groups).iterator()]
 | 
			
		||||
        return Group.objects.filter(pk__in=group_pks)
 | 
			
		||||
 | 
			
		||||
    def group_attributes(self, request: Optional[HttpRequest] = None) -> dict[str, Any]:
 | 
			
		||||
        """Get a dictionary containing the attributes from all groups the user belongs to,
 | 
			
		||||
        including the users attributes"""
 | 
			
		||||
        final_attributes = {}
 | 
			
		||||
        if request and hasattr(request, "tenant"):
 | 
			
		||||
            always_merger.merge(final_attributes, request.tenant.attributes)
 | 
			
		||||
        for group in self.ak_groups.all().order_by("name"):
 | 
			
		||||
        for group in self.all_groups().order_by("name"):
 | 
			
		||||
            always_merger.merge(final_attributes, group.attributes)
 | 
			
		||||
        always_merger.merge(final_attributes, self.attributes)
 | 
			
		||||
        return final_attributes
 | 
			
		||||
@ -196,7 +208,7 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
 | 
			
		||||
    @cached_property
 | 
			
		||||
    def is_superuser(self) -> bool:
 | 
			
		||||
        """Get supseruser status based on membership in a group with superuser status"""
 | 
			
		||||
        return self.ak_groups.filter(is_superuser=True).exists()
 | 
			
		||||
        return self.all_groups().filter(is_superuser=True).exists()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def is_staff(self) -> bool:
 | 
			
		||||
 | 
			
		||||
@ -21,22 +21,26 @@ class TestGroups(TestCase):
 | 
			
		||||
        """Test parent membership"""
 | 
			
		||||
        user = User.objects.create(username=generate_id())
 | 
			
		||||
        user2 = User.objects.create(username=generate_id())
 | 
			
		||||
        first = Group.objects.create(name=generate_id())
 | 
			
		||||
        second = Group.objects.create(name=generate_id(), parent=first)
 | 
			
		||||
        second.users.add(user)
 | 
			
		||||
        self.assertTrue(first.is_member(user))
 | 
			
		||||
        self.assertFalse(first.is_member(user2))
 | 
			
		||||
        parent = Group.objects.create(name=generate_id())
 | 
			
		||||
        child = Group.objects.create(name=generate_id(), parent=parent)
 | 
			
		||||
        child.users.add(user)
 | 
			
		||||
        self.assertTrue(child.is_member(user))
 | 
			
		||||
        self.assertTrue(parent.is_member(user))
 | 
			
		||||
        self.assertFalse(child.is_member(user2))
 | 
			
		||||
        self.assertFalse(parent.is_member(user2))
 | 
			
		||||
 | 
			
		||||
    def test_group_membership_parent_extra(self):
 | 
			
		||||
        """Test parent membership"""
 | 
			
		||||
        user = User.objects.create(username=generate_id())
 | 
			
		||||
        user2 = User.objects.create(username=generate_id())
 | 
			
		||||
        first = Group.objects.create(name=generate_id())
 | 
			
		||||
        second = Group.objects.create(name=generate_id(), parent=first)
 | 
			
		||||
        parent = Group.objects.create(name=generate_id())
 | 
			
		||||
        second = Group.objects.create(name=generate_id(), parent=parent)
 | 
			
		||||
        third = Group.objects.create(name=generate_id(), parent=second)
 | 
			
		||||
        second.users.add(user)
 | 
			
		||||
        self.assertTrue(first.is_member(user))
 | 
			
		||||
        self.assertFalse(first.is_member(user2))
 | 
			
		||||
        self.assertTrue(parent.is_member(user))
 | 
			
		||||
        self.assertFalse(parent.is_member(user2))
 | 
			
		||||
        self.assertTrue(second.is_member(user))
 | 
			
		||||
        self.assertFalse(second.is_member(user2))
 | 
			
		||||
        self.assertFalse(third.is_member(user))
 | 
			
		||||
        self.assertFalse(third.is_member(user2))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -112,7 +112,7 @@ class BaseEvaluator:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def expr_is_group_member(user: User, **group_filters) -> bool:
 | 
			
		||||
        """Check if `user` is member of group with name `group_name`"""
 | 
			
		||||
        return user.ak_groups.filter(**group_filters).exists()
 | 
			
		||||
        return user.all_groups().filter(**group_filters).exists()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def expr_user_by(**filters) -> Optional[User]:
 | 
			
		||||
 | 
			
		||||
@ -13,6 +13,14 @@ slug: "/releases/2023.7"
 | 
			
		||||
 | 
			
		||||
    For Kubernetes install, a manual one-time migration has to be done: [Upgrading PostgreSQL on Kubernetes](../../troubleshooting/postgres/upgrade_kubernetes.md)
 | 
			
		||||
 | 
			
		||||
-   Changed nested Group membership behaviour
 | 
			
		||||
 | 
			
		||||
    In previous versions, nested groups were handled very inconsistently. Binding a group to an application/etc would check the membership recursively, however when using `user.ak_groups.all()` would only return direct memberships. Additionally, using `user.group_attributes()` would do the same and only merge all group attributes for direct memberships.
 | 
			
		||||
 | 
			
		||||
    This has been changed to always use the same logic as when checking for access, which means dealing with complex group structures is a lot more consistent.
 | 
			
		||||
 | 
			
		||||
    Policies that do use `user.ak_groups.all()` will retain the current behaviour, to use the new behaviour replace the call with `user.all_groups()`.
 | 
			
		||||
 | 
			
		||||
## New features
 | 
			
		||||
 | 
			
		||||
## Upgrading
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user