Compare commits

..

5 Commits

Author SHA1 Message Date
57a38c93fc cleanup and fix tests
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-04-25 22:57:41 +02:00
93b9dae178 fix logic
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-04-25 20:32:58 +02:00
589c123dc1 update in models too
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-04-25 20:19:31 +02:00
a4d9f08095 also default to entryDN for new sources
(this will also help us with a future migration to better save association with ldap sources)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-04-25 20:01:27 +02:00
d856e403f8 sources/ldap: allow using entryDN as uniqueness field 2024-04-25 19:59:11 +02:00
310 changed files with 5533 additions and 15567 deletions

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 2024.4.2
current_version = 2024.4.0
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

View File

@ -12,7 +12,7 @@ should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower()
branch_name = os.environ["GITHUB_REF"]
if os.environ.get("GITHUB_HEAD_REF", "") != "":
branch_name = os.environ["GITHUB_HEAD_REF"]
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-").replace("'", "-")
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
image_names = os.getenv("IMAGE_NAME").split(",")
image_arch = os.getenv("IMAGE_ARCH") or None
@ -54,9 +54,9 @@ image_main_tag = image_tags[0]
image_tags_rendered = ",".join(image_tags)
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print(f"shouldBuild={should_build}", file=_output)
print(f"sha={sha}", file=_output)
print(f"version={version}", file=_output)
print(f"prerelease={prerelease}", file=_output)
print(f"imageTags={image_tags_rendered}", file=_output)
print(f"imageMainTag={image_main_tag}", file=_output)
print("shouldBuild=%s" % should_build, file=_output)
print("sha=%s" % sha, file=_output)
print("version=%s" % version, file=_output)
print("prerelease=%s" % prerelease, file=_output)
print("imageTags=%s" % image_tags_rendered, file=_output)
print("imageMainTag=%s" % image_main_tag, file=_output)

View File

@ -29,7 +29,7 @@ jobs:
- name: Generate API
run: make gen-client-go
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
uses: golangci/golangci-lint-action@v5
with:
version: v1.54.2
args: --timeout 5000s --verbose

View File

@ -155,8 +155,8 @@ jobs:
- uses: actions/checkout@v4
- name: Run test suite in final docker images
run: |
echo "PG_PASS=$(openssl rand 32 | base64)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64)" >> .env
echo "PG_PASS=$(openssl rand -base64 32)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env
docker compose pull -q
docker compose up --no-start
docker compose start postgresql redis

View File

@ -14,8 +14,8 @@ jobs:
- uses: actions/checkout@v4
- name: Pre-release test
run: |
echo "PG_PASS=$(openssl rand 32 | base64)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64)" >> .env
echo "PG_PASS=$(openssl rand -base64 32)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env
docker buildx install
mkdir -p ./gen-ts-api
docker build -t testing:latest .

View File

@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1
# Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/node:22 as website-builder
FROM --platform=${BUILDPLATFORM} docker.io/node:21 as website-builder
ENV NODE_ENV=production
@ -20,7 +20,7 @@ COPY ./SECURITY.md /work/
RUN npm run build-bundled
# Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/node:22 as web-builder
FROM --platform=${BUILDPLATFORM} docker.io/node:21 as web-builder
ENV NODE_ENV=production
@ -38,7 +38,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build
# Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.3-bookworm AS go-builder
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.2-bookworm AS go-builder
ARG TARGETOS
ARG TARGETARCH

View File

@ -19,7 +19,6 @@ pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null)
CODESPELL_ARGS = -D - -D .github/codespell-dictionary.txt \
-I .github/codespell-words.txt \
-S 'web/src/locales/**' \
-S 'website/developer-docs/api/reference/**' \
authentik \
internal \
cmd \
@ -47,8 +46,8 @@ test-go:
go test -timeout 0 -v -race -cover ./...
test-docker: ## Run all tests in a docker-compose
echo "PG_PASS=$(shell openssl rand 32 | base64)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(shell openssl rand 32 | base64)" >> .env
echo "PG_PASS=$(openssl rand -base64 32)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env
docker compose pull -q
docker compose up --no-start
docker compose start postgresql redis

View File

@ -2,7 +2,7 @@
from os import environ
__version__ = "2024.4.2"
__version__ = "2024.4.0"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -17,7 +17,6 @@ from rest_framework.fields import CharField, IntegerField, SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError
from rest_framework.validators import UniqueValidator
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.used_by import UsedByMixin
@ -101,10 +100,7 @@ class GroupSerializer(ModelSerializer):
extra_kwargs = {
"users": {
"default": list,
},
# TODO: This field isn't unique on the database which is hard to backport
# hence we just validate the uniqueness here
"name": {"validators": [UniqueValidator(Group.objects.all())]},
}
}
@ -158,18 +154,12 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
pk = IntegerField(required=True)
queryset = Group.objects.none()
queryset = Group.objects.all().select_related("parent").prefetch_related("users")
serializer_class = GroupSerializer
search_fields = ["name", "is_superuser"]
filterset_class = GroupFilter
ordering = ["name"]
def get_queryset(self):
base_qs = Group.objects.all().select_related("parent").prefetch_related("roles")
if self.serializer_class(context={"request": self.request})._should_include_users:
base_qs = base_qs.prefetch_related("users")
return base_qs
@extend_schema(
parameters=[
OpenApiParameter("include_users", bool, default=True),

View File

@ -407,11 +407,8 @@ class UserViewSet(UsedByMixin, ModelViewSet):
search_fields = ["username", "name", "is_active", "email", "uuid"]
filterset_class = UsersFilter
def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous()
if self.serializer_class(context={"request": self.request})._should_include_groups:
base_qs = base_qs.prefetch_related("ak_groups")
return base_qs
def get_queryset(self): # pragma: no cover
return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
@extend_schema(
parameters=[

View File

@ -0,0 +1,7 @@
"""authentik core exceptions"""
from authentik.lib.sentry import SentryIgnoredException
class PropertyMappingExpressionException(SentryIgnoredException):
"""Error when a PropertyMapping Exception expression could not be parsed or evaluated."""

View File

@ -6,7 +6,6 @@ from django.db.models import Model
from django.http import HttpRequest
from prometheus_client import Histogram
from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import User
from authentik.events.models import Event, EventAction
from authentik.lib.expression.evaluator import BaseEvaluator
@ -48,7 +47,6 @@ class PropertyMappingEvaluator(BaseEvaluator):
self._context["request"] = req
req.context.update(**kwargs)
self._context.update(**kwargs)
self._globals["SkipObject"] = SkipObjectException
self.dry_run = dry_run
def handle_error(self, exc: Exception, expression_source: str):

View File

@ -1,13 +0,0 @@
"""authentik core exceptions"""
from authentik.lib.sentry import SentryIgnoredException
class PropertyMappingExpressionException(SentryIgnoredException):
"""Error when a PropertyMapping Exception expression could not be parsed or evaluated."""
class SkipObjectException(PropertyMappingExpressionException):
"""Exception which can be raised in a property mapping to skip syncing an object.
Only applies to Property mappings which sync objects, and not on mappings which transitively
apply to a single user"""

View File

@ -7,10 +7,9 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def backport_is_backchannel(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.scim.models import SCIMProvider
from authentik.core.models import BackchannelProvider
for model in [LDAPProvider, SCIMProvider]:
for model in BackchannelProvider.__subclasses__():
try:
for obj in model.objects.only("is_backchannel"):
obj.is_backchannel = True

View File

@ -22,7 +22,7 @@ from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.blueprints.models import ManagedModel
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.lib.avatars import get_avatar
from authentik.lib.generators import generate_id
@ -632,7 +632,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel):
raise NotImplementedError
def __str__(self) -> str:
return f"User-source connection (user={self.user_id}, source={self.source_id})"
return f"User-source connection (user={self.user.username}, source={self.source.slug})"
class Meta:
unique_together = (("user", "source"),)

View File

@ -100,6 +100,8 @@ class SourceFlowManager:
if self.request.user.is_authenticated:
new_connection.user = self.request.user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection.save()
return Action.LINK, new_connection
existing_connections = self.connection_type.objects.filter(
@ -146,6 +148,7 @@ class SourceFlowManager:
]:
new_connection.user = user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection.save()
return Action.LINK, new_connection
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_DENY,

View File

@ -2,9 +2,7 @@
from datetime import datetime, timedelta
from django.conf import ImproperlyConfigured
from django.contrib.sessions.backends.cache import KEY_PREFIX
from django.contrib.sessions.backends.db import SessionStore as DBSessionStore
from django.core.cache import cache
from django.utils.timezone import now
from structlog.stdlib import get_logger
@ -17,7 +15,6 @@ from authentik.core.models import (
User,
)
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.lib.config import CONFIG
from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
@ -42,8 +39,6 @@ def clean_expired_models(self: SystemTask):
amount = 0
for session in AuthenticatedSession.objects.all():
match CONFIG.get("session_storage", "cache"):
case "cache":
cache_key = f"{KEY_PREFIX}{session.session_key}"
value = None
try:
@ -54,19 +49,6 @@ def clean_expired_models(self: SystemTask):
if not value:
session.delete()
amount += 1
case "db":
if not (
DBSessionStore.get_model_class()
.objects.filter(session_key=session.session_key, expire_date__gt=now())
.exists()
):
session.delete()
amount += 1
case _:
# Should never happen, as we check for other values in authentik/root/settings.py
raise ImproperlyConfigured(
"Invalid session_storage setting, allowed values are db and cache"
)
LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)
messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")

View File

@ -5,7 +5,7 @@ from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_admin_user, create_test_user
from authentik.core.tests.utils import create_test_user
from authentik.lib.generators import generate_id
@ -16,13 +16,6 @@ class TestGroupsAPI(APITestCase):
self.login_user = create_test_user()
self.user = User.objects.create(username="test-user")
def test_list_with_users(self):
"""Test listing with users"""
admin = create_test_admin_user()
self.client.force_login(admin)
response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
self.assertEqual(response.status_code, 200)
def test_add_user(self):
"""Test add_user"""
group = Group.objects.create(name=generate_id())

View File

@ -3,7 +3,7 @@
from django.test import RequestFactory, TestCase
from guardian.shortcuts import get_anonymous_user
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction
@ -66,11 +66,14 @@ class TestPropertyMappings(TestCase):
expression="return request.http_request.path",
)
http_request = self.factory.get("/")
tmpl = f"""
res = ak_call_policy('{expr.name}')
tmpl = (
"""
res = ak_call_policy('%s')
result = [request.http_request.path, res.raw_result]
return result
"""
% expr.name
)
evaluator = PropertyMapping(expression=tmpl, name=generate_id())
res = evaluator.evaluate(self.user, http_request)
self.assertEqual(res, ["/", "/"])

View File

@ -48,21 +48,15 @@ class TestSourceFlowManager(TestCase):
def test_authenticated_link(self):
"""Test authenticated user linking"""
UserOAuthSourceConnection.objects.create(
user=get_anonymous_user(), source=self.source, identifier=self.identifier
)
user = User.objects.create(username="foo", email="foo@bar.baz")
flow_manager = OAuthSourceFlowManager(
self.source, get_request("/", user=user), self.identifier, {}
)
action, connection = flow_manager.get_action()
action, _ = flow_manager.get_action()
self.assertEqual(action, Action.LINK)
self.assertIsNone(connection.pk)
flow_manager.get_flow()
def test_unauthenticated_link(self):
"""Test un-authenticated user linking"""
flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {})
action, connection = flow_manager.get_action()
self.assertEqual(action, Action.LINK)
self.assertIsNone(connection.pk)
flow_manager.get_flow()
def test_unauthenticated_enroll_email(self):

View File

@ -41,12 +41,6 @@ class TestUsersAPI(APITestCase):
)
self.assertEqual(response.status_code, 200)
def test_list_with_groups(self):
"""Test listing with groups"""
self.client.force_login(self.admin)
response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"})
self.assertEqual(response.status_code, 200)
def test_metrics(self):
"""Test user's metrics"""
self.client.force_login(self.admin)

View File

@ -8,6 +8,7 @@ from rest_framework.test import APITestCase
from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.config import CONFIG
from authentik.tenants.utils import get_current_tenant
@ -24,6 +25,7 @@ class TestUsersAvatars(APITestCase):
tenant.avatars = mode
tenant.save()
@CONFIG.patch("avatars", "none")
def test_avatars_none(self):
"""Test avatars none"""
self.set_avatar_mode("none")

View File

@ -4,7 +4,7 @@ from django.utils.text import slugify
from authentik.brands.models import Brand
from authentik.core.models import Group, User
from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair
from authentik.flows.models import Flow, FlowDesignation
from authentik.lib.generators import generate_id
@ -50,10 +50,12 @@ def create_test_brand(**kwargs) -> Brand:
return Brand.objects.create(domain=uid, default=True, **kwargs)
def create_test_cert(alg=PrivateKeyAlg.RSA) -> CertificateKeyPair:
def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair:
"""Generate a certificate for testing"""
builder = CertificateBuilder(f"{generate_id()}.self-signed.goauthentik.io")
builder.alg = alg
builder = CertificateBuilder(
name=f"{generate_id()}.self-signed.goauthentik.io",
use_ec_private_key=use_ec_private_key,
)
builder.build(
subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"],
validity_days=360,

View File

@ -14,13 +14,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import (
CharField,
ChoiceField,
DateTimeField,
IntegerField,
SerializerMethodField,
)
from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.request import Request
from rest_framework.response import Response
@ -32,7 +26,7 @@ from authentik.api.authorization import SecretKeyFilter
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.crypto.apps import MANAGED_KEY
from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction
from authentik.rbac.decorators import permission_required
@ -184,7 +178,6 @@ class CertificateGenerationSerializer(PassiveSerializer):
common_name = CharField()
subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name"))
validity_days = IntegerField(initial=365)
alg = ChoiceField(default=PrivateKeyAlg.RSA, choices=PrivateKeyAlg.choices)
class CertificateKeyPairFilter(FilterSet):
@ -247,7 +240,6 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
raw_san = data.validated_data.get("subject_alt_name", "")
sans = raw_san.split(",") if raw_san != "" else []
builder = CertificateBuilder(data.validated_data["common_name"])
builder.alg = data.validated_data["alg"]
builder.build(
subject_alt_names=sans,
validity_days=int(data.validated_data["validity_days"]),

View File

@ -9,28 +9,20 @@ from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from cryptography.x509.oid import NameOID
from django.db import models
from django.utils.translation import gettext_lazy as _
from authentik import __version__
from authentik.crypto.models import CertificateKeyPair
class PrivateKeyAlg(models.TextChoices):
"""Algorithm to create private key with"""
RSA = "rsa", _("rsa")
ECDSA = "ecdsa", _("ecdsa")
class CertificateBuilder:
"""Build self-signed certificates"""
common_name: str
alg: PrivateKeyAlg
def __init__(self, name: str):
self.alg = PrivateKeyAlg.RSA
_use_ec_private_key: bool
def __init__(self, name: str, use_ec_private_key=False):
self._use_ec_private_key = use_ec_private_key
self.__public_key = None
self.__private_key = None
self.__builder = None
@ -50,13 +42,11 @@ class CertificateBuilder:
def generate_private_key(self) -> PrivateKeyTypes:
"""Generate private key"""
if self.alg == PrivateKeyAlg.ECDSA:
if self._use_ec_private_key:
return ec.generate_private_key(curve=ec.SECP256R1())
if self.alg == PrivateKeyAlg.RSA:
return rsa.generate_private_key(
public_exponent=65537, key_size=4096, backend=default_backend()
)
raise ValueError(f"Invalid alg: {self.alg}")
def build(
self,

View File

@ -2,12 +2,11 @@
from copy import deepcopy
from functools import partial
from typing import Any
from django.apps.registry import apps
from django.core.files import File
from django.db import connection
from django.db.models import ManyToManyRel, Model
from django.db.models import Model
from django.db.models.expressions import BaseExpression, Combinable
from django.db.models.signals import post_init
from django.http import HttpRequest
@ -45,7 +44,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
post_init.disconnect(dispatch_uid=request.request_id)
def serialize_simple(self, model: Model) -> dict:
"""Serialize a model in a very simple way. No ForeignKeys or other relationships are
"""Serialize a model in a very simple way. No ForeginKeys or other relationships are
resolved"""
data = {}
deferred_fields = model.get_deferred_fields()
@ -71,9 +70,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
for key, value in before.items():
if after.get(key) != value:
diff[key] = {"previous_value": value, "new_value": after.get(key)}
for key, value in after.items():
if key not in before and key not in diff and before.get(key) != value:
diff[key] = {"previous_value": before.get(key), "new_value": value}
return sanitize_item(diff)
def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
@ -102,37 +98,8 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
thread_kwargs = {}
if hasattr(instance, "_previous_state") or created:
prev_state = getattr(instance, "_previous_state", {})
if created:
prev_state = {}
# Get current state
new_state = self.serialize_simple(instance)
diff = self.diff(prev_state, new_state)
thread_kwargs["diff"] = diff
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
def m2m_changed_handler( # noqa: PLR0913
self,
request: HttpRequest,
sender,
instance: Model,
action: str,
pk_set: set[Any],
thread_kwargs: dict | None = None,
**_,
):
thread_kwargs = {}
m2m_field = None
# For the audit log we don't care about `pre_` or `post_` so we trim that part off
_, _, action_direction = action.partition("_")
# resolve the "through" model to an actual field
for field in instance._meta.get_fields():
if not isinstance(field, ManyToManyRel):
continue
if field.through == sender:
m2m_field = field
if m2m_field:
# If we're clearing we just set the "flag" to True
if action_direction == "clear":
pk_set = True
thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}}
return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)

View File

@ -1,22 +1,9 @@
from unittest.mock import PropertyMock, patch
from django.apps import apps
from django.conf import settings
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction
from authentik.events.utils import sanitize_item
from authentik.lib.generators import generate_id
from django.test import TestCase
class TestEnterpriseAudit(APITestCase):
"""Test audit middleware"""
def setUp(self) -> None:
self.user = create_test_admin_user()
class TestEnterpriseAudit(TestCase):
def test_import(self):
"""Ensure middleware is imported when app.ready is called"""
@ -29,182 +16,3 @@ class TestEnterpriseAudit(APITestCase):
self.assertIn(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE
)
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
def test_create(self):
"""Test create audit log"""
self.client.force_login(self.user)
username = generate_id()
response = self.client.post(
reverse("authentik_api:user-list"),
data={"name": generate_id(), "username": username, "groups": [], "path": "foo"},
)
user = User.objects.get(username=username)
self.assertEqual(response.status_code, 201)
events = Event.objects.filter(
action=EventAction.MODEL_CREATED,
context__model__model_name="user",
context__model__app="authentik_core",
context__model__pk=user.pk,
)
event = events.first()
self.assertIsNotNone(event)
self.assertIsNotNone(event.context["diff"])
diff = event.context["diff"]
self.assertEqual(
diff,
{
"name": {
"new_value": user.name,
"previous_value": None,
},
"path": {"new_value": "foo", "previous_value": None},
"type": {"new_value": "internal", "previous_value": None},
"uuid": {
"new_value": user.uuid.hex,
"previous_value": None,
},
"email": {"new_value": "", "previous_value": None},
"username": {
"new_value": user.username,
"previous_value": None,
},
"is_active": {"new_value": True, "previous_value": None},
"attributes": {"new_value": {}, "previous_value": None},
"date_joined": {
"new_value": sanitize_item(user.date_joined),
"previous_value": None,
},
"first_name": {"new_value": "", "previous_value": None},
"id": {"new_value": user.pk, "previous_value": None},
"last_name": {"new_value": "", "previous_value": None},
"password": {"new_value": "********************", "previous_value": None},
"password_change_date": {
"new_value": sanitize_item(user.password_change_date),
"previous_value": None,
},
},
)
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
def test_update(self):
"""Test update audit log"""
self.client.force_login(self.user)
user = create_test_admin_user()
current_name = user.name
new_name = generate_id()
response = self.client.patch(
reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
data={"name": new_name},
)
user.refresh_from_db()
self.assertEqual(response.status_code, 200)
events = Event.objects.filter(
action=EventAction.MODEL_UPDATED,
context__model__model_name="user",
context__model__app="authentik_core",
context__model__pk=user.pk,
)
event = events.first()
self.assertIsNotNone(event)
self.assertIsNotNone(event.context["diff"])
diff = event.context["diff"]
self.assertEqual(
diff,
{
"name": {
"new_value": new_name,
"previous_value": current_name,
},
},
)
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
def test_delete(self):
"""Test delete audit log"""
self.client.force_login(self.user)
user = create_test_admin_user()
response = self.client.delete(
reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
)
self.assertEqual(response.status_code, 204)
events = Event.objects.filter(
action=EventAction.MODEL_DELETED,
context__model__model_name="user",
context__model__app="authentik_core",
context__model__pk=user.pk,
)
event = events.first()
self.assertIsNotNone(event)
self.assertNotIn("diff", event.context)
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
def test_m2m_add(self):
"""Test m2m add audit log"""
self.client.force_login(self.user)
user = create_test_admin_user()
group = Group.objects.create(name=generate_id())
response = self.client.post(
reverse("authentik_api:group-add-user", kwargs={"pk": group.group_uuid}),
data={
"pk": user.pk,
},
)
self.assertEqual(response.status_code, 204)
events = Event.objects.filter(
action=EventAction.MODEL_UPDATED,
context__model__model_name="group",
context__model__app="authentik_core",
context__model__pk=group.pk.hex,
)
event = events.first()
self.assertIsNotNone(event)
self.assertIsNotNone(event.context["diff"])
diff = event.context["diff"]
self.assertEqual(
diff,
{"users": {"add": [user.pk]}},
)
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
def test_m2m_remove(self):
"""Test m2m remove audit log"""
self.client.force_login(self.user)
user = create_test_admin_user()
group = Group.objects.create(name=generate_id())
response = self.client.post(
reverse("authentik_api:group-remove-user", kwargs={"pk": group.group_uuid}),
data={
"pk": user.pk,
},
)
self.assertEqual(response.status_code, 204)
events = Event.objects.filter(
action=EventAction.MODEL_UPDATED,
context__model__model_name="group",
context__model__app="authentik_core",
context__model__pk=group.pk.hex,
)
event = events.first()
self.assertIsNotNone(event)
self.assertIsNotNone(event.context["diff"])
diff = event.context["diff"]
self.assertEqual(
diff,
{"users": {"remove": [user.pk]}},
)

View File

@ -1,39 +0,0 @@
"""google Property mappings API Views"""
from django_filters.filters import AllValuesMultipleFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderMapping
class GoogleProviderMappingSerializer(PropertyMappingSerializer):
"""GoogleProviderMapping Serializer"""
class Meta:
model = GoogleWorkspaceProviderMapping
fields = PropertyMappingSerializer.Meta.fields
class GoogleProviderMappingFilter(FilterSet):
"""Filter for GoogleProviderMapping"""
managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))
class Meta:
model = GoogleWorkspaceProviderMapping
fields = "__all__"
class GoogleProviderMappingViewSet(UsedByMixin, ModelViewSet):
"""GoogleProviderMapping Viewset"""
queryset = GoogleWorkspaceProviderMapping.objects.all()
serializer_class = GoogleProviderMappingSerializer
filterset_class = GoogleProviderMappingFilter
search_fields = ["name"]
ordering = ["name"]

View File

@ -1,54 +0,0 @@
"""Google Provider API Views"""
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
class GoogleProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
"""GoogleProvider Serializer"""
class Meta:
model = GoogleWorkspaceProvider
fields = [
"pk",
"name",
"property_mappings",
"property_mappings_group",
"component",
"assigned_backchannel_application_slug",
"assigned_backchannel_application_name",
"verbose_name",
"verbose_name_plural",
"meta_model_name",
"delegated_subject",
"credentials",
"scopes",
"exclude_users_service_account",
"filter_group",
"user_delete_action",
"group_delete_action",
"default_group_email_domain",
]
extra_kwargs = {}
class GoogleProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
"""GoogleProvider Viewset"""
queryset = GoogleWorkspaceProvider.objects.all()
serializer_class = GoogleProviderSerializer
filterset_fields = [
"name",
"exclude_users_service_account",
"delegated_subject",
"filter_group",
]
search_fields = ["name"]
ordering = ["name"]
sync_single_task = google_workspace_sync

View File

@ -1,9 +0,0 @@
from authentik.enterprise.apps import EnterpriseConfig
class AuthentikEnterpriseProviderGoogleConfig(EnterpriseConfig):
name = "authentik.enterprise.providers.google_workspace"
label = "authentik_providers_google_workspace"
verbose_name = "authentik Enterprise.Providers.Google Workspace"
default = True

View File

@ -1,71 +0,0 @@
from django.db.models import Model
from django.http import HttpResponseNotFound
from google.auth.exceptions import GoogleAuthError, TransportError
from googleapiclient.discovery import build
from googleapiclient.errors import Error, HttpError
from googleapiclient.http import HttpRequest
from httplib2 import HttpLib2Error, HttpLib2ErrorWithResponse
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.lib.sync.outgoing import HTTP_CONFLICT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.exceptions import (
NotFoundSyncException,
ObjectExistsSyncException,
StopSync,
TransientSyncException,
)
class GoogleWorkspaceSyncClient[TModel: Model, TConnection: Model, TSchema: dict](
BaseOutgoingSyncClient[TModel, TConnection, TSchema, GoogleWorkspaceProvider]
):
"""Base client for syncing to google workspace"""
domains: list
def __init__(self, provider: GoogleWorkspaceProvider) -> None:
super().__init__(provider)
self.directory_service = build(
"admin",
"directory_v1",
cache_discovery=False,
**provider.google_credentials(),
)
self.__prefetch_domains()
def __prefetch_domains(self):
self.domains = []
domains = self._request(self.directory_service.domains().list(customer="my_customer"))
for domain in domains.get("domains", []):
domain_name = domain.get("domainName")
self.domains.append(domain_name)
def _request(self, request: HttpRequest):
try:
response = request.execute()
except GoogleAuthError as exc:
if isinstance(exc, TransportError):
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
raise StopSync(exc) from exc
except HttpLib2Error as exc:
if isinstance(exc, HttpLib2ErrorWithResponse):
self._response_handle_status_code(exc.response.status, exc)
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
except HttpError as exc:
self._response_handle_status_code(exc.status_code, exc)
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
except Error as exc:
raise TransientSyncException(f"Failed to send request: {str(exc)}") from exc
return response
def _response_handle_status_code(self, status_code: int, root_exc: Exception):
if status_code == HttpResponseNotFound.status_code:
raise NotFoundSyncException("Object not found") from root_exc
if status_code == HTTP_CONFLICT:
raise ObjectExistsSyncException("Object exists") from root_exc
def check_email_valid(self, *emails: str):
for email in emails:
if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
raise TransientSyncException(f"Invalid email domain: {email}")

View File

@ -1,245 +0,0 @@
from deepmerge import always_merger
from django.db import transaction
from django.utils.text import slugify
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.models import Group
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
from authentik.enterprise.providers.google_workspace.models import (
GoogleWorkspaceDeleteAction,
GoogleWorkspaceProviderGroup,
GoogleWorkspaceProviderMapping,
GoogleWorkspaceProviderUser,
)
from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import (
NotFoundSyncException,
ObjectExistsSyncException,
StopSync,
TransientSyncException,
)
from authentik.lib.utils.errors import exception_to_string
class GoogleWorkspaceGroupClient(
GoogleWorkspaceSyncClient[Group, GoogleWorkspaceProviderGroup, dict]
):
"""Google client for groups"""
connection_type = GoogleWorkspaceProviderGroup
connection_type_query = "group"
can_discover = True
def to_schema(self, obj: Group) -> dict:
"""Convert authentik group"""
raw_google_group = {
"email": f"{slugify(obj.name)}@{self.provider.default_group_email_domain}"
}
for mapping in (
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
):
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
continue
try:
mapping: GoogleWorkspaceProviderMapping
value = mapping.evaluate(
user=None,
request=None,
group=obj,
provider=self.provider,
)
if value is None:
continue
always_merger.merge(raw_google_group, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=mapping,
).save()
raise StopSync(exc, obj, mapping) from exc
if not raw_google_group:
raise StopSync(ValueError("No group mappings configured"), obj)
return raw_google_group
def delete(self, obj: Group):
"""Delete group"""
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=obj
).first()
if not google_group:
self.logger.debug("Group does not exist in Google, skipping")
return None
with transaction.atomic():
if self.provider.group_delete_action == GoogleWorkspaceDeleteAction.DELETE:
self._request(
self.directory_service.groups().delete(groupKey=google_group.google_id)
)
google_group.delete()
def create(self, group: Group):
"""Create group from scratch and create a connection object"""
google_group = self.to_schema(group)
self.check_email_valid(google_group["email"])
with transaction.atomic():
try:
response = self._request(self.directory_service.groups().insert(body=google_group))
except ObjectExistsSyncException:
# group already exists in google workspace, so we can connect them manually
# for groups we need to fetch the group from google as we connect on
# ID and not group email
group_data = self._request(
self.directory_service.groups().get(groupKey=google_group["email"])
)
GoogleWorkspaceProviderGroup.objects.create(
provider=self.provider, group=group, google_id=group_data["id"]
)
else:
GoogleWorkspaceProviderGroup.objects.create(
provider=self.provider, group=group, google_id=response["id"]
)
def update(self, group: Group, connection: GoogleWorkspaceProviderGroup):
"""Update existing group"""
google_group = self.to_schema(group)
self.check_email_valid(google_group["email"])
try:
return self._request(
self.directory_service.groups().update(
groupKey=connection.google_id,
body=google_group,
)
)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
def write(self, obj: Group):
google_group, created = super().write(obj)
if created:
self.create_sync_members(obj, google_group)
return google_group
def create_sync_members(self, obj: Group, google_group: dict):
"""Sync all members after a group was created"""
users = list(obj.users.order_by("id").values_list("id", flat=True))
connections = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user__pk__in=users
)
for user in connections:
try:
self._request(
self.directory_service.members().insert(
groupKey=google_group["id"],
body={
"email": user.google_id,
},
)
)
except TransientSyncException:
continue
def update_group(self, group: Group, action: Direction, users_set: set[int]):
"""Update a groups members"""
if action == Direction.add:
return self._patch_add_users(group, users_set)
if action == Direction.remove:
return self._patch_remove_users(group, users_set)
def _patch(self, google_group_id: str, direction: Direction, members: list[str]):
for user in members:
try:
if direction == Direction.add:
self._request(
self.directory_service.members().insert(
groupKey=google_group_id, body={"email": user}
)
)
if direction == Direction.remove:
self._request(
self.directory_service.members().delete(
groupKey=google_group_id, memberKey=user
)
)
except ObjectExistsSyncException:
pass
except TransientSyncException:
raise
def _patch_add_users(self, group: Group, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
if not google_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
user_ids = list(
GoogleWorkspaceProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
).values_list("google_id", flat=True)
)
if len(user_ids) < 1:
return
self._patch(google_group.google_id, Direction.add, user_ids)
def _patch_remove_users(self, group: Group, users_set: set[int]):
"""Remove users in users_set from group"""
if len(users_set) < 1:
return
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
if not google_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
user_ids = list(
GoogleWorkspaceProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider
).values_list("google_id", flat=True)
)
if len(user_ids) < 1:
return
self._patch(google_group.google_id, Direction.remove, user_ids)
def discover(self):
"""Iterate through all groups and connect them with authentik groups if possible"""
request = self.directory_service.groups().list(
customer="my_customer", maxResults=500, orderBy="email"
)
while request:
response = request.execute()
for group in response.get("groups", []):
self._discover_single_group(group)
request = self.directory_service.groups().list_next(
previous_request=request, previous_response=response
)
def _discover_single_group(self, group: dict):
"""handle discovery of a single group"""
google_name = group["name"]
google_id = group["id"]
matching_authentik_group = (
self.provider.get_object_qs(Group).filter(name=google_name).first()
)
if not matching_authentik_group:
return
GoogleWorkspaceProviderGroup.objects.get_or_create(
provider=self.provider,
group=matching_authentik_group,
google_id=google_id,
)

View File

@ -1,41 +0,0 @@
from json import dumps
from httplib2 import Response
class MockHTTP:
_recorded_requests = []
_responses = {}
def __init__(
self,
raise_on_unrecorded=True,
) -> None:
self._recorded_requests = []
self._responses = {}
self.raise_on_unrecorded = raise_on_unrecorded
def add_response(self, uri: str, body: str | dict = "", meta: dict | None = None, method="GET"):
if isinstance(body, dict):
body = dumps(body)
self._responses[(uri, method.upper())] = (body, meta or {"status": "200"})
def requests(self):
return self._recorded_requests
def request(
self,
uri,
method="GET",
body=None,
headers=None,
redirections=1,
connection_type=None,
):
key = (uri, method.upper())
self._recorded_requests.append((uri, method, body, headers))
if key not in self._responses and self.raise_on_unrecorded:
raise AssertionError(key)
body, meta = self._responses[key]
return Response(meta), body.encode("utf-8")

View File

@ -1,141 +0,0 @@
from deepmerge import always_merger
from django.db import transaction
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.models import User
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
from authentik.enterprise.providers.google_workspace.models import (
GoogleWorkspaceDeleteAction,
GoogleWorkspaceProviderMapping,
GoogleWorkspaceProviderUser,
)
from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.exceptions import (
ObjectExistsSyncException,
StopSync,
TransientSyncException,
)
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.utils import delete_none_values
class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceProviderUser, dict]):
"""Sync authentik users into google workspace"""
connection_type = GoogleWorkspaceProviderUser
connection_type_query = "user"
can_discover = True
def to_schema(self, obj: User) -> dict:
"""Convert authentik user"""
raw_google_user = {}
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
continue
try:
mapping: GoogleWorkspaceProviderMapping
value = mapping.evaluate(
user=obj,
request=None,
provider=self.provider,
)
if value is None:
continue
always_merger.merge(raw_google_user, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=mapping,
).save()
raise StopSync(exc, obj, mapping) from exc
if not raw_google_user:
raise StopSync(ValueError("No user mappings configured"), obj)
if "primaryEmail" not in raw_google_user:
raw_google_user["primaryEmail"] = str(obj.email)
return delete_none_values(raw_google_user)
def delete(self, obj: User):
"""Delete user"""
google_user = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user=obj
).first()
if not google_user:
self.logger.debug("User does not exist in Google, skipping")
return None
with transaction.atomic():
response = None
if self.provider.user_delete_action == GoogleWorkspaceDeleteAction.DELETE:
response = self._request(
self.directory_service.users().delete(userKey=google_user.google_id)
)
elif self.provider.user_delete_action == GoogleWorkspaceDeleteAction.SUSPEND:
response = self._request(
self.directory_service.users().update(
userKey=google_user.google_id, body={"suspended": True}
)
)
google_user.delete()
return response
def create(self, user: User):
"""Create user from scratch and create a connection object"""
google_user = self.to_schema(user)
self.check_email_valid(
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
)
with transaction.atomic():
try:
response = self._request(self.directory_service.users().insert(body=google_user))
except ObjectExistsSyncException:
# user already exists in google workspace, so we can connect them manually
GoogleWorkspaceProviderUser.objects.create(
provider=self.provider, user=user, google_id=user.email
)
except TransientSyncException as exc:
raise exc
else:
GoogleWorkspaceProviderUser.objects.create(
provider=self.provider, user=user, google_id=response["primaryEmail"]
)
def update(self, user: User, connection: GoogleWorkspaceProviderUser):
"""Update existing user"""
google_user = self.to_schema(user)
self.check_email_valid(
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
)
self._request(
self.directory_service.users().update(userKey=connection.google_id, body=google_user)
)
def discover(self):
"""Iterate through all users and connect them with authentik users if possible"""
request = self.directory_service.users().list(
customer="my_customer", maxResults=500, orderBy="email"
)
while request:
response = request.execute()
for user in response.get("users", []):
self._discover_single_user(user)
request = self.directory_service.users().list_next(
previous_request=request, previous_response=response
)
def _discover_single_user(self, user: dict):
"""handle discovery of a single user"""
email = user["primaryEmail"]
matching_authentik_user = self.provider.get_object_qs(User).filter(email=email).first()
if not matching_authentik_user:
return
GoogleWorkspaceProviderUser.objects.get_or_create(
provider=self.provider,
user=matching_authentik_user,
google_id=email,
)

View File

@ -1,167 +0,0 @@
# Generated by Django 5.0.4 on 2024-05-07 16:03
import django.db.models.deletion
import uuid
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_core", "0035_alter_group_options_and_more"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name="GoogleWorkspaceProviderMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
],
options={
"verbose_name": "Google Workspace Provider Mapping",
"verbose_name_plural": "Google Workspace Provider Mappings",
},
bases=("authentik_core.propertymapping",),
),
migrations.CreateModel(
name="GoogleWorkspaceProvider",
fields=[
(
"provider_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.provider",
),
),
("delegated_subject", models.EmailField(max_length=254)),
("credentials", models.JSONField()),
(
"scopes",
models.TextField(
default="https://www.googleapis.com/auth/admin.directory.user,https://www.googleapis.com/auth/admin.directory.group,https://www.googleapis.com/auth/admin.directory.group.member,https://www.googleapis.com/auth/admin.directory.domain.readonly"
),
),
("default_group_email_domain", models.TextField()),
("exclude_users_service_account", models.BooleanField(default=False)),
(
"user_delete_action",
models.TextField(
choices=[
("do_nothing", "Do Nothing"),
("delete", "Delete"),
("suspend", "Suspend"),
],
default="delete",
),
),
(
"group_delete_action",
models.TextField(
choices=[
("do_nothing", "Do Nothing"),
("delete", "Delete"),
("suspend", "Suspend"),
],
default="delete",
),
),
(
"filter_group",
models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.group",
),
),
(
"property_mappings_group",
models.ManyToManyField(
blank=True,
default=None,
help_text="Property mappings used for group creation/updating.",
to="authentik_core.propertymapping",
),
),
],
options={
"verbose_name": "Google Workspace Provider",
"verbose_name_plural": "Google Workspace Providers",
},
bases=("authentik_core.provider", models.Model),
),
migrations.CreateModel(
name="GoogleWorkspaceProviderGroup",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
("google_id", models.TextField()),
(
"group",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_core.group"
),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_google_workspace.googleworkspaceprovider",
),
),
],
options={
"unique_together": {("google_id", "group", "provider")},
},
),
migrations.CreateModel(
name="GoogleWorkspaceProviderUser",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
("google_id", models.TextField()),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_google_workspace.googleworkspaceprovider",
),
),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
),
),
],
options={
"unique_together": {("google_id", "user", "provider")},
},
),
]

View File

@ -1,179 +0,0 @@
"""Google workspace sync provider"""
from typing import Any, Self
from uuid import uuid4
from django.db import models
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from google.oauth2.service_account import Credentials
from rest_framework.serializers import Serializer
from authentik.core.models import (
BackchannelProvider,
Group,
PropertyMapping,
User,
UserTypes,
)
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
def default_scopes() -> list[str]:
return [
"https://www.googleapis.com/auth/admin.directory.user",
"https://www.googleapis.com/auth/admin.directory.group",
"https://www.googleapis.com/auth/admin.directory.group.member",
"https://www.googleapis.com/auth/admin.directory.domain.readonly",
]
class GoogleWorkspaceDeleteAction(models.TextChoices):
"""Action taken when a user/group is deleted in authentik. Suspend is not available for groups,
and will be treated as `do_nothing`"""
DO_NOTHING = "do_nothing"
DELETE = "delete"
SUSPEND = "suspend"
class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
"""Sync users from authentik into Google Workspace."""
delegated_subject = models.EmailField()
credentials = models.JSONField()
scopes = models.TextField(default=",".join(default_scopes()))
default_group_email_domain = models.TextField()
exclude_users_service_account = models.BooleanField(default=False)
user_delete_action = models.TextField(
choices=GoogleWorkspaceDeleteAction.choices, default=GoogleWorkspaceDeleteAction.DELETE
)
group_delete_action = models.TextField(
choices=GoogleWorkspaceDeleteAction.choices, default=GoogleWorkspaceDeleteAction.DELETE
)
filter_group = models.ForeignKey(
"authentik_core.group", on_delete=models.SET_DEFAULT, default=None, null=True
)
property_mappings_group = models.ManyToManyField(
PropertyMapping,
default=None,
blank=True,
help_text=_("Property mappings used for group creation/updating."),
)
def client_for_model(
self, model: type[User | Group]
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
if issubclass(model, User):
from authentik.enterprise.providers.google_workspace.clients.users import (
GoogleWorkspaceUserClient,
)
return GoogleWorkspaceUserClient(self)
if issubclass(model, Group):
from authentik.enterprise.providers.google_workspace.clients.groups import (
GoogleWorkspaceGroupClient,
)
return GoogleWorkspaceGroupClient(self)
raise ValueError(f"Invalid model {model}")
def get_object_qs(self, type: type[User | Group]) -> QuerySet[User | Group]:
if type == User:
# Get queryset of all users with consistent ordering
# according to the provider's settings
base = User.objects.all().exclude_anonymous()
if self.exclude_users_service_account:
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
type=UserTypes.INTERNAL_SERVICE_ACCOUNT
)
if self.filter_group:
base = base.filter(ak_groups__in=[self.filter_group])
return base.order_by("pk")
if type == Group:
# Get queryset of all groups with consistent ordering
return Group.objects.all().order_by("pk")
raise ValueError(f"Invalid type {type}")
def google_credentials(self):
return {
"credentials": Credentials.from_service_account_info(
self.credentials, scopes=self.scopes.split(",")
).with_subject(self.delegated_subject),
}
@property
def component(self) -> str:
return "ak-provider-google-workspace-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.google_workspace.api.providers import (
GoogleProviderSerializer,
)
return GoogleProviderSerializer
def __str__(self):
return f"Google Workspace Provider {self.name}"
class Meta:
verbose_name = _("Google Workspace Provider")
verbose_name_plural = _("Google Workspace Providers")
class GoogleWorkspaceProviderMapping(PropertyMapping):
"""Map authentik data to outgoing Google requests"""
@property
def component(self) -> str:
return "ak-property-mapping-google-workspace-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.google_workspace.api.property_mappings import (
GoogleProviderMappingSerializer,
)
return GoogleProviderMappingSerializer
def __str__(self):
return f"Google Workspace Provider Mapping {self.name}"
class Meta:
verbose_name = _("Google Workspace Provider Mapping")
verbose_name_plural = _("Google Workspace Provider Mappings")
class GoogleWorkspaceProviderUser(models.Model):
"""Mapping of a user and provider to a Google user ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
google_id = models.TextField()
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
class Meta:
unique_together = (("google_id", "user", "provider"),)
def __str__(self) -> str:
return f"Google Workspace User {self.user_id} to {self.provider_id}"
class GoogleWorkspaceProviderGroup(models.Model):
"""Mapping of a group and provider to a Google group ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
google_id = models.TextField()
group = models.ForeignKey(Group, on_delete=models.CASCADE)
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
class Meta:
unique_together = (("google_id", "group", "provider"),)
def __str__(self) -> str:
return f"Google Workspace Group {self.group_id} to {self.provider_id}"

View File

@ -1,13 +0,0 @@
"""Google workspace provider task Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"providers_google_workspace_sync": {
"task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all",
"schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -1,16 +0,0 @@
"""Google provider signals"""
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import (
google_workspace_sync,
google_workspace_sync_direct,
google_workspace_sync_m2m,
)
from authentik.lib.sync.outgoing.signals import register_signals
register_signals(
GoogleWorkspaceProvider,
task_sync_single=google_workspace_sync,
task_sync_direct=google_workspace_sync_direct,
task_sync_m2m=google_workspace_sync_m2m,
)

View File

@ -1,34 +0,0 @@
"""Google Provider tasks"""
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing.tasks import SyncTasks
from authentik.root.celery import CELERY_APP
sync_tasks = SyncTasks(GoogleWorkspaceProvider)
@CELERY_APP.task()
def google_workspace_sync_objects(*args, **kwargs):
return sync_tasks.sync_objects(*args, **kwargs)
@CELERY_APP.task(base=SystemTask, bind=True)
def google_workspace_sync(self, provider_pk: int, *args, **kwargs):
"""Run full sync for Google Workspace provider"""
return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects)
@CELERY_APP.task()
def google_workspace_sync_all():
return sync_tasks.sync_all(google_workspace_sync)
@CELERY_APP.task()
def google_workspace_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs)
@CELERY_APP.task()
def google_workspace_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs)

View File

@ -1,14 +0,0 @@
{
"kind": "admin#directory#domains",
"etag": "\"a1kA7zE2sFLsHiFwgXN9G3effoc9grR2OwUu8_95xD4/uvC5HsKHylhnUtnRV6ZxINODtV0\"",
"domains": [
{
"kind": "admin#directory#domain",
"etag": "\"a1kA7zE2sFLsHiFwgXN9G3effoc9grR2OwUu8_95xD4/V4koSPWBFIWuIpAmUamO96QhTLo\"",
"domainName": "goauthentik.io",
"isPrimary": true,
"verified": true,
"creationTime": "1543048869840"
}
]
}

View File

@ -1,313 +0,0 @@
"""Google Workspace Group tests"""
from unittest.mock import MagicMock, patch
from django.test import TestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.core.tests.utils import create_test_user
from authentik.enterprise.providers.google_workspace.clients.test_http import MockHTTP
from authentik.enterprise.providers.google_workspace.models import (
GoogleWorkspaceDeleteAction,
GoogleWorkspaceProvider,
GoogleWorkspaceProviderGroup,
GoogleWorkspaceProviderMapping,
)
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
from authentik.tenants.models import Tenant
domains_list_v1_mock = load_fixture("fixtures/domains_list_v1.json")
class GoogleWorkspaceGroupTests(TestCase):
"""Google workspace Group tests"""
@apply_blueprint("system/providers-google-workspace.yaml")
def setUp(self) -> None:
# Delete all groups and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple groups
Tenant.objects.update(avatars="none")
User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete()
self.provider: GoogleWorkspaceProvider = GoogleWorkspaceProvider.objects.create(
name=generate_id(),
credentials={},
delegated_subject="",
exclude_users_service_account=True,
default_group_email_domain="goauthentik.io",
)
self.app: Application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
)
self.app.backchannel_providers.add(self.provider)
self.provider.property_mappings.add(
GoogleWorkspaceProviderMapping.objects.get(
managed="goauthentik.io/providers/google_workspace/user"
)
)
self.provider.property_mappings_group.add(
GoogleWorkspaceProviderMapping.objects.get(
managed="goauthentik.io/providers/google_workspace/group"
)
)
self.api_key = generate_id()
def test_group_create(self):
"""Test group creation"""
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
method="POST",
body={"id": generate_id()},
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
group = Group.objects.create(name=uid)
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNotNone(google_group)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 2)
def test_group_create_update(self):
"""Test group updating"""
uid = generate_id()
ext_id = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
method="POST",
body={"id": ext_id},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}?key={self.api_key}&alt=json",
method="PUT",
body={"id": ext_id},
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
group = Group.objects.create(name=uid)
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNotNone(google_group)
group.name = "new name"
group.save()
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 4)
def test_group_create_delete(self):
"""Test group deletion"""
uid = generate_id()
ext_id = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
method="POST",
body={"id": ext_id},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}?key={self.api_key}",
method="DELETE",
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
group = Group.objects.create(name=uid)
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNotNone(google_group)
group.delete()
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 4)
def test_group_create_member_add(self):
"""Test group creation"""
uid = generate_id()
ext_id = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
method="POST",
body={"id": ext_id},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
method="POST",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
method="PUT",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}/members?key={self.api_key}&alt=json",
method="POST",
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = create_test_user(uid)
group = Group.objects.create(name=uid)
group.users.add(user)
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNotNone(google_group)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 8)
def test_group_create_member_remove(self):
"""Test group creation"""
uid = generate_id()
ext_id = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
method="POST",
body={"id": ext_id},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
method="POST",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
method="PUT",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}/members/{uid}%40goauthentik.io?key={self.api_key}",
method="DELETE",
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups/{ext_id}/members?key={self.api_key}&alt=json",
method="POST",
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = create_test_user(uid)
group = Group.objects.create(name=uid)
group.users.add(user)
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNotNone(google_group)
group.users.remove(user)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 10)
def test_group_create_delete_do_nothing(self):
"""Test group deletion (delete action = do nothing)"""
self.provider.group_delete_action = GoogleWorkspaceDeleteAction.DO_NOTHING
self.provider.save()
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?key={self.api_key}&alt=json",
method="POST",
body={"id": uid},
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
group = Group.objects.create(name=uid)
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNotNone(google_group)
group.delete()
self.assertEqual(len(http.requests()), 3)
self.assertFalse(
GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group__name=uid
).exists()
)
def test_sync_task(self):
"""Test group discovery"""
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
method="GET",
body={"users": []},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
method="GET",
body={"groups": [{"id": uid, "name": uid}]},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups/{uid}?key={self.api_key}&alt=json",
method="PUT",
body={"id": uid},
)
self.app.backchannel_providers.remove(self.provider)
different_group = Group.objects.create(
name=uid,
)
self.app.backchannel_providers.add(self.provider)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
google_workspace_sync.delay(self.provider.pk).get()
self.assertTrue(
GoogleWorkspaceProviderGroup.objects.filter(
group=different_group, provider=self.provider
).exists()
)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 5)

View File

@ -1,287 +0,0 @@
"""Google Workspace User tests"""
from json import loads
from unittest.mock import MagicMock, patch
from django.test import TestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.enterprise.providers.google_workspace.clients.test_http import MockHTTP
from authentik.enterprise.providers.google_workspace.models import (
GoogleWorkspaceDeleteAction,
GoogleWorkspaceProvider,
GoogleWorkspaceProviderMapping,
GoogleWorkspaceProviderUser,
)
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
from authentik.tenants.models import Tenant
domains_list_v1_mock = load_fixture("fixtures/domains_list_v1.json")
class GoogleWorkspaceUserTests(TestCase):
"""Google workspace User tests"""
@apply_blueprint("system/providers-google-workspace.yaml")
def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users
Tenant.objects.update(avatars="none")
User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete()
self.provider: GoogleWorkspaceProvider = GoogleWorkspaceProvider.objects.create(
name=generate_id(),
credentials={},
delegated_subject="",
exclude_users_service_account=True,
default_group_email_domain="goauthentik.io",
)
self.app: Application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
)
self.app.backchannel_providers.add(self.provider)
self.provider.property_mappings.add(
GoogleWorkspaceProviderMapping.objects.get(
managed="goauthentik.io/providers/google_workspace/user"
)
)
self.provider.property_mappings_group.add(
GoogleWorkspaceProviderMapping.objects.get(
managed="goauthentik.io/providers/google_workspace/group"
)
)
self.api_key = generate_id()
def test_user_create(self):
"""Test user creation"""
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
method="POST",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
google_user = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user=user
).first()
self.assertIsNotNone(google_user)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 2)
def test_user_create_update(self):
"""Test user updating"""
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
method="POST",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
method="PUT",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
google_user = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user=user
).first()
self.assertIsNotNone(google_user)
user.name = "new name"
user.save()
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 4)
def test_user_create_delete(self):
"""Test user deletion"""
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
method="POST",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}",
method="DELETE",
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
google_user = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user=user
).first()
self.assertIsNotNone(google_user)
user.delete()
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 4)
def test_user_create_delete_suspend(self):
"""Test user deletion (delete action = Suspend)"""
self.provider.user_delete_action = GoogleWorkspaceDeleteAction.SUSPEND
self.provider.save()
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
method="POST",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
method="PUT",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
google_user = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user=user
).first()
self.assertIsNotNone(google_user)
user.delete()
self.assertEqual(len(http.requests()), 4)
_, _, body, _ = http.requests()[3]
self.assertEqual(
loads(body),
{
"suspended": True,
},
)
self.assertFalse(
GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user__username=uid
).exists()
)
def test_user_create_delete_do_nothing(self):
"""Test user deletion (delete action = do nothing)"""
self.provider.user_delete_action = GoogleWorkspaceDeleteAction.DO_NOTHING
self.provider.save()
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?key={self.api_key}&alt=json",
method="POST",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
google_user = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user=user
).first()
self.assertIsNotNone(google_user)
user.delete()
self.assertEqual(len(http.requests()), 3)
self.assertFalse(
GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user__username=uid
).exists()
)
def test_sync_task(self):
"""Test user discovery"""
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
method="GET",
body={"users": [{"primaryEmail": f"{uid}@goauthentik.io"}]},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/groups?customer=my_customer&maxResults=500&orderBy=email&key={self.api_key}&alt=json",
method="GET",
body={"groups": []},
)
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/users/{uid}%40goauthentik.io?key={self.api_key}&alt=json",
method="PUT",
body={"primaryEmail": f"{uid}@goauthentik.io"},
)
self.app.backchannel_providers.remove(self.provider)
different_user = User.objects.create(
username=uid,
email=f"{uid}@goauthentik.io",
)
self.app.backchannel_providers.add(self.provider)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
google_workspace_sync.delay(self.provider.pk).get()
self.assertTrue(
GoogleWorkspaceProviderUser.objects.filter(
user=different_user, provider=self.provider
).exists()
)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 5)

View File

@ -1,11 +0,0 @@
"""google provider urls"""
from authentik.enterprise.providers.google_workspace.api.property_mappings import (
GoogleProviderMappingViewSet,
)
from authentik.enterprise.providers.google_workspace.api.providers import GoogleProviderViewSet
api_urlpatterns = [
("providers/google_workspace", GoogleProviderViewSet),
("propertymappings/provider/google_workspace", GoogleProviderMappingViewSet),
]

View File

@ -11,7 +11,7 @@ from django.utils.translation import gettext as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User, default_token_key
from authentik.events.models import Event, EventAction
from authentik.lib.models import SerializerModel
@ -201,7 +201,10 @@ class ConnectionToken(ExpiringModel):
return settings
def __str__(self):
return f"RAC Connection token {self.session_id} to {self.provider_id}/{self.endpoint_id}"
return (
f"RAC Connection token {self.session.user} to "
f"{self.endpoint.provider.name}/{self.endpoint.name}"
)
class Meta:
verbose_name = _("RAC Connection token")

View File

@ -14,7 +14,6 @@ CELERY_BEAT_SCHEDULE = {
TENANT_APPS = [
"authentik.enterprise.audit",
"authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.rac",
"authentik.enterprise.stages.source",
]

View File

@ -60,8 +60,6 @@ class SystemTaskSerializer(ModelSerializer):
"duration",
"status",
"messages",
"expires",
"expiring",
]

View File

@ -116,12 +116,12 @@ class AuditMiddleware:
return user
user = getattr(request, "user", self.anonymous_user)
if not user.is_authenticated:
self._ensure_fallback_user()
return self.anonymous_user
return user
def connect(self, request: HttpRequest):
"""Connect signal for automatic logging"""
self._ensure_fallback_user()
if not hasattr(request, "request_id"):
return
post_save.connect(
@ -214,15 +214,7 @@ class AuditMiddleware:
model=model_to_dict(instance),
).run()
def m2m_changed_handler(
self,
request: HttpRequest,
sender,
instance: Model,
action: str,
thread_kwargs: dict | None = None,
**_,
):
def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_):
"""Signal handler for all object's m2m_changed"""
if action not in ["pre_add", "pre_remove", "post_clear"]:
return
@ -237,5 +229,4 @@ class AuditMiddleware:
request,
user=user,
model=model_to_dict(instance),
**thread_kwargs,
).run()

View File

@ -556,7 +556,7 @@ class Notification(SerializerModel):
if len(self.body) > NOTIFICATION_SUMMARY_LENGTH
else self.body
)
return f"Notification for user {self.user_id}: {body_trunc}"
return f"Notification for user {self.user}: {body_trunc}"
class Meta:
verbose_name = _("Notification")

View File

@ -6,7 +6,7 @@ from typing import Any
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
from structlog.stdlib import get_logger
from tenant_schemas_celery.task import TenantTask
from authentik.events.logs import LogEvent
@ -15,12 +15,12 @@ from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
LOGGER = get_logger()
class SystemTask(TenantTask):
"""Task which can save its state to the cache"""
logger: BoundLogger
# For tasks that should only be listed if they failed, set this to False
save_on_success: bool
@ -63,7 +63,6 @@ class SystemTask(TenantTask):
def before_start(self, task_id, args, kwargs):
self._start_precise = perf_counter()
self._start = now()
self.logger = get_logger().bind(task_id=task_id)
return super().before_start(task_id, args, kwargs)
def db(self) -> DBSystemTask | None:
@ -120,7 +119,7 @@ class SystemTask(TenantTask):
"task_call_kwargs": sanitize_item(kwargs),
"status": self._status,
"messages": sanitize_item(self._messages),
"expires": now() + timedelta(hours=self.result_timeout_hours + 3),
"expires": now() + timedelta(hours=self.result_timeout_hours),
"expiring": True,
},
)

View File

@ -4,7 +4,7 @@ from django.db.models.query_utils import Q
from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import User
from authentik.events.models import (
Event,

View File

@ -1,35 +0,0 @@
"""authentik event models tests"""
from collections.abc import Callable
from django.db.models import Model
from django.test import TestCase
from authentik.core.models import default_token_key
from authentik.lib.utils.reflection import get_apps
class TestModels(TestCase):
"""Test Models"""
def model_tester_factory(test_model: type[Model]) -> Callable:
"""Test models' __str__ and __repr__"""
def tester(self: TestModels):
allowed = 0
# Token-like objects need to lookup the current tenant to get the default token length
for field in test_model._meta.fields:
if field.default == default_token_key:
allowed += 1
with self.assertNumQueries(allowed):
str(test_model())
with self.assertNumQueries(allowed):
repr(test_model())
return tester
for app in get_apps():
for model in app.get_models():
setattr(TestModels, f"test_{app.label}_{model.__name__}", model_tester_factory(model))

View File

@ -278,7 +278,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
},
)
@action(detail=True, pagination_class=None, filter_backends=[])
def execute(self, request: Request, slug: str):
def execute(self, request: Request, _slug: str):
"""Execute flow for current user"""
# Because we pre-plan the flow here, and not in the planner, we need to manually clear
# the history of the inspector

View File

@ -6,7 +6,6 @@ from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user
from authentik.flows.api.stages import StageSerializer, StageViewSet
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
from authentik.stages.dummy.models import DummyStage
@ -102,21 +101,3 @@ class TestFlowsAPI(APITestCase):
reverse("authentik_api:stage-types"),
)
self.assertEqual(response.status_code, 200)
def test_execute(self):
"""Test execute endpoint"""
user = create_test_admin_user()
self.client.force_login(user)
flow = Flow.objects.create(
name=generate_id(),
slug=generate_id(),
designation=FlowDesignation.AUTHENTICATION,
)
FlowStageBinding.objects.create(
target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
)
response = self.client.get(
reverse("authentik_api:flow-execute", kwargs={"slug": flow.slug})
)
self.assertEqual(response.status_code, 200)

View File

@ -53,7 +53,6 @@ cache:
# result_backend:
# url: ""
# transport_options: ""
debug: false
remote_debug: false

View File

@ -9,7 +9,6 @@ from typing import Any
from cachetools import TLRUCache, cached
from django.core.exceptions import FieldError
from django.utils.text import slugify
from guardian.shortcuts import get_anonymous_user
from rest_framework.serializers import ValidationError
from sentry_sdk.hub import Hub
@ -57,7 +56,6 @@ class BaseEvaluator:
"requests": get_http_session(),
"resolve_dns": BaseEvaluator.expr_resolve_dns,
"reverse_dns": BaseEvaluator.expr_reverse_dns,
"slugify": slugify,
}
self._context = {}

View File

@ -100,7 +100,6 @@ def get_logger_config():
"fsevents": "WARNING",
"uvicorn": "WARNING",
"gunicorn": "INFO",
"requests_mock": "WARNING",
}
for handler_name, level in handler_level_map.items():
base_config["loggers"][handler_name] = {

View File

@ -1,5 +0,0 @@
"""Sync constants"""
PAGE_SIZE = 100
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
HTTP_CONFLICT = 409

View File

@ -1,54 +0,0 @@
from collections.abc import Callable
from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.core.api.utils import PassiveSerializer
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
class SyncStatusSerializer(PassiveSerializer):
"""Provider sync status"""
is_running = BooleanField(read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers"""
sync_single_task: Callable = None
@extend_schema(
responses={
200: SyncStatusSerializer(),
404: OpenApiResponse(description="Task not found"),
}
)
@action(
methods=["GET"],
detail=True,
pagination_class=None,
url_path="sync/status",
filter_backends=[],
)
def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status"""
provider: OutgoingSyncProvider = self.get_object()
tasks = list(
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
name=self.sync_single_task.__name__,
uid=slugify(provider.name),
)
)
status = {
"tasks": tasks,
"is_running": provider.sync_lock.locked(),
}
return Response(SyncStatusSerializer(status).data)

View File

@ -1,83 +0,0 @@
"""Basic outgoing sync Client"""
from enum import StrEnum
from typing import TYPE_CHECKING
from django.db import DatabaseError
from structlog.stdlib import get_logger
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException
if TYPE_CHECKING:
from django.db.models import Model
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
class Direction(StrEnum):
add = "add"
remove = "remove"
class BaseOutgoingSyncClient[
TModel: "Model", TConnection: "Model", TSchema: dict, TProvider: "OutgoingSyncProvider"
]:
"""Basic Outgoing sync client Client"""
provider: TProvider
connection_type: type[TConnection]
connection_type_query: str
can_discover = False
def __init__(self, provider: TProvider):
self.logger = get_logger().bind(provider=provider.name)
self.provider = provider
def create(self, obj: TModel) -> TConnection:
"""Create object in remote destination"""
raise NotImplementedError()
def update(self, obj: TModel, connection: object):
"""Update object in remote destination"""
raise NotImplementedError()
def write(self, obj: TModel) -> tuple[TConnection, bool]:
"""Write object to destination. Uses self.create and self.update, but
can be overwritten for further logic"""
remote_obj = self.connection_type.objects.filter(
provider=self.provider, **{self.connection_type_query: obj}
).first()
connection: TConnection | None = None
try:
if not remote_obj:
connection = self.create(obj)
return connection, True
try:
self.update(obj, remote_obj)
return remote_obj, False
except NotFoundSyncException:
remote_obj.delete()
connection = self.create(obj)
return connection, True
except DatabaseError as exc:
self.logger.warning("Failed to write object", obj=obj, exc=exc)
if connection:
connection.delete()
return None, False
def delete(self, obj: TModel):
"""Delete object from destination"""
raise NotImplementedError()
def to_schema(self, obj: TModel) -> TSchema:
"""Convert object to destination schema"""
raise NotImplementedError()
def discover(self):
"""Optional method. Can be used to implement a "discovery" where
upon creation of this provider, this function will be called and can
pre-link any users/groups in the remote system with the respective
object in authentik based on a common identifier"""
raise NotImplementedError()

View File

@ -1,37 +0,0 @@
from authentik.lib.sentry import SentryIgnoredException
class BaseSyncException(SentryIgnoredException):
"""Base class for all sync exceptions"""
class TransientSyncException(BaseSyncException):
"""Transient sync exception which may be caused by network blips, etc"""
class NotFoundSyncException(BaseSyncException):
"""Exception when an object was not found in the remote system"""
class ObjectExistsSyncException(BaseSyncException):
"""Exception when an object already exists in the remote system"""
class StopSync(BaseSyncException):
"""Exception raised when a configuration error should stop the sync process"""
def __init__(
self, exc: Exception, obj: object | None = None, mapping: object | None = None
) -> None:
self.exc = exc
self.obj = obj
self.mapping = mapping
def detail(self) -> str:
"""Get human readable details of this error"""
msg = f"Error {str(self.exc)}"
if self.obj:
msg += f", caused by {self.obj}"
if self.mapping:
msg += f" (mapping {self.mapping})"
return msg

View File

@ -1,32 +0,0 @@
from typing import Any, Self
from django.core.cache import cache
from django.db.models import Model, QuerySet
from redis.lock import Lock
from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
class OutgoingSyncProvider(Model):
class Meta:
abstract = True
def client_for_model[
T: User | Group
](self, model: type[T]) -> BaseOutgoingSyncClient[T, Any, Any, Self]:
raise NotImplementedError
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
raise NotImplementedError
@property
def sync_lock(self) -> Lock:
"""Redis lock to prevent multiple parallel syncs happening"""
return Lock(
cache.client.get_client(),
name=f"goauthentik.io/providers/outgoing-sync/{str(self.pk)}",
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
)

View File

@ -1,71 +0,0 @@
from collections.abc import Callable
from django.core.paginator import Paginator
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete
from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path
def register_signals(
provider_type: type[OutgoingSyncProvider],
task_sync_single: Callable[[int], None],
task_sync_direct: Callable[[int], None],
task_sync_m2m: Callable[[int], None],
):
"""Register sync signals"""
uid = class_to_path(provider_type)
def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_):
"""Trigger sync when Provider is saved"""
users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
task_sync_single.apply_async(
(instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False)
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
return
task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
return
task_sync_direct.delay(
class_to_path(instance.__class__), instance.pk, Direction.remove.value
)
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
def model_m2m_changed(
sender: type[Model], instance, action: str, pk_set: set, reverse: bool, **kwargs
):
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_m2m.delay(str(instance.pk), action, list(pk_set))
else:
for group_pk in pk_set:
task_sync_m2m.delay(group_pk, action, [instance.pk])
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)

View File

@ -1,215 +0,0 @@
from collections.abc import Callable
from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group, User
from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import StopSync, TransientSyncException
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class
class SyncTasks:
"""Container for all sync 'tasks' (this class doesn't actually contain celery
tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
logger: BoundLogger
def __init__(self, provider_model: type[OutgoingSyncProvider]) -> None:
super().__init__()
self._provider_model = provider_model
def sync_all(self, single_sync: Callable[[int], None]):
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
self.trigger_single_task(provider, single_sync)
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
"""Wrapper single sync task that correctly sets time limits based
on the amount of objects that will be synced"""
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
return sync_task.apply_async(
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
def sync_single(
self,
task: SystemTask,
provider_pk: int,
sync_objects: Callable[[int, int], list[str]],
):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
)
provider = self._provider_model.objects.filter(
pk=provider_pk, backchannel_application__isnull=False
).first()
if not provider:
return
lock = provider.sync_lock
if lock.locked():
self.logger.debug("Sync locked, skipping task", source=provider.name)
return
task.set_uid(slugify(provider.name))
messages = []
messages.append(_("Starting full provider sync"))
self.logger.debug("Starting provider sync")
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
with allow_join_result(), lock:
try:
for page in users_paginator.page_range:
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
for msg in sync_objects.apply_async(
args=(class_to_path(User), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
).get():
messages.append(msg)
for page in groups_paginator.page_range:
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
for msg in sync_objects.apply_async(
args=(class_to_path(Group), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
).get():
messages.append(msg)
except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc)
raise task.retry(exc=exc) from exc
except StopSync as exc:
task.set_error(exc)
return
task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects(self, object_type: str, page: int, provider_pk: int):
_object_type = path_to_class(object_type)
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk,
object_type=object_type,
)
messages = []
provider = self._provider_model.objects.filter(pk=provider_pk).first()
if not provider:
return messages
try:
client = provider.client_for_model(_object_type)
except TransientSyncException:
return messages
paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
if client.can_discover:
self.logger.debug("starting discover")
client.discover()
self.logger.debug("starting sync for page", page=page)
for obj in paginator.page(page).object_list:
obj: Model
try:
client.write(obj)
except SkipObjectException:
continue
except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj)
messages.append(
LogEvent(
_(
(
"Failed to sync {object_type} {object_name} "
"due to transient error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger="",
)
)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc)
messages.append(
LogEvent(
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
),
log_level="warning",
logger="",
)
)
break
return messages
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
model_class: type[Model] = path_to_class(model)
instance = model_class.objects.filter(pk=pk).first()
if not instance:
return
operation = Direction(raw_op)
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
if not queryset:
continue
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=instance.pk).exists():
continue
try:
if operation == Direction.add:
client.write(instance)
if operation == Direction.remove:
client.delete(instance)
except (StopSync, TransientSyncException) as exc:
self.logger.warning(exc, provider_pk=provider.pk)
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=group_pk).exists():
continue
client = provider.client_for_model(Group)
try:
operation = None
if action == "post_add":
operation = Direction.add
if action == "post_remove":
operation = Direction.remove
client.update_group(group, operation, pk_set)
except (StopSync, TransientSyncException) as exc:
self.logger.warning(exc, provider_pk=provider.pk)

View File

@ -24,7 +24,7 @@ def load_fixture(path: str, **kwargs) -> str:
fixture = _fixture.read()
try:
return fixture % kwargs
except (TypeError, ValueError):
except TypeError:
return fixture

View File

@ -96,13 +96,16 @@ class TestEvaluator(TestCase):
execution_logging=True,
expression="ak_message(request.http_request.path)\nreturn True",
)
tmpl = f"""
tmpl = (
"""
ak_message(request.http_request.path)
res = ak_call_policy('{expr.name}')
res = ak_call_policy('%s')
ak_message(request.http_request.path)
for msg in res.messages:
ak_message(msg)
"""
% expr.name
)
evaluator = PolicyEvaluator("test")
evaluator.set_policy_request(self.request)
res = evaluator.evaluate(tmpl)

View File

@ -326,7 +326,7 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel):
verbose_name_plural = _("Authorization Codes")
def __str__(self):
return f"Authorization code for {self.provider_id} for user {self.user_id}"
return f"Authorization code for {self.provider} for user {self.user}"
@property
def serializer(self) -> Serializer:
@ -356,7 +356,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel):
verbose_name_plural = _("OAuth2 Access Tokens")
def __str__(self):
return f"Access Token for {self.provider_id} for user {self.user_id}"
return f"Access Token for {self.provider} for user {self.user}"
@property
def id_token(self) -> IDToken:
@ -399,7 +399,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
verbose_name_plural = _("OAuth2 Refresh Tokens")
def __str__(self):
return f"Refresh Token for {self.provider_id} for user {self.user_id}"
return f"Refresh Token for {self.provider} for user {self.user}"
@property
def id_token(self) -> IDToken:
@ -443,4 +443,4 @@ class DeviceToken(ExpiringModel):
verbose_name_plural = _("Device Tokens")
def __str__(self):
return f"Device Token for {self.provider_id}"
return f"Device Token for {self.provider}"

View File

@ -10,7 +10,6 @@ from jwt import PyJWKSet
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_cert, create_test_flow
from authentik.crypto.builder import PrivateKeyAlg
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider
@ -83,7 +82,7 @@ class TestJWKS(OAuthTestCase):
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
signing_key=create_test_cert(use_ec_private_key=True),
)
app = Application.objects.create(name="test", slug="test", provider=provider)
response = self.client.get(

View File

@ -8,7 +8,7 @@ from django.views import View
from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import Application
from authentik.providers.oauth2.constants import (
ACR_AUTHENTIK_DEFAULT,

View File

@ -11,7 +11,7 @@ from django.views import View
from django.views.decorators.csrf import csrf_exempt
from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.events.models import Event, EventAction
from authentik.flows.challenge import PermissionDict
from authentik.providers.oauth2.constants import (

View File

@ -1,44 +0,0 @@
# Generated by Django 5.0.4 on 2024-05-01 15:32
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_saml", "0013_samlprovider_default_relay_state"),
]
operations = [
migrations.AlterField(
model_name="samlprovider",
name="digest_algorithm",
field=models.TextField(
choices=[
("http://www.w3.org/2000/09/xmldsig#sha1", "SHA1"),
("http://www.w3.org/2001/04/xmlenc#sha256", "SHA256"),
("http://www.w3.org/2001/04/xmldsig-more#sha384", "SHA384"),
("http://www.w3.org/2001/04/xmlenc#sha512", "SHA512"),
],
default="http://www.w3.org/2001/04/xmlenc#sha256",
),
),
migrations.AlterField(
model_name="samlprovider",
name="signature_algorithm",
field=models.TextField(
choices=[
("http://www.w3.org/2000/09/xmldsig#rsa-sha1", "RSA-SHA1"),
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", "RSA-SHA256"),
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", "RSA-SHA384"),
("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", "RSA-SHA512"),
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", "ECDSA-SHA1"),
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", "ECDSA-SHA256"),
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", "ECDSA-SHA384"),
("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", "ECDSA-SHA512"),
("http://www.w3.org/2000/09/xmldsig#dsa-sha1", "DSA-SHA1"),
],
default="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
),
),
]

View File

@ -11,10 +11,6 @@ from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.time import timedelta_string_validator
from authentik.sources.saml.processors.constants import (
DSA_SHA1,
ECDSA_SHA1,
ECDSA_SHA256,
ECDSA_SHA384,
ECDSA_SHA512,
RSA_SHA1,
RSA_SHA256,
RSA_SHA384,
@ -96,7 +92,8 @@ class SAMLProvider(Provider):
),
)
digest_algorithm = models.TextField(
digest_algorithm = models.CharField(
max_length=50,
choices=(
(SHA1, _("SHA1")),
(SHA256, _("SHA256")),
@ -105,16 +102,13 @@ class SAMLProvider(Provider):
),
default=SHA256,
)
signature_algorithm = models.TextField(
signature_algorithm = models.CharField(
max_length=50,
choices=(
(RSA_SHA1, _("RSA-SHA1")),
(RSA_SHA256, _("RSA-SHA256")),
(RSA_SHA384, _("RSA-SHA384")),
(RSA_SHA512, _("RSA-SHA512")),
(ECDSA_SHA1, _("ECDSA-SHA1")),
(ECDSA_SHA256, _("ECDSA-SHA256")),
(ECDSA_SHA384, _("ECDSA-SHA384")),
(ECDSA_SHA512, _("ECDSA-SHA512")),
(DSA_SHA1, _("DSA-SHA1")),
),
default=RSA_SHA256,

View File

@ -9,7 +9,7 @@ from lxml import etree # nosec
from lxml.etree import Element, SubElement # nosec
from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.events.models import Event, EventAction
from authentik.events.signals import get_login_event
from authentik.lib.utils.time import timedelta_from_string

View File

@ -7,14 +7,13 @@ from lxml import etree # nosec
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_cert, create_test_flow
from authentik.crypto.builder import PrivateKeyAlg
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
from authentik.lib.xml import lxml_from_string
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.metadata import MetadataProcessor
from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser
from authentik.sources.saml.processors.constants import ECDSA_SHA256, NS_MAP, NS_SAML_METADATA
from authentik.sources.saml.processors.constants import NS_MAP, NS_SAML_METADATA
class TestServiceProviderMetadataParser(TestCase):
@ -108,41 +107,12 @@ class TestServiceProviderMetadataParser(TestCase):
load_fixture("fixtures/cert.xml").replace("/apps/user_saml", "")
)
def test_signature_rsa(self):
"""Test signature validation (RSA)"""
def test_signature(self):
"""Test signature validation"""
provider = SAMLProvider.objects.create(
name=generate_id(),
authorization_flow=self.flow,
signing_kp=create_test_cert(PrivateKeyAlg.RSA),
)
Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=provider,
)
request = self.factory.get("/")
metadata = MetadataProcessor(provider, request).build_entity_descriptor()
root = fromstring(metadata.encode())
xmlsec.tree.add_ids(root, ["ID"])
signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP)
signature_node = signature_nodes[0]
ctx = xmlsec.SignatureContext()
key = xmlsec.Key.from_memory(
provider.signing_kp.certificate_data,
xmlsec.constants.KeyDataFormatCertPem,
None,
)
ctx.key = key
ctx.verify(signature_node)
def test_signature_ecdsa(self):
"""Test signature validation (ECDSA)"""
provider = SAMLProvider.objects.create(
name=generate_id(),
authorization_flow=self.flow,
signing_kp=create_test_cert(PrivateKeyAlg.ECDSA),
signature_algorithm=ECDSA_SHA256,
signing_kp=create_test_cert(),
)
Application.objects.create(
name=generate_id(),

View File

@ -1,12 +1,19 @@
"""SCIM Provider API Views"""
from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.fields import BooleanField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync
class SCIMProviderSerializer(ProviderSerializer):
@ -33,7 +40,14 @@ class SCIMProviderSerializer(ProviderSerializer):
extra_kwargs = {}
class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelViewSet):
class SCIMSyncStatusSerializer(PassiveSerializer):
"""SCIM Provider sync status"""
is_running = BooleanField(read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class SCIMProviderViewSet(UsedByMixin, ModelViewSet):
"""SCIMProvider Viewset"""
queryset = SCIMProvider.objects.all()
@ -41,4 +55,25 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"]
search_fields = ["name", "url"]
ordering = ["name", "url"]
sync_single_task = scim_sync
@extend_schema(
responses={
200: SCIMSyncStatusSerializer(),
404: OpenApiResponse(description="Task not found"),
}
)
@action(methods=["GET"], detail=True, pagination_class=None, filter_backends=[])
def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status"""
provider: SCIMProvider = self.get_object()
tasks = list(
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
name="scim_sync",
uid=slugify(provider.name),
)
)
status = {
"tasks": tasks,
"is_running": provider.sync_lock.locked(),
}
return Response(SCIMSyncStatusSerializer(status).data)

View File

@ -0,0 +1,4 @@
"""SCIM constants"""
PAGE_SIZE = 100
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour

View File

@ -1,37 +1,33 @@
"""SCIM Client"""
from typing import TYPE_CHECKING
from typing import Generic, TypeVar
from django.http import HttpResponseBadRequest, HttpResponseNotFound
from pydantic import ValidationError
from requests import RequestException, Session
from structlog.stdlib import get_logger
from authentik.lib.sync.outgoing import HTTP_CONFLICT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, ObjectExistsSyncException
from authentik.lib.utils.http import get_http_session
from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.providers.scim.clients.exceptions import ResourceMissing, SCIMRequestException
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMProvider
if TYPE_CHECKING:
from django.db.models import Model
from pydantic import BaseModel
T = TypeVar("T")
SchemaType = TypeVar("SchemaType")
class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
BaseOutgoingSyncClient[TModel, TConnection, TSchema, SCIMProvider]
):
class SCIMClient(Generic[T, SchemaType]):
"""SCIM Client"""
base_url: str
token: str
provider: SCIMProvider
_session: Session
_config: ServiceProviderConfiguration
def __init__(self, provider: SCIMProvider):
super().__init__(provider)
self._session = get_http_session()
self.provider = provider
# Remove trailing slashes as we assume the URL doesn't have any
@ -40,6 +36,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
base_url = base_url[:-1]
self.base_url = base_url
self.token = provider.token
self.logger = get_logger().bind(provider=provider.name)
self._config = self.get_service_provider_config()
def _request(self, method: str, path: str, **kwargs) -> dict:
@ -60,9 +57,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
self.logger.debug("scim request", path=path, method=method, **kwargs)
if response.status_code >= HttpResponseBadRequest.status_code:
if response.status_code == HttpResponseNotFound.status_code:
raise NotFoundSyncException(response)
if response.status_code == HTTP_CONFLICT:
raise ObjectExistsSyncException(response)
raise ResourceMissing(response)
self.logger.warning(
"Failed to send SCIM request", path=path, method=method, response=response.text
)
@ -81,3 +76,15 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
except (ValidationError, SCIMRequestException) as exc:
self.logger.warning("failed to get ServiceProviderConfig", exc=exc)
return default_config
def write(self, obj: T):
"""Write object to SCIM"""
raise NotImplementedError()
def delete(self, obj: T):
"""Delete object from SCIM"""
raise NotImplementedError()
def to_scim(self, obj: T) -> SchemaType:
"""Convert object to scim"""
raise NotImplementedError()

View File

@ -3,11 +3,28 @@
from pydantic import ValidationError
from requests import Response
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
from authentik.lib.sentry import SentryIgnoredException
from authentik.providers.scim.clients.schema import SCIMError
class SCIMRequestException(TransientSyncException):
class StopSync(SentryIgnoredException):
"""Exception raised when a configuration error should stop the sync process"""
def __init__(self, exc: Exception, obj: object, mapping: object | None = None) -> None:
self.exc = exc
self.obj = obj
self.mapping = mapping
def detail(self) -> str:
"""Get human readable details of this error"""
msg = f"Error {str(self.exc)}, caused by {self.obj}"
if self.mapping:
msg += f" (mapping {self.mapping})"
return msg
class SCIMRequestException(SentryIgnoredException):
"""Exception raised when an SCIM request fails"""
_response: Response | None
@ -22,8 +39,13 @@ class SCIMRequestException(TransientSyncException):
if not self._response:
return self._message
try:
error = SCIMError.model_validate_json(self._response.text)
error = SCIMError.parse_raw(self._response.text)
return error.detail
except ValidationError:
pass
return self._message
class ResourceMissing(SCIMRequestException):
"""Error raised when the provider raises a 404, meaning that we
should delete our internal ID and re-create the object"""

View File

@ -5,36 +5,47 @@ from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp, PatchOperation
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group
from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import (
NotFoundSyncException,
ObjectExistsSyncException,
StopSync,
)
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.utils import delete_none_values
from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
ResourceMissing,
SCIMRequestException,
StopSync,
)
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.clients.schema import PatchRequest
from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMUser
class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
"""SCIM client for groups"""
connection_type = SCIMGroup
connection_type_query = "group"
def write(self, obj: Group):
"""Write a group"""
scim_group = SCIMGroup.objects.filter(provider=self.provider, group=obj).first()
if not scim_group:
return self._create(obj)
try:
return self._update(obj, scim_group)
except ResourceMissing:
scim_group.delete()
return self._create(obj)
def to_schema(self, obj: Group) -> SCIMGroupSchema:
def delete(self, obj: Group):
"""Delete group"""
scim_group = SCIMGroup.objects.filter(provider=self.provider, group=obj).first()
if not scim_group:
self.logger.debug("Group does not exist in SCIM, skipping")
return None
response = self._request("DELETE", f"/Groups/{scim_group.id}")
scim_group.delete()
return response
def to_scim(self, obj: Group) -> SCIMGroupSchema:
"""Convert authentik user into SCIM"""
raw_scim_group = {
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:Group",),
@ -55,8 +66,6 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
if value is None:
continue
always_merger.merge(raw_scim_group, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
@ -80,26 +89,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
for user in connections:
members.append(
GroupMember(
value=user.scim_id,
value=user.id,
)
)
if members:
scim_group.members = members
return scim_group
def delete(self, obj: Group):
"""Delete group"""
scim_group = SCIMGroup.objects.filter(provider=self.provider, group=obj).first()
if not scim_group:
self.logger.debug("Group does not exist in SCIM, skipping")
return None
response = self._request("DELETE", f"/Groups/{scim_group.scim_id}")
scim_group.delete()
return response
def create(self, group: Group):
def _create(self, group: Group):
"""Create group from scratch and create a connection object"""
scim_group = self.to_schema(group)
scim_group = self.to_scim(group)
response = self._request(
"POST",
"/Groups",
@ -108,28 +107,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
exclude_unset=True,
),
)
scim_id = response.get("id")
if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`")
SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id)
SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"])
def update(self, group: Group, connection: SCIMGroup):
def _update(self, group: Group, connection: SCIMGroup):
"""Update existing group"""
scim_group = self.to_schema(group)
scim_group.id = connection.scim_id
scim_group = self.to_scim(group)
scim_group.id = connection.id
try:
return self._request(
"PUT",
f"/Groups/{connection.scim_id}",
f"/Groups/{scim_group.id}",
json=scim_group.model_dump(
mode="json",
exclude_unset=True,
),
)
except NotFoundSyncException:
except ResourceMissing:
# Resource missing is handled by self.write, which will re-create the group
raise
except (SCIMRequestException, ObjectExistsSyncException):
except SCIMRequestException:
# Some providers don't support PUT on groups, so this is mainly a fix for the initial
# sync, send patch add requests for all the users the group currently has
users = list(group.users.order_by("id").values_list("id", flat=True))
@ -144,12 +140,12 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
),
)
def update_group(self, group: Group, action: Direction, users_set: set[int]):
def update_group(self, group: Group, action: PatchOp, users_set: set[int]):
"""Update a group, either using PUT to replace it or PATCH if supported"""
if self._config.patch.supported:
if action == Direction.add:
if action == PatchOp.add:
return self._patch_add_users(group, users_set)
if action == Direction.remove:
if action == PatchOp.remove:
return self._patch_remove_users(group, users_set)
try:
return self.write(group)
@ -157,9 +153,9 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
if self._config.is_fallback:
# Assume that provider does not support PUT and also doesn't support
# ServiceProviderConfig, so try PATCH as a fallback
if action == Direction.add:
if action == PatchOp.add:
return self._patch_add_users(group, users_set)
if action == Direction.remove:
if action == PatchOp.remove:
return self._patch_remove_users(group, users_set)
raise exc
@ -189,13 +185,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
return
user_ids = list(
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
"scim_id", flat=True
"id", flat=True
)
)
if len(user_ids) < 1:
return
self._patch(
scim_group.scim_id,
scim_group.id,
PatchOperation(
op=PatchOp.add,
path="members",
@ -215,13 +211,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
return
user_ids = list(
SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
"scim_id", flat=True
"id", flat=True
)
)
if len(user_ids) < 1:
return
self._patch(
scim_group.scim_id,
scim_group.id,
PatchOperation(
op=PatchOp.remove,
path="members",

View File

@ -9,14 +9,13 @@ from pydanticscim.service_provider import (
)
from pydanticscim.user import User as BaseUser
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
class User(BaseUser):
"""Modified User schema with added externalId field"""
schemas: list[str] = [SCIM_USER_SCHEMA]
schemas: list[str] = [
"urn:ietf:params:scim:schemas:core:2.0:User",
]
externalId: str | None = None
meta: dict | None = None
@ -24,7 +23,9 @@ class User(BaseUser):
class Group(BaseGroup):
"""Modified Group schema with added externalId field"""
schemas: list[str] = [SCIM_GROUP_SCHEMA]
schemas: list[str] = [
"urn:ietf:params:scim:schemas:core:2.0:Group",
]
externalId: str | None = None
meta: dict | None = None

View File

@ -3,27 +3,42 @@
from deepmerge import always_merger
from pydantic import ValidationError
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import User
from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.utils import delete_none_values
from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import ResourceMissing, StopSync
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
from authentik.providers.scim.models import SCIMMapping, SCIMUser
class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
"""SCIM client for users"""
connection_type = SCIMUser
connection_type_query = "user"
def write(self, obj: User):
"""Write a user"""
scim_user = SCIMUser.objects.filter(provider=self.provider, user=obj).first()
if not scim_user:
return self._create(obj)
try:
return self._update(obj, scim_user)
except ResourceMissing:
scim_user.delete()
return self._create(obj)
def to_schema(self, obj: User) -> SCIMUserSchema:
def delete(self, obj: User):
"""Delete user"""
scim_user = SCIMUser.objects.filter(provider=self.provider, user=obj).first()
if not scim_user:
self.logger.debug("User does not exist in SCIM, skipping")
return None
response = self._request("DELETE", f"/Users/{scim_user.id}")
scim_user.delete()
return response
def to_scim(self, obj: User) -> SCIMUserSchema:
"""Convert authentik user into SCIM"""
raw_scim_user = {
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:User",),
@ -41,8 +56,6 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
if value is None:
continue
always_merger.merge(raw_scim_user, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
@ -61,19 +74,9 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
scim_user.externalId = str(obj.uid)
return scim_user
def delete(self, obj: User):
"""Delete user"""
scim_user = SCIMUser.objects.filter(provider=self.provider, user=obj).first()
if not scim_user:
self.logger.debug("User does not exist in SCIM, skipping")
return None
response = self._request("DELETE", f"/Users/{scim_user.scim_id}")
scim_user.delete()
return response
def create(self, user: User):
def _create(self, user: User):
"""Create user from scratch and create a connection object"""
scim_user = self.to_schema(user)
scim_user = self.to_scim(user)
response = self._request(
"POST",
"/Users",
@ -82,18 +85,15 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
exclude_unset=True,
),
)
scim_id = response.get("id")
if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`")
SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"])
def update(self, user: User, connection: SCIMUser):
def _update(self, user: User, connection: SCIMUser):
"""Update existing user"""
scim_user = self.to_schema(user)
scim_user.id = connection.scim_id
scim_user = self.to_scim(user)
scim_user.id = connection.id
self._request(
"PUT",
f"/Users/{connection.scim_id}",
f"/Users/{connection.id}",
json=scim_user.model_dump(
mode="json",
exclude_unset=True,

View File

@ -3,7 +3,7 @@
from structlog.stdlib import get_logger
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.providers.scim.tasks import scim_sync
from authentik.tenants.management import TenantCommand
LOGGER = get_logger()
@ -21,4 +21,4 @@ class Command(TenantCommand):
if not provider:
LOGGER.warning("Provider does not exist", name=provider_name)
continue
sync_tasks.trigger_single_task(provider, scim_sync).get()
scim_sync.delay(provider.pk).get()

View File

@ -1,76 +0,0 @@
# Generated by Django 5.0.4 on 2024-05-03 12:38
import uuid
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from authentik.lib.migrations import progress_bar
def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser")
SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup")
db_alias = schema_editor.connection.alias
print("\nFixing primary key for SCIM users, this might take a couple of minutes...")
for user in progress_bar(SCIMUser.objects.using(db_alias).all()):
SCIMUser.objects.using(db_alias).filter(
pk=user.pk, user=user.user_id, provider=user.provider_id
).update(scim_id=user.pk, id=uuid.uuid4())
print("\nFixing primary key for SCIM groups, this might take a couple of minutes...")
for group in progress_bar(SCIMGroup.objects.using(db_alias).all()):
SCIMGroup.objects.using(db_alias).filter(
pk=group.pk, group=group.group_id, provider=group.provider_id
).update(scim_id=group.pk, id=uuid.uuid4())
class Migration(migrations.Migration):
dependencies = [
(
"authentik_providers_scim",
"0001_squashed_0006_rename_parent_group_scimprovider_filter_group",
),
]
operations = [
migrations.AddField(
model_name="scimgroup",
name="scim_id",
field=models.TextField(default="temp"),
preserve_default=False,
),
migrations.AddField(
model_name="scimuser",
name="scim_id",
field=models.TextField(default="temp"),
preserve_default=False,
),
migrations.RunPython(fix_scim_user_group_pk),
migrations.AlterField(
model_name="scimgroup",
name="id",
field=models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
migrations.AlterField(
model_name="scimuser",
name="id",
field=models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()),
migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()),
migrations.AlterUniqueTogether(
name="scimgroup",
unique_together={("scim_id", "group", "provider")},
),
migrations.AlterUniqueTogether(
name="scimuser",
unique_together={("scim_id", "user", "provider")},
),
]

View File

@ -1,19 +1,17 @@
"""SCIM Provider models"""
from typing import Any, Self
from uuid import uuid4
from django.core.cache import cache
from django.db import models
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from redis.lock import Lock
from rest_framework.serializers import Serializer
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.providers.scim.clients import PAGE_TIMEOUT
class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
class SCIMProvider(BackchannelProvider):
"""SCIM 2.0 provider to create users and groups in external applications"""
exclude_users_service_account = models.BooleanField(default=False)
@ -32,23 +30,18 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
help_text=_("Property mappings used for group creation/updating."),
)
def client_for_model(
self, model: type[User | Group]
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
if issubclass(model, User):
from authentik.providers.scim.clients.users import SCIMUserClient
@property
def sync_lock(self) -> Lock:
"""Redis lock for syncing SCIM to prevent multiple parallel syncs happening"""
return Lock(
cache.client.get_client(),
name=f"goauthentik.io/providers/scim/sync-{str(self.pk)}",
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
)
return SCIMUserClient(self)
if issubclass(model, Group):
from authentik.providers.scim.clients.groups import SCIMGroupClient
return SCIMGroupClient(self)
raise ValueError(f"Invalid model {model}")
def get_object_qs(self, type: type[User | Group]) -> QuerySet[User | Group]:
if type == User:
# Get queryset of all users with consistent ordering
# according to the provider's settings
def get_user_qs(self) -> QuerySet[User]:
"""Get queryset of all users with consistent ordering
according to the provider's settings"""
base = User.objects.all().exclude_anonymous()
if self.exclude_users_service_account:
base = base.exclude(type=UserTypes.SERVICE_ACCOUNT).exclude(
@ -57,10 +50,10 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
if self.filter_group:
base = base.filter(ak_groups__in=[self.filter_group])
return base.order_by("pk")
if type == Group:
# Get queryset of all groups with consistent ordering
def get_group_qs(self) -> QuerySet[Group]:
"""Get queryset of all groups with consistent ordering"""
return Group.objects.all().order_by("pk")
raise ValueError(f"Invalid type {type}")
@property
def component(self) -> str:
@ -89,7 +82,7 @@ class SCIMMapping(PropertyMapping):
@property
def serializer(self) -> type[Serializer]:
from authentik.providers.scim.api.property_mappings import SCIMMappingSerializer
from authentik.providers.scim.api.property_mapping import SCIMMappingSerializer
return SCIMMappingSerializer
@ -104,28 +97,26 @@ class SCIMMapping(PropertyMapping):
class SCIMUser(models.Model):
"""Mapping of a user and provider to a SCIM user ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
scim_id = models.TextField()
id = models.TextField(primary_key=True)
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
class Meta:
unique_together = (("scim_id", "user", "provider"),)
unique_together = (("id", "user", "provider"),)
def __str__(self) -> str:
return f"SCIM User {self.user_id} to {self.provider_id}"
return f"SCIM User {self.user.username} to {self.provider.name}"
class SCIMGroup(models.Model):
"""Mapping of a group and provider to a SCIM user ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
scim_id = models.TextField()
id = models.TextField(primary_key=True)
group = models.ForeignKey(Group, on_delete=models.CASCADE)
provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
class Meta:
unique_together = (("scim_id", "group", "provider"),)
unique_together = (("id", "group", "provider"),)
def __str__(self) -> str:
return f"SCIM Group {self.group_id} to {self.provider_id}"
return f"SCIM Group {self.group.name} to {self.provider.name}"

View File

@ -7,7 +7,7 @@ from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"providers_scim_sync": {
"task": "authentik.providers.scim.tasks.scim_sync_all",
"schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*/4"),
"schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -1,12 +1,56 @@
"""SCIM provider signals"""
from authentik.lib.sync.outgoing.signals import register_signals
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, scim_sync_direct, scim_sync_m2m
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete
from django.dispatch import receiver
from pydanticscim.responses import PatchOp
from structlog.stdlib import get_logger
register_signals(
SCIMProvider,
task_sync_single=scim_sync,
task_sync_direct=scim_sync_direct,
task_sync_m2m=scim_sync_m2m,
)
from authentik.core.models import Group, User
from authentik.lib.utils.reflection import class_to_path
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_sync
LOGGER = get_logger()
@receiver(post_save, sender=SCIMProvider)
def post_save_provider(sender: type[Model], instance, created: bool, **_):
"""Trigger sync when SCIM provider is saved"""
scim_sync.delay(instance.pk)
@receiver(post_save, sender=User)
@receiver(post_save, sender=Group)
def post_save_scim(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler"""
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.add.value)
@receiver(pre_delete, sender=User)
@receiver(pre_delete, sender=Group)
def pre_delete_scim(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler"""
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
scim_signal_direct.delay(class_to_path(instance.__class__), instance.pk, PatchOp.remove.value)
@receiver(m2m_changed, sender=User.ak_groups.through)
def m2m_changed_scim(
sender: type[Model], instance, action: str, pk_set: set, reverse: bool, **kwargs
):
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return
if not SCIMProvider.objects.filter(backchannel_application__isnull=False).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
scim_signal_m2m.delay(str(instance.pk), action, list(pk_set))
else:
for group_pk in pk_set:
scim_signal_m2m.delay(group_pk, action, [instance.pk])

View File

@ -1,34 +1,226 @@
"""SCIM Provider tasks"""
from typing import Any
from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from pydanticscim.responses import PatchOp
from structlog.stdlib import get_logger
from authentik.core.models import Group, User
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing.tasks import SyncTasks
from authentik.lib.utils.reflection import path_to_class
from authentik.providers.scim.clients import PAGE_SIZE, PAGE_TIMEOUT
from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import SCIMRequestException, StopSync
from authentik.providers.scim.clients.group import SCIMGroupClient
from authentik.providers.scim.clients.user import SCIMUserClient
from authentik.providers.scim.models import SCIMProvider
from authentik.root.celery import CELERY_APP
sync_tasks = SyncTasks(SCIMProvider)
LOGGER = get_logger(__name__)
@CELERY_APP.task()
def scim_sync_objects(*args, **kwargs):
return sync_tasks.sync_objects(*args, **kwargs)
@CELERY_APP.task(base=SystemTask, bind=True)
def scim_sync(self, provider_pk: int, *args, **kwargs):
"""Run full sync for SCIM provider"""
return sync_tasks.sync_single(self, provider_pk, scim_sync_objects)
def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient:
"""Get SCIM client for model"""
if isinstance(model, User):
return SCIMUserClient(provider)
if isinstance(model, Group):
return SCIMGroupClient(provider)
raise ValueError(f"Invalid model {model}")
@CELERY_APP.task()
def scim_sync_all():
return sync_tasks.sync_all(scim_sync)
"""Run sync for all providers"""
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
scim_sync.delay(provider.pk)
@CELERY_APP.task(bind=True, base=SystemTask)
def scim_sync(self: SystemTask, provider_pk: int) -> None:
"""Run SCIM full sync for provider"""
provider: SCIMProvider = SCIMProvider.objects.filter(
pk=provider_pk, backchannel_application__isnull=False
).first()
if not provider:
return
lock = provider.sync_lock
if lock.locked():
LOGGER.debug("SCIM sync locked, skipping task", source=provider.name)
return
self.set_uid(slugify(provider.name))
messages = []
messages.append(_("Starting full SCIM sync"))
LOGGER.debug("Starting SCIM sync")
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
self.soft_time_limit = self.time_limit = (
users_paginator.count + groups_paginator.count
) * PAGE_TIMEOUT
with allow_join_result():
try:
for page in users_paginator.page_range:
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
for msg in scim_sync_users.delay(page, provider_pk).get():
messages.append(msg)
for page in groups_paginator.page_range:
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
for msg in scim_sync_group.delay(page, provider_pk).get():
messages.append(msg)
except StopSync as exc:
self.set_error(exc)
return
self.set_status(TaskStatus.SUCCESSFUL, *messages)
@CELERY_APP.task(
soft_time_limit=PAGE_TIMEOUT,
task_time_limit=PAGE_TIMEOUT,
)
def scim_sync_users(page: int, provider_pk: int):
"""Sync single or multiple users to SCIM"""
messages = []
provider: SCIMProvider = SCIMProvider.objects.filter(pk=provider_pk).first()
if not provider:
return messages
try:
client = SCIMUserClient(provider)
except SCIMRequestException:
return messages
paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
LOGGER.debug("starting user sync for page", page=page)
for user in paginator.page(page).object_list:
try:
client.write(user)
except SCIMRequestException as exc:
LOGGER.warning("failed to sync user", exc=exc, user=user)
messages.append(
_(
"Failed to sync user {user_name} due to remote error: {error}".format_map(
{
"user_name": user.username,
"error": exc.detail(),
}
)
)
)
except StopSync as exc:
LOGGER.warning("Stopping sync", exc=exc)
messages.append(
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
)
)
break
return messages
@CELERY_APP.task()
def scim_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs)
def scim_sync_group(page: int, provider_pk: int):
"""Sync single or multiple groups to SCIM"""
messages = []
provider: SCIMProvider = SCIMProvider.objects.filter(pk=provider_pk).first()
if not provider:
return messages
try:
client = SCIMGroupClient(provider)
except SCIMRequestException:
return messages
paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
LOGGER.debug("starting group sync for page", page=page)
for group in paginator.page(page).object_list:
try:
client.write(group)
except SCIMRequestException as exc:
LOGGER.warning("failed to sync group", exc=exc, group=group)
messages.append(
_(
"Failed to sync group {group_name} due to remote error: {error}".format_map(
{
"group_name": group.name,
"error": exc.detail(),
}
)
)
)
except StopSync as exc:
LOGGER.warning("Stopping sync", exc=exc)
messages.append(
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
)
)
break
return messages
@CELERY_APP.task()
def scim_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs)
def scim_signal_direct(model: str, pk: Any, raw_op: str):
"""Handler for post_save and pre_delete signal"""
model_class: type[Model] = path_to_class(model)
instance = model_class.objects.filter(pk=pk).first()
if not instance:
return
operation = PatchOp(raw_op)
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
client = client_for_model(provider, instance)
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet | None = None
if isinstance(instance, User):
queryset = provider.get_user_qs()
if isinstance(instance, Group):
queryset = provider.get_group_qs()
if not queryset:
continue
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=instance.pk).exists():
continue
try:
if operation == PatchOp.add:
client.write(instance)
if operation == PatchOp.remove:
client.delete(instance)
except (StopSync, SCIMRequestException) as exc:
LOGGER.warning(exc)
@CELERY_APP.task()
def scim_signal_m2m(group_pk: str, action: str, pk_set: list[int]):
"""Update m2m (group membership)"""
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_group_qs()
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=group_pk).exists():
continue
client = SCIMGroupClient(provider)
try:
operation = None
if action == "post_add":
operation = PatchOp.add
if action == "post_remove":
operation = PatchOp.remove
client.update_group(group, operation, pk_set)
except (StopSync, SCIMRequestException) as exc:
LOGGER.warning(exc)

View File

@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.providers.scim.tasks import scim_sync
from authentik.tenants.models import Tenant
@ -79,7 +79,7 @@ class SCIMMembershipTests(TestCase):
)
self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
scim_sync.delay(self.provider.pk).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
@ -169,7 +169,7 @@ class SCIMMembershipTests(TestCase):
)
self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
scim_sync.delay(self.provider.pk).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")

View File

@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.providers.scim.tasks import scim_sync
from authentik.tenants.models import Tenant
@ -88,72 +88,6 @@ class SCIMUserTests(TestCase):
},
)
@Mocker()
def test_user_create_different_provider_same_id(self, mock: Mocker):
"""Test user creation with multiple providers that happen
to return the same object ID"""
# Create duplicate provider
provider: SCIMProvider = SCIMProvider.objects.create(
name=generate_id(),
url="https://localhost",
token=generate_id(),
exclude_users_service_account=True,
)
app: Application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
)
app.backchannel_providers.add(provider)
provider.property_mappings.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
)
provider.property_mappings_group.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
)
scim_id = generate_id()
mock.get(
"https://localhost/ServiceProviderConfig",
json={},
)
mock.post(
"https://localhost/Users",
json={
"id": scim_id,
},
)
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 4)
self.assertEqual(mock.request_history[0].method, "GET")
self.assertEqual(mock.request_history[1].method, "POST")
self.assertJSONEqual(
mock.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
"primary": True,
"type": "other",
"value": f"{uid}@goauthentik.io",
}
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"givenName": uid,
},
"displayName": f"{uid} {uid}",
"userName": uid,
},
)
@Mocker()
def test_user_create_update(self, mock: Mocker):
"""Test user creation and update"""
@ -302,7 +236,7 @@ class SCIMUserTests(TestCase):
email=f"{uid}@goauthentik.io",
)
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
scim_sync.delay(self.provider.pk).get()
self.assertEqual(mock.call_count, 5)
self.assertEqual(mock.request_history[0].method, "GET")

View File

@ -1,6 +1,6 @@
"""API URLs"""
from authentik.providers.scim.api.property_mappings import SCIMMappingViewSet
from authentik.providers.scim.api.property_mapping import SCIMMappingViewSet
from authentik.providers.scim.api.providers import SCIMProviderViewSet
api_urlpatterns = [

View File

@ -155,9 +155,6 @@ SPECTACULAR_SETTINGS = {
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
"UserTypeEnum": "authentik.core.models.UserTypes",
"GoogleWorkspaceDeleteAction": (
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceDeleteAction"
),
},
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
@ -379,13 +376,7 @@ CELERY = {
"task_default_queue": "authentik",
"broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")),
"result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")),
"broker_transport_options": CONFIG.get_dict_from_b64_json(
"broker.transport_options", {"retry_policy": {"timeout": 5.0}}
),
"result_backend_transport_options": CONFIG.get_dict_from_b64_json(
"result_backend.transport_options", {"retry_policy": {"timeout": 5.0}}
),
"redis_retry_on_timeout": True,
"broker_transport_options": CONFIG.get_dict_from_b64_json("broker.transport_options"),
}
# Sentry integration

View File

@ -76,7 +76,7 @@ class S3Storage(BaseS3Storage):
return safe_join(self.location, connection.schema_name, name)
except ValueError:
raise SuspiciousOperation(f"Attempted access to '{name}' denied.") from None
raise SuspiciousOperation("Attempted access to '%s' denied." % name) from None
# This is a fix for https://github.com/jschneier/django-storages/pull/839
def url(self, name, parameters=None, expire=None, http_method=None):

View File

@ -10,7 +10,7 @@ from drf_spectacular.utils import extend_schema, extend_schema_field, inline_ser
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import DictField, ListField, SerializerMethodField
from rest_framework.fields import BooleanField, DictField, ListField, SerializerMethodField
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import Request
from rest_framework.response import Response
@ -19,8 +19,9 @@ from rest_framework.viewsets import ModelViewSet
from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.sync.outgoing.api import SyncStatusSerializer
from authentik.events.api.tasks import SystemTaskSerializer
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
from authentik.sources.ldap.tasks import CACHE_KEY_STATUS, SYNC_CLASSES
@ -88,6 +89,13 @@ class LDAPSourceSerializer(SourceSerializer):
extra_kwargs = {"bind_password": {"write_only": True}}
class LDAPSyncStatusSerializer(PassiveSerializer):
"""LDAP Source sync status"""
is_running = BooleanField(read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
"""LDAP Source Viewset"""
@ -124,16 +132,10 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
@extend_schema(
responses={
200: SyncStatusSerializer(),
200: LDAPSyncStatusSerializer(),
}
)
@action(
methods=["GET"],
detail=True,
pagination_class=None,
url_path="sync/status",
filter_backends=[],
)
@action(methods=["GET"], detail=True, pagination_class=None, filter_backends=[])
def sync_status(self, request: Request, slug: str) -> Response:
"""Get source's sync status"""
source: LDAPSource = self.get_object()
@ -147,7 +149,7 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
"tasks": tasks,
"is_running": source.sync_lock.locked(),
}
return Response(SyncStatusSerializer(status).data)
return Response(LDAPSyncStatusSerializer(status).data)
@extend_schema(
responses={

View File

@ -153,7 +153,7 @@ class Migration(migrations.Migration):
(
"object_uniqueness_field",
models.TextField(
default="objectSid", help_text="Field which contains a unique Identifier."
default="entryDN", help_text="Field which contains a unique Identifier."
),
),
("sync_groups", models.BooleanField(default=True)),

View File

@ -88,7 +88,7 @@ class LDAPSource(Source):
help_text=_("Consider Objects matching this filter to be Groups."),
)
object_uniqueness_field = models.TextField(
default="objectSid", help_text=_("Field which contains a unique Identifier.")
default="entryDN", help_text=_("Field which contains a unique Identifier.")
)
property_mappings_group = models.ManyToManyField(

Some files were not shown because too many files have changed in this diff Show More