Compare commits
2 Commits
Author | SHA1 | Date
---|---|---
| c8be337414 |
| 5c85c2c9e6 |
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 2024.4.2
current_version = 2024.4.1
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
@@ -12,7 +12,7 @@ should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower()
branch_name = os.environ["GITHUB_REF"]
if os.environ.get("GITHUB_HEAD_REF", "") != "":
    branch_name = os.environ["GITHUB_HEAD_REF"]
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-").replace("'", "-")
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")

image_names = os.getenv("IMAGE_NAME").split(",")
image_arch = os.getenv("IMAGE_ARCH") or None
.github/workflows/ci-outpost.yml (2 changes, vendored)
@@ -29,7 +29,7 @@ jobs:
      - name: Generate API
        run: make gen-client-go
      - name: golangci-lint
        uses: golangci/golangci-lint-action@v6
        uses: golangci/golangci-lint-action@v5
        with:
          version: v1.54.2
          args: --timeout 5000s --verbose
@@ -38,7 +38,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build

# Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.3-bookworm AS go-builder
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.2-bookworm AS go-builder

ARG TARGETOS
ARG TARGETARCH
Makefile (1 change)
@@ -19,7 +19,6 @@ pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null)
CODESPELL_ARGS = -D - -D .github/codespell-dictionary.txt \
	-I .github/codespell-words.txt \
	-S 'web/src/locales/**' \
	-S 'website/developer-docs/api/reference/**' \
	authentik \
	internal \
	cmd \
@@ -2,7 +2,7 @@

from os import environ

__version__ = "2024.4.2"
__version__ = "2024.4.1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
@@ -17,7 +17,6 @@ from rest_framework.fields import CharField, IntegerField, SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError
from rest_framework.validators import UniqueValidator
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.used_by import UsedByMixin
@@ -101,10 +100,7 @@ class GroupSerializer(ModelSerializer):
        extra_kwargs = {
            "users": {
                "default": list,
            },
            # TODO: This field isn't unique on the database which is hard to backport
            # hence we just validate the uniqueness here
            "name": {"validators": [UniqueValidator(Group.objects.all())]},
        }
        }
authentik/core/exceptions.py (7 changes, new file)
@@ -0,0 +1,7 @@
"""authentik core exceptions"""

from authentik.lib.sentry import SentryIgnoredException


class PropertyMappingExpressionException(SentryIgnoredException):
    """Error when a PropertyMapping Exception expression could not be parsed or evaluated."""
@@ -6,7 +6,6 @@ from django.db.models import Model
from django.http import HttpRequest
from prometheus_client import Histogram

from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import User
from authentik.events.models import Event, EventAction
from authentik.lib.expression.evaluator import BaseEvaluator
@@ -48,7 +47,6 @@ class PropertyMappingEvaluator(BaseEvaluator):
        self._context["request"] = req
        req.context.update(**kwargs)
        self._context.update(**kwargs)
        self._globals["SkipObject"] = SkipObjectException
        self.dry_run = dry_run

    def handle_error(self, exc: Exception, expression_source: str):
@@ -1,13 +0,0 @@
"""authentik core exceptions"""

from authentik.lib.sentry import SentryIgnoredException


class PropertyMappingExpressionException(SentryIgnoredException):
    """Error when a PropertyMapping Exception expression could not be parsed or evaluated."""


class SkipObjectException(PropertyMappingExpressionException):
    """Exception which can be raised in a property mapping to skip syncing an object.
    Only applies to Property mappings which sync objects, and not on mappings which transitively
    apply to a single user"""
@@ -7,10 +7,9 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor


def backport_is_backchannel(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
    from authentik.providers.ldap.models import LDAPProvider
    from authentik.providers.scim.models import SCIMProvider
    from authentik.core.models import BackchannelProvider

    for model in [LDAPProvider, SCIMProvider]:
    for model in BackchannelProvider.__subclasses__():
        try:
            for obj in model.objects.only("is_backchannel"):
                obj.is_backchannel = True
@@ -22,7 +22,7 @@ from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger

from authentik.blueprints.models import ManagedModel
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.lib.avatars import get_avatar
from authentik.lib.generators import generate_id
@@ -100,6 +100,8 @@ class SourceFlowManager:
        if self.request.user.is_authenticated:
            new_connection.user = self.request.user
            new_connection = self.update_connection(new_connection, **kwargs)

            new_connection.save()
            return Action.LINK, new_connection

        existing_connections = self.connection_type.objects.filter(
@@ -146,6 +148,7 @@
        ]:
            new_connection.user = user
            new_connection = self.update_connection(new_connection, **kwargs)
            new_connection.save()
            return Action.LINK, new_connection
        if self.source.user_matching_mode in [
            SourceUserMatchingModes.EMAIL_DENY,
@@ -2,9 +2,7 @@

from datetime import datetime, timedelta

from django.conf import ImproperlyConfigured
from django.contrib.sessions.backends.cache import KEY_PREFIX
from django.contrib.sessions.backends.db import SessionStore as DBSessionStore
from django.core.cache import cache
from django.utils.timezone import now
from structlog.stdlib import get_logger
@@ -17,7 +15,6 @@ from authentik.core.models import (
    User,
)
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP

LOGGER = get_logger()
@@ -42,31 +39,16 @@ def clean_expired_models(self: SystemTask):
    amount = 0

    for session in AuthenticatedSession.objects.all():
        match CONFIG.get("session_storage", "cache"):
            case "cache":
                cache_key = f"{KEY_PREFIX}{session.session_key}"
                value = None
                try:
                    value = cache.get(cache_key)
        cache_key = f"{KEY_PREFIX}{session.session_key}"
        value = None
        try:
            value = cache.get(cache_key)

                except Exception as exc:
                    LOGGER.debug("Failed to get session from cache", exc=exc)
                if not value:
                    session.delete()
                    amount += 1
            case "db":
                if not (
                    DBSessionStore.get_model_class()
                    .objects.filter(session_key=session.session_key, expire_date__gt=now())
                    .exists()
                ):
                    session.delete()
                    amount += 1
            case _:
                # Should never happen, as we check for other values in authentik/root/settings.py
                raise ImproperlyConfigured(
                    "Invalid session_storage setting, allowed values are db and cache"
                )
        except Exception as exc:
            LOGGER.debug("Failed to get session from cache", exc=exc)
        if not value:
            session.delete()
            amount += 1
    LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)

    messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")
@@ -3,7 +3,7 @@
from django.test import RequestFactory, TestCase
from guardian.shortcuts import get_anonymous_user

from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction
@@ -48,21 +48,15 @@ class TestSourceFlowManager(TestCase):

    def test_authenticated_link(self):
        """Test authenticated user linking"""
        UserOAuthSourceConnection.objects.create(
            user=get_anonymous_user(), source=self.source, identifier=self.identifier
        )
        user = User.objects.create(username="foo", email="foo@bar.baz")
        flow_manager = OAuthSourceFlowManager(
            self.source, get_request("/", user=user), self.identifier, {}
        )
        action, connection = flow_manager.get_action()
        action, _ = flow_manager.get_action()
        self.assertEqual(action, Action.LINK)
        self.assertIsNone(connection.pk)
        flow_manager.get_flow()

    def test_unauthenticated_link(self):
        """Test un-authenticated user linking"""
        flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {})
        action, connection = flow_manager.get_action()
        self.assertEqual(action, Action.LINK)
        self.assertIsNone(connection.pk)
        flow_manager.get_flow()

    def test_unauthenticated_enroll_email(self):
@@ -2,12 +2,11 @@

from copy import deepcopy
from functools import partial
from typing import Any

from django.apps.registry import apps
from django.core.files import File
from django.db import connection
from django.db.models import ManyToManyRel, Model
from django.db.models import Model
from django.db.models.expressions import BaseExpression, Combinable
from django.db.models.signals import post_init
from django.http import HttpRequest
@@ -45,7 +44,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
        post_init.disconnect(dispatch_uid=request.request_id)

    def serialize_simple(self, model: Model) -> dict:
        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
        """Serialize a model in a very simple way. No ForeginKeys or other relationships are
        resolved"""
        data = {}
        deferred_fields = model.get_deferred_fields()
@@ -71,9 +70,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
        for key, value in before.items():
            if after.get(key) != value:
                diff[key] = {"previous_value": value, "new_value": after.get(key)}
        for key, value in after.items():
            if key not in before and key not in diff and before.get(key) != value:
                diff[key] = {"previous_value": before.get(key), "new_value": value}
        return sanitize_item(diff)

    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
@@ -102,37 +98,8 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
        thread_kwargs = {}
        if hasattr(instance, "_previous_state") or created:
            prev_state = getattr(instance, "_previous_state", {})
            if created:
                prev_state = {}
            # Get current state
            new_state = self.serialize_simple(instance)
            diff = self.diff(prev_state, new_state)
            thread_kwargs["diff"] = diff
        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)

    def m2m_changed_handler(  # noqa: PLR0913
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        action: str,
        pk_set: set[Any],
        thread_kwargs: dict | None = None,
        **_,
    ):
        thread_kwargs = {}
        m2m_field = None
        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
        _, _, action_direction = action.partition("_")
        # resolve the "through" model to an actual field
        for field in instance._meta.get_fields():
            if not isinstance(field, ManyToManyRel):
                continue
            if field.through == sender:
                m2m_field = field
        if m2m_field:
            # If we're clearing we just set the "flag" to True
            if action_direction == "clear":
                pk_set = True
            thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}}
        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
@@ -1,22 +1,9 @@
from unittest.mock import PropertyMock, patch

from django.apps import apps
from django.conf import settings
from django.urls import reverse
from rest_framework.test import APITestCase

from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction
from authentik.events.utils import sanitize_item
from authentik.lib.generators import generate_id
from django.test import TestCase


class TestEnterpriseAudit(APITestCase):
    """Test audit middleware"""

    def setUp(self) -> None:
        self.user = create_test_admin_user()
class TestEnterpriseAudit(TestCase):

    def test_import(self):
        """Ensure middleware is imported when app.ready is called"""
@ -29,182 +16,3 @@ class TestEnterpriseAudit(APITestCase):
|
||||
self.assertIn(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_create(self):
|
||||
"""Test create audit log"""
|
||||
self.client.force_login(self.user)
|
||||
username = generate_id()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:user-list"),
|
||||
data={"name": generate_id(), "username": username, "groups": [], "path": "foo"},
|
||||
)
|
||||
user = User.objects.get(username=username)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED,
|
||||
context__model__model_name="user",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=user.pk,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{
|
||||
"name": {
|
||||
"new_value": user.name,
|
||||
"previous_value": None,
|
||||
},
|
||||
"path": {"new_value": "foo", "previous_value": None},
|
||||
"type": {"new_value": "internal", "previous_value": None},
|
||||
"uuid": {
|
||||
"new_value": user.uuid.hex,
|
||||
"previous_value": None,
|
||||
},
|
||||
"email": {"new_value": "", "previous_value": None},
|
||||
"username": {
|
||||
"new_value": user.username,
|
||||
"previous_value": None,
|
||||
},
|
||||
"is_active": {"new_value": True, "previous_value": None},
|
||||
"attributes": {"new_value": {}, "previous_value": None},
|
||||
"date_joined": {
|
||||
"new_value": sanitize_item(user.date_joined),
|
||||
"previous_value": None,
|
||||
},
|
||||
"first_name": {"new_value": "", "previous_value": None},
|
||||
"id": {"new_value": user.pk, "previous_value": None},
|
||||
"last_name": {"new_value": "", "previous_value": None},
|
||||
"password": {"new_value": "********************", "previous_value": None},
|
||||
"password_change_date": {
|
||||
"new_value": sanitize_item(user.password_change_date),
|
||||
"previous_value": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_update(self):
|
||||
"""Test update audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
current_name = user.name
|
||||
new_name = generate_id()
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
|
||||
data={"name": new_name},
|
||||
)
|
||||
user.refresh_from_db()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_UPDATED,
|
||||
context__model__model_name="user",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=user.pk,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{
|
||||
"name": {
|
||||
"new_value": new_name,
|
||||
"previous_value": current_name,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_delete(self):
|
||||
"""Test delete audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
response = self.client.delete(
|
||||
reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_DELETED,
|
||||
context__model__model_name="user",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=user.pk,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertNotIn("diff", event.context)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_m2m_add(self):
|
||||
"""Test m2m add audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:group-add-user", kwargs={"pk": group.group_uuid}),
|
||||
data={
|
||||
"pk": user.pk,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_UPDATED,
|
||||
context__model__model_name="group",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=group.pk.hex,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{"users": {"add": [user.pk]}},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_m2m_remove(self):
|
||||
"""Test m2m remove audit log"""
|
||||
self.client.force_login(self.user)
|
||||
user = create_test_admin_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:group-remove-user", kwargs={"pk": group.group_uuid}),
|
||||
data={
|
||||
"pk": user.pk,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
events = Event.objects.filter(
|
||||
action=EventAction.MODEL_UPDATED,
|
||||
context__model__model_name="group",
|
||||
context__model__app="authentik_core",
|
||||
context__model__pk=group.pk.hex,
|
||||
)
|
||||
event = events.first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIsNotNone(event.context["diff"])
|
||||
diff = event.context["diff"]
|
||||
self.assertEqual(
|
||||
diff,
|
||||
{"users": {"remove": [user.pk]}},
|
||||
)
|
||||
|
@ -1,39 +0,0 @@
|
||||
"""google Property mappings API Views"""
|
||||
|
||||
from django_filters.filters import AllValuesMultipleFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderMapping
|
||||
|
||||
|
||||
class GoogleProviderMappingSerializer(PropertyMappingSerializer):
|
||||
"""GoogleProviderMapping Serializer"""
|
||||
|
||||
class Meta:
|
||||
model = GoogleWorkspaceProviderMapping
|
||||
fields = PropertyMappingSerializer.Meta.fields
|
||||
|
||||
|
||||
class GoogleProviderMappingFilter(FilterSet):
|
||||
"""Filter for GoogleProviderMapping"""
|
||||
|
||||
managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))
|
||||
|
||||
class Meta:
|
||||
model = GoogleWorkspaceProviderMapping
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class GoogleProviderMappingViewSet(UsedByMixin, ModelViewSet):
|
||||
"""GoogleProviderMapping Viewset"""
|
||||
|
||||
queryset = GoogleWorkspaceProviderMapping.objects.all()
|
||||
serializer_class = GoogleProviderMappingSerializer
|
||||
filterset_class = GoogleProviderMappingFilter
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
@ -1,54 +0,0 @@
|
||||
"""Google Provider API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
|
||||
|
||||
class GoogleProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
|
||||
"""GoogleProvider Serializer"""
|
||||
|
||||
class Meta:
|
||||
model = GoogleWorkspaceProvider
|
||||
fields = [
|
||||
"pk",
|
||||
"name",
|
||||
"property_mappings",
|
||||
"property_mappings_group",
|
||||
"component",
|
||||
"assigned_backchannel_application_slug",
|
||||
"assigned_backchannel_application_name",
|
||||
"verbose_name",
|
||||
"verbose_name_plural",
|
||||
"meta_model_name",
|
||||
"delegated_subject",
|
||||
"credentials",
|
||||
"scopes",
|
||||
"exclude_users_service_account",
|
||||
"filter_group",
|
||||
"user_delete_action",
|
||||
"group_delete_action",
|
||||
"default_group_email_domain",
|
||||
]
|
||||
extra_kwargs = {}
|
||||
|
||||
|
||||
class GoogleProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
|
||||
"""GoogleProvider Viewset"""
|
||||
|
||||
queryset = GoogleWorkspaceProvider.objects.all()
|
||||
serializer_class = GoogleProviderSerializer
|
||||
filterset_fields = [
|
||||
"name",
|
||||
"exclude_users_service_account",
|
||||
"delegated_subject",
|
||||
"filter_group",
|
||||
]
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_single_task = google_workspace_sync
|
@ -1,9 +0,0 @@
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
|
||||
|
||||
class AuthentikEnterpriseProviderGoogleConfig(EnterpriseConfig):
|
||||
|
||||
name = "authentik.enterprise.providers.google_workspace"
|
||||
label = "authentik_providers_google_workspace"
|
||||
verbose_name = "authentik Enterprise.Providers.Google Workspace"
|
||||
default = True
|
@ -1,71 +0,0 @@
|
||||
from django.db.models import Model
|
||||
from django.http import HttpResponseNotFound
|
||||
from google.auth.exceptions import GoogleAuthError, TransportError
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import Error, HttpError
|
||||
from googleapiclient.http import HttpRequest
|
||||
from httplib2 import HttpLib2Error, HttpLib2ErrorWithResponse
|
||||
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.lib.sync.outgoing import HTTP_CONFLICT
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
NotFoundSyncException,
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
TransientSyncException,
|
||||
)
|
||||
|
||||
|
||||
class GoogleWorkspaceSyncClient[TModel: Model, TConnection: Model, TSchema: dict](
|
||||
BaseOutgoingSyncClient[TModel, TConnection, TSchema, GoogleWorkspaceProvider]
|
||||
):
|
||||
"""Base client for syncing to google workspace"""
|
||||
|
||||
domains: list
|
||||
|
||||
def __init__(self, provider: GoogleWorkspaceProvider) -> None:
|
||||
super().__init__(provider)
|
||||
self.directory_service = build(
|
||||
"admin",
|
||||
"directory_v1",
|
||||
cache_discovery=False,
|
||||
**provider.google_credentials(),
|
||||
)
|
||||
self.__prefetch_domains()
|
||||
|
||||
def __prefetch_domains(self):
|
||||
self.domains = []
|
||||
domains = self._request(self.directory_service.domains().list(customer="my_customer"))
|
||||
for domain in domains.get("domains", []):
|
||||
domain_name = domain.get("domainName")
|
||||
self.domains.append(domain_name)
|
||||
|
||||
def _request(self, request: HttpRequest):
|
||||
try:
|
||||
response = request.execute()
|
||||
except GoogleAuthError as exc:
|
||||
if isinstance(exc, TransportError):
|
||||
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
|
||||
raise StopSync(exc) from exc
|
||||
except HttpLib2Error as exc:
|
||||
if isinstance(exc, HttpLib2ErrorWithResponse):
|
||||
self._response_handle_status_code(exc.response.status, exc)
|
||||
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
|
||||
except HttpError as exc:
|
||||
self._response_handle_status_code(exc.status_code, exc)
|
||||
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
|
||||
except Error as exc:
|
||||
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
|
||||
return response
|
||||
|
||||
def _response_handle_status_code(self, status_code: int, root_exc: Exception):
|
||||
if status_code == HttpResponseNotFound.status_code:
|
||||
raise NotFoundSyncException("Object not found") from root_exc
|
||||
if status_code == HTTP_CONFLICT:
|
||||
raise ObjectExistsSyncException("Object exists") from root_exc
|
||||
|
||||
def check_email_valid(self, *emails: str):
|
||||
for email in emails:
|
||||
if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
|
||||
raise TransientSyncException(f"Invalid email domain: {email}")
|
@ -1,245 +0,0 @@
|
||||
from deepmerge import always_merger
|
||||
from django.db import transaction
|
||||
from django.utils.text import slugify
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import Group
|
||||
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
||||
from authentik.enterprise.providers.google_workspace.models import (
|
||||
GoogleWorkspaceDeleteAction,
|
||||
GoogleWorkspaceProviderGroup,
|
||||
GoogleWorkspaceProviderMapping,
|
||||
GoogleWorkspaceProviderUser,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
NotFoundSyncException,
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
TransientSyncException,
|
||||
)
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
|
||||
class GoogleWorkspaceGroupClient(
|
||||
GoogleWorkspaceSyncClient[Group, GoogleWorkspaceProviderGroup, dict]
|
||||
):
|
||||
"""Google client for groups"""
|
||||
|
||||
connection_type = GoogleWorkspaceProviderGroup
|
||||
connection_type_query = "group"
|
||||
can_discover = True
|
||||
|
||||
def to_schema(self, obj: Group) -> dict:
|
||||
"""Convert authentik group"""
|
||||
raw_google_group = {
|
||||
"email": f"{slugify(obj.name)}@{self.provider.default_group_email_domain}"
|
||||
}
|
||||
for mapping in (
|
||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
|
||||
):
|
||||
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
|
||||
continue
|
||||
try:
|
||||
mapping: GoogleWorkspaceProviderMapping
|
||||
value = mapping.evaluate(
|
||||
user=None,
|
||||
request=None,
|
||||
group=obj,
|
||||
provider=self.provider,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_google_group, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_google_group:
|
||||
raise StopSync(ValueError("No group mappings configured"), obj)
|
||||
|
||||
return raw_google_group
|
||||
|
||||
def delete(self, obj: Group):
|
||||
"""Delete group"""
|
||||
google_group = GoogleWorkspaceProviderGroup.objects.filter(
|
||||
provider=self.provider, group=obj
|
||||
).first()
|
||||
if not google_group:
|
||||
self.logger.debug("Group does not exist in Google, skipping")
|
||||
return None
|
||||
with transaction.atomic():
|
||||
if self.provider.group_delete_action == GoogleWorkspaceDeleteAction.DELETE:
|
||||
self._request(
|
||||
self.directory_service.groups().delete(groupKey=google_group.google_id)
|
||||
)
|
||||
google_group.delete()
|
||||
|
||||
def create(self, group: Group):
|
||||
"""Create group from scratch and create a connection object"""
|
||||
google_group = self.to_schema(group)
|
||||
self.check_email_valid(google_group["email"])
|
||||
with transaction.atomic():
|
||||
try:
|
||||
response = self._request(self.directory_service.groups().insert(body=google_group))
|
||||
except ObjectExistsSyncException:
|
||||
# group already exists in google workspace, so we can connect them manually
|
||||
# for groups we need to fetch the group from google as we connect on
|
||||
# ID and not group email
|
||||
group_data = self._request(
|
||||
self.directory_service.groups().get(groupKey=google_group["email"])
|
||||
)
|
||||
GoogleWorkspaceProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, google_id=group_data["id"]
|
||||
)
|
||||
else:
|
||||
GoogleWorkspaceProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, google_id=response["id"]
|
||||
)
|
||||
|
||||
def update(self, group: Group, connection: GoogleWorkspaceProviderGroup):
|
||||
"""Update existing group"""
|
||||
google_group = self.to_schema(group)
|
||||
self.check_email_valid(google_group["email"])
|
||||
try:
|
||||
return self._request(
|
||||
self.directory_service.groups().update(
|
||||
groupKey=connection.google_id,
|
||||
body=google_group,
|
||||
)
|
||||
)
|
||||
except NotFoundSyncException:
|
||||
# Resource missing is handled by self.write, which will re-create the group
|
||||
raise
|
||||
|
||||
def write(self, obj: Group):
|
||||
google_group, created = super().write(obj)
|
||||
if created:
|
||||
self.create_sync_members(obj, google_group)
|
||||
return google_group
|
||||
|
||||
def create_sync_members(self, obj: Group, google_group: dict):
|
||||
"""Sync all members after a group was created"""
|
||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||
connections = GoogleWorkspaceProviderUser.objects.filter(
|
||||
provider=self.provider, user__pk__in=users
|
||||
)
|
||||
for user in connections:
|
||||
try:
|
||||
self._request(
|
||||
self.directory_service.members().insert(
|
||||
groupKey=google_group["id"],
|
||||
body={
|
||||
"email": user.google_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
except TransientSyncException:
|
||||
continue
|
||||
|
||||
def update_group(self, group: Group, action: Direction, users_set: set[int]):
|
||||
"""Update a groups members"""
|
||||
if action == Direction.add:
|
||||
return self._patch_add_users(group, users_set)
|
||||
if action == Direction.remove:
|
||||
return self._patch_remove_users(group, users_set)
|
||||
|
||||
def _patch(self, google_group_id: str, direction: Direction, members: list[str]):
|
||||
for user in members:
|
||||
try:
|
||||
if direction == Direction.add:
|
||||
self._request(
|
||||
self.directory_service.members().insert(
|
||||
groupKey=google_group_id, body={"email": user}
|
||||
)
|
||||
)
|
||||
if direction == Direction.remove:
|
||||
self._request(
|
||||
self.directory_service.members().delete(
|
||||
groupKey=google_group_id, memberKey=user
|
||||
)
|
||||
)
|
||||
except ObjectExistsSyncException:
|
||||
pass
|
||||
except TransientSyncException:
|
||||
raise
|
||||
|
||||
def _patch_add_users(self, group: Group, users_set: set[int]):
|
||||
"""Add users in users_set to group"""
|
||||
if len(users_set) < 1:
|
||||
return
|
||||
google_group = GoogleWorkspaceProviderGroup.objects.filter(
|
||||
provider=self.provider, group=group
|
||||
).first()
|
||||
if not google_group:
|
||||
self.logger.warning(
|
||||
"could not sync group membership, group does not exist", group=group
|
||||
)
|
||||
return
|
||||
user_ids = list(
|
||||
GoogleWorkspaceProviderUser.objects.filter(
|
||||
user__pk__in=users_set, provider=self.provider
|
||||
).values_list("google_id", flat=True)
|
||||
)
|
||||
if len(user_ids) < 1:
|
||||
return
|
||||
self._patch(google_group.google_id, Direction.add, user_ids)
|
||||
|
||||
def _patch_remove_users(self, group: Group, users_set: set[int]):
|
||||
"""Remove users in users_set from group"""
|
||||
if len(users_set) < 1:
|
||||
return
|
||||
google_group = GoogleWorkspaceProviderGroup.objects.filter(
|
||||
provider=self.provider, group=group
|
||||
).first()
|
||||
if not google_group:
|
||||
self.logger.warning(
|
||||
"could not sync group membership, group does not exist", group=group
|
||||
)
|
||||
return
|
||||
user_ids = list(
|
||||
GoogleWorkspaceProviderUser.objects.filter(
|
||||
user__pk__in=users_set, provider=self.provider
|
||||
).values_list("google_id", flat=True)
|
||||
)
|
||||
if len(user_ids) < 1:
|
||||
return
|
||||
self._patch(google_group.google_id, Direction.remove, user_ids)
|
||||
|
||||
def discover(self):
|
||||
"""Iterate through all groups and connect them with authentik groups if possible"""
|
||||
request = self.directory_service.groups().list(
|
||||
customer="my_customer", maxResults=500, orderBy="email"
|
||||
)
|
||||
while request:
|
||||
response = request.execute()
|
||||
for group in response.get("groups", []):
|
||||
self._discover_single_group(group)
|
||||
request = self.directory_service.groups().list_next(
|
||||
previous_request=request, previous_response=response
|
||||
)
|
||||
|
||||
def _discover_single_group(self, group: dict):
|
||||
"""handle discovery of a single group"""
|
||||
google_name = group["name"]
|
||||
google_id = group["id"]
|
||||
matching_authentik_group = (
|
||||
self.provider.get_object_qs(Group).filter(name=google_name).first()
|
||||
)
|
||||
if not matching_authentik_group:
|
||||
return
|
||||
GoogleWorkspaceProviderGroup.objects.get_or_create(
|
||||
provider=self.provider,
|
||||
group=matching_authentik_group,
|
||||
google_id=google_id,
|
||||
)
|
@ -1,41 +0,0 @@
|
||||
from json import dumps
|
||||
|
||||
from httplib2 import Response
|
||||
|
||||
|
||||
class MockHTTP:
|
||||
|
||||
_recorded_requests = []
|
||||
_responses = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
raise_on_unrecorded=True,
|
||||
) -> None:
|
||||
self._recorded_requests = []
|
||||
self._responses = {}
|
||||
self.raise_on_unrecorded = raise_on_unrecorded
|
||||
|
||||
def add_response(self, uri: str, body: str | dict = "", meta: dict | None = None, method="GET"):
|
||||
if isinstance(body, dict):
|
||||
body = dumps(body)
|
||||
self._responses[(uri, method.upper())] = (body, meta or {"status": "200"})
|
||||
|
||||
def requests(self):
|
||||
return self._recorded_requests
|
||||
|
||||
def request(
|
||||
self,
|
||||
uri,
|
||||
method="GET",
|
||||
body=None,
|
||||
headers=None,
|
||||
redirections=1,
|
||||
connection_type=None,
|
||||
):
|
||||
key = (uri, method.upper())
|
||||
self._recorded_requests.append((uri, method, body, headers))
|
||||
if key not in self._responses and self.raise_on_unrecorded:
|
||||
raise AssertionError(key)
|
||||
body, meta = self._responses[key]
|
||||
return Response(meta), body.encode("utf-8")
|
@ -1,141 +0,0 @@
|
||||
from deepmerge import always_merger
|
||||
from django.db import transaction
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
||||
from authentik.enterprise.providers.google_workspace.models import (
|
||||
GoogleWorkspaceDeleteAction,
|
||||
GoogleWorkspaceProviderMapping,
|
||||
GoogleWorkspaceProviderUser,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
TransientSyncException,
|
||||
)
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.utils import delete_none_values
|
||||
|
||||
|
||||
class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceProviderUser, dict]):
|
||||
"""Sync authentik users into google workspace"""
|
||||
|
||||
connection_type = GoogleWorkspaceProviderUser
|
||||
connection_type_query = "user"
|
||||
can_discover = True
|
||||
|
||||
def to_schema(self, obj: User) -> dict:
|
||||
"""Convert authentik user"""
|
||||
raw_google_user = {}
|
||||
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
|
||||
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
|
||||
continue
|
||||
try:
|
||||
mapping: GoogleWorkspaceProviderMapping
|
||||
value = mapping.evaluate(
|
||||
user=obj,
|
||||
request=None,
|
||||
provider=self.provider,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_google_user, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_google_user:
|
||||
raise StopSync(ValueError("No user mappings configured"), obj)
|
||||
if "primaryEmail" not in raw_google_user:
|
||||
raw_google_user["primaryEmail"] = str(obj.email)
|
||||
return delete_none_values(raw_google_user)
|
||||
|
||||
def delete(self, obj: User):
|
||||
"""Delete user"""
|
||||
google_user = GoogleWorkspaceProviderUser.objects.filter(
|
||||
provider=self.provider, user=obj
|
||||
).first()
|
||||
if not google_user:
|
||||
self.logger.debug("User does not exist in Google, skipping")
|
||||
return None
|
||||
with transaction.atomic():
|
||||
response = None
|
||||
if self.provider.user_delete_action == GoogleWorkspaceDeleteAction.DELETE:
|
||||
response = self._request(
|
||||
self.directory_service.users().delete(userKey=google_user.google_id)
|
||||
)
|
||||
elif self.provider.user_delete_action == GoogleWorkspaceDeleteAction.SUSPEND:
|
||||
response = self._request(
|
||||
self.directory_service.users().update(
|
||||
userKey=google_user.google_id, body={"suspended": True}
|
||||
)
|
||||
)
|
||||
google_user.delete()
|
||||
return response
|
||||
|
||||
def create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
google_user = self.to_schema(user)
|
||||
self.check_email_valid(
|
||||
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
||||
)
|
||||
with transaction.atomic():
|
||||
try:
|
||||
response = self._request(self.directory_service.users().insert(body=google_user))
|
||||
except ObjectExistsSyncException:
|
||||
# user already exists in google workspace, so we can connect them manually
|
||||
GoogleWorkspaceProviderUser.objects.create(
|
||||
provider=self.provider, user=user, google_id=user.email
|
||||
)
|
||||
except TransientSyncException as exc:
|
||||
raise exc
|
||||
else:
|
||||
GoogleWorkspaceProviderUser.objects.create(
|
||||
provider=self.provider, user=user, google_id=response["primaryEmail"]
|
||||
)
|
||||
|
||||
def update(self, user: User, connection: GoogleWorkspaceProviderUser):
|
||||
"""Update existing user"""
|
||||
google_user = self.to_schema(user)
|
||||
self.check_email_valid(
|
||||
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
||||
)
|
||||
self._request(
|
||||
self.directory_service.users().update(userKey=connection.google_id, body=google_user)
|
||||
)
|
||||
|
||||
def discover(self):
|
||||
"""Iterate through all users and connect them with authentik users if possible"""
|
||||
request = self.directory_service.users().list(
|
||||
customer="my_customer", maxResults=500, orderBy="email"
|
||||
)
|
||||
while request:
|
||||
response = request.execute()
|
||||
for user in response.get("users", []):
|
||||
self._discover_single_user(user)
|
||||
request = self.directory_service.users().list_next(
|
||||
previous_request=request, previous_response=response
|
||||
)
|
||||
|
||||
def _discover_single_user(self, user: dict):
|
||||
"""handle discovery of a single user"""
|
||||
email = user["primaryEmail"]
|
||||
matching_authentik_user = self.provider.get_object_qs(User).filter(email=email).first()
|
||||
if not matching_authentik_user:
|
||||
return
|
||||
GoogleWorkspaceProviderUser.objects.get_or_create(
|
||||
provider=self.provider,
|
||||
user=matching_authentik_user,
|
||||
google_id=email,
|
||||
)
|
@ -1,167 +0,0 @@
|
||||
# Generated by Django 5.0.4 on 2024-05-07 16:03
|
||||
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0035_alter_group_options_and_more"),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="GoogleWorkspaceProviderMapping",
|
||||
fields=[
|
||||
(
|
||||
"propertymapping_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.propertymapping",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Google Workspace Provider Mapping",
|
||||
"verbose_name_plural": "Google Workspace Provider Mappings",
|
||||
},
|
||||
bases=("authentik_core.propertymapping",),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="GoogleWorkspaceProvider",
|
||||
fields=[
|
||||
(
|
||||
"provider_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.provider",
|
||||
),
|
||||
),
|
||||
("delegated_subject", models.EmailField(max_length=254)),
|
||||
("credentials", models.JSONField()),
|
||||
(
|
||||
"scopes",
|
||||
models.TextField(
|
||||
default="https://www.googleapis.com/auth/admin.directory.user,https://www.googleapis.com/auth/admin.directory.group,https://www.googleapis.com/auth/admin.directory.group.member,https://www.googleapis.com/auth/admin.directory.domain.readonly"
|
||||
),
|
||||
),
|
||||
("default_group_email_domain", models.TextField()),
|
||||
("exclude_users_service_account", models.BooleanField(default=False)),
|
||||
(
|
||||
"user_delete_action",
|
||||
models.TextField(
|
||||
choices=[
|
||||
("do_nothing", "Do Nothing"),
|
||||
("delete", "Delete"),
|
||||
("suspend", "Suspend"),
|
||||
],
|
||||
default="delete",
|
||||
),
|
||||
),
|
||||
(
|
||||
"group_delete_action",
|
||||
models.TextField(
|
||||
choices=[
|
||||
("do_nothing", "Do Nothing"),
|
||||
("delete", "Delete"),
|
||||
("suspend", "Suspend"),
|
||||
],
|
||||
default="delete",
|
||||
),
|
||||
),
|
||||
(
|
||||
"filter_group",
|
||||
models.ForeignKey(
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_DEFAULT,
|
||||
to="authentik_core.group",
|
||||
),
|
||||
),
|
||||
(
|
||||
"property_mappings_group",
|
||||
models.ManyToManyField(
|
||||
blank=True,
|
||||
default=None,
|
||||
help_text="Property mappings used for group creation/updating.",
|
||||
to="authentik_core.propertymapping",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Google Workspace Provider",
|
||||
"verbose_name_plural": "Google Workspace Providers",
|
||||
},
|
||||
bases=("authentik_core.provider", models.Model),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="GoogleWorkspaceProviderGroup",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
("google_id", models.TextField()),
|
||||
(
|
||||
"group",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="authentik_core.group"
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_google_workspace.googleworkspaceprovider",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"unique_together": {("google_id", "group", "provider")},
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="GoogleWorkspaceProviderUser",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
("google_id", models.TextField()),
|
||||
(
|
||||
"provider",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_google_workspace.googleworkspaceprovider",
|
||||
),
|
||||
),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"unique_together": {("google_id", "user", "provider")},
|
||||
},
|
||||
),
|
||||
]
|
@ -1,179 +0,0 @@
|
||||
"""Google workspace sync provider"""
|
||||
|
||||
from typing import Any, Self
|
||||
from uuid import uuid4
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from google.oauth2.service_account import Credentials
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import (
|
||||
BackchannelProvider,
|
||||
Group,
|
||||
PropertyMapping,
|
||||
User,
|
||||
UserTypes,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
|
||||
|
||||
def default_scopes() -> list[str]:
|
||||
return [
|
||||
"https://www.googleapis.com/auth/admin.directory.user",
|
||||
"https://www.googleapis.com/auth/admin.directory.group",
|
||||
"https://www.googleapis.com/auth/admin.directory.group.member",
|
||||
"https://www.googleapis.com/auth/admin.directory.domain.readonly",
|
||||
]
|
||||
|
||||
|
||||
class GoogleWorkspaceDeleteAction(models.TextChoices):
|
||||
"""Action taken when a user/group is deleted in authentik. Suspend is not available for groups,
|
||||
and will be treated as `do_nothing`"""
|
||||
|
||||
DO_NOTHING = "do_nothing"
|
||||
DELETE = "delete"
|
||||
SUSPEND = "suspend"
|
||||
|
||||
|
||||
class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
"""Sync users from authentik into Google Workspace."""
|
||||
|
||||
delegated_subject = models.EmailField()
|
||||
credentials = models.JSONField()
|
||||
scopes = models.TextField(default=",".join(default_scopes()))
|
||||
|
||||
default_group_email_domain = models.TextField()
|
||||
exclude_users_service_account = models.BooleanField(default=False)
|
||||
user_delete_action = models.TextField(
|
||||
choices=GoogleWorkspaceDeleteAction.choices, default=GoogleWorkspaceDeleteAction.DELETE
|
||||
)
|
||||
group_delete_action = models.TextField(
|
||||
choices=GoogleWorkspaceDeleteAction.choices, default=GoogleWorkspaceDeleteAction.DELETE
|
||||
)
|
||||
|
||||
filter_group = models.ForeignKey(
|
||||
"authentik_core.group", on_delete=models.SET_DEFAULT, default=None, null=True
|
||||
)
|
||||
|
||||
property_mappings_group = models.ManyToManyField(
|
||||
PropertyMapping,
|
||||
default=None,
|
||||
blank=True,
|
||||
help_text=_("Property mappings used for group creation/updating."),
|
||||
)
|
||||
|
||||
def client_for_model(
|
||||
self, model: type[User | Group]
|
||||
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
|
||||
if issubclass(model, User):
|
||||
from authentik.enterprise.providers.google_workspace.clients.users import (
|
||||
GoogleWorkspaceUserClient,
|
||||
)
|
||||
|
||||
return GoogleWorkspaceUserClient(self)
|
||||
if issubclass(model, Group):
|
||||
from authentik.enterprise.providers.google_workspace.clients.groups import (
|
||||
GoogleWorkspaceGroupClient,
|
||||
)
|
||||
|
||||
return GoogleWorkspaceGroupClient(self)
|
||||
raise ValueError(f"Invalid model {model}")
|
||||
|
||||
def get_object_qs(self, type: type[User | Group]) -> QuerySet[User | Group]:
|
||||
if type == User:
|
||||
# Get queryset of all users with consistent ordering
|
||||
# according to the provider's settings
|
||||
base = User.objects.all().exclude_anonymous()
|
||||
if self.exclude_users_service_account:
|
||||
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
|
||||
type=UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
)
|
||||
if self.filter_group:
|
||||
base = base.filter(ak_groups__in=[self.filter_group])
|
||||
return base.order_by("pk")
|
||||
if type == Group:
|
||||
# Get queryset of all groups with consistent ordering
|
||||
return Group.objects.all().order_by("pk")
|
||||
raise ValueError(f"Invalid type {type}")
|
||||
|
||||
def google_credentials(self):
|
||||
return {
|
||||
"credentials": Credentials.from_service_account_info(
|
||||
self.credentials, scopes=self.scopes.split(",")
|
||||
).with_subject(self.delegated_subject),
|
||||
}
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-provider-google-workspace-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.google_workspace.api.providers import (
|
||||
GoogleProviderSerializer,
|
||||
)
|
||||
|
||||
return GoogleProviderSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"Google Workspace Provider {self.name}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Google Workspace Provider")
|
||||
verbose_name_plural = _("Google Workspace Providers")
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderMapping(PropertyMapping):
|
||||
"""Map authentik data to outgoing Google requests"""
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-google-workspace-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.google_workspace.api.property_mappings import (
|
||||
GoogleProviderMappingSerializer,
|
||||
)
|
||||
|
||||
return GoogleProviderMappingSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"Google Workspace Provider Mapping {self.name}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Google Workspace Provider Mapping")
|
||||
verbose_name_plural = _("Google Workspace Provider Mappings")
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderUser(models.Model):
|
||||
"""Mapping of a user and provider to a Google user ID"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
google_id = models.TextField()
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
unique_together = (("google_id", "user", "provider"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Google Workspace User {self.user_id} to {self.provider_id}"
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderGroup(models.Model):
|
||||
"""Mapping of a group and provider to a Google group ID"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
google_id = models.TextField()
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
unique_together = (("google_id", "group", "provider"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Google Workspace Group {self.group_id} to {self.provider_id}"
|
@ -1,13 +0,0 @@
|
||||
"""Google workspace provider task Settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"providers_google_workspace_sync": {
|
||||
"task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all",
|
||||
"schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
"""Google provider signals"""
|
||||
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import (
|
||||
google_workspace_sync,
|
||||
google_workspace_sync_direct,
|
||||
google_workspace_sync_m2m,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.signals import register_signals
|
||||
|
||||
register_signals(
|
||||
GoogleWorkspaceProvider,
|
||||
task_sync_single=google_workspace_sync,
|
||||
task_sync_direct=google_workspace_sync_direct,
|
||||
task_sync_m2m=google_workspace_sync_m2m,
|
||||
)
|
@ -1,34 +0,0 @@
"""Google Provider tasks"""

from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing.tasks import SyncTasks
from authentik.root.celery import CELERY_APP

sync_tasks = SyncTasks(GoogleWorkspaceProvider)


@CELERY_APP.task()
def google_workspace_sync_objects(*args, **kwargs):
    return sync_tasks.sync_objects(*args, **kwargs)


@CELERY_APP.task(base=SystemTask, bind=True)
def google_workspace_sync(self, provider_pk: int, *args, **kwargs):
    """Run full sync for Google Workspace provider"""
    return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects)


@CELERY_APP.task()
def google_workspace_sync_all():
    return sync_tasks.sync_all(google_workspace_sync)


@CELERY_APP.task()
def google_workspace_sync_direct(*args, **kwargs):
    return sync_tasks.sync_signal_direct(*args, **kwargs)


@CELERY_APP.task()
def google_workspace_sync_m2m(*args, **kwargs):
    return sync_tasks.sync_signal_m2m(*args, **kwargs)
@ -1,14 +0,0 @@
{
    "kind": "admin#directory#domains",
    "etag": "\"a1kA7zE2sFLsHiFwgXN9G3effoc9grR2OwUu8_95xD4/uvC5HsKHylhnUtnRV6ZxINODtV0\"",
    "domains": [
        {
            "kind": "admin#directory#domain",
            "etag": "\"a1kA7zE2sFLsHiFwgXN9G3effoc9grR2OwUu8_95xD4/V4koSPWBFIWuIpAmUamO96QhTLo\"",
            "domainName": "goauthentik.io",
            "isPrimary": true,
            "verified": true,
            "creationTime": "1543048869840"
        }
    ]
}
@ -1,313 +0,0 @@
"""Google Workspace Group tests"""

from unittest.mock import MagicMock, patch

from django.test import TestCase

from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.core.tests.utils import create_test_user
from authentik.enterprise.providers.google_workspace.clients.test_http import MockHTTP
from authentik.enterprise.providers.google_workspace.models import (
    GoogleWorkspaceDeleteAction,
    GoogleWorkspaceProvider,
    GoogleWorkspaceProviderGroup,
    GoogleWorkspaceProviderMapping,
)
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
from authentik.tenants.models import Tenant

domains_list_v1_mock = load_fixture("fixtures/domains_list_v1.json")


class GoogleWorkspaceGroupTests(TestCase):
    """Google workspace Group tests"""

    @apply_blueprint("system/providers-google-workspace.yaml")
    def setUp(self) -> None:
        # Delete all groups and groups as the mocked HTTP responses only return one ID
        # which will cause errors with multiple groups
        Tenant.objects.update(avatars="none")
        User.objects.all().exclude_anonymous().delete()
        Group.objects.all().delete()
        self.provider: GoogleWorkspaceProvider = GoogleWorkspaceProvider.objects.create(
            name=generate_id(),
            credentials={},
            delegated_subject="",
            exclude_users_service_account=True,
            default_group_email_domain="goauthentik.io",
        )
        self.app: Application = Application.objects.create(
            name=generate_id(),
            slug=generate_id(),
        )
        self.app.backchannel_providers.add(self.provider)
        self.provider.property_mappings.add(
            GoogleWorkspaceProviderMapping.objects.get(
                managed="goauthentik.io/providers/google_workspace/user"
            )
        )
        self.provider.property_mappings_group.add(
            GoogleWorkspaceProviderMapping.objects.get(
                managed="goauthentik.io/providers/google_workspace/group"
            )
        )
        self.api_key = generate_id()

    def test_group_create(self):
        """Test group creation"""
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
            method="POST",
            body={"id": generate_id()},
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            group = Group.objects.create(name=uid)
            google_group = GoogleWorkspaceProviderGroup.objects.filter(
                provider=self.provider, group=group
            ).first()
            self.assertIsNotNone(google_group)
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 2)

    def test_group_create_update(self):
        """Test group updating"""
        uid = generate_id()
        ext_id = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
            method="POST",
            body={"id": ext_id},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}?key={self.api_key}&alt=json",
            method="PUT",
            body={"id": ext_id},
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            group = Group.objects.create(name=uid)
            google_group = GoogleWorkspaceProviderGroup.objects.filter(
                provider=self.provider, group=group
            ).first()
            self.assertIsNotNone(google_group)

            group.name = "new name"
            group.save()
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 4)

    def test_group_create_delete(self):
        """Test group deletion"""
        uid = generate_id()
        ext_id = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
            method="POST",
            body={"id": ext_id},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}?key={self.api_key}",
            method="DELETE",
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            group = Group.objects.create(name=uid)
            google_group = GoogleWorkspaceProviderGroup.objects.filter(
                provider=self.provider, group=group
            ).first()
            self.assertIsNotNone(google_group)

            group.delete()
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 4)

    def test_group_create_member_add(self):
        """Test group creation"""
        uid = generate_id()
        ext_id = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
            method="POST",
            body={"id": ext_id},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
            method="POST",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
            method="PUT",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}/members?key={self.api_key}&alt=json",
            method="POST",
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            user = create_test_user(uid)
            group = Group.objects.create(name=uid)
            group.users.add(user)
            google_group = GoogleWorkspaceProviderGroup.objects.filter(
                provider=self.provider, group=group
            ).first()
            self.assertIsNotNone(google_group)
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 8)

    def test_group_create_member_remove(self):
        """Test group creation"""
        uid = generate_id()
        ext_id = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
            method="POST",
            body={"id": ext_id},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
            method="POST",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
            method="PUT",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}/members/{uid}%40goauthentik.io?key={self.api_key}",
            method="DELETE",
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}/members?key={self.api_key}&alt=json",
            method="POST",
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            user = create_test_user(uid)
            group = Group.objects.create(name=uid)
            group.users.add(user)
            google_group = GoogleWorkspaceProviderGroup.objects.filter(
                provider=self.provider, group=group
            ).first()
            self.assertIsNotNone(google_group)
            group.users.remove(user)

            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 10)

    def test_group_create_delete_do_nothing(self):
        """Test group deletion (delete action = do nothing)"""
        self.provider.group_delete_action = GoogleWorkspaceDeleteAction.DO_NOTHING
        self.provider.save()
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
            method="POST",
            body={"id": uid},
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            group = Group.objects.create(name=uid)
            google_group = GoogleWorkspaceProviderGroup.objects.filter(
                provider=self.provider, group=group
            ).first()
            self.assertIsNotNone(google_group)

            group.delete()
            self.assertEqual(len(http.requests()), 3)
            self.assertFalse(
                GoogleWorkspaceProviderGroup.objects.filter(
                    provider=self.provider, group__name=uid
                ).exists()
            )

    def test_sync_task(self):
        """Test group discovery"""
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
            method="GET",
            body={"users": []},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
            method="GET",
            body={"groups": [{"id": uid, "name": uid}]},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups/{uid}?key={self.api_key}&alt=json",
            method="PUT",
            body={"id": uid},
        )
        self.app.backchannel_providers.remove(self.provider)
        different_group = Group.objects.create(
            name=uid,
        )
        self.app.backchannel_providers.add(self.provider)
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            google_workspace_sync.delay(self.provider.pk).get()
            self.assertTrue(
                GoogleWorkspaceProviderGroup.objects.filter(
                    group=different_group, provider=self.provider
                ).exists()
            )
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 5)
@ -1,287 +0,0 @@
"""Google Workspace User tests"""

from json import loads
from unittest.mock import MagicMock, patch

from django.test import TestCase

from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.enterprise.providers.google_workspace.clients.test_http import MockHTTP
from authentik.enterprise.providers.google_workspace.models import (
    GoogleWorkspaceDeleteAction,
    GoogleWorkspaceProvider,
    GoogleWorkspaceProviderMapping,
    GoogleWorkspaceProviderUser,
)
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
from authentik.tenants.models import Tenant

domains_list_v1_mock = load_fixture("fixtures/domains_list_v1.json")


class GoogleWorkspaceUserTests(TestCase):
    """Google workspace User tests"""

    @apply_blueprint("system/providers-google-workspace.yaml")
    def setUp(self) -> None:
        # Delete all users and groups as the mocked HTTP responses only return one ID
        # which will cause errors with multiple users
        Tenant.objects.update(avatars="none")
        User.objects.all().exclude_anonymous().delete()
        Group.objects.all().delete()
        self.provider: GoogleWorkspaceProvider = GoogleWorkspaceProvider.objects.create(
            name=generate_id(),
            credentials={},
            delegated_subject="",
            exclude_users_service_account=True,
            default_group_email_domain="goauthentik.io",
        )
        self.app: Application = Application.objects.create(
            name=generate_id(),
            slug=generate_id(),
        )
        self.app.backchannel_providers.add(self.provider)
        self.provider.property_mappings.add(
            GoogleWorkspaceProviderMapping.objects.get(
                managed="goauthentik.io/providers/google_workspace/user"
            )
        )
        self.provider.property_mappings_group.add(
            GoogleWorkspaceProviderMapping.objects.get(
                managed="goauthentik.io/providers/google_workspace/group"
            )
        )
        self.api_key = generate_id()

    def test_user_create(self):
        """Test user creation"""
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
            method="POST",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            user = User.objects.create(
                username=uid,
                name=f"{uid} {uid}",
                email=f"{uid}@goauthentik.io",
            )
            google_user = GoogleWorkspaceProviderUser.objects.filter(
                provider=self.provider, user=user
            ).first()
            self.assertIsNotNone(google_user)
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 2)

    def test_user_create_update(self):
        """Test user updating"""
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
            method="POST",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
            method="PUT",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            user = User.objects.create(
                username=uid,
                name=f"{uid} {uid}",
                email=f"{uid}@goauthentik.io",
            )
            google_user = GoogleWorkspaceProviderUser.objects.filter(
                provider=self.provider, user=user
            ).first()
            self.assertIsNotNone(google_user)

            user.name = "new name"
            user.save()
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 4)

    def test_user_create_delete(self):
        """Test user deletion"""
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
            method="POST",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}",
            method="DELETE",
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            user = User.objects.create(
                username=uid,
                name=f"{uid} {uid}",
                email=f"{uid}@goauthentik.io",
            )
            google_user = GoogleWorkspaceProviderUser.objects.filter(
                provider=self.provider, user=user
            ).first()
            self.assertIsNotNone(google_user)

            user.delete()
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 4)

    def test_user_create_delete_suspend(self):
        """Test user deletion (delete action = Suspend)"""
        self.provider.user_delete_action = GoogleWorkspaceDeleteAction.SUSPEND
        self.provider.save()
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
            method="POST",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
            method="PUT",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            user = User.objects.create(
                username=uid,
                name=f"{uid} {uid}",
                email=f"{uid}@goauthentik.io",
            )
            google_user = GoogleWorkspaceProviderUser.objects.filter(
                provider=self.provider, user=user
            ).first()
            self.assertIsNotNone(google_user)

            user.delete()
            self.assertEqual(len(http.requests()), 4)
            _, _, body, _ = http.requests()[3]
            self.assertEqual(
                loads(body),
                {
                    "suspended": True,
                },
            )
            self.assertFalse(
                GoogleWorkspaceProviderUser.objects.filter(
                    provider=self.provider, user__username=uid
                ).exists()
            )

    def test_user_create_delete_do_nothing(self):
        """Test user deletion (delete action = do nothing)"""
        self.provider.user_delete_action = GoogleWorkspaceDeleteAction.DO_NOTHING
        self.provider.save()
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
            method="POST",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            user = User.objects.create(
                username=uid,
                name=f"{uid} {uid}",
                email=f"{uid}@goauthentik.io",
            )
            google_user = GoogleWorkspaceProviderUser.objects.filter(
                provider=self.provider, user=user
            ).first()
            self.assertIsNotNone(google_user)

            user.delete()
            self.assertEqual(len(http.requests()), 3)
            self.assertFalse(
                GoogleWorkspaceProviderUser.objects.filter(
                    provider=self.provider, user__username=uid
                ).exists()
            )

    def test_sync_task(self):
        """Test user discovery"""
        uid = generate_id()
        http = MockHTTP()
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
            domains_list_v1_mock,
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
            method="GET",
            body={"users": [{"primaryEmail": f"{uid}@goauthentik.io"}]},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/groups?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
            method="GET",
            body={"groups": []},
        )
        http.add_response(
            f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
            method="PUT",
            body={"primaryEmail": f"{uid}@goauthentik.io"},
        )
        self.app.backchannel_providers.remove(self.provider)
        different_user = User.objects.create(
            username=uid,
            email=f"{uid}@goauthentik.io",
        )
        self.app.backchannel_providers.add(self.provider)
        with patch(
            "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
            MagicMock(return_value={"developerKey": self.api_key, "http": http}),
        ):
            google_workspace_sync.delay(self.provider.pk).get()
            self.assertTrue(
                GoogleWorkspaceProviderUser.objects.filter(
                    user=different_user, provider=self.provider
                ).exists()
            )
            self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
            self.assertEqual(len(http.requests()), 5)
@ -1,11 +0,0 @@
"""google provider urls"""

from authentik.enterprise.providers.google_workspace.api.property_mappings import (
    GoogleProviderMappingViewSet,
)
from authentik.enterprise.providers.google_workspace.api.providers import GoogleProviderViewSet

api_urlpatterns = [
    ("providers/google_workspace", GoogleProviderViewSet),
    ("propertymappings/provider/google_workspace", GoogleProviderMappingViewSet),
]
@ -11,7 +11,7 @@ from django.utils.translation import gettext as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger

from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User, default_token_key
from authentik.events.models import Event, EventAction
from authentik.lib.models import SerializerModel
@ -14,7 +14,6 @@ CELERY_BEAT_SCHEDULE = {

TENANT_APPS = [
    "authentik.enterprise.audit",
    "authentik.enterprise.providers.google_workspace",
    "authentik.enterprise.providers.rac",
    "authentik.enterprise.stages.source",
]
@ -60,8 +60,6 @@ class SystemTaskSerializer(ModelSerializer):
            "duration",
            "status",
            "messages",
            "expires",
            "expiring",
        ]


@ -214,15 +214,7 @@ class AuditMiddleware:
            model=model_to_dict(instance),
        ).run()

    def m2m_changed_handler(
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        action: str,
        thread_kwargs: dict | None = None,
        **_,
    ):
    def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_):
        """Signal handler for all object's m2m_changed"""
        if action not in ["pre_add", "pre_remove", "post_clear"]:
            return
@ -237,5 +229,4 @@ class AuditMiddleware:
            request,
            user=user,
            model=model_to_dict(instance),
            **thread_kwargs,
        ).run()

@ -6,7 +6,7 @@ from typing import Any

from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
from structlog.stdlib import get_logger
from tenant_schemas_celery.task import TenantTask

from authentik.events.logs import LogEvent
@ -15,12 +15,12 @@ from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string

LOGGER = get_logger()


class SystemTask(TenantTask):
    """Task which can save its state to the cache"""

    logger: BoundLogger

    # For tasks that should only be listed if they failed, set this to False
    save_on_success: bool

@ -63,7 +63,6 @@ class SystemTask(TenantTask):
    def before_start(self, task_id, args, kwargs):
        self._start_precise = perf_counter()
        self._start = now()
        self.logger = get_logger().bind(task_id=task_id)
        return super().before_start(task_id, args, kwargs)

    def db(self) -> DBSystemTask | None:
@ -120,7 +119,7 @@ class SystemTask(TenantTask):
                "task_call_kwargs": sanitize_item(kwargs),
                "status": self._status,
                "messages": sanitize_item(self._messages),
                "expires": now() + timedelta(hours=self.result_timeout_hours + 3),
                "expires": now() + timedelta(hours=self.result_timeout_hours),
                "expiring": True,
            },
        )

@ -4,7 +4,7 @@ from django.db.models.query_utils import Q
from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger

from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import User
from authentik.events.models import (
    Event,
@ -53,7 +53,6 @@ cache:

# result_backend:
#   url: ""
#   transport_options: ""

debug: false
remote_debug: false
@ -9,7 +9,6 @@ from typing import Any

from cachetools import TLRUCache, cached
from django.core.exceptions import FieldError
from django.utils.text import slugify
from guardian.shortcuts import get_anonymous_user
from rest_framework.serializers import ValidationError
from sentry_sdk.hub import Hub
@ -57,7 +56,6 @@ class BaseEvaluator:
            "requests": get_http_session(),
            "resolve_dns": BaseEvaluator.expr_resolve_dns,
            "reverse_dns": BaseEvaluator.expr_reverse_dns,
            "slugify": slugify,
        }
        self._context = {}

@ -100,7 +100,6 @@ def get_logger_config():
        "fsevents": "WARNING",
        "uvicorn": "WARNING",
        "gunicorn": "INFO",
        "requests_mock": "WARNING",
    }
    for handler_name, level in handler_level_map.items():
        base_config["loggers"][handler_name] = {
@ -1,5 +0,0 @@
"""Sync constants"""

PAGE_SIZE = 100
PAGE_TIMEOUT = 60 * 60 * 0.5  # Half an hour
HTTP_CONFLICT = 409
@ -1,54 +0,0 @@
from collections.abc import Callable

from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response

from authentik.core.api.utils import PassiveSerializer
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider


class SyncStatusSerializer(PassiveSerializer):
    """Provider sync status"""

    is_running = BooleanField(read_only=True)
    tasks = SystemTaskSerializer(many=True, read_only=True)


class OutgoingSyncProviderStatusMixin:
    """Common API Endpoints for Outgoing sync providers"""

    sync_single_task: Callable = None

    @extend_schema(
        responses={
            200: SyncStatusSerializer(),
            404: OpenApiResponse(description="Task not found"),
        }
    )
    @action(
        methods=["GET"],
        detail=True,
        pagination_class=None,
        url_path="sync/status",
        filter_backends=[],
    )
    def sync_status(self, request: Request, pk: int) -> Response:
        """Get provider's sync status"""
        provider: OutgoingSyncProvider = self.get_object()
        tasks = list(
            get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
                name=self.sync_single_task.__name__,
                uid=slugify(provider.name),
            )
        )
        status = {
            "tasks": tasks,
            "is_running": provider.sync_lock.locked(),
        }
        return Response(SyncStatusSerializer(status).data)
@ -1,83 +0,0 @@
"""Basic outgoing sync Client"""

from enum import StrEnum
from typing import TYPE_CHECKING

from django.db import DatabaseError
from structlog.stdlib import get_logger

from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException

if TYPE_CHECKING:
    from django.db.models import Model

    from authentik.lib.sync.outgoing.models import OutgoingSyncProvider


class Direction(StrEnum):

    add = "add"
    remove = "remove"


class BaseOutgoingSyncClient[
    TModel: "Model", TConnection: "Model", TSchema: dict, TProvider: "OutgoingSyncProvider"
]:
    """Basic Outgoing sync client Client"""

    provider: TProvider
    connection_type: type[TConnection]
    connection_type_query: str

    can_discover = False

    def __init__(self, provider: TProvider):
        self.logger = get_logger().bind(provider=provider.name)
        self.provider = provider

    def create(self, obj: TModel) -> TConnection:
        """Create object in remote destination"""
        raise NotImplementedError()

    def update(self, obj: TModel, connection: object):
        """Update object in remote destination"""
        raise NotImplementedError()

    def write(self, obj: TModel) -> tuple[TConnection, bool]:
        """Write object to destination. Uses self.create and self.update, but
        can be overwritten for further logic"""
        remote_obj = self.connection_type.objects.filter(
            provider=self.provider, **{self.connection_type_query: obj}
        ).first()
        connection: TConnection | None = None
        try:
            if not remote_obj:
                connection = self.create(obj)
                return connection, True
            try:
                self.update(obj, remote_obj)
                return remote_obj, False
            except NotFoundSyncException:
                remote_obj.delete()
                connection = self.create(obj)
                return connection, True
        except DatabaseError as exc:
            self.logger.warning("Failed to write object", obj=obj, exc=exc)
            if connection:
                connection.delete()
            return None, False

    def delete(self, obj: TModel):
        """Delete object from destination"""
        raise NotImplementedError()

    def to_schema(self, obj: TModel) -> TSchema:
        """Convert object to destination schema"""
        raise NotImplementedError()

    def discover(self):
        """Optional method. Can be used to implement a "discovery" where
        upon creation of this provider, this function will be called and can
        pre-link any users/groups in the remote system with the respective
        object in authentik based on a common identifier"""
        raise NotImplementedError()
@ -1,37 +0,0 @@
from authentik.lib.sentry import SentryIgnoredException


class BaseSyncException(SentryIgnoredException):
    """Base class for all sync exceptions"""


class TransientSyncException(BaseSyncException):
    """Transient sync exception which may be caused by network blips, etc"""


class NotFoundSyncException(BaseSyncException):
    """Exception when an object was not found in the remote system"""


class ObjectExistsSyncException(BaseSyncException):
    """Exception when an object already exists in the remote system"""


class StopSync(BaseSyncException):
    """Exception raised when a configuration error should stop the sync process"""

    def __init__(
        self, exc: Exception, obj: object | None = None, mapping: object | None = None
    ) -> None:
        self.exc = exc
        self.obj = obj
        self.mapping = mapping

    def detail(self) -> str:
        """Get human readable details of this error"""
        msg = f"Error {str(self.exc)}"
        if self.obj:
            msg += f", caused by {self.obj}"
        if self.mapping:
            msg += f" (mapping {self.mapping})"
        return msg
@ -1,32 +0,0 @@
from typing import Any, Self

from django.core.cache import cache
from django.db.models import Model, QuerySet
from redis.lock import Lock

from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient


class OutgoingSyncProvider(Model):

    class Meta:
        abstract = True

    def client_for_model[
        T: User | Group
    ](self, model: type[T]) -> BaseOutgoingSyncClient[T, Any, Any, Self]:
        raise NotImplementedError

    def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
        raise NotImplementedError

    @property
    def sync_lock(self) -> Lock:
        """Redis lock to prevent multiple parallel syncs happening"""
        return Lock(
            cache.client.get_client(),
            name=f"goauthentik.io/providers/outgoing-sync/{str(self.pk)}",
            timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
        )
@ -1,71 +0,0 @@
from collections.abc import Callable

from django.core.paginator import Paginator
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete

from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path


def register_signals(
    provider_type: type[OutgoingSyncProvider],
    task_sync_single: Callable[[int], None],
    task_sync_direct: Callable[[int], None],
    task_sync_m2m: Callable[[int], None],
):
    """Register sync signals"""
    uid = class_to_path(provider_type)

    def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_):
        """Trigger sync when Provider is saved"""
        users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE)
        groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE)
        soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
        time_limit = soft_time_limit * 1.5
        task_sync_single.apply_async(
            (instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
        )

    post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False)

    def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
        """Post save handler"""
        if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
            return
        task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)

    post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
    post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)

    def model_pre_delete(sender: type[Model], instance: User | Group, **_):
        """Pre-delete handler"""
        if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
            return
        task_sync_direct.delay(
            class_to_path(instance.__class__), instance.pk, Direction.remove.value
        )

    pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
    pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)

    def model_m2m_changed(
        sender: type[Model], instance, action: str, pk_set: set, reverse: bool, **kwargs
    ):
        """Sync group membership"""
        if action not in ["post_add", "post_remove"]:
            return
        if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
            return
        # reverse: instance is a Group, pk_set is a list of user pks
        # non-reverse: instance is a User, pk_set is a list of groups
        if reverse:
            task_sync_m2m.delay(str(instance.pk), action, list(pk_set))
        else:
            for group_pk in pk_set:
                task_sync_m2m.delay(group_pk, action, [instance.pk])

    m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
@ -1,215 +0,0 @@
from collections.abc import Callable

from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger

from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group, User
from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import StopSync, TransientSyncException
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class


class SyncTasks:
    """Container for all sync 'tasks' (this class doesn't actually contain celery
    tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""

    logger: BoundLogger

    def __init__(self, provider_model: type[OutgoingSyncProvider]) -> None:
        super().__init__()
        self._provider_model = provider_model

    def sync_all(self, single_sync: Callable[[int], None]):
        for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
            self.trigger_single_task(provider, single_sync)

    def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
        """Wrapper single sync task that correctly sets time limits based
        on the amount of objects that will be synced"""
        users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
        groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
        soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
        time_limit = soft_time_limit * 1.5
        return sync_task.apply_async(
            (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
        )

    def sync_single(
        self,
        task: SystemTask,
        provider_pk: int,
        sync_objects: Callable[[int, int], list[str]],
    ):
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),
            provider_pk=provider_pk,
        )
        provider = self._provider_model.objects.filter(
            pk=provider_pk, backchannel_application__isnull=False
        ).first()
        if not provider:
            return
        lock = provider.sync_lock
        if lock.locked():
            self.logger.debug("Sync locked, skipping task", source=provider.name)
            return
        task.set_uid(slugify(provider.name))
        messages = []
        messages.append(_("Starting full provider sync"))
        self.logger.debug("Starting provider sync")
        users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
        groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
        with allow_join_result(), lock:
            try:
                for page in users_paginator.page_range:
                    messages.append(_("Syncing page %(page)d of users" % {"page": page}))
                    for msg in sync_objects.apply_async(
                        args=(class_to_path(User), page, provider_pk),
                        time_limit=PAGE_TIMEOUT,
                        soft_time_limit=PAGE_TIMEOUT,
                    ).get():
                        messages.append(msg)
                for page in groups_paginator.page_range:
                    messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
                    for msg in sync_objects.apply_async(
                        args=(class_to_path(Group), page, provider_pk),
                        time_limit=PAGE_TIMEOUT,
                        soft_time_limit=PAGE_TIMEOUT,
                    ).get():
                        messages.append(msg)
            except TransientSyncException as exc:
                self.logger.warning("transient sync exception", exc=exc)
                raise task.retry(exc=exc) from exc
            except StopSync as exc:
                task.set_error(exc)
                return
        task.set_status(TaskStatus.SUCCESSFUL, *messages)

    def sync_objects(self, object_type: str, page: int, provider_pk: int):
        _object_type = path_to_class(object_type)
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),
            provider_pk=provider_pk,
            object_type=object_type,
        )
        messages = []
        provider = self._provider_model.objects.filter(pk=provider_pk).first()
        if not provider:
            return messages
        try:
            client = provider.client_for_model(_object_type)
        except TransientSyncException:
            return messages
        paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
        if client.can_discover:
            self.logger.debug("starting discover")
            client.discover()
        self.logger.debug("starting sync for page", page=page)
        for obj in paginator.page(page).object_list:
            obj: Model
            try:
                client.write(obj)
            except SkipObjectException:
                continue
            except TransientSyncException as exc:
                self.logger.warning("failed to sync object", exc=exc, user=obj)
                messages.append(
                    LogEvent(
                        _(
                            (
                                "Failed to sync {object_type} {object_name} "
                                "due to transient error: {error}"
                            ).format_map(
                                {
                                    "object_type": obj._meta.verbose_name,
                                    "object_name": str(obj),
                                    "error": str(exc),
                                }
                            )
                        ),
                        log_level="warning",
                        logger="",
                    )
                )
            except StopSync as exc:
                self.logger.warning("Stopping sync", exc=exc)
                messages.append(
                    LogEvent(
                        _(
                            "Stopping sync due to error: {error}".format_map(
                                {
                                    "error": exc.detail(),
                                }
                            )
                        ),
                        log_level="warning",
                        logger="",
                    )
                )
                break
        return messages

    def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),
        )
        model_class: type[Model] = path_to_class(model)
        instance = model_class.objects.filter(pk=pk).first()
        if not instance:
            return
        operation = Direction(raw_op)
        for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
            client = provider.client_for_model(instance.__class__)
            # Check if the object is allowed within the provider's restrictions
            queryset = provider.get_object_qs(instance.__class__)
            if not queryset:
                continue

            # The queryset we get from the provider must include the instance we've got given
            # otherwise ignore this provider
            if not queryset.filter(pk=instance.pk).exists():
                continue

            try:
                if operation == Direction.add:
                    client.write(instance)
                if operation == Direction.remove:
                    client.delete(instance)
            except (StopSync, TransientSyncException) as exc:
                self.logger.warning(exc, provider_pk=provider.pk)

    def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
        self.logger = get_logger().bind(
            provider_type=class_to_path(self._provider_model),
        )
        group = Group.objects.filter(pk=group_pk).first()
        if not group:
            return
        for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
            # Check if the object is allowed within the provider's restrictions
            queryset: QuerySet = provider.get_object_qs(Group)
            # The queryset we get from the provider must include the instance we've got given
            # otherwise ignore this provider
            if not queryset.filter(pk=group_pk).exists():
                continue

            client = provider.client_for_model(Group)
            try:
                operation = None
                if action == "post_add":
                    operation = Direction.add
                if action == "post_remove":
                    operation = Direction.remove
                client.update_group(group, operation, pk_set)
            except (StopSync, TransientSyncException) as exc:
                self.logger.warning(exc, provider_pk=provider.pk)
@ -24,7 +24,7 @@ def load_fixture(path: str, **kwargs) -> str:
        fixture = _fixture.read()
    try:
        return fixture % kwargs
    except (TypeError, ValueError):
    except TypeError:
        return fixture


@ -8,7 +8,7 @@ from django.views import View
from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger

from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import Application
from authentik.providers.oauth2.constants import (
    ACR_AUTHENTIK_DEFAULT,
@ -11,7 +11,7 @@ from django.views import View
from django.views.decorators.csrf import csrf_exempt
from structlog.stdlib import get_logger

from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.events.models import Event, EventAction
from authentik.flows.challenge import PermissionDict
from authentik.providers.oauth2.constants import (
@ -9,7 +9,7 @@ from lxml import etree  # nosec
from lxml.etree import Element, SubElement  # nosec
from structlog.stdlib import get_logger

from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.events.models import Event, EventAction
from authentik.events.signals import get_login_event
from authentik.lib.utils.time import timedelta_from_string
@ -1,12 +1,19 @@
"""SCIM Provider API Views"""

from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync


class SCIMProviderSerializer(ProviderSerializer):
@ -33,7 +40,14 @@ class SCIMProviderSerializer(ProviderSerializer):
        extra_kwargs = {}


class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
class SCIMSyncStatusSerializer(PassiveSerializer):
    """SCIM Provider sync status"""

    is_running = BooleanField(read_only=True)
    tasks = SystemTaskSerializer(many=True, read_only=True)


class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
    """SCIMProvider Viewset"""

    queryset = SCIMProvider.objects.all()
@ -41,4 +55,25 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
    filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"]
    search_fields = ["name", "url"]
    ordering = ["name", "url"]
    sync_single_task = scim_sync

    @extend_schema(
        responses={
            200: SCIMSyncStatusSerializer(),
            404: OpenApiResponse(description="Task not found"),
        }
    )
    @action(methods=["GET"], detail=True, pagination_class=None, filter_backends=[])
    def sync_status(self, request: Request, pk: int) -> Response:
        """Get provider's sync status"""
        provider: SCIMProvider = self.get_object()
        tasks = list(
            get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
                name="scim_sync",
                uid=slugify(provider.name),
            )
        )
        status = {
            "tasks": tasks,
            "is_running": provider.sync_lock.locked(),
        }
        return Response(SCIMSyncStatusSerializer(status).data)

@ -0,0 +1,4 @@
"""SCIM constants"""

PAGE_SIZE = 100
PAGE_TIMEOUT = 60 * 60 * 0.5  # Half an hour
@ -1,37 +1,33 @@
"""SCIM Client"""

from typing import TYPE_CHECKING
from typing import Generic, TypeVar

from django.http import HttpResponseBadRequest, HttpResponseNotFound
from pydantic import ValidationError
from requests import RequestException, Session
from structlog.stdlib import get_logger

from authentik.lib.sync.outgoing import HTTP_CONFLICT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, ObjectExistsSyncException
from authentik.lib.utils.http import get_http_session
from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.providers.scim.clients.exceptions import ResourceMissing, SCIMRequestException
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMProvider

if TYPE_CHECKING:
from django.db.models import Model
from pydantic import BaseModel
T = TypeVar("T")

SchemaType = TypeVar("SchemaType")


class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
BaseOutgoingSyncClient[TModel, TConnection, TSchema, SCIMProvider]
):
class SCIMClient(Generic[T, SchemaType]):
"""SCIM Client"""

base_url: str
token: str
provider: SCIMProvider

_session: Session
_config: ServiceProviderConfiguration

def __init__(self, provider: SCIMProvider):
super().__init__(provider)
self._session = get_http_session()
self.provider = provider
# Remove trailing slashes as we assume the URL doesn't have any
@ -40,6 +36,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
base_url = base_url[:-1]
self.base_url = base_url
self.token = provider.token
self.logger = get_logger().bind(provider=provider.name)
self._config = self.get_service_provider_config()

def _request(self, method: str, path: str, **kwargs) -> dict:
@ -60,9 +57,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
self.logger.debug("scim request", path=path, method=method, **kwargs)
if response.status_code >= HttpResponseBadRequest.status_code:
if response.status_code == HttpResponseNotFound.status_code:
raise NotFoundSyncException(response)
if response.status_code == HTTP_CONFLICT:
raise ObjectExistsSyncException(response)
raise ResourceMissing(response)
self.logger.warning(
"Failed to send SCIM request", path=path, method=method, response=response.text
)
@ -81,3 +76,15 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
except (ValidationError, SCIMRequestException) as exc:
self.logger.warning("failed to get ServiceProviderConfig", exc=exc)
return default_config

def write(self, obj: T):
"""Write object to SCIM"""
raise NotImplementedError()

def delete(self, obj: T):
"""Delete object from SCIM"""
raise NotImplementedError()

def to_scim(self, obj: T) -> SchemaType:
"""Convert object to scim"""
raise NotImplementedError()
|
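The client diff above translates HTTP error statuses into the shared outgoing-sync exceptions: 404 becomes `NotFoundSyncException`, 409 (`HTTP_CONFLICT`) becomes `ObjectExistsSyncException`, and anything else at or above 400 is treated as a generic SCIM request failure. A minimal standalone sketch of that dispatch follows; the exception classes and the bare `requests` session here are stand-ins for illustration, not the exact authentik helper.

```python
# Hedged sketch: mapping HTTP error statuses to typed sync exceptions, mirroring
# the dispatch in the SCIM client above. Class names are illustrative stand-ins.
import requests


class NotFoundSyncException(Exception):
    """Remote object does not exist (HTTP 404)."""


class ObjectExistsSyncException(Exception):
    """Remote object already exists (HTTP 409)."""


class SCIMRequestException(Exception):
    """Any other failed SCIM request (HTTP >= 400)."""


def scim_request(session: requests.Session, method: str, url: str, token: str, **kwargs) -> dict:
    """Send a SCIM request and raise a typed exception on error statuses."""
    response = session.request(
        method,
        url,
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/scim+json"},
        **kwargs,
    )
    if response.status_code == 404:
        raise NotFoundSyncException(response.text)
    if response.status_code == 409:
        raise ObjectExistsSyncException(response.text)
    if response.status_code >= 400:
        raise SCIMRequestException(response.text)
    # DELETE typically returns 204 No Content, so there may be no body to decode
    if response.status_code == 204 or not response.text:
        return {}
    return response.json()
```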
@ -3,11 +3,28 @@
from pydantic import ValidationError
from requests import Response

from authentik.lib.sync.outgoing.exceptions import TransientSyncException
from authentik.lib.sentry import SentryIgnoredException
from authentik.providers.scim.clients.schema import SCIMError


class SCIMRequestException(TransientSyncException):
class StopSync(SentryIgnoredException):
"""Exception raised when a configuration error should stop the sync process"""

def __init__(self, exc: Exception, obj: object, mapping: object | None = None) -> None:
self.exc = exc
self.obj = obj
self.mapping = mapping

def detail(self) -> str:
"""Get human readable details of this error"""
msg = f"Error {str(self.exc)}, caused by {self.obj}"

if self.mapping:
msg += f" (mapping {self.mapping})"
return msg


class SCIMRequestException(SentryIgnoredException):
"""Exception raised when an SCIM request fails"""

_response: Response | None
@ -22,8 +39,13 @@ class SCIMRequestException(TransientSyncException):
if not self._response:
return self._message
try:
error = SCIMError.model_validate_json(self._response.text)
error = SCIMError.parse_raw(self._response.text)
return error.detail
except ValidationError:
pass
return self._message


class ResourceMissing(SCIMRequestException):
"""Error raised when the provider raises a 404, meaning that we
should delete our internal ID and re-create the object"""
@ -5,36 +5,47 @@ from pydantic import ValidationError
|
||||
from pydanticscim.group import GroupMember
|
||||
from pydanticscim.responses import PatchOp, PatchOperation
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import Group
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
NotFoundSyncException,
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
)
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.utils import delete_none_values
|
||||
from authentik.providers.scim.clients.base import SCIMClient
|
||||
from authentik.providers.scim.clients.exceptions import (
|
||||
ResourceMissing,
|
||||
SCIMRequestException,
|
||||
StopSync,
|
||||
)
|
||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
|
||||
from authentik.providers.scim.clients.schema import PatchRequest
|
||||
from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMUser
|
||||
|
||||
|
||||
class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
|
||||
"""SCIM client for groups"""
|
||||
|
||||
connection_type = SCIMGroup
|
||||
connection_type_query = "group"
|
||||
def write(self, obj: Group):
|
||||
"""Write a group"""
|
||||
scim_group = SCIMGroup.objects.filter(provider=self.provider, group=obj).first()
|
||||
if not scim_group:
|
||||
return self._create(obj)
|
||||
try:
|
||||
return self._update(obj, scim_group)
|
||||
except ResourceMissing:
|
||||
scim_group.delete()
|
||||
return self._create(obj)
|
||||
|
||||
def to_schema(self, obj: Group) -> SCIMGroupSchema:
|
||||
def delete(self, obj: Group):
|
||||
"""Delete group"""
|
||||
scim_group = SCIMGroup.objects.filter(provider=self.provider, group=obj).first()
|
||||
if not scim_group:
|
||||
self.logger.debug("Group does not exist in SCIM, skipping")
|
||||
return None
|
||||
response = self._request("DELETE", f"/Groups/{scim_group.id}")
|
||||
scim_group.delete()
|
||||
return response
|
||||
|
||||
def to_scim(self, obj: Group) -> SCIMGroupSchema:
|
||||
"""Convert authentik user into SCIM"""
|
||||
raw_scim_group = {
|
||||
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:Group",),
|
||||
@ -55,8 +66,6 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_scim_group, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
@ -80,26 +89,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.scim_id,
|
||||
value=user.id,
|
||||
)
|
||||
)
|
||||
if members:
|
||||
scim_group.members = members
|
||||
return scim_group
|
||||
|
||||
def delete(self, obj: Group):
|
||||
"""Delete group"""
|
||||
scim_group = SCIMGroup.objects.filter(provider=self.provider, group=obj).first()
|
||||
if not scim_group:
|
||||
self.logger.debug("Group does not exist in SCIM, skipping")
|
||||
return None
|
||||
response = self._request("DELETE", f"/Groups/{scim_group.scim_id}")
|
||||
scim_group.delete()
|
||||
return response
|
||||
|
||||
def create(self, group: Group):
|
||||
def _create(self, group: Group):
|
||||
"""Create group from scratch and create a connection object"""
|
||||
scim_group = self.to_schema(group)
|
||||
scim_group = self.to_scim(group)
|
||||
response = self._request(
|
||||
"POST",
|
||||
"/Groups",
|
||||
@ -108,28 +107,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
scim_id = response.get("id")
|
||||
if not scim_id or scim_id == "":
|
||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||
SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id)
|
||||
SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"])
|
||||
|
||||
def update(self, group: Group, connection: SCIMGroup):
|
||||
def _update(self, group: Group, connection: SCIMGroup):
|
||||
"""Update existing group"""
|
||||
scim_group = self.to_schema(group)
|
||||
scim_group.id = connection.scim_id
|
||||
scim_group = self.to_scim(group)
|
||||
scim_group.id = connection.id
|
||||
try:
|
||||
return self._request(
|
||||
"PUT",
|
||||
f"/Groups/{connection.scim_id}",
|
||||
f"/Groups/{scim_group.id}",
|
||||
json=scim_group.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
except NotFoundSyncException:
|
||||
except ResourceMissing:
|
||||
# Resource missing is handled by self.write, which will re-create the group
|
||||
raise
|
||||
except (SCIMRequestException, ObjectExistsSyncException):
|
||||
except SCIMRequestException:
|
||||
# Some providers don't support PUT on groups, so this is mainly a fix for the initial
|
||||
# sync, send patch add requests for all the users the group currently has
|
||||
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||
@ -144,12 +140,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
),
|
||||
)
|
||||
|
||||
def update_group(self, group: Group, action: Direction, users_set: set[int]):
|
||||
def update_group(self, group: Group, action: PatchOp, users_set: set[int]):
|
||||
"""Update a group, either using PUT to replace it or PATCH if supported"""
|
||||
if self._config.patch.supported:
|
||||
if action == Direction.add:
|
||||
if action == PatchOp.add:
|
||||
return self._patch_add_users(group, users_set)
|
||||
if action == Direction.remove:
|
||||
if action == PatchOp.remove:
|
||||
return self._patch_remove_users(group, users_set)
|
||||
try:
|
||||
return self.write(group)
|
||||
@ -157,9 +153,9 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
if self._config.is_fallback:
|
||||
# Assume that provider does not support PUT and also doesn't support
|
||||
# ServiceProviderConfig, so try PATCH as a fallback
|
||||
if action == Direction.add:
|
||||
if action == PatchOp.add:
|
||||
return self._patch_add_users(group, users_set)
|
||||
if action == Direction.remove:
|
||||
if action == PatchOp.remove:
|
||||
return self._patch_remove_users(group, users_set)
|
||||
raise exc
|
||||
|
||||
@ -189,13 +185,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
return
|
||||
user_ids = list(
|
||||
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
|
||||
"scim_id", flat=True
|
||||
"id", flat=True
|
||||
)
|
||||
)
|
||||
if len(user_ids) < 1:
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
scim_group.id,
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
@ -215,13 +211,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
return
|
||||
user_ids = list(
|
||||
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
|
||||
"scim_id", flat=True
|
||||
"id", flat=True
|
||||
)
|
||||
)
|
||||
if len(user_ids) < 1:
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
scim_group.id,
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
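The group client above falls back to SCIM PATCH operations on the `members` path (`op=add` / `op=remove`) when the service provider supports patching, or when it rejects PUT on groups during the initial sync. The sketch below shows the RFC 7644 PatchOp body such a request carries; the helper name and the example IDs are made up for illustration, and the `value`-list form for `remove` mirrors what the diff itself sends rather than a universally required shape.

```python
# Hedged sketch: building the RFC 7644 PatchOp payload used to add or remove
# group members, as the _patch_add_users/_patch_remove_users paths above do.
# build_members_patch and the example IDs are illustrative, not authentik API.
import json


def build_members_patch(op: str, scim_user_ids: list[str]) -> dict:
    """Return a SCIM PatchOp payload targeting the group's members attribute."""
    assert op in ("add", "remove")
    return {
        "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
        "Operations": [
            {
                "op": op,
                "path": "members",
                "value": [{"value": user_id} for user_id in scim_user_ids],
            }
        ],
    }


# Example: PATCH /Groups/{id} with this body adds two users to the group.
print(json.dumps(build_members_patch("add", ["2819c223-7f76", "9f0e3c1a-41bd"]), indent=2))
```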
@ -9,14 +9,13 @@ from pydanticscim.service_provider import (
)
from pydanticscim.user import User as BaseUser

SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"


class User(BaseUser):
"""Modified User schema with added externalId field"""

schemas: list[str] = [SCIM_USER_SCHEMA]
schemas: list[str] = [
"urn:ietf:params:scim:schemas:core:2.0:User",
]
externalId: str | None = None
meta: dict | None = None

@ -24,7 +23,9 @@ class User(BaseUser):
class Group(BaseGroup):
"""Modified Group schema with added externalId field"""

schemas: list[str] = [SCIM_GROUP_SCHEMA]
schemas: list[str] = [
"urn:ietf:params:scim:schemas:core:2.0:Group",
]
externalId: str | None = None
meta: dict | None = None
@ -3,27 +3,42 @@
|
||||
from deepmerge import always_merger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.utils import delete_none_values
|
||||
from authentik.providers.scim.clients.base import SCIMClient
|
||||
from authentik.providers.scim.clients.exceptions import ResourceMissing, StopSync
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMUser
|
||||
|
||||
|
||||
class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
|
||||
"""SCIM client for users"""
|
||||
|
||||
connection_type = SCIMUser
|
||||
connection_type_query = "user"
|
||||
def write(self, obj: User):
|
||||
"""Write a user"""
|
||||
scim_user = SCIMUser.objects.filter(provider=self.provider, user=obj).first()
|
||||
if not scim_user:
|
||||
return self._create(obj)
|
||||
try:
|
||||
return self._update(obj, scim_user)
|
||||
except ResourceMissing:
|
||||
scim_user.delete()
|
||||
return self._create(obj)
|
||||
|
||||
def to_schema(self, obj: User) -> SCIMUserSchema:
|
||||
def delete(self, obj: User):
|
||||
"""Delete user"""
|
||||
scim_user = SCIMUser.objects.filter(provider=self.provider, user=obj).first()
|
||||
if not scim_user:
|
||||
self.logger.debug("User does not exist in SCIM, skipping")
|
||||
return None
|
||||
response = self._request("DELETE", f"/Users/{scim_user.id}")
|
||||
scim_user.delete()
|
||||
return response
|
||||
|
||||
def to_scim(self, obj: User) -> SCIMUserSchema:
|
||||
"""Convert authentik user into SCIM"""
|
||||
raw_scim_user = {
|
||||
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:User",),
|
||||
@ -41,8 +56,6 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_scim_user, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
@ -61,19 +74,9 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
scim_user.externalId = str(obj.uid)
|
||||
return scim_user
|
||||
|
||||
def delete(self, obj: User):
|
||||
"""Delete user"""
|
||||
scim_user = SCIMUser.objects.filter(provider=self.provider, user=obj).first()
|
||||
if not scim_user:
|
||||
self.logger.debug("User does not exist in SCIM, skipping")
|
||||
return None
|
||||
response = self._request("DELETE", f"/Users/{scim_user.scim_id}")
|
||||
scim_user.delete()
|
||||
return response
|
||||
|
||||
def create(self, user: User):
|
||||
def _create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
scim_user = self.to_schema(user)
|
||||
scim_user = self.to_scim(user)
|
||||
response = self._request(
|
||||
"POST",
|
||||
"/Users",
|
||||
@ -82,18 +85,15 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
scim_id = response.get("id")
|
||||
if not scim_id or scim_id == "":
|
||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||
SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
|
||||
SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"])
|
||||
|
||||
def update(self, user: User, connection: SCIMUser):
|
||||
def _update(self, user: User, connection: SCIMUser):
|
||||
"""Update existing user"""
|
||||
scim_user = self.to_schema(user)
|
||||
scim_user.id = connection.scim_id
|
||||
scim_user = self.to_scim(user)
|
||||
scim_user.id = connection.id
|
||||
self._request(
|
||||
"PUT",
|
||||
f"/Users/{connection.scim_id}",
|
||||
f"/Users/{connection.id}",
|
||||
json=scim_user.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
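Both the user and group clients build the outgoing representation the same way: each property mapping is evaluated, `None` results are skipped, the returned dicts are deep-merged over a small base payload, and `None` values are dropped before the result is validated into the pydantic schema. A condensed sketch of that pattern using `deepmerge` follows; the mapping callables and field values are placeholders, not authentik's `PropertyMapping` objects.

```python
# Hedged sketch: merging several property-mapping results into one SCIM user
# payload, following the always_merger + delete_none_values pattern above.
# The lambdas stand in for evaluated property mappings.
from typing import Callable

from deepmerge import always_merger


def drop_none(value: dict) -> dict:
    """Recursively drop None values so they don't overwrite real data."""
    return {
        key: drop_none(val) if isinstance(val, dict) else val
        for key, val in value.items()
        if val is not None
    }


def build_scim_user(username: str, mappings: list[Callable[[], dict | None]]) -> dict:
    raw = {
        "schemas": ("urn:ietf:params:scim:schemas:core:2.0:User",),
        "userName": username,
    }
    for mapping in mappings:
        value = mapping()
        if value is None:
            continue  # a mapping may opt out for this object
        always_merger.merge(raw, value)
    return drop_none(raw)


# Example mappings: one sets the display name, one adds an email address.
payload = build_scim_user(
    "jdoe",
    [
        lambda: {"displayName": "Jane Doe"},
        lambda: {"emails": [{"value": "jdoe@example.com", "primary": True}]},
    ],
)
```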
@ -3,7 +3,7 @@
from structlog.stdlib import get_logger

from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.providers.scim.tasks import scim_task_wrapper
from authentik.tenants.management import TenantCommand

LOGGER = get_logger()
@ -21,4 +21,4 @@ class Command(TenantCommand):
if not provider:
LOGGER.warning("Provider does not exist", name=provider_name)
continue
sync_tasks.trigger_single_task(provider, scim_sync).get()
scim_task_wrapper(provider.pk).get()
@ -1,76 +0,0 @@
|
||||
# Generated by Django 5.0.4 on 2024-05-03 12:38
|
||||
|
||||
import uuid
|
||||
from django.db import migrations, models
|
||||
from django.apps.registry import Apps
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
from authentik.lib.migrations import progress_bar
|
||||
|
||||
|
||||
def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser")
|
||||
SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup")
|
||||
db_alias = schema_editor.connection.alias
|
||||
print("\nFixing primary key for SCIM users, this might take a couple of minutes...")
|
||||
for user in progress_bar(SCIMUser.objects.using(db_alias).all()):
|
||||
SCIMUser.objects.using(db_alias).filter(
|
||||
pk=user.pk, user=user.user_id, provider=user.provider_id
|
||||
).update(scim_id=user.pk, id=uuid.uuid4())
|
||||
|
||||
print("\nFixing primary key for SCIM groups, this might take a couple of minutes...")
|
||||
for group in progress_bar(SCIMGroup.objects.using(db_alias).all()):
|
||||
SCIMGroup.objects.using(db_alias).filter(
|
||||
pk=group.pk, group=group.group_id, provider=group.provider_id
|
||||
).update(scim_id=group.pk, id=uuid.uuid4())
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_providers_scim",
|
||||
"0001_squashed_0006_rename_parent_group_scimprovider_filter_group",
|
||||
),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="scimgroup",
|
||||
name="scim_id",
|
||||
field=models.TextField(default="temp"),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimuser",
|
||||
name="scim_id",
|
||||
field=models.TextField(default="temp"),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.RunPython(fix_scim_user_group_pk),
|
||||
migrations.AlterField(
|
||||
model_name="scimgroup",
|
||||
name="id",
|
||||
field=models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimuser",
|
||||
name="id",
|
||||
field=models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()),
|
||||
migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimgroup",
|
||||
unique_together={("scim_id", "group", "provider")},
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimuser",
|
||||
unique_together={("scim_id", "user", "provider")},
|
||||
),
|
||||
]
|
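The migration shown above (present on only one side of this comparison) follows a common swap-the-primary-key recipe: add a temporary text column with a throwaway default, copy the old key into it row by row while both columns exist, then re-type `id` as a UUID primary key and tighten the unique constraints. A generic sketch of that recipe for a hypothetical model follows; the app, model, and field names are placeholders, not the authentik migration.

```python
# Hedged sketch: the add -> backfill -> re-type primary-key recipe used by the
# migration above, shown for a hypothetical "Link" model. Names are illustrative.
import uuid

from django.db import migrations, models


def backfill_remote_id(apps, schema_editor):
    """Copy the old text primary key into the new remote_id column."""
    Link = apps.get_model("example_app", "Link")
    db_alias = schema_editor.connection.alias
    for link in Link.objects.using(db_alias).all():
        Link.objects.using(db_alias).filter(pk=link.pk).update(
            remote_id=link.pk, id=uuid.uuid4()
        )


class Migration(migrations.Migration):
    dependencies = [("example_app", "0001_initial")]

    operations = [
        # 1. add the new column with a throwaway default so existing rows validate
        migrations.AddField(
            model_name="link",
            name="remote_id",
            field=models.TextField(default="temp"),
            preserve_default=False,
        ),
        # 2. copy data while the old text primary key still exists
        migrations.RunPython(backfill_remote_id),
        # 3. re-type the primary key and finalize the new column
        migrations.AlterField(
            model_name="link",
            name="id",
            field=models.UUIDField(
                default=uuid.uuid4, editable=False, primary_key=True, serialize=False
            ),
        ),
        migrations.AlterField(model_name="link", name="remote_id", field=models.TextField()),
    ]
```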
@ -1,19 +1,17 @@
|
||||
"""SCIM Provider models"""
|
||||
|
||||
from typing import Any, Self
|
||||
from uuid import uuid4
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from redis.lock import Lock
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.providers.scim.clients import PAGE_TIMEOUT
|
||||
|
||||
|
||||
class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
class SCIMProvider(BackchannelProvider):
|
||||
"""SCIM 2.0 provider to create users and groups in external applications"""
|
||||
|
||||
exclude_users_service_account = models.BooleanField(default=False)
|
||||
@ -32,35 +30,30 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
help_text=_("Property mappings used for group creation/updating."),
|
||||
)
|
||||
|
||||
def client_for_model(
|
||||
self, model: type[User | Group]
|
||||
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
|
||||
if issubclass(model, User):
|
||||
from authentik.providers.scim.clients.users import SCIMUserClient
|
||||
@property
|
||||
def sync_lock(self) -> Lock:
|
||||
"""Redis lock for syncing SCIM to prevent multiple parallel syncs happening"""
|
||||
return Lock(
|
||||
cache.client.get_client(),
|
||||
name=f"goauthentik.io/providers/scim/sync-{str(self.pk)}",
|
||||
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
|
||||
)
|
||||
|
||||
return SCIMUserClient(self)
|
||||
if issubclass(model, Group):
|
||||
from authentik.providers.scim.clients.groups import SCIMGroupClient
|
||||
def get_user_qs(self) -> QuerySet[User]:
|
||||
"""Get queryset of all users with consistent ordering
|
||||
according to the provider's settings"""
|
||||
base = User.objects.all().exclude_anonymous()
|
||||
if self.exclude_users_service_account:
|
||||
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
|
||||
type=UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
)
|
||||
if self.filter_group:
|
||||
base = base.filter(ak_groups__in=[self.filter_group])
|
||||
return base.order_by("pk")
|
||||
|
||||
return SCIMGroupClient(self)
|
||||
raise ValueError(f"Invalid model {model}")
|
||||
|
||||
def get_object_qs(self, type: type[User | Group]) -> QuerySet[User | Group]:
|
||||
if type == User:
|
||||
# Get queryset of all users with consistent ordering
|
||||
# according to the provider's settings
|
||||
base = User.objects.all().exclude_anonymous()
|
||||
if self.exclude_users_service_account:
|
||||
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
|
||||
type=UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
)
|
||||
if self.filter_group:
|
||||
base = base.filter(ak_groups__in=[self.filter_group])
|
||||
return base.order_by("pk")
|
||||
if type == Group:
|
||||
# Get queryset of all groups with consistent ordering
|
||||
return Group.objects.all().order_by("pk")
|
||||
raise ValueError(f"Invalid type {type}")
|
||||
def get_group_qs(self) -> QuerySet[Group]:
|
||||
"""Get queryset of all groups with consistent ordering"""
|
||||
return Group.objects.all().order_by("pk")
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
@ -89,7 +82,7 @@ class SCIMMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.providers.scim.api.property_mappings import SCIMMappingSerializer
|
||||
from authentik.providers.scim.api.property_mapping import SCIMMappingSerializer
|
||||
|
||||
return SCIMMappingSerializer
|
||||
|
||||
@ -104,13 +97,12 @@ class SCIMMapping(PropertyMapping):
|
||||
class SCIMUser(models.Model):
|
||||
"""Mapping of a user and provider to a SCIM user ID"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
scim_id = models.TextField()
|
||||
id = models.TextField(primary_key=True)
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
unique_together = (("scim_id", "user", "provider"),)
|
||||
unique_together = (("id", "user", "provider"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM User {self.user_id} to {self.provider_id}"
|
||||
@ -119,13 +111,12 @@ class SCIMUser(models.Model):
|
||||
class SCIMGroup(models.Model):
|
||||
"""Mapping of a group and provider to a SCIM user ID"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
scim_id = models.TextField()
|
||||
id = models.TextField(primary_key=True)
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
unique_together = (("scim_id", "group", "provider"),)
|
||||
unique_together = (("id", "group", "provider"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM Group {self.group_id} to {self.provider_id}"
|
||||
|
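The `sync_lock` property in the models diff above hands out a Redis lock named after the provider so that only one full sync can run at a time; the sync task then checks `lock.locked()` and bails out early instead of queueing duplicate work. A standalone sketch of that guard using `redis-py` directly follows; the connection setup, lock name, and `run_full_sync` callable are assumptions for illustration.

```python
# Hedged sketch: guarding a long-running sync with a named Redis lock, as the
# sync_lock property and the locked() check above do. Connection details and
# the run_full_sync callable are illustrative placeholders.
from typing import Callable

from redis import Redis
from redis.lock import Lock

PAGE_TIMEOUT = 60 * 60 * 0.5  # half an hour per page, as in the constants above


def sync_guarded(provider_pk: int, run_full_sync: Callable[[], None]) -> bool:
    """Run the sync only if no other worker currently holds the provider's lock."""
    client = Redis(host="localhost", port=6379)  # assumption: local Redis
    lock = Lock(
        client,
        name=f"goauthentik.io/providers/scim/sync-{provider_pk}",
        timeout=PAGE_TIMEOUT * 3,  # generous upper bound before the lock expires
    )
    if lock.locked():
        # another worker is mid-sync; skip instead of piling up duplicate runs
        return False
    with lock:
        run_full_sync()
    return True
```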
@ -7,7 +7,7 @@ from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"providers_scim_sync": {
"task": "authentik.providers.scim.tasks.scim_sync_all",
"schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*/4"),
"schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*"),
"options": {"queue": "authentik_scheduled"},
},
}
@ -1,12 +1,56 @@
|
||||
"""SCIM provider signals"""
|
||||
|
||||
from authentik.lib.sync.outgoing.signals import register_signals
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_direct, scim_sync_m2m
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import m2m_changed, post_save, pre_delete
|
||||
from django.dispatch import receiver
|
||||
from pydanticscim.responses import PatchOp
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
register_signals(
|
||||
SCIMProvider,
|
||||
task_sync_single=scim_sync,
|
||||
task_sync_direct=scim_sync_direct,
|
||||
task_sync_m2m=scim_sync_m2m,
|
||||
)
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_task_wrapper
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@receiver(post_save, sender=SCIMProvider)
|
||||
def post_save_provider(sender: type[Model], instance, created: bool, **_):
|
||||
"""Trigger sync when SCIM provider is saved"""
|
||||
scim_task_wrapper(instance.pk)
|
||||
|
||||
|
||||
@receiver(post_save, sender=User)
|
||||
@receiver(post_save, sender=Group)
|
||||
def post_save_scim(sender: type[Model], instance: User | Group, created: bool, **_):
|
||||
"""Post save handler"""
|
||||
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
|
||||
return
|
||||
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.add.value)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=User)
|
||||
@receiver(pre_delete, sender=Group)
|
||||
def pre_delete_scim(sender: type[Model], instance: User | Group, **_):
|
||||
"""Pre-delete handler"""
|
||||
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
|
||||
return
|
||||
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.remove.value)
|
||||
|
||||
|
||||
@receiver(m2m_changed, sender=User.ak_groups.through)
|
||||
def m2m_changed_scim(
|
||||
sender: type[Model], instance, action: str, pk_set: set, reverse: bool, **kwargs
|
||||
):
|
||||
"""Sync group membership"""
|
||||
if action not in ["post_add", "post_remove"]:
|
||||
return
|
||||
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
|
||||
return
|
||||
# reverse: instance is a Group, pk_set is a list of user pks
|
||||
# non-reverse: instance is a User, pk_set is a list of groups
|
||||
if reverse:
|
||||
scim_signal_m2m.delay(str(instance.pk), action, list(pk_set))
|
||||
else:
|
||||
for group_pk in pk_set:
|
||||
scim_signal_m2m.delay(group_pk, action, [instance.pk])
|
||||
|
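One side of the signals diff above replaces the explicit receivers with a single `register_signals` helper shared by outgoing-sync providers; its implementation is not part of this comparison. The sketch below is an assumption of what such a factory plausibly does, wiring `post_save`, `pre_delete`, and `m2m_changed` to the three Celery tasks passed in, with handler bodies that mirror the receivers this file used to define explicitly.

```python
# Hedged sketch (assumption): a generic register_signals helper in the spirit of
# authentik.lib.sync.outgoing.signals. The real implementation may differ; the
# handler bodies mirror the old per-provider receivers shown in the diff above.
from django.db.models.signals import m2m_changed, post_save, pre_delete

from authentik.core.models import Group, User
from authentik.lib.utils.reflection import class_to_path


def register_signals(provider_model, task_sync_single, task_sync_direct, task_sync_m2m):
    """Connect model signals so changes are pushed to providers asynchronously."""

    def provider_saved(sender, instance, created, **_):
        # re-run a full sync whenever the provider itself is (re)configured
        task_sync_single.delay(instance.pk)

    def object_saved(sender, instance, created, **_):
        if not provider_model.objects.filter(backchannel_application__isnull=False).exists():
            return
        task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, "add")

    def object_deleted(sender, instance, **_):
        if not provider_model.objects.filter(backchannel_application__isnull=False).exists():
            return
        task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, "remove")

    def membership_changed(sender, instance, action, pk_set, reverse, **_):
        if action not in ("post_add", "post_remove"):
            return
        if reverse:  # instance is a Group, pk_set holds user pks
            task_sync_m2m.delay(str(instance.pk), action, list(pk_set))
        else:  # instance is a User, pk_set holds group pks
            for group_pk in pk_set:
                task_sync_m2m.delay(group_pk, action, [instance.pk])

    post_save.connect(provider_saved, sender=provider_model, weak=False)
    for model in (User, Group):
        post_save.connect(object_saved, sender=model, weak=False)
        pre_delete.connect(object_deleted, sender=model, weak=False)
    m2m_changed.connect(membership_changed, sender=User.ak_groups.through, weak=False)
```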
@ -1,34 +1,242 @@
|
||||
"""SCIM Provider tasks"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from celery.result import allow_join_result
|
||||
from django.core.paginator import Paginator
|
||||
from django.db.models import Model, QuerySet
|
||||
from django.utils.text import slugify
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from pydanticscim.responses import PatchOp
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.sync.outgoing.tasks import SyncTasks
|
||||
from authentik.lib.utils.reflection import path_to_class
|
||||
from authentik.providers.scim.clients import PAGE_SIZE, PAGE_TIMEOUT
|
||||
from authentik.providers.scim.clients.base import SCIMClient
|
||||
from authentik.providers.scim.clients.exceptions import SCIMRequestException, StopSync
|
||||
from authentik.providers.scim.clients.group import SCIMGroupClient
|
||||
from authentik.providers.scim.clients.user import SCIMUserClient
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
sync_tasks = SyncTasks(SCIMProvider)
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def scim_sync_objects(*args, **kwargs):
|
||||
return sync_tasks.sync_objects(*args, **kwargs)
|
||||
|
||||
|
||||
@CELERY_APP.task(base=SystemTask, bind=True)
|
||||
def scim_sync(self, provider_pk: int, *args, **kwargs):
|
||||
"""Run full sync for SCIM provider"""
|
||||
return sync_tasks.sync_single(self, provider_pk, scim_sync_objects)
|
||||
def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient:
|
||||
"""Get SCIM client for model"""
|
||||
if isinstance(model, User):
|
||||
return SCIMUserClient(provider)
|
||||
if isinstance(model, Group):
|
||||
return SCIMGroupClient(provider)
|
||||
raise ValueError(f"Invalid model {model}")
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def scim_sync_all():
|
||||
return sync_tasks.sync_all(scim_sync)
|
||||
"""Run sync for all providers"""
|
||||
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
|
||||
scim_task_wrapper(provider.pk)
|
||||
|
||||
|
||||
def scim_task_wrapper(provider_pk: int):
|
||||
"""Wrap scim_sync to set the correct timeouts"""
|
||||
provider: SCIMProvider = SCIMProvider.objects.filter(
|
||||
pk=provider_pk, backchannel_application__isnull=False
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
|
||||
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
|
||||
time_limit = soft_time_limit * 1.5
|
||||
return scim_sync.apply_async(
|
||||
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
|
||||
)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def scim_sync(self: SystemTask, provider_pk: int) -> None:
|
||||
"""Run SCIM full sync for provider"""
|
||||
provider: SCIMProvider = SCIMProvider.objects.filter(
|
||||
pk=provider_pk, backchannel_application__isnull=False
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
lock = provider.sync_lock
|
||||
if lock.locked():
|
||||
LOGGER.debug("SCIM sync locked, skipping task", source=provider.name)
|
||||
return
|
||||
self.set_uid(slugify(provider.name))
|
||||
messages = []
|
||||
messages.append(_("Starting full SCIM sync"))
|
||||
LOGGER.debug("Starting SCIM sync")
|
||||
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
|
||||
self.soft_time_limit = self.time_limit = (
|
||||
users_paginator.num_pages + groups_paginator.num_pages
|
||||
) * PAGE_TIMEOUT
|
||||
with allow_join_result():
|
||||
try:
|
||||
for page in users_paginator.page_range:
|
||||
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
|
||||
for msg in scim_sync_users.delay(page, provider_pk).get():
|
||||
messages.append(msg)
|
||||
for page in groups_paginator.page_range:
|
||||
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
|
||||
for msg in scim_sync_group.delay(page, provider_pk).get():
|
||||
messages.append(msg)
|
||||
except StopSync as exc:
|
||||
self.set_error(exc)
|
||||
return
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
soft_time_limit=PAGE_TIMEOUT,
|
||||
task_time_limit=PAGE_TIMEOUT,
|
||||
)
|
||||
def scim_sync_users(page: int, provider_pk: int):
|
||||
"""Sync single or multiple users to SCIM"""
|
||||
messages = []
|
||||
provider: SCIMProvider = SCIMProvider.objects.filter(pk=provider_pk).first()
|
||||
if not provider:
|
||||
return messages
|
||||
try:
|
||||
client = SCIMUserClient(provider)
|
||||
except SCIMRequestException:
|
||||
return messages
|
||||
paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
|
||||
LOGGER.debug("starting user sync for page", page=page)
|
||||
for user in paginator.page(page).object_list:
|
||||
try:
|
||||
client.write(user)
|
||||
except SCIMRequestException as exc:
|
||||
LOGGER.warning("failed to sync user", exc=exc, user=user)
|
||||
messages.append(
|
||||
_(
|
||||
"Failed to sync user {user_name} due to remote error: {error}".format_map(
|
||||
{
|
||||
"user_name": user.username,
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
except StopSync as exc:
|
||||
LOGGER.warning("Stopping sync", exc=exc)
|
||||
messages.append(
|
||||
_(
|
||||
"Stopping sync due to error: {error}".format_map(
|
||||
{
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def scim_sync_direct(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_direct(*args, **kwargs)
|
||||
def scim_sync_group(page: int, provider_pk: int):
|
||||
"""Sync single or multiple groups to SCIM"""
|
||||
messages = []
|
||||
provider: SCIMProvider = SCIMProvider.objects.filter(pk=provider_pk).first()
|
||||
if not provider:
|
||||
return messages
|
||||
try:
|
||||
client = SCIMGroupClient(provider)
|
||||
except SCIMRequestException:
|
||||
return messages
|
||||
paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
|
||||
LOGGER.debug("starting group sync for page", page=page)
|
||||
for group in paginator.page(page).object_list:
|
||||
try:
|
||||
client.write(group)
|
||||
except SCIMRequestException as exc:
|
||||
LOGGER.warning("failed to sync group", exc=exc, group=group)
|
||||
messages.append(
|
||||
_(
|
||||
"Failed to sync group {group_name} due to remote error: {error}".format_map(
|
||||
{
|
||||
"group_name": group.name,
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
except StopSync as exc:
|
||||
LOGGER.warning("Stopping sync", exc=exc)
|
||||
messages.append(
|
||||
_(
|
||||
"Stopping sync due to error: {error}".format_map(
|
||||
{
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def scim_sync_m2m(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_m2m(*args, **kwargs)
|
||||
def scim_signal_direct(model: str, pk: Any, raw_op: str):
|
||||
"""Handler for post_save and pre_delete signal"""
|
||||
model_class: type[Model] = path_to_class(model)
|
||||
instance = model_class.objects.filter(pk=pk).first()
|
||||
if not instance:
|
||||
return
|
||||
operation = PatchOp(raw_op)
|
||||
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
|
||||
client = client_for_model(provider, instance)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset: QuerySet | None = None
|
||||
if isinstance(instance, User):
|
||||
queryset = provider.get_user_qs()
|
||||
if isinstance(instance, Group):
|
||||
queryset = provider.get_group_qs()
|
||||
if not queryset:
|
||||
continue
|
||||
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=instance.pk).exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
if operation == PatchOp.add:
|
||||
client.write(instance)
|
||||
if operation == PatchOp.remove:
|
||||
client.delete(instance)
|
||||
except (StopSync, SCIMRequestException) as exc:
|
||||
LOGGER.warning(exc)
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def scim_signal_m2m(group_pk: str, action: str, pk_set: list[int]):
|
||||
"""Update m2m (group membership)"""
|
||||
group = Group.objects.filter(pk=group_pk).first()
|
||||
if not group:
|
||||
return
|
||||
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset: QuerySet = provider.get_group_qs()
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=group_pk).exists():
|
||||
continue
|
||||
|
||||
client = SCIMGroupClient(provider)
|
||||
try:
|
||||
operation = None
|
||||
if action == "post_add":
|
||||
operation = PatchOp.add
|
||||
if action == "post_remove":
|
||||
operation = PatchOp.remove
|
||||
client.update_group(group, operation, pk_set)
|
||||
except (StopSync, SCIMRequestException) as exc:
|
||||
LOGGER.warning(exc)
|
||||
|
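`scim_task_wrapper` in the tasks diff above sizes the Celery limits from the amount of work to do: one `PAGE_TIMEOUT` (half an hour) per page of users and groups for the soft limit, and 1.5 times that for the hard limit, both passed to `apply_async`. A small worked example of that arithmetic; the object counts below are invented for illustration.

```python
# Hedged sketch: how the soft/hard time limits in scim_task_wrapper scale with
# the number of pages to sync. The counts below are made up for illustration.
from math import ceil

PAGE_SIZE = 100
PAGE_TIMEOUT = 60 * 60 * 0.5  # 1800 seconds per page


def sync_time_limits(user_count: int, group_count: int) -> tuple[int, int]:
    """Return (soft_time_limit, time_limit) in seconds for a full sync."""
    # Django's Paginator always reports at least one page, even for empty querysets
    user_pages = max(ceil(user_count / PAGE_SIZE), 1)
    group_pages = max(ceil(group_count / PAGE_SIZE), 1)
    soft_time_limit = (user_pages + group_pages) * PAGE_TIMEOUT
    return int(soft_time_limit), int(soft_time_limit * 1.5)


# 2500 users -> 25 pages, 180 groups -> 2 pages: 27 * 1800 s = 48600 s soft
# limit and 72900 s hard limit, passed to scim_sync.apply_async(...).
print(sync_time_limits(2500, 180))
```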
@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.providers.scim.tasks import scim_task_wrapper
from authentik.tenants.models import Tenant


@ -79,7 +79,7 @@ class SCIMMembershipTests(TestCase):
)

self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
scim_task_wrapper(self.provider.pk).get()

self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
@ -169,7 +169,7 @@ class SCIMMembershipTests(TestCase):
)

self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
scim_task_wrapper(self.provider.pk).get()

self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application, Group, User
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync, sync_tasks
|
||||
from authentik.providers.scim.tasks import scim_task_wrapper
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
|
||||
@ -88,72 +88,6 @@ class SCIMUserTests(TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_user_create_different_provider_same_id(self, mock: Mocker):
|
||||
"""Test user creation with multiple providers that happen
|
||||
to return the same object ID"""
|
||||
# Create duplicate provider
|
||||
provider: SCIMProvider = SCIMProvider.objects.create(
|
||||
name=generate_id(),
|
||||
url="https://localhost",
|
||||
token=generate_id(),
|
||||
exclude_users_service_account=True,
|
||||
)
|
||||
app: Application = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
)
|
||||
app.backchannel_providers.add(provider)
|
||||
provider.property_mappings.add(
|
||||
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
|
||||
)
|
||||
provider.property_mappings_group.add(
|
||||
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
|
||||
)
|
||||
|
||||
scim_id = generate_id()
|
||||
mock.get(
|
||||
"https://localhost/ServiceProviderConfig",
|
||||
json={},
|
||||
)
|
||||
mock.post(
|
||||
"https://localhost/Users",
|
||||
json={
|
||||
"id": scim_id,
|
||||
},
|
||||
)
|
||||
uid = generate_id()
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
)
|
||||
self.assertEqual(mock.call_count, 4)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
self.assertEqual(mock.request_history[1].method, "POST")
|
||||
self.assertJSONEqual(
|
||||
mock.request_history[1].body,
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"active": True,
|
||||
"emails": [
|
||||
{
|
||||
"primary": True,
|
||||
"type": "other",
|
||||
"value": f"{uid}@goauthentik.io",
|
||||
}
|
||||
],
|
||||
"externalId": user.uid,
|
||||
"name": {
|
||||
"familyName": uid,
|
||||
"formatted": f"{uid} {uid}",
|
||||
"givenName": uid,
|
||||
},
|
||||
"displayName": f"{uid} {uid}",
|
||||
"userName": uid,
|
||||
},
|
||||
)
|
||||
|
||||
@Mocker()
|
||||
def test_user_create_update(self, mock: Mocker):
|
||||
"""Test user creation and update"""
|
||||
@ -302,7 +236,7 @@ class SCIMUserTests(TestCase):
|
||||
email=f"{uid}@goauthentik.io",
|
||||
)
|
||||
|
||||
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
|
||||
scim_task_wrapper(self.provider.pk).get()
|
||||
|
||||
self.assertEqual(mock.call_count, 5)
|
||||
self.assertEqual(mock.request_history[0].method, "GET")
|
||||
|
@ -1,6 +1,6 @@
"""API URLs"""

from authentik.providers.scim.api.property_mappings import SCIMMappingViewSet
from authentik.providers.scim.api.property_mapping import SCIMMappingViewSet
from authentik.providers.scim.api.providers import SCIMProviderViewSet

api_urlpatterns = [
@ -155,9 +155,6 @@ SPECTACULAR_SETTINGS = {
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
"UserTypeEnum": "authentik.core.models.UserTypes",
"GoogleWorkspaceDeleteAction": (
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceDeleteAction"
),
},
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
@ -379,13 +376,7 @@ CELERY = {
"task_default_queue": "authentik",
"broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")),
"result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")),
"broker_transport_options": CONFIG.get_dict_from_b64_json(
"broker.transport_options", {"retry_policy": {"timeout": 5.0}}
),
"result_backend_transport_options": CONFIG.get_dict_from_b64_json(
"result_backend.transport_options", {"retry_policy": {"timeout": 5.0}}
),
"redis_retry_on_timeout": True,
"broker_transport_options": CONFIG.get_dict_from_b64_json("broker.transport_options"),
}

# Sentry integration
@ -10,7 +10,7 @@ from drf_spectacular.utils import extend_schema, extend_schema_field, inline_ser
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import DictField, ListField, SerializerMethodField
|
||||
from rest_framework.fields import BooleanField, DictField, ListField, SerializerMethodField
|
||||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
@ -19,8 +19,9 @@ from rest_framework.viewsets import ModelViewSet
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.lib.sync.outgoing.api import SyncStatusSerializer
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
|
||||
from authentik.sources.ldap.tasks import CACHE_KEY_STATUS, SYNC_CLASSES
|
||||
|
||||
@ -88,6 +89,13 @@ class LDAPSourceSerializer(SourceSerializer):
|
||||
extra_kwargs = {"bind_password": {"write_only": True}}
|
||||
|
||||
|
||||
class LDAPSyncStatusSerializer(PassiveSerializer):
|
||||
"""LDAP Source sync status"""
|
||||
|
||||
is_running = BooleanField(read_only=True)
|
||||
tasks = SystemTaskSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
"""LDAP Source Viewset"""
|
||||
|
||||
@ -124,16 +132,10 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
200: SyncStatusSerializer(),
|
||||
200: LDAPSyncStatusSerializer(),
|
||||
}
|
||||
)
|
||||
@action(
|
||||
methods=["GET"],
|
||||
detail=True,
|
||||
pagination_class=None,
|
||||
url_path="sync/status",
|
||||
filter_backends=[],
|
||||
)
|
||||
@action(methods=["GET"], detail=True, pagination_class=None, filter_backends=[])
|
||||
def sync_status(self, request: Request, slug: str) -> Response:
|
||||
"""Get source's sync status"""
|
||||
source: LDAPSource = self.get_object()
|
||||
@ -147,7 +149,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
"tasks": tasks,
|
||||
"is_running": source.sync_lock.locked(),
|
||||
}
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
return Response(LDAPSyncStatusSerializer(status).data)
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
|
@ -9,10 +9,7 @@ from django.db.models.query import QuerySet
from ldap3 import DEREF_ALWAYS, SUBTREE, Connection
from structlog.stdlib import BoundLogger, get_logger

from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG, set_path_in_dict
from authentik.lib.merge import MERGE_LIST_UNIQUE
@ -174,8 +171,6 @@ class BaseLDAPSynchronizer:
set_path_in_dict(properties, object_field, value)
else:
properties[object_field] = flatten(value)
except SkipObjectException as exc:
raise exc from exc
except PropertyMappingExpressionException as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
@ -6,7 +6,6 @@ from django.core.exceptions import FieldError
from django.db.utils import IntegrityError
from ldap3 import ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, SUBTREE

from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group
from authentik.events.models import Event, EventAction
from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchronizer, flatten
@ -66,8 +65,6 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
defaults,
)
self._logger.debug("Created group with attributes", **defaults)
except SkipObjectException:
continue
except (IntegrityError, FieldError, TypeError, AttributeError) as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
@ -6,7 +6,6 @@ from django.core.exceptions import FieldError
from django.db.utils import IntegrityError
from ldap3 import ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, SUBTREE

from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import User
from authentik.events.models import Event, EventAction
from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchronizer, flatten
@ -60,8 +59,6 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
ak_user, created = self.update_or_create_attributes(
User, {f"attributes__{LDAP_UNIQUENESS}": uniq}, defaults
)
except SkipObjectException:
continue
except (IntegrityError, FieldError, TypeError, AttributeError) as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
@ -13,7 +13,6 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
|
||||
from authentik.sources.scim.models import SCIMSourceGroup
|
||||
from authentik.sources.scim.views.v2.base import SCIMView
|
||||
@ -27,11 +26,9 @@ class GroupsView(SCIMView):
|
||||
def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
|
||||
"""Convert Group to SCIM data"""
|
||||
payload = SCIMGroupModel(
|
||||
schemas=[SCIM_USER_SCHEMA],
|
||||
id=str(scim_group.group.pk),
|
||||
externalId=scim_group.id,
|
||||
displayName=scim_group.group.name,
|
||||
members=[],
|
||||
meta={
|
||||
"resourceType": "Group",
|
||||
"location": self.request.build_absolute_uri(
|
||||
@ -45,24 +42,28 @@ class GroupsView(SCIMView):
|
||||
),
|
||||
},
|
||||
)
|
||||
for member in scim_group.group.users.order_by("pk"):
|
||||
member: User
|
||||
payload.members.append(GroupMember(value=str(member.uuid)))
|
||||
return payload.model_dump(mode="json", exclude_unset=True)
|
||||
return payload.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
)
|
||||
|
||||
def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
|
||||
"""List Group handler"""
|
||||
base_query = SCIMSourceGroup.objects.select_related("group").prefetch_related(
|
||||
"group__users"
|
||||
)
|
||||
if group_id:
|
||||
connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
|
||||
connection = (
|
||||
SCIMSourceGroup.objects.filter(source=self.source, group__group_uuid=group_id)
|
||||
.select_related("group")
|
||||
.first()
|
||||
)
|
||||
if not connection:
|
||||
raise Http404
|
||||
return Response(self.group_to_scim(connection))
|
||||
connections = (
|
||||
base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
|
||||
SCIMSourceGroup.objects.filter(source=self.source)
|
||||
.select_related("group")
|
||||
.order_by("pk")
|
||||
)
|
||||
connections = connections.filter(self.filter_parse(request))
|
||||
page = self.paginate_query(connections)
|
||||
return Response(
|
||||
{
|
||||
@ -78,8 +79,6 @@ class GroupsView(SCIMView):
|
||||
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict):
|
||||
"""Partial update a group"""
|
||||
group = connection.group if connection else Group()
|
||||
if _group := Group.objects.filter(name=data.get("displayName")).first():
|
||||
group = _group
|
||||
if "displayName" in data:
|
||||
group.name = data.get("displayName")
|
||||
if group.name == "":
|
||||
|
@ -11,7 +11,6 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserModel
|
||||
from authentik.sources.scim.models import SCIMSourceUser
|
||||
from authentik.sources.scim.views.v2.base import SCIMView
|
||||
@ -34,7 +33,6 @@ class UsersView(SCIMView):
|
||||
def user_to_scim(self, scim_user: SCIMSourceUser) -> dict:
|
||||
"""Convert User to SCIM data"""
|
||||
payload = SCIMUserModel(
|
||||
schemas=[SCIM_USER_SCHEMA],
|
||||
id=str(scim_user.user.uuid),
|
||||
externalId=scim_user.id,
|
||||
userName=scim_user.user.username,
|
||||
@ -64,7 +62,10 @@ class UsersView(SCIMView):
|
||||
),
|
||||
},
|
||||
)
|
||||
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||
final_payload = payload.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=True,
|
||||
)
|
||||
final_payload.update(scim_user.attributes)
|
||||
return final_payload
|
||||
|
||||
@ -98,8 +99,6 @@ class UsersView(SCIMView):
|
||||
def update_user(self, connection: SCIMSourceUser | None, data: QueryDict):
|
||||
"""Partial update a user"""
|
||||
user = connection.user if connection else User()
|
||||
if _user := User.objects.filter(username=data.get("userName")).first():
|
||||
user = _user
|
||||
user.path = self.source.get_user_path()
|
||||
if "userName" in data:
|
||||
user.username = data.get("userName")
|
||||
|
@ -9,7 +9,7 @@ from rest_framework.validators import UniqueValidator
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.used_by import UsedByMixin
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.flows.api.stages import StageSerializer
from authentik.flows.challenge import ChallengeTypes, HttpChallengeResponse
from authentik.flows.planner import FlowPlan
@ -23,8 +23,8 @@ from rest_framework.fields import (
from rest_framework.serializers import BaseSerializer
from structlog.stdlib import get_logger

from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import User
from authentik.flows.models import Stage
from authentik.lib.models import SerializerModel
@ -3,7 +3,6 @@
from tenant_schemas_celery.scheduler import (
TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler,
)
from tenant_schemas_celery.scheduler import TenantAwareScheduleEntry


class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler):
@ -12,11 +11,3 @@ class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler):
@classmethod
def get_queryset(cls):
return super().get_queryset().filter(ready=True)

def apply_entry(self, entry: TenantAwareScheduleEntry, producer=None):
# https://github.com/maciej-gol/tenant-schemas-celery/blob/master/tenant_schemas_celery/scheduler.py#L85
# When (as by default) no tenant schemas are set, the public schema is excluded
# so we need to explicitly include it here, otherwise the task is not executed
if entry.tenant_schemas is None:
entry.tenant_schemas = self.get_queryset().values_list("schema_name", flat=True)
return super().apply_entry(entry, producer)
@ -2,7 +2,7 @@
"$schema": "http://json-schema.org/draft-07/schema",
"$id": "https://goauthentik.io/blueprints/schema.json",
"type": "object",
"title": "authentik 2024.4.2 Blueprint schema",
"title": "authentik 2024.4.1 Blueprint schema",
"required": [
"version",
"entries"
@ -2520,80 +2520,6 @@
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_google_workspace.googleworkspaceprovider"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_google_workspace.googleworkspaceprovider"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_google_workspace.googleworkspaceprovider"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_google_workspace.googleworkspaceprovidermapping"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_google_workspace.googleworkspaceprovidermapping"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_google_workspace.googleworkspaceprovidermapping"
}
}
},
{
"type": "object",
"required": [
@ -3411,7 +3337,6 @@
"authentik.core",
"authentik.enterprise",
"authentik.enterprise.audit",
"authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.rac",
"authentik.enterprise.stages.source",
"authentik.events"
@ -3493,8 +3418,6 @@
"authentik_core.application",
"authentik_core.token",
"authentik_enterprise.license",
"authentik_providers_google_workspace.googleworkspaceprovider",
"authentik_providers_google_workspace.googleworkspaceprovidermapping",
"authentik_providers_rac.racprovider",
"authentik_providers_rac.endpoint",
"authentik_providers_rac.racpropertymapping",
@ -8199,109 +8122,6 @@
},
"required": []
},
"model_authentik_providers_google_workspace.googleworkspaceprovider": {
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"property_mappings": {
"type": "array",
"items": {
"type": "string",
"format": "uuid"
},
"title": "Property mappings"
},
"property_mappings_group": {
"type": "array",
"items": {
"type": "string",
"format": "uuid",
"description": "Property mappings used for group creation/updating."
},
"title": "Property mappings group",
"description": "Property mappings used for group creation/updating."
},
"delegated_subject": {
"type": "string",
"format": "email",
"maxLength": 254,
"minLength": 1,
"title": "Delegated subject"
},
"credentials": {
"type": "object",
"additionalProperties": true,
"title": "Credentials"
},
"scopes": {
"type": "string",
"minLength": 1,
"title": "Scopes"
},
"exclude_users_service_account": {
"type": "boolean",
"title": "Exclude users service account"
},
"filter_group": {
"type": "string",
"format": "uuid",
"title": "Filter group"
},
"user_delete_action": {
"type": "string",
"enum": [
"do_nothing",
"delete",
"suspend"
],
"title": "User delete action"
},
"group_delete_action": {
"type": "string",
"enum": [
"do_nothing",
"delete",
"suspend"
],
"title": "Group delete action"
},
"default_group_email_domain": {
"type": "string",
"minLength": 1,
"title": "Default group email domain"
}
},
"required": []
},
"model_authentik_providers_google_workspace.googleworkspaceprovidermapping": {
"type": "object",
"properties": {
"managed": {
"type": [
"string",
"null"
],
"minLength": 1,
"title": "Managed by authentik",
"description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
},
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"expression": {
"type": "string",
"minLength": 1,
"title": "Expression"
}
},
"required": []
},
"model_authentik_providers_rac.racprovider": {
"type": "object",
"properties": {
@ -1,42 +0,0 @@
version: 1
metadata:
labels:
blueprints.goauthentik.io/system: "true"
name: System - Google Provider - Mappings
entries:
- identifiers:
managed: goauthentik.io/providers/google_workspace/user
model: authentik_providers_google_workspace.googleworkspaceprovidermapping
attrs:
name: "authentik default Google Workspace Mapping: User"
# https://developers.google.com/admin-sdk/directory/reference/rest/v1/users#User
expression: |
# Google require givenName and familyName to be set
givenName, familyName = request.user.name, " "
formatted = request.user.name + " "
# This default sets givenName to the name before the first space
# and the remainder as family name
# if the user's name has no space the givenName is the entire name
if " " in request.user.name:
givenName, _, familyName = request.user.name.partition(" ")
formatted = request.user.name
return {
"name": {
"fullName": formatted,
"familyName": familyName.strip(),
"givenName": givenName.strip(),
"displayName": formatted,
},
"password": request.user.password,
"suspended": not request.user.is_active,
}
- identifiers:
managed: goauthentik.io/providers/google_workspace/group
model: authentik_providers_google_workspace.googleworkspaceprovidermapping
attrs:
name: "authentik default Google Workspace Mapping: Group"
# https://developers.google.com/admin-sdk/directory/reference/rest/v1/groups#Group
expression: |
return {
"name": group.name,
}
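The user mapping above splits request.user.name because the Google Admin SDK requires both givenName and familyName. A simplified, self-contained sketch of that splitting rule; split_name is a hypothetical helper for illustration, not part of the blueprint expression API:

    # Everything before the first space becomes givenName, the rest familyName;
    # a name without spaces keeps the whole value as givenName.
    def split_name(full_name: str) -> dict:
        given, family = full_name, " "
        if " " in full_name:
            given, _, family = full_name.partition(" ")
        return {
            "fullName": full_name,
            "givenName": given.strip(),
            "familyName": family.strip(),
        }

    print(split_name("Ada Lovelace"))  # givenName "Ada", familyName "Lovelace"
    print(split_name("Ada"))           # givenName "Ada", familyName ""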
@ -32,7 +32,7 @@ services:
volumes:
- redis:/data
server:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.2}
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.1}
restart: unless-stopped
command: server
environment:
@ -53,7 +53,7 @@ services:
- postgresql
- redis
worker:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.2}
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.1}
restart: unless-stopped
command: worker
environment:
6
go.mod
@ -10,7 +10,7 @@ require (
github.com/go-ldap/ldap/v3 v3.4.8
github.com/go-openapi/runtime v0.28.0
github.com/go-openapi/strfmt v0.23.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/gorilla/handlers v1.5.2
github.com/gorilla/mux v1.8.1
@ -28,9 +28,9 @@ require (
github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.9.0
github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2024042.2
goauthentik.io/api/v3 v3.2024041.2
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.20.0
golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0
gopkg.in/yaml.v2 v2.4.0
layeh.com/radius v0.0.0-20210819152912-ad72663a72ab
12
go.sum
@ -111,8 +111,8 @@ github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+Gr
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3BumrGD58=
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@ -294,8 +294,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
goauthentik.io/api/v3 v3.2024042.2 h1:aGfIVrNXEWVuvKH3YDZpGINhnhWNwcVAGTla/Ck4hD8=
goauthentik.io/api/v3 v3.2024042.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
goauthentik.io/api/v3 v3.2024041.2 h1:gbquIA8RU+9jJbFdGckQTtJzOfWVp2+QdF4LuNVTAWM=
goauthentik.io/api/v3 v3.2024041.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -382,8 +382,8 @@ golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4Iltr
golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo=
golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -29,4 +29,4 @@ func UserAgent() string {
return fmt.Sprintf("authentik@%s", FullVersion())
}

const VERSION = "2024.4.2"
const VERSION = "2024.4.1"
@ -192,9 +192,7 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, server Server) (*A
})
})

mux.HandleFunc("/outpost.goauthentik.io/start", func(w http.ResponseWriter, r *http.Request) {
a.handleAuthStart(w, r, "")
})
mux.HandleFunc("/outpost.goauthentik.io/start", a.handleAuthStart)
mux.HandleFunc("/outpost.goauthentik.io/callback", a.handleAuthCallback)
mux.HandleFunc("/outpost.goauthentik.io/sign_out", a.handleSignOut)
switch *p.Mode {
@ -59,11 +59,19 @@ func (a *Application) forwardHandleTraefik(rw http.ResponseWriter, r *http.Reque
a.log.Trace("path can be accessed without authentication")
return
}
a.handleAuthStart(rw, r)
// set the redirect flag to the current URL we have, since we redirect
// to a (possibly) different domain, but we want to be redirected back
// to the application
// X-Forwarded-Uri is only the path, so we need to build the entire URL
a.handleAuthStart(rw, r, fwd.String())
s, _ := a.sessions.Get(r, a.SessionName())
if _, redirectSet := s.Values[constants.SessionRedirect]; !redirectSet {
s.Values[constants.SessionRedirect] = fwd.String()
err = s.Save(r, rw)
if err != nil {
a.log.WithError(err).Warning("failed to save session")
}
}
}

func (a *Application) forwardHandleCaddy(rw http.ResponseWriter, r *http.Request) {
@ -102,11 +110,19 @@ func (a *Application) forwardHandleCaddy(rw http.ResponseWriter, r *http.Request
a.log.Trace("path can be accessed without authentication")
return
}
a.handleAuthStart(rw, r)
// set the redirect flag to the current URL we have, since we redirect
// to a (possibly) different domain, but we want to be redirected back
// to the application
// X-Forwarded-Uri is only the path, so we need to build the entire URL
a.handleAuthStart(rw, r, fwd.String())
s, _ := a.sessions.Get(r, a.SessionName())
if _, redirectSet := s.Values[constants.SessionRedirect]; !redirectSet {
s.Values[constants.SessionRedirect] = fwd.String()
err = s.Save(r, rw)
if err != nil {
a.log.WithError(err).Warning("failed to save session")
}
}
}

func (a *Application) forwardHandleNginx(rw http.ResponseWriter, r *http.Request) {
@ -169,9 +185,17 @@ func (a *Application) forwardHandleEnvoy(rw http.ResponseWriter, r *http.Request
a.log.Trace("path can be accessed without authentication")
return
}
a.handleAuthStart(rw, r)
// set the redirect flag to the current URL we have, since we redirect
// to a (possibly) different domain, but we want to be redirected back
// to the application
// X-Forwarded-Uri is only the path, so we need to build the entire URL
a.handleAuthStart(rw, r, fwd.String())
s, _ := a.sessions.Get(r, a.SessionName())
if _, redirectSet := s.Values[constants.SessionRedirect]; !redirectSet {
s.Values[constants.SessionRedirect] = fwd.String()
err = s.Save(r, rw)
if err != nil {
a.log.WithError(err).Warning("failed to save session before redirect")
}
}
}
@ -47,14 +47,16 @@ func TestForwardHandleCaddy_Single_Headers(t *testing.T) {
a.forwardHandleCaddy(rr, req)

assert.Equal(t, http.StatusFound, rr.Code)
loc, st := a.assertState(t, req, rr)
loc, _ := rr.Result().Location()
s, _ := a.sessions.Get(req, a.SessionName())
shouldUrl := url.Values{
"client_id": []string{*a.proxyConfig.ClientId},
"redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"},
"response_type": []string{"code"},
"state": []string{s.Values[constants.SessionOAuthState].(string)},
}
assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String())
assert.Equal(t, "http://test.goauthentik.io/app", st.Redirect)
assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect])
}

func TestForwardHandleCaddy_Single_Claims(t *testing.T) {
@ -132,12 +134,14 @@ func TestForwardHandleCaddy_Domain_Header(t *testing.T) {
a.forwardHandleCaddy(rr, req)

assert.Equal(t, http.StatusFound, rr.Code)
loc, st := a.assertState(t, req, rr)
loc, _ := rr.Result().Location()
s, _ := a.sessions.Get(req, a.SessionName())
shouldUrl := url.Values{
"client_id": []string{*a.proxyConfig.ClientId},
"redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"},
"response_type": []string{"code"},
"state": []string{s.Values[constants.SessionOAuthState].(string)},
}
assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String())
assert.Equal(t, "http://test.goauthentik.io/app", st.Redirect)
assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect])
}
@ -32,14 +32,16 @@ func TestForwardHandleEnvoy_Single_Headers(t *testing.T) {
a.forwardHandleEnvoy(rr, req)

assert.Equal(t, http.StatusFound, rr.Code)
loc, st := a.assertState(t, req, rr)
loc, _ := rr.Result().Location()
s, _ := a.sessions.Get(req, a.SessionName())
shouldUrl := url.Values{
"client_id": []string{*a.proxyConfig.ClientId},
"redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"},
"response_type": []string{"code"},
"state": []string{s.Values[constants.SessionOAuthState].(string)},
}
assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String())
assert.Equal(t, "http://ext.t.goauthentik.io/app", st.Redirect)
assert.Equal(t, "http://ext.t.goauthentik.io/app", s.Values[constants.SessionRedirect])
}

func TestForwardHandleEnvoy_Single_Claims(t *testing.T) {
@ -100,13 +102,15 @@ func TestForwardHandleEnvoy_Domain_Header(t *testing.T) {
a.forwardHandleEnvoy(rr, req)

assert.Equal(t, http.StatusFound, rr.Code)
loc, st := a.assertState(t, req, rr)
loc, _ := rr.Result().Location()
s, _ := a.sessions.Get(req, a.SessionName())

shouldUrl := url.Values{
"client_id": []string{*a.proxyConfig.ClientId},
"redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"},
"response_type": []string{"code"},
"state": []string{s.Values[constants.SessionOAuthState].(string)},
}
assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String())
assert.Equal(t, "http://test.goauthentik.io/app", st.Redirect)
assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect])
}
@ -47,14 +47,16 @@ func TestForwardHandleTraefik_Single_Headers(t *testing.T) {
a.forwardHandleTraefik(rr, req)

assert.Equal(t, http.StatusFound, rr.Code)
loc, st := a.assertState(t, req, rr)
loc, _ := rr.Result().Location()
s, _ := a.sessions.Get(req, a.SessionName())
shouldUrl := url.Values{
"client_id": []string{*a.proxyConfig.ClientId},
"redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"},
"response_type": []string{"code"},
"state": []string{s.Values[constants.SessionOAuthState].(string)},
}
assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String())
assert.Equal(t, "http://test.goauthentik.io/app", st.Redirect)
assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect])
}

func TestForwardHandleTraefik_Single_Claims(t *testing.T) {
@ -132,12 +134,14 @@ func TestForwardHandleTraefik_Domain_Header(t *testing.T) {
a.forwardHandleTraefik(rr, req)

assert.Equal(t, http.StatusFound, rr.Code)
loc, st := a.assertState(t, req, rr)
loc, _ := rr.Result().Location()
s, _ := a.sessions.Get(req, a.SessionName())
shouldUrl := url.Values{
"client_id": []string{*a.proxyConfig.ClientId},
"redirect_uri": []string{"https://ext.t.goauthentik.io/outpost.goauthentik.io/callback?X-authentik-auth-callback=true"},
"response_type": []string{"code"},
"state": []string{s.Values[constants.SessionOAuthState].(string)},
}
assert.Equal(t, fmt.Sprintf("http://fake-auth.t.goauthentik.io/auth?%s", shouldUrl.Encode()), loc.String())
assert.Equal(t, "http://test.goauthentik.io/app", st.Redirect)
assert.Equal(t, "http://test.goauthentik.io/app", s.Values[constants.SessionRedirect])
}
@ -1,10 +1,13 @@
package application

import (
"encoding/base64"
"net/http"
"net/url"
"strings"
"time"

"github.com/gorilla/securecookie"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/proxyv2/constants"
)
@ -45,59 +48,69 @@ func (a *Application) checkRedirectParam(r *http.Request) (string, bool) {
return u.String(), true
}

func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request, fwd string) {
state, err := a.createState(r, fwd)
if err != nil {
a.log.WithError(err).Warning("failed to create state")
func (a *Application) handleAuthStart(rw http.ResponseWriter, r *http.Request) {
newState := base64.RawURLEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
s, _ := a.sessions.Get(r, a.SessionName())
// Check if we already have a state in the session,
// and if we do we don't do anything here
currentState, ok := s.Values[constants.SessionOAuthState].(string)
if ok {
claims, err := a.checkAuth(rw, r)
if err != nil && claims != nil {
a.log.Trace("auth start request with existing authenticated session")
a.redirect(rw, r)
return
}
a.log.Trace("session already has state, sending redirect to current state")
http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(currentState), http.StatusFound)
return
}
s, _ := a.sessions.Get(r, a.SessionName())
err = s.Save(r, rw)
rd, ok := a.checkRedirectParam(r)
if ok {
s.Values[constants.SessionRedirect] = rd
a.log.WithField("rd", rd).Trace("Setting redirect")
}
s.Values[constants.SessionOAuthState] = newState
err := s.Save(r, rw)
if err != nil {
a.log.WithError(err).Warning("failed to save session")
}
http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(state), http.StatusFound)
http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(newState), http.StatusFound)
}

func (a *Application) redirectToStart(rw http.ResponseWriter, r *http.Request) {
func (a *Application) handleAuthCallback(rw http.ResponseWriter, r *http.Request) {
s, err := a.sessions.Get(r, a.SessionName())
if err != nil {
a.log.WithError(err).Warning("failed to decode session")
a.log.WithError(err).Trace("failed to get session")
}
if r.Header.Get(constants.HeaderAuthorization) != "" && *a.proxyConfig.InterceptHeaderAuth {
rw.WriteHeader(401)
er := a.errorTemplates.Execute(rw, ErrorPageData{
Title: "Unauthenticated",
Message: "Due to 'Receive header authentication' being set, no redirect is performed.",
ProxyPrefix: "/outpost.goauthentik.io",
})
if er != nil {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
}
state, ok := s.Values[constants.SessionOAuthState]
if !ok {
a.log.Warning("No state saved in session")
a.redirect(rw, r)
return
}

redirectUrl := urlJoin(a.proxyConfig.ExternalHost, r.URL.Path)

if a.Mode() == api.PROXYMODE_FORWARD_DOMAIN {
dom := strings.TrimPrefix(*a.proxyConfig.CookieDomain, ".")
// In forward_domain we only check that the current URL's host
// ends with the cookie domain (remove the leading period if set)
if !strings.HasSuffix(r.URL.Hostname(), dom) {
a.log.WithField("url", r.URL.String()).WithField("cd", dom).Warning("Invalid redirect found")
redirectUrl = a.proxyConfig.ExternalHost
}
}
if _, redirectSet := s.Values[constants.SessionRedirect]; !redirectSet {
s.Values[constants.SessionRedirect] = redirectUrl
err = s.Save(r, rw)
claims, err := a.redeemCallback(state.(string), r.URL, r.Context())
if err != nil {
a.log.WithError(err).Warning("failed to redeem code")
rw.WriteHeader(400)
// To prevent the user from just refreshing and cause more errors, delete
// the state from the session
delete(s.Values, constants.SessionOAuthState)
err := s.Save(r, rw)
if err != nil {
a.log.WithError(err).Warning("failed to save session before redirect")
a.log.WithError(err).Warning("failed to save session")
rw.WriteHeader(400)
return
}
return
}

urlArgs := url.Values{
redirectParam: []string{redirectUrl},
s.Options.MaxAge = int(time.Until(time.Unix(int64(claims.Exp), 0)).Seconds())
s.Values[constants.SessionClaims] = &claims
err = s.Save(r, rw)
if err != nil {
a.log.WithError(err).Warning("failed to save session")
rw.WriteHeader(400)
return
}
authUrl := urlJoin(a.proxyConfig.ExternalHost, "/outpost.goauthentik.io/start")
http.Redirect(rw, r, authUrl+"?"+urlArgs.Encode(), http.StatusFound)
a.redirect(rw, r)
}
@ -3,43 +3,22 @@ package application
import (
"context"
"fmt"
"net/http"
"net/url"
"time"

"goauthentik.io/internal/outpost/proxyv2/constants"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
)

func (a *Application) handleAuthCallback(rw http.ResponseWriter, r *http.Request) {
state := a.stateFromRequest(r)
if state == nil {
a.log.Warning("invalid state")
a.redirect(rw, r)
return
func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Context) (*Claims, error) {
state := u.Query().Get("state")
a.log.WithFields(log.Fields{
"states": savedState,
"expected": state,
}).Trace("tracing states")
if savedState != state {
return nil, fmt.Errorf("invalid state")
}
claims, err := a.redeemCallback(r.URL, r.Context())
if err != nil {
a.log.WithError(err).Warning("failed to redeem code")
a.redirect(rw, r)
return
}
s, err := a.sessions.Get(r, a.SessionName())
if err != nil {
a.log.WithError(err).Trace("failed to get session")
}
s.Options.MaxAge = int(time.Until(time.Unix(int64(claims.Exp), 0)).Seconds())
s.Values[constants.SessionClaims] = &claims
err = s.Save(r, rw)
if err != nil {
a.log.WithError(err).Warning("failed to save session")
rw.WriteHeader(400)
return
}
a.redirect(rw, r)
}

func (a *Application) redeemCallback(u *url.URL, c context.Context) (*Claims, error) {
code := u.Query().Get("code")
if code == "" {
return nil, fmt.Errorf("blank code")
@ -1,95 +0,0 @@
package application

import (
"encoding/base32"
"encoding/base64"
"fmt"
"net/http"

"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/securecookie"
"github.com/mitchellh/mapstructure"
)

type OAuthState struct {
Issuer string `json:"iss" mapstructure:"iss"`
SessionID string `json:"sid" mapstructure:"sid"`
State string `json:"state" mapstructure:"state"`
Redirect string `json:"redirect" mapstructure:"redirect"`
}

func (oas *OAuthState) GetExpirationTime() (*jwt.NumericDate, error) { return nil, nil }
func (oas *OAuthState) GetIssuedAt() (*jwt.NumericDate, error) { return nil, nil }
func (oas *OAuthState) GetNotBefore() (*jwt.NumericDate, error) { return nil, nil }
func (oas *OAuthState) GetIssuer() (string, error) { return oas.Issuer, nil }
func (oas *OAuthState) GetSubject() (string, error) { return oas.State, nil }
func (oas *OAuthState) GetAudience() (jwt.ClaimStrings, error) { return nil, nil }

var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)

func (a *Application) createState(r *http.Request, fwd string) (string, error) {
s, _ := a.sessions.Get(r, a.SessionName())
if s.ID == "" {
// Ensure session has an ID
s.ID = base32RawStdEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
}
st := &OAuthState{
Issuer: fmt.Sprintf("goauthentik.io/outpost/%s", a.proxyConfig.GetClientId()),
State: base64.RawURLEncoding.EncodeToString(securecookie.GenerateRandomKey(32)),
SessionID: s.ID,
Redirect: fwd,
}
if fwd == "" {
// This should only really be hit for nginx forward_auth
// as for that the auth start redirect URL is generated by the
// reverse proxy, and as such we won't have a request we just
// denied to reference for final URL
rd, ok := a.checkRedirectParam(r)
if ok {
a.log.WithField("rd", rd).Trace("Setting redirect")
st.Redirect = rd
}
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, st)
tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret()))
if err != nil {
return "", err
}
return tokenString, nil
}

func (a *Application) stateFromRequest(r *http.Request) *OAuthState {
stateJwt := r.URL.Query().Get("state")
token, err := jwt.Parse(stateJwt, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(a.proxyConfig.GetCookieSecret()), nil
})
if err != nil {
a.log.WithError(err).Warning("failed to parse state jwt")
return nil
}
iss, err := token.Claims.GetIssuer()
if err != nil {
a.log.WithError(err).Warning("state jwt without issuer")
return nil
}
if iss != fmt.Sprintf("goauthentik.io/outpost/%s", a.proxyConfig.GetClientId()) {
a.log.WithField("issuer", iss).Warning("invalid state jwt issuer")
return nil
}
claims := &OAuthState{}
err = mapstructure.Decode(token.Claims, &claims)
if err != nil {
a.log.WithError(err).Warning("failed to mapdecode")
return nil
}
s, _ := a.sessions.Get(r, a.SessionName())
if claims.SessionID != s.ID {
a.log.WithField("is", claims.SessionID).WithField("should", s.ID).Warning("mismatched session ID")
return nil
}
return claims
}
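The state.go file above, present on only one side of this comparison, is the other half of the change: the OAuth state becomes an HS256-signed JWT carrying issuer, session ID, and redirect URL, verified against the outpost's cookie secret and the current session. A rough illustration of that sign-and-verify round trip in Python with PyJWT (the actual code uses github.com/golang-jwt/jwt/v5; the secret, issuer, and nonce values here are placeholders):

    import secrets

    import jwt  # PyJWT, for illustration only

    COOKIE_SECRET = "replace-me"                   # placeholder for the cookie secret
    ISSUER = "goauthentik.io/outpost/<client-id>"  # placeholder issuer value

    def create_state(session_id: str, redirect: str) -> str:
        claims = {
            "iss": ISSUER,
            "sid": session_id,
            "state": secrets.token_urlsafe(32),
            "redirect": redirect,
        }
        return jwt.encode(claims, COOKIE_SECRET, algorithm="HS256")

    def state_from_request(token: str, session_id: str):
        try:
            claims = jwt.decode(token, COOKIE_SECRET, algorithms=["HS256"], issuer=ISSUER)
        except jwt.InvalidTokenError:
            return None
        # Reject states minted for a different session, mirroring the SessionID check above.
        if claims.get("sid") != session_id:
            return None
        return claims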
Some files were not shown because too many files have changed in this diff