core: revert backchannel only filtering (#10455)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-07-11 16:57:19 +02:00
committed by GitHub
parent 96f04d32ea
commit a5e45ba78e
3 changed files with 22 additions and 8 deletions

View File

@ -2,6 +2,7 @@ from collections.abc import Callable
from django.core.paginator import Paginator
from django.db.models import Model
from django.db.models.query import Q
from django.db.models.signals import m2m_changed, post_save, pre_delete
from authentik.core.models import Group, User
@ -34,7 +35,9 @@ def register_signals(
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
if not provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)
@ -43,7 +46,9 @@ def register_signals(
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
if not provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
task_sync_direct.delay(
class_to_path(instance.__class__), instance.pk, Direction.remove.value
@ -58,7 +63,9 @@ def register_signals(
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
if not provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups

View File

@ -5,6 +5,7 @@ from celery.exceptions import Retry
from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.db.models.query import Q
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
@ -37,7 +38,9 @@ class SyncTasks:
self._provider_model = provider_model
def sync_all(self, single_sync: Callable[[int], None]):
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
self.trigger_single_task(provider, single_sync)
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
@ -62,7 +65,8 @@ class SyncTasks:
provider_pk=provider_pk,
)
provider = self._provider_model.objects.filter(
pk=provider_pk, backchannel_application__isnull=False
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
@ -204,7 +208,9 @@ class SyncTasks:
if not instance:
return
operation = Direction(raw_op)
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
@ -233,7 +239,9 @@ class SyncTasks:
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given

View File

@ -15,7 +15,6 @@ const doGroupBy = (items: Provider[]) => groupBy(items, (item) => item.verboseNa
async function fetch(query?: string) {
const args: ProvidersAllListRequest = {
ordering: "name",
backchannel: false,
};
if (query !== undefined) {
args.search = query;