Compare commits

..

6 Commits

SHA1        Message                                    Date
a5379c35aa  add to user                                2024-05-18 18:00:00 +02:00
e4c11a5284  manager for deleted objects                2024-05-18 17:59:06 +02:00
a4853a1e09  migrate outpost to soft-delete             2024-05-18 17:59:06 +02:00
b65b72d910  core: exclude anonymous user by default    2024-05-18 17:59:06 +02:00
cd7be6a1a4  initial soft delete                        2024-05-18 17:58:03 +02:00
e5cb8ef541  unrelated reorganization                   2024-05-18 17:58:01 +02:00

All six commits are signed off by Jens Langhammer <jens@goauthentik.io>.
329 changed files with 6654 additions and 15401 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2024.6.0-rc1 current_version = 2024.4.2
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

View File

@ -1,3 +1,5 @@
version: "3.7"
services: services:
postgresql: postgresql:
image: docker.io/library/postgres:${PSQL_TAG:-16} image: docker.io/library/postgres:${PSQL_TAG:-16}

View File

@ -4,4 +4,3 @@ hass
warmup warmup
ontext ontext
singed singed
assertIn

View File

@ -50,6 +50,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
psql: psql:
- 12-alpine
- 15-alpine - 15-alpine
- 16-alpine - 16-alpine
steps: steps:
@ -103,6 +104,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
psql: psql:
- 12-alpine
- 15-alpine - 15-alpine
- 16-alpine - 16-alpine
steps: steps:
@ -250,8 +252,8 @@ jobs:
push: ${{ steps.ev.outputs.shouldBuild == 'true' }} push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
build-args: | build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache cache-from: type=gha
cache-to: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache,mode=max cache-to: type=gha,mode=max
platforms: linux/${{ matrix.arch }} platforms: linux/${{ matrix.arch }}
pr-comment: pr-comment:
needs: needs:

View File

@ -105,8 +105,8 @@ jobs:
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
context: . context: .
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache cache-from: type=gha
cache-to: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache,mode=max cache-to: type=gha,mode=max
build-binary: build-binary:
timeout-minutes: 120 timeout-minutes: 120
needs: needs:

View File

@ -38,7 +38,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build RUN npm run build
# Stage 3: Build go proxy # Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.3-bookworm AS go-builder
ARG TARGETOS ARG TARGETOS
ARG TARGETARCH ARG TARGETARCH
@ -49,11 +49,6 @@ ARG GOARCH=$TARGETARCH
WORKDIR /go/src/goauthentik.io WORKDIR /go/src/goauthentik.io
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
dpkg --add-architecture arm64 && \
apt-get update && \
apt-get install -y --no-install-recommends crossbuild-essential-arm64 gcc-aarch64-linux-gnu
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \ RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \ --mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
--mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/go/pkg/mod \
@ -68,11 +63,11 @@ COPY ./internal /go/src/goauthentik.io/internal
COPY ./go.mod /go/src/goauthentik.io/go.mod COPY ./go.mod /go/src/goauthentik.io/go.mod
COPY ./go.sum /go/src/goauthentik.io/go.sum COPY ./go.sum /go/src/goauthentik.io/go.sum
ENV CGO_ENABLED=0
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \ --mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
if [ "$TARGETARCH" = "arm64" ]; then export CC=aarch64-linux-gnu-gcc && export CC_FOR_TARGET=gcc-aarch64-linux-gnu; fi && \ GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server
CGO_ENABLED=1 GOEXPERIMENT="systemcrypto" GOFLAGS="-tags=requirefips" GOARM="${TARGETVARIANT#v}" \
go build -o /go/authentik ./cmd/server
# Stage 4: MaxMind GeoIP # Stage 4: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
@ -89,7 +84,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 5: Python dependencies # Stage 5: Python dependencies
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps FROM docker.io/python:3.12.3-slim-bookworm AS python-deps
WORKDIR /ak-root/poetry WORKDIR /ak-root/poetry
@ -102,7 +97,7 @@ RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloa
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \ RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
apt-get update && \ apt-get update && \
# Required for installing pip packages # Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libpq-dev apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev
RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
--mount=type=bind,target=./poetry.lock,src=./poetry.lock \ --mount=type=bind,target=./poetry.lock,src=./poetry.lock \
@ -110,13 +105,12 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
--mount=type=cache,target=/root/.cache/pypoetry \ --mount=type=cache,target=/root/.cache/pypoetry \
python -m venv /ak-root/venv/ && \ python -m venv /ak-root/venv/ && \
bash -c "source ${VENV_PATH}/bin/activate && \ bash -c "source ${VENV_PATH}/bin/activate && \
pip3 install --upgrade pip && \ pip3 install --upgrade pip && \
pip3 install poetry && \ pip3 install poetry && \
poetry install --only=main --no-ansi --no-interaction --no-root && \ poetry install --only=main --no-ansi --no-interaction --no-root"
pip install --force-reinstall /wheels/*"
# Stage 6: Run # Stage 6: Run
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image FROM docker.io/python:3.12.3-slim-bookworm AS final-image
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ARG VERSION ARG VERSION
@ -133,7 +127,7 @@ WORKDIR /
# We cannot cache this layer otherwise we'll end up with a bigger image # We cannot cache this layer otherwise we'll end up with a bigger image
RUN apt-get update && \ RUN apt-get update && \
# Required for runtime # Required for runtime
apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates && \ apt-get install -y --no-install-recommends libpq5 openssl libxmlsec1-openssl libmaxminddb0 ca-certificates && \
# Required for bootstrap & healtcheck # Required for bootstrap & healtcheck
apt-get install -y --no-install-recommends runit && \ apt-get install -y --no-install-recommends runit && \
apt-get clean && \ apt-get clean && \
@ -169,8 +163,6 @@ ENV TMPDIR=/dev/shm/ \
VENV_PATH="/ak-root/venv" \ VENV_PATH="/ak-root/venv" \
POETRY_VIRTUALENVS_CREATE=false POETRY_VIRTUALENVS_CREATE=false
ENV GOFIPS=1
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD [ "ak", "healthcheck" ] HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD [ "ak", "healthcheck" ]
ENTRYPOINT [ "dumb-init", "--", "ak" ] ENTRYPOINT [ "dumb-init", "--", "ak" ]

View File

@ -253,7 +253,6 @@ website-watch: ## Build and watch the documentation website, updating automatic
######################### #########################
docker: ## Build a docker image of the current source tree docker: ## Build a docker image of the current source tree
mkdir -p ${GEN_API_TS}
DOCKER_BUILDKIT=1 docker build . --progress plain --tag ${DOCKER_IMAGE} DOCKER_BUILDKIT=1 docker build . --progress plain --tag ${DOCKER_IMAGE}
######################### #########################

View File

@ -2,7 +2,7 @@
from os import environ from os import environ
__version__ = "2024.6.0" __version__ = "2024.4.2"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -2,21 +2,18 @@
import platform import platform
from datetime import datetime from datetime import datetime
from ssl import OPENSSL_VERSION
from sys import version as python_version from sys import version as python_version
from typing import TypedDict from typing import TypedDict
from cryptography.hazmat.backends.openssl.backend import backend
from django.utils.timezone import now from django.utils.timezone import now
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema
from gunicorn import version_info as gunicorn_version
from rest_framework.fields import SerializerMethodField from rest_framework.fields import SerializerMethodField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from authentik import get_full_version
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.enterprise.license import LicenseKey
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.reflection import get_env from authentik.lib.utils.reflection import get_env
from authentik.outposts.apps import MANAGED_OUTPOST from authentik.outposts.apps import MANAGED_OUTPOST
@ -28,13 +25,11 @@ class RuntimeDict(TypedDict):
"""Runtime information""" """Runtime information"""
python_version: str python_version: str
gunicorn_version: str
environment: str environment: str
architecture: str architecture: str
platform: str platform: str
uname: str uname: str
openssl_version: str
openssl_fips_enabled: bool | None
authentik_version: str
class SystemInfoSerializer(PassiveSerializer): class SystemInfoSerializer(PassiveSerializer):
@ -69,15 +64,11 @@ class SystemInfoSerializer(PassiveSerializer):
def get_runtime(self, request: Request) -> RuntimeDict: def get_runtime(self, request: Request) -> RuntimeDict:
"""Get versions""" """Get versions"""
return { return {
"architecture": platform.machine(),
"authentik_version": get_full_version(),
"environment": get_env(),
"openssl_fips_enabled": (
backend._fips_enabled if LicenseKey.get_total().is_valid() else None
),
"openssl_version": OPENSSL_VERSION,
"platform": platform.platform(),
"python_version": python_version, "python_version": python_version,
"gunicorn_version": ".".join(str(x) for x in gunicorn_version),
"environment": get_env(),
"architecture": platform.machine(),
"platform": platform.platform(),
"uname": " ".join(platform.uname()), "uname": " ".join(platform.uname()),
} }
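
For orientation, the two sides of this hunk report different runtime fields. A hypothetical response payload for each side (all values below are invented for illustration; only the field names come from the hunk above):

    # Invented example payloads; only the field names come from the hunk above.
    runtime_with_openssl_info = {      # side that adds OpenSSL/FIPS reporting
        "architecture": "x86_64",
        "authentik_version": "2024.6.0-rc1",
        "environment": "kubernetes",
        "openssl_fips_enabled": None,  # populated only with a valid enterprise license
        "openssl_version": "OpenSSL 3.0.13 30 Jan 2024",
        "platform": "Linux-6.8.0-x86_64-with-glibc2.36",
        "python_version": "3.12.3 (main, Apr 10 2024) [GCC 12.2.0]",
        "uname": "Linux worker-1 6.8.0 ...",
    }

    runtime_with_gunicorn_info = {     # side that reports the gunicorn version instead
        "python_version": "3.12.3 (main, Apr 10 2024) [GCC 12.2.0]",
        "gunicorn_version": "22.0.0",
        "environment": "kubernetes",
        "architecture": "x86_64",
        "platform": "Linux-6.8.0-x86_64-with-glibc2.36",
        "uname": "Linux worker-1 6.8.0 ...",
    }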

View File

@ -75,7 +75,7 @@ class BlueprintEntry:
_state: BlueprintEntryState = field(default_factory=BlueprintEntryState) _state: BlueprintEntryState = field(default_factory=BlueprintEntryState)
def __post_init__(self, *args, **kwargs) -> None: def __post_init__(self, *args, **kwargs) -> None:
self.__tag_contexts: list[YAMLTagContext] = [] self.__tag_contexts: list["YAMLTagContext"] = []
@staticmethod @staticmethod
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry": def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry":

View File

@ -4,7 +4,6 @@ from collections.abc import Iterable
from uuid import UUID from uuid import UUID
from django.apps import apps from django.apps import apps
from django.contrib.auth import get_user_model
from django.db.models import Model, Q, QuerySet from django.db.models import Model, Q, QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
@ -47,8 +46,6 @@ class Exporter:
def get_model_instances(self, model: type[Model]) -> QuerySet: def get_model_instances(self, model: type[Model]) -> QuerySet:
"""Return a queryset for `model`. Can be used to filter some """Return a queryset for `model`. Can be used to filter some
objects on some models""" objects on some models"""
if model == get_user_model():
return model.objects.exclude_anonymous()
return model.objects.all() return model.objects.all()
def _pre_export(self, blueprint: Blueprint): def _pre_export(self, blueprint: Blueprint):

View File

@ -58,7 +58,7 @@ from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel from authentik.policies.models import Policy, PolicyBindingModel
from authentik.policies.reputation.models import Reputation from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser from authentik.providers.scim.models import SCIMGroup, SCIMUser
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -97,8 +97,8 @@ def excluded_models() -> list[type[Model]]:
# FIXME: these shouldn't need to be explicitly listed, but rather based off of a mixin # FIXME: these shouldn't need to be explicitly listed, but rather based off of a mixin
FlowToken, FlowToken,
LicenseUsage, LicenseUsage,
SCIMProviderGroup, SCIMGroup,
SCIMProviderUser, SCIMUser,
Tenant, Tenant,
SystemTask, SystemTask,
ConnectionToken, ConnectionToken,

View File

@ -2,7 +2,6 @@
from json import loads from json import loads
from django.db.models import Prefetch
from django.http import Http404 from django.http import Http404
from django_filters.filters import CharFilter, ModelMultipleChoiceFilter from django_filters.filters import CharFilter, ModelMultipleChoiceFilter
from django_filters.filterset import FilterSet from django_filters.filterset import FilterSet
@ -167,14 +166,8 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
def get_queryset(self): def get_queryset(self):
base_qs = Group.objects.all().select_related("parent").prefetch_related("roles") base_qs = Group.objects.all().select_related("parent").prefetch_related("roles")
if self.serializer_class(context={"request": self.request})._should_include_users: if self.serializer_class(context={"request": self.request})._should_include_users:
base_qs = base_qs.prefetch_related("users") base_qs = base_qs.prefetch_related("users")
else:
base_qs = base_qs.prefetch_related(
Prefetch("users", queryset=User.objects.all().only("id"))
)
return base_qs return base_qs
@extend_schema( @extend_schema(
@ -185,14 +178,6 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs) return super().list(request, *args, **kwargs)
@extend_schema(
parameters=[
OpenApiParameter("include_users", bool, default=True),
]
)
def retrieve(self, request, *args, **kwargs):
return super().retrieve(request, *args, **kwargs)
@permission_required("authentik_core.add_user_to_group") @permission_required("authentik_core.add_user_to_group")
@extend_schema( @extend_schema(
request=UserAccountSerializer, request=UserAccountSerializer,

View File

@ -1,79 +0,0 @@
"""API Utilities"""
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.fields import (
BooleanField,
CharField,
)
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.core.api.utils import PassiveSerializer
from authentik.enterprise.apps import EnterpriseConfig
from authentik.lib.utils.reflection import all_subclasses
class TypeCreateSerializer(PassiveSerializer):
"""Types of an object that can be created"""
name = CharField(required=True)
description = CharField(required=True)
component = CharField(required=True)
model_name = CharField(required=True)
icon_url = CharField(required=False)
requires_enterprise = BooleanField(default=False)
class CreatableType:
"""Class to inherit from to mark a model as creatable, even if the model itself is marked
as abstract"""
class NonCreatableType:
"""Class to inherit from to mark a model as non-creatable even if it is not abstract"""
class TypesMixin:
"""Mixin which adds an API endpoint to list all possible types that can be created"""
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[])
def types(self, request: Request, additional: list[dict] | None = None) -> Response:
"""Get all creatable types"""
data = []
for subclass in all_subclasses(self.queryset.model):
instance = None
if subclass._meta.abstract:
if not issubclass(subclass, CreatableType):
continue
# Circumvent the django protection for not being able to instantiate
# abstract models. We need a model instance to access .component
# and further down .icon_url
instance = subclass.__new__(subclass)
# Django re-sets abstract = False so we need to override that
instance.Meta.abstract = True
else:
if issubclass(subclass, NonCreatableType):
continue
instance = subclass()
try:
data.append(
{
"name": subclass._meta.verbose_name,
"description": subclass.__doc__,
"component": instance.component,
"model_name": subclass._meta.model_name,
"icon_url": getattr(instance, "icon_url", None),
"requires_enterprise": isinstance(
subclass._meta.app_config, EnterpriseConfig
),
}
)
except NotImplementedError:
continue
if additional:
data.extend(additional)
data = sorted(data, key=lambda x: x["name"])
return Response(TypeCreateSerializer(data, many=True).data)
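
The types() implementation above instantiates abstract models via __new__ so it can read per-instance attributes such as component without triggering Django's "abstract model may not be instantiated" guard. A standalone sketch of the same pattern outside Django, with invented class and attribute names:

    # Standalone illustration of reading instance attributes from an object created
    # with __new__, so __init__ (and any guard inside it) never runs. Names invented.
    class AbstractishProvider:
        def __init__(self):
            raise TypeError("abstract, may not be instantiated")  # stand-in guard

        @property
        def component(self) -> str:
            return "ak-provider-example-form"

    instance = AbstractishProvider.__new__(AbstractishProvider)  # bypasses __init__
    print(instance.component)  # -> ak-provider-example-form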

View File

@ -9,22 +9,18 @@ from rest_framework import mixins
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import PermissionDenied from rest_framework.exceptions import PermissionDenied
from rest_framework.fields import BooleanField, CharField from rest_framework.fields import BooleanField, CharField
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer, SerializerMethodField from rest_framework.serializers import ModelSerializer, SerializerMethodField
from rest_framework.viewsets import GenericViewSet from rest_framework.viewsets import GenericViewSet
from authentik.blueprints.api import ManagedSerializer from authentik.blueprints.api import ManagedSerializer
from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ( from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
MetaNameSerializer,
PassiveSerializer,
)
from authentik.core.expression.evaluator import PropertyMappingEvaluator from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.models import Group, PropertyMapping, User from authentik.core.models import PropertyMapping
from authentik.events.utils import sanitize_item from authentik.events.utils import sanitize_item
from authentik.lib.utils.reflection import all_subclasses
from authentik.policies.api.exec import PolicyTestSerializer from authentik.policies.api.exec import PolicyTestSerializer
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
@ -68,7 +64,6 @@ class PropertyMappingSerializer(ManagedSerializer, ModelSerializer, MetaNameSeri
class PropertyMappingViewSet( class PropertyMappingViewSet(
TypesMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.DestroyModelMixin, mixins.DestroyModelMixin,
UsedByMixin, UsedByMixin,
@ -77,15 +72,7 @@ class PropertyMappingViewSet(
): ):
"""PropertyMapping Viewset""" """PropertyMapping Viewset"""
class PropertyMappingTestSerializer(PolicyTestSerializer): queryset = PropertyMapping.objects.none()
"""Test property mapping execution for a user/group with context"""
user = PrimaryKeyRelatedField(queryset=User.objects.all(), required=False, allow_null=True)
group = PrimaryKeyRelatedField(
queryset=Group.objects.all(), required=False, allow_null=True
)
queryset = PropertyMapping.objects.select_subclasses()
serializer_class = PropertyMappingSerializer serializer_class = PropertyMappingSerializer
search_fields = [ search_fields = [
"name", "name",
@ -93,9 +80,29 @@ class PropertyMappingViewSet(
filterset_fields = {"managed": ["isnull"]} filterset_fields = {"managed": ["isnull"]}
ordering = ["name"] ordering = ["name"]
def get_queryset(self): # pragma: no cover
return PropertyMapping.objects.select_subclasses()
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[])
def types(self, request: Request) -> Response:
"""Get all creatable property-mapping types"""
data = []
for subclass in all_subclasses(self.queryset.model):
subclass: PropertyMapping
data.append(
{
"name": subclass._meta.verbose_name,
"description": subclass.__doc__,
"component": subclass().component,
"model_name": subclass._meta.model_name,
}
)
return Response(TypeCreateSerializer(data, many=True).data)
@permission_required("authentik_core.view_propertymapping") @permission_required("authentik_core.view_propertymapping")
@extend_schema( @extend_schema(
request=PropertyMappingTestSerializer(), request=PolicyTestSerializer(),
responses={ responses={
200: PropertyMappingTestResultSerializer, 200: PropertyMappingTestResultSerializer,
400: OpenApiResponse(description="Invalid parameters"), 400: OpenApiResponse(description="Invalid parameters"),
@ -113,39 +120,29 @@ class PropertyMappingViewSet(
"""Test Property Mapping""" """Test Property Mapping"""
_mapping: PropertyMapping = self.get_object() _mapping: PropertyMapping = self.get_object()
# Use `get_subclass` to get correct class and correct `.evaluate` implementation # Use `get_subclass` to get correct class and correct `.evaluate` implementation
mapping: PropertyMapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk) mapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk)
# FIXME: when we separate policy mappings between ones for sources # FIXME: when we separate policy mappings between ones for sources
# and ones for providers, we need to make the user field optional for the source mapping # and ones for providers, we need to make the user field optional for the source mapping
test_params = self.PropertyMappingTestSerializer(data=request.data) test_params = PolicyTestSerializer(data=request.data)
if not test_params.is_valid(): if not test_params.is_valid():
return Response(test_params.errors, status=400) return Response(test_params.errors, status=400)
format_result = str(request.GET.get("format_result", "false")).lower() == "true" format_result = str(request.GET.get("format_result", "false")).lower() == "true"
context: dict = test_params.validated_data.get("context", {}) # User permission check, only allow mapping testing for users that are readable
context.setdefault("user", None) users = get_objects_for_user(request.user, "authentik_core.view_user").filter(
pk=test_params.validated_data["user"].pk
if user := test_params.validated_data.get("user"): )
# User permission check, only allow mapping testing for users that are readable if not users.exists():
users = get_objects_for_user(request.user, "authentik_core.view_user").filter( raise PermissionDenied()
pk=user.pk
)
if not users.exists():
raise PermissionDenied()
context["user"] = user
if group := test_params.validated_data.get("group"):
# Group permission check, only allow mapping testing for groups that are readable
groups = get_objects_for_user(request.user, "authentik_core.view_group").filter(
pk=group.pk
)
if not groups.exists():
raise PermissionDenied()
context["group"] = group
context["request"] = self.request
response_data = {"successful": True, "result": ""} response_data = {"successful": True, "result": ""}
try: try:
result = mapping.evaluate(**context) result = mapping.evaluate(
users.first(),
self.request,
**test_params.validated_data.get("context", {}),
)
response_data["result"] = dumps( response_data["result"] = dumps(
sanitize_item(result), indent=(4 if format_result else None) sanitize_item(result), indent=(4 if format_result else None)
) )

View File

@ -5,15 +5,20 @@ from django.db.models.query import Q
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_filters.filters import BooleanFilter from django_filters.filters import BooleanFilter
from django_filters.filterset import FilterSet from django_filters.filterset import FilterSet
from drf_spectacular.utils import extend_schema
from rest_framework import mixins from rest_framework import mixins
from rest_framework.decorators import action
from rest_framework.fields import ReadOnlyField from rest_framework.fields import ReadOnlyField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer, SerializerMethodField from rest_framework.serializers import ModelSerializer, SerializerMethodField
from rest_framework.viewsets import GenericViewSet from rest_framework.viewsets import GenericViewSet
from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import MetaNameSerializer from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
from authentik.core.models import Provider from authentik.core.models import Provider
from authentik.enterprise.apps import EnterpriseConfig
from authentik.lib.utils.reflection import all_subclasses
class ProviderSerializer(ModelSerializer, MetaNameSerializer): class ProviderSerializer(ModelSerializer, MetaNameSerializer):
@ -81,7 +86,6 @@ class ProviderFilter(FilterSet):
class ProviderViewSet( class ProviderViewSet(
TypesMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.DestroyModelMixin, mixins.DestroyModelMixin,
UsedByMixin, UsedByMixin,
@ -100,3 +104,31 @@ class ProviderViewSet(
def get_queryset(self): # pragma: no cover def get_queryset(self): # pragma: no cover
return Provider.objects.select_subclasses() return Provider.objects.select_subclasses()
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[])
def types(self, request: Request) -> Response:
"""Get all creatable provider types"""
data = []
for subclass in all_subclasses(self.queryset.model):
subclass: Provider
if subclass._meta.abstract:
continue
data.append(
{
"name": subclass._meta.verbose_name,
"description": subclass.__doc__,
"component": subclass().component,
"model_name": subclass._meta.model_name,
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
}
)
data.append(
{
"name": _("SAML Provider from Metadata"),
"description": _("Create a SAML Provider by importing its Metadata."),
"component": "ak-provider-saml-import-form",
"model_name": "",
}
)
return Response(TypeCreateSerializer(data, many=True).data)

View File

@ -17,9 +17,8 @@ from structlog.stdlib import get_logger
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import MetaNameSerializer from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
from authentik.core.models import Source, UserSourceConnection from authentik.core.models import Source, UserSourceConnection
from authentik.core.types import UserSettingSerializer from authentik.core.types import UserSettingSerializer
from authentik.lib.utils.file import ( from authentik.lib.utils.file import (
@ -28,6 +27,7 @@ from authentik.lib.utils.file import (
set_file, set_file,
set_file_url, set_file_url,
) )
from authentik.lib.utils.reflection import all_subclasses
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
@ -74,7 +74,6 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
class SourceViewSet( class SourceViewSet(
TypesMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.DestroyModelMixin, mixins.DestroyModelMixin,
UsedByMixin, UsedByMixin,
@ -133,6 +132,30 @@ class SourceViewSet(
source: Source = self.get_object() source: Source = self.get_object()
return set_file_url(request, source, "icon") return set_file_url(request, source, "icon")
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[])
def types(self, request: Request) -> Response:
"""Get all creatable source types"""
data = []
for subclass in all_subclasses(self.queryset.model):
subclass: Source
component = ""
if len(subclass.__subclasses__()) > 0:
continue
if subclass._meta.abstract:
component = subclass.__bases__[0]().component
else:
component = subclass().component
data.append(
{
"name": subclass._meta.verbose_name,
"description": subclass.__doc__,
"component": component,
"model_name": subclass._meta.model_name,
}
)
return Response(TypeCreateSerializer(data, many=True).data)
@extend_schema(responses={200: UserSettingSerializer(many=True)}) @extend_schema(responses={200: UserSettingSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[]) @action(detail=False, pagination_class=None, filter_backends=[])
def user_settings(self, request: Request) -> Response: def user_settings(self, request: Request) -> Response:

View File

@ -39,12 +39,12 @@ def get_delete_action(manager: Manager) -> str:
"""Get the delete action from the Foreign key, falls back to cascade""" """Get the delete action from the Foreign key, falls back to cascade"""
if hasattr(manager, "field"): if hasattr(manager, "field"):
if manager.field.remote_field.on_delete.__name__ == SET_NULL.__name__: if manager.field.remote_field.on_delete.__name__ == SET_NULL.__name__:
return DeleteAction.SET_NULL.value return DeleteAction.SET_NULL.name
if manager.field.remote_field.on_delete.__name__ == SET_DEFAULT.__name__: if manager.field.remote_field.on_delete.__name__ == SET_DEFAULT.__name__:
return DeleteAction.SET_DEFAULT.value return DeleteAction.SET_DEFAULT.name
if hasattr(manager, "source_field"): if hasattr(manager, "source_field"):
return DeleteAction.CASCADE_MANY.value return DeleteAction.CASCADE_MANY.name
return DeleteAction.CASCADE.value return DeleteAction.CASCADE.name
class UsedByMixin: class UsedByMixin:
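
The only change in this file is returning the enum member's value on one side and its name on the other. A minimal illustration of the difference; the string values below are assumptions, only the member names appear in the diff:

    # .name is the identifier ("SET_NULL"); .value is whatever the member maps to.
    # The string values here are invented; only the member names come from the diff.
    from enum import Enum

    class DeleteAction(Enum):
        CASCADE = "cascade"
        CASCADE_MANY = "cascade_many"
        SET_NULL = "set_null"
        SET_DEFAULT = "set_default"

    print(DeleteAction.SET_NULL.name)   # SET_NULL
    print(DeleteAction.SET_NULL.value)  # set_null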

View File

@ -408,7 +408,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
filterset_class = UsersFilter filterset_class = UsersFilter
def get_queryset(self): def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous() base_qs = User.objects.all()
if self.serializer_class(context={"request": self.request})._should_include_groups: if self.serializer_class(context={"request": self.request})._should_include_groups:
base_qs = base_qs.prefetch_related("ak_groups") base_qs = base_qs.prefetch_related("ak_groups")
return base_qs return base_qs

View File

@ -6,16 +6,8 @@ from django.db.models import Model
from drf_spectacular.extensions import OpenApiSerializerFieldExtension from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import build_basic_type from drf_spectacular.plumbing import build_basic_type
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from rest_framework.fields import ( from rest_framework.fields import BooleanField, CharField, IntegerField, JSONField
CharField, from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
IntegerField,
JSONField,
SerializerMethodField,
)
from rest_framework.serializers import (
Serializer,
ValidationError,
)
def is_dict(value: Any): def is_dict(value: Any):
@ -76,6 +68,16 @@ class MetaNameSerializer(PassiveSerializer):
return f"{obj._meta.app_label}.{obj._meta.model_name}" return f"{obj._meta.app_label}.{obj._meta.model_name}"
class TypeCreateSerializer(PassiveSerializer):
"""Types of an object that can be created"""
name = CharField(required=True)
description = CharField(required=True)
component = CharField(required=True)
model_name = CharField(required=True)
requires_enterprise = BooleanField(default=False)
class CacheSerializer(PassiveSerializer): class CacheSerializer(PassiveSerializer):
"""Generic cache stats for an object""" """Generic cache stats for an object"""

View File

@ -31,9 +31,8 @@ class InbuiltBackend(ModelBackend):
# Since we can't directly pass other variables to signals, and we want to log the method # Since we can't directly pass other variables to signals, and we want to log the method
# and the token used, we assume we're running in a flow and set a variable in the context # and the token used, we assume we're running in a flow and set a variable in the context
flow_plan: FlowPlan = request.session.get(SESSION_KEY_PLAN, FlowPlan("")) flow_plan: FlowPlan = request.session.get(SESSION_KEY_PLAN, FlowPlan(""))
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, method) flow_plan.context[PLAN_CONTEXT_METHOD] = method
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {}) flow_plan.context[PLAN_CONTEXT_METHOD_ARGS] = cleanse_dict(sanitize_dict(kwargs))
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].update(cleanse_dict(sanitize_dict(kwargs)))
request.session[SESSION_KEY_PLAN] = flow_plan request.session[SESSION_KEY_PLAN] = flow_plan
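
The difference in this hunk is dict.setdefault (keep whatever an earlier stage already recorded in the flow plan) versus plain assignment (always overwrite). A minimal illustration with made-up keys and values:

    # setdefault only writes when the key is absent; assignment always overwrites.
    context = {"auth_method": "password"}        # recorded by an earlier stage

    context.setdefault("auth_method", "token")   # stays "password"
    context["auth_method"] = "token"             # becomes "token"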

View File

@ -1,6 +1,5 @@
"""Property Mapping Evaluator""" """Property Mapping Evaluator"""
from types import CodeType
from typing import Any from typing import Any
from django.db.models import Model from django.db.models import Model
@ -25,8 +24,6 @@ class PropertyMappingEvaluator(BaseEvaluator):
"""Custom Evaluator that adds some different context variables.""" """Custom Evaluator that adds some different context variables."""
dry_run: bool dry_run: bool
model: Model
_compiled: CodeType | None = None
def __init__( def __init__(
self, self,
@ -36,32 +33,23 @@ class PropertyMappingEvaluator(BaseEvaluator):
dry_run: bool | None = False, dry_run: bool | None = False,
**kwargs, **kwargs,
): ):
self.model = model
if hasattr(model, "name"): if hasattr(model, "name"):
_filename = model.name _filename = model.name
else: else:
_filename = str(model) _filename = str(model)
super().__init__(filename=_filename) super().__init__(filename=_filename)
self.dry_run = dry_run
self.set_context(user, request, **kwargs)
def set_context(
self,
user: User | None = None,
request: HttpRequest | None = None,
**kwargs,
):
req = PolicyRequest(user=User()) req = PolicyRequest(user=User())
req.obj = self.model req.obj = model
if user: if user:
req.user = user req.user = user
self._context["user"] = user self._context["user"] = user
if request: if request:
req.http_request = request req.http_request = request
req.context.update(**kwargs)
self._context["request"] = req self._context["request"] = req
req.context.update(**kwargs)
self._context.update(**kwargs) self._context.update(**kwargs)
self._globals["SkipObject"] = SkipObjectException self._globals["SkipObject"] = SkipObjectException
self.dry_run = dry_run
def handle_error(self, exc: Exception, expression_source: str): def handle_error(self, exc: Exception, expression_source: str):
"""Exception Handler""" """Exception Handler"""
@ -83,9 +71,3 @@ class PropertyMappingEvaluator(BaseEvaluator):
def evaluate(self, *args, **kwargs) -> Any: def evaluate(self, *args, **kwargs) -> Any:
with PROPERTY_MAPPING_TIME.labels(mapping_name=self._filename).time(): with PROPERTY_MAPPING_TIME.labels(mapping_name=self._filename).time():
return super().evaluate(*args, **kwargs) return super().evaluate(*args, **kwargs)
def compile(self, expression: str | None = None) -> Any:
if not self._compiled:
compiled = super().compile(expression or self.model.expression)
self._compiled = compiled
return self._compiled

View File

@ -6,11 +6,6 @@ from authentik.lib.sentry import SentryIgnoredException
class PropertyMappingExpressionException(SentryIgnoredException): class PropertyMappingExpressionException(SentryIgnoredException):
"""Error when a PropertyMapping Exception expression could not be parsed or evaluated.""" """Error when a PropertyMapping Exception expression could not be parsed or evaluated."""
def __init__(self, exc: Exception, mapping) -> None:
super().__init__()
self.exc = exc
self.mapping = mapping
class SkipObjectException(PropertyMappingExpressionException): class SkipObjectException(PropertyMappingExpressionException):
"""Exception which can be raised in a property mapping to skip syncing an object. """Exception which can be raised in a property mapping to skip syncing an object.

View File

@ -10,7 +10,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import Count from django.db.models import Count
import authentik.core.models import authentik.core.models
import authentik.lib.models import authentik.lib.validators
def migrate_sessions(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): def migrate_sessions(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
@ -160,7 +160,7 @@ class Migration(migrations.Migration):
field=models.TextField( field=models.TextField(
blank=True, blank=True,
default="", default="",
validators=[authentik.lib.models.DomainlessFormattedURLValidator()], validators=[authentik.lib.validators.DomainlessFormattedURLValidator()],
), ),
), ),
migrations.RunPython( migrations.RunPython(

View File

@ -0,0 +1,23 @@
# Generated by Django 5.0.4 on 2024-04-23 16:59
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0035_alter_group_options_and_more"),
]
operations = [
migrations.AddField(
model_name="group",
name="deleted_at",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="user",
name="deleted_at",
field=models.DateTimeField(blank=True, null=True),
),
]
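
This migration adds the deleted_at column that the soft-delete commits in this compare build on. The SoftDeleteModel and SoftDeleteQuerySet imported in the core models hunks further down live in authentik/lib/models.py, which is not shown here; the following is only a minimal sketch of how such a mixin is commonly structured, with every name except deleted_at assumed:

    # Hypothetical sketch only -- the real implementation is in authentik/lib/models.py
    # and is not part of this compare. Only the `deleted_at` field name is taken from
    # the migration above; everything else is an assumption.
    from django.db import models
    from django.utils.timezone import now


    class SoftDeleteQuerySet(models.QuerySet):
        """Flag rows as deleted instead of removing them."""

        def delete(self):
            return super().update(deleted_at=now())


    class SoftDeleteModel(models.Model):
        """Abstract mixin adding the deleted_at column created by the migration above."""

        deleted_at = models.DateTimeField(null=True, blank=True)

        class Meta:
            abstract = True

        def delete(self, *args, **kwargs):
            self.deleted_at = now()
            self.save(update_fields=["deleted_at"])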

View File

@ -15,7 +15,6 @@ from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_cte import CTEQuerySet, With
from guardian.conf import settings from guardian.conf import settings
from guardian.mixins import GuardianUserMixin from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager from model_utils.managers import InheritanceManager
@ -29,10 +28,12 @@ from authentik.lib.avatars import get_avatar
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.models import ( from authentik.lib.models import (
CreatedUpdatedModel, CreatedUpdatedModel,
DomainlessFormattedURLValidator,
SerializerModel, SerializerModel,
SoftDeleteModel,
SoftDeleteQuerySet,
) )
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.lib.validators import DomainlessFormattedURLValidator
from authentik.policies.models import PolicyBindingModel from authentik.policies.models import PolicyBindingModel
from authentik.tenants.models import DEFAULT_TOKEN_DURATION, DEFAULT_TOKEN_LENGTH from authentik.tenants.models import DEFAULT_TOKEN_DURATION, DEFAULT_TOKEN_LENGTH
from authentik.tenants.utils import get_current_tenant, get_unique_identifier from authentik.tenants.utils import get_current_tenant, get_unique_identifier
@ -57,8 +58,6 @@ options.DEFAULT_NAMES = options.DEFAULT_NAMES + (
"authentik_used_by_shadows", "authentik_used_by_shadows",
) )
GROUP_RECURSION_LIMIT = 20
def default_token_duration() -> datetime: def default_token_duration() -> datetime:
"""Default duration a Token is valid""" """Default duration a Token is valid"""
@ -99,41 +98,7 @@ class UserTypes(models.TextChoices):
INTERNAL_SERVICE_ACCOUNT = "internal_service_account" INTERNAL_SERVICE_ACCOUNT = "internal_service_account"
class GroupQuerySet(CTEQuerySet): class Group(SoftDeleteModel, SerializerModel):
def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents
or are indirectly related."""
def make_cte(cte):
"""Build the query that ends up in WITH RECURSIVE"""
# Start from self, aka the current query
# Add a depth attribute to limit the recursion
return self.annotate(
relative_depth=models.Value(0, output_field=models.IntegerField())
).union(
# Here is the recursive part of the query. cte refers to the previous iteration
# Only select groups for which the parent is part of the previous iteration
# and increase the depth
# Finally, limit the depth
cte.join(Group, group_uuid=cte.col.parent_id)
.annotate(
relative_depth=models.ExpressionWrapper(
cte.col.relative_depth
+ models.Value(1, output_field=models.IntegerField()),
output_field=models.IntegerField(),
)
)
.filter(relative_depth__lt=GROUP_RECURSION_LIMIT),
all=True,
)
# Build the recursive query, see above
cte = With.recursive(make_cte)
# Return the result, as a usable queryset for Group.
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
class Group(SerializerModel):
"""Group model which supports a basic hierarchy and has attributes""" """Group model which supports a basic hierarchy and has attributes"""
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
@ -155,8 +120,6 @@ class Group(SerializerModel):
) )
attributes = models.JSONField(default=dict, blank=True) attributes = models.JSONField(default=dict, blank=True)
objects = GroupQuerySet.as_manager()
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:
from authentik.core.api.groups import GroupSerializer from authentik.core.api.groups import GroupSerializer
@ -175,11 +138,36 @@ class Group(SerializerModel):
return user.all_groups().filter(group_uuid=self.group_uuid).exists() return user.all_groups().filter(group_uuid=self.group_uuid).exists()
def children_recursive(self: Self | QuerySet["Group"]) -> QuerySet["Group"]: def children_recursive(self: Self | QuerySet["Group"]) -> QuerySet["Group"]:
"""Compatibility layer for Group.objects.with_children_recursive()""" """Recursively get all groups that have this as parent or are indirectly related"""
qs = self direct_groups = []
if not isinstance(self, QuerySet): if isinstance(self, QuerySet):
qs = Group.objects.filter(group_uuid=self.group_uuid) direct_groups = list(x for x in self.all().values_list("pk", flat=True).iterator())
return qs.with_children_recursive() else:
direct_groups = [self.pk]
if len(direct_groups) < 1:
return Group.objects.none()
query = """
WITH RECURSIVE parents AS (
SELECT authentik_core_group.*, 0 AS relative_depth
FROM authentik_core_group
WHERE authentik_core_group.group_uuid = ANY(%s)
UNION ALL
SELECT authentik_core_group.*, parents.relative_depth + 1
FROM authentik_core_group, parents
WHERE (
authentik_core_group.group_uuid = parents.parent_id and
parents.relative_depth < 20
)
)
SELECT group_uuid
FROM parents
GROUP BY group_uuid, name
ORDER BY name;
"""
group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()]
return Group.objects.filter(pk__in=group_pks)
def __str__(self): def __str__(self):
return f"Group {self.name}" return f"Group {self.name}"
@ -200,31 +188,21 @@ class Group(SerializerModel):
] ]
class UserQuerySet(models.QuerySet):
"""User queryset"""
def exclude_anonymous(self):
"""Exclude anonymous user"""
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
class UserManager(DjangoUserManager): class UserManager(DjangoUserManager):
"""User manager that doesn't assign is_superuser and is_staff""" """User manager that doesn't assign is_superuser and is_staff"""
def get_queryset(self): def get_queryset(self):
"""Create special user queryset""" """Create special user queryset"""
return UserQuerySet(self.model, using=self._db) return SoftDeleteQuerySet(self.model, using=self._db).exclude(
**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME}
)
def create_user(self, username, email=None, password=None, **extra_fields): def create_user(self, username, email=None, password=None, **extra_fields):
"""User manager that doesn't assign is_superuser and is_staff""" """User manager that doesn't assign is_superuser and is_staff"""
return self._create_user(username, email, password, **extra_fields) return self._create_user(username, email, password, **extra_fields)
def exclude_anonymous(self) -> QuerySet:
"""Exclude anonymous user"""
return self.get_queryset().exclude_anonymous()
class User(SoftDeleteModel, SerializerModel, GuardianUserMixin, AbstractUser):
class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model.""" """authentik User model, based on django's contrib auth user model."""
uuid = models.UUIDField(default=uuid4, editable=False, unique=True) uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
@ -246,8 +224,10 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
return User._meta.get_field("path").default return User._meta.get_field("path").default
def all_groups(self) -> QuerySet[Group]: def all_groups(self) -> QuerySet[Group]:
"""Recursively get all groups this user is a member of.""" """Recursively get all groups this user is a member of.
return self.ak_groups.all().with_children_recursive() At least one query is done to get the direct groups of the user, with groups
there are at most 3 queries done"""
return Group.children_recursive(self.ak_groups.all())
def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]: def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]:
"""Get a dictionary containing the attributes from all groups the user belongs to, """Get a dictionary containing the attributes from all groups the user belongs to,
@ -389,10 +369,6 @@ class Provider(SerializerModel):
Can return None for providers that are not URL-based""" Can return None for providers that are not URL-based"""
return None return None
@property
def icon_url(self) -> str | None:
return None
@property @property
def component(self) -> str: def component(self) -> str:
"""Return component used to edit this object""" """Return component used to edit this object"""
@ -784,7 +760,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
try: try:
return evaluator.evaluate(self.expression) return evaluator.evaluate(self.expression)
except Exception as exc: except Exception as exc:
raise PropertyMappingExpressionException(self, exc) from exc raise PropertyMappingExpressionException(exc) from exc
def __str__(self): def __str__(self):
return f"Property Mapping {self.name}" return f"Property Mapping {self.name}"

View File

@ -23,17 +23,6 @@ class TestGroupsAPI(APITestCase):
response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"}) response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_retrieve_with_users(self):
"""Test retrieve with users"""
admin = create_test_admin_user()
group = Group.objects.create(name=generate_id())
self.client.force_login(admin)
response = self.client.get(
reverse("authentik_api:group-detail", kwargs={"pk": group.pk}),
{"include_users": "true"},
)
self.assertEqual(response.status_code, 200)
def test_add_user(self): def test_add_user(self):
"""Test add_user""" """Test add_user"""
group = Group.objects.create(name=generate_id()) group = Group.objects.create(name=generate_id())

View File

@ -1,14 +1,14 @@
"""authentik core models tests""" """authentik core models tests"""
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from time import sleep
from django.test import RequestFactory, TestCase from django.test import RequestFactory, TestCase
from django.utils.timezone import now from django.utils.timezone import now
from freezegun import freeze_time
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from authentik.core.models import Provider, Source, Token from authentik.core.models import Provider, Source, Token
from authentik.flows.models import Stage
from authentik.lib.utils.reflection import all_subclasses from authentik.lib.utils.reflection import all_subclasses
@ -17,20 +17,18 @@ class TestModels(TestCase):
def test_token_expire(self): def test_token_expire(self):
"""Test token expiring""" """Test token expiring"""
with freeze_time() as freeze: token = Token.objects.create(expires=now(), user=get_anonymous_user())
token = Token.objects.create(expires=now(), user=get_anonymous_user()) sleep(0.5)
freeze.tick(timedelta(seconds=1)) self.assertTrue(token.is_expired)
self.assertTrue(token.is_expired)
def test_token_expire_no_expire(self): def test_token_expire_no_expire(self):
"""Test token expiring with "expiring" set""" """Test token expiring with "expiring" set"""
with freeze_time() as freeze: token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False) sleep(0.5)
freeze.tick(timedelta(seconds=1)) self.assertFalse(token.is_expired)
self.assertFalse(token.is_expired)
def source_tester_factory(test_model: type[Source]) -> Callable: def source_tester_factory(test_model: type[Stage]) -> Callable:
"""Test source""" """Test source"""
factory = RequestFactory() factory = RequestFactory()
@ -38,19 +36,19 @@ def source_tester_factory(test_model: type[Source]) -> Callable:
def tester(self: TestModels): def tester(self: TestModels):
model_class = None model_class = None
if test_model._meta.abstract: if test_model._meta.abstract: # pragma: no cover
model_class = [x for x in test_model.__bases__ if issubclass(x, Source)][0]() model_class = test_model.__bases__[0]()
else: else:
model_class = test_model() model_class = test_model()
model_class.slug = "test" model_class.slug = "test"
self.assertIsNotNone(model_class.component) self.assertIsNotNone(model_class.component)
model_class.ui_login_button(request) _ = model_class.ui_login_button(request)
model_class.ui_user_settings() _ = model_class.ui_user_settings()
return tester return tester
def provider_tester_factory(test_model: type[Provider]) -> Callable: def provider_tester_factory(test_model: type[Stage]) -> Callable:
"""Test provider""" """Test provider"""
def tester(self: TestModels): def tester(self: TestModels):
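
The token-expiry tests above differ in how they let the token expire: one side sleeps for real wall-clock time, the other advances a frozen clock with freezegun. A minimal, self-contained sketch of the freezegun approach:

    # Minimal freezegun sketch: advance a virtual clock instead of sleeping,
    # keeping the expiry check deterministic and fast.
    from datetime import datetime, timedelta

    from freezegun import freeze_time

    with freeze_time() as frozen:
        expires = datetime.now()              # "expires right now"
        frozen.tick(timedelta(seconds=1))     # virtual clock moves past the expiry
        assert expires < datetime.now()       # the token would now count as expired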

View File

@ -6,10 +6,9 @@ from django.urls import reverse
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.core.api.property_mappings import PropertyMappingSerializer from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.models import Group, PropertyMapping from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
class TestPropertyMappingAPI(APITestCase): class TestPropertyMappingAPI(APITestCase):
@ -17,40 +16,23 @@ class TestPropertyMappingAPI(APITestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.mapping = PropertyMapping.objects.create(
name="dummy", expression="""return {'foo': 'bar'}"""
)
self.user = create_test_admin_user() self.user = create_test_admin_user()
self.client.force_login(self.user) self.client.force_login(self.user)
def test_test_call(self): def test_test_call(self):
"""Test PropertyMappings's test endpoint""" """Test PropertMappings's test endpoint"""
mapping = PropertyMapping.objects.create(
name="dummy", expression="""return {'foo': 'bar', 'baz': user.username}"""
)
response = self.client.post( response = self.client.post(
reverse("authentik_api:propertymapping-test", kwargs={"pk": mapping.pk}), reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}),
data={ data={
"user": self.user.pk, "user": self.user.pk,
}, },
) )
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{"result": dumps({"foo": "bar", "baz": self.user.username}), "successful": True}, {"result": dumps({"foo": "bar"}), "successful": True},
)
def test_test_call_group(self):
"""Test PropertyMappings's test endpoint"""
mapping = PropertyMapping.objects.create(
name="dummy", expression="""return {'foo': 'bar', 'baz': group.name}"""
)
group = Group.objects.create(name=generate_id())
response = self.client.post(
reverse("authentik_api:propertymapping-test", kwargs={"pk": mapping.pk}),
data={
"group": group.pk,
},
)
self.assertJSONEqual(
response.content.decode(),
{"result": dumps({"foo": "bar", "baz": group.name}), "successful": True},
) )
def test_validate(self): def test_validate(self):

View File

@ -42,8 +42,8 @@ class TestUsersAvatars(APITestCase):
with Mocker() as mocker: with Mocker() as mocker:
mocker.head( mocker.head(
( (
"https://www.gravatar.com/avatar/76eb3c74c8beb6faa037f1b6e2ecb3e252bdac" "https://secure.gravatar.com/avatar/84730f9c1851d1ea03f1a"
"6cf71fb567ae36025a9d4ea86b?size=158&rating=g&default=404" "a9ed85bd1ea?size=158&rating=g&default=404"
), ),
text="foo", text="foo",
) )

View File

@ -12,7 +12,7 @@ from authentik.core.api.applications import ApplicationViewSet
from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet
from authentik.core.api.devices import AdminDeviceViewSet, DeviceViewSet from authentik.core.api.devices import AdminDeviceViewSet, DeviceViewSet
from authentik.core.api.groups import GroupViewSet from authentik.core.api.groups import GroupViewSet
from authentik.core.api.property_mappings import PropertyMappingViewSet from authentik.core.api.propertymappings import PropertyMappingViewSet
from authentik.core.api.providers import ProviderViewSet from authentik.core.api.providers import ProviderViewSet
from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSet from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSet
from authentik.core.api.tokens import TokenViewSet from authentik.core.api.tokens import TokenViewSet

View File

@ -92,11 +92,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
@property @property
def kid(self): def kid(self):
"""Get Key ID used for JWKS""" """Get Key ID used for JWKS"""
return ( return md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else "" # nosec
md5(self.key_data.encode("utf-8"), usedforsecurity=False).hexdigest()
if self.key_data
else ""
) # nosec
def __str__(self) -> str: def __str__(self) -> str:
return f"Certificate-Key Pair {self.name}" return f"Certificate-Key Pair {self.name}"

View File

@ -241,7 +241,7 @@ class TestCrypto(APITestCase):
"model_name": "oauth2provider", "model_name": "oauth2provider",
"pk": str(provider.pk), "pk": str(provider.pk),
"name": str(provider), "name": str(provider),
"action": DeleteAction.SET_NULL.value, "action": DeleteAction.SET_NULL.name,
} }
], ],
) )
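The expected payload switches from the enum member's .name to its .value. A minimal illustration with assumed member values:

from enum import Enum

class DeleteAction(Enum):  # illustrative stand-in; member values are assumptions
    DO_NOTHING = "do_nothing"
    SET_NULL = "set_null"

print(DeleteAction.SET_NULL.name)   # "SET_NULL"  (what the old assertion compared)
print(DeleteAction.SET_NULL.value)  # "set_null"  (what the serializer emits now)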

View File

@ -132,7 +132,7 @@ class LicenseKey:
@staticmethod @staticmethod
def base_user_qs() -> QuerySet: def base_user_qs() -> QuerySet:
"""Base query set for all users""" """Base query set for all users"""
return User.objects.all().exclude_anonymous().exclude(is_active=False) return User.objects.all().exclude(is_active=False)
@staticmethod @staticmethod
def get_default_user_count(): def get_default_user_count():
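License counting now also filters out the anonymous user via a queryset helper. A minimal sketch, assuming the anonymous user is identified by django-guardian's reserved username:

from django.db import models

class UserQuerySet(models.QuerySet):
    def exclude_anonymous(self):
        # Assumption: the internal anonymous user carries the reserved
        # "AnonymousUser" username used by django-guardian.
        return self.exclude(username="AnonymousUser")

# The license base query then becomes:
# User.objects.all().exclude_anonymous().exclude(is_active=False)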

View File

@ -1,16 +1,14 @@
"""GoogleWorkspaceProviderGroup API Views""" """GoogleWorkspaceProviderGroup API Views"""
from rest_framework import mixins from rest_framework.viewsets import ModelViewSet
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.users import UserGroupSerializer from authentik.core.api.users import UserGroupSerializer
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderGroup from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderGroup
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
class GoogleWorkspaceProviderGroupSerializer(ModelSerializer): class GoogleWorkspaceProviderGroupSerializer(SourceSerializer):
"""GoogleWorkspaceProviderGroup Serializer""" """GoogleWorkspaceProviderGroup Serializer"""
group_obj = UserGroupSerializer(source="group", read_only=True) group_obj = UserGroupSerializer(source="group", read_only=True)
@ -20,24 +18,12 @@ class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
model = GoogleWorkspaceProviderGroup model = GoogleWorkspaceProviderGroup
fields = [ fields = [
"id", "id",
"google_id",
"group", "group",
"group_obj", "group_obj",
"provider",
"attributes",
] ]
extra_kwargs = {"attributes": {"read_only": True}}
class GoogleWorkspaceProviderGroupViewSet( class GoogleWorkspaceProviderGroupViewSet(UsedByMixin, ModelViewSet):
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""GoogleWorkspaceProviderGroup Viewset""" """GoogleWorkspaceProviderGroup Viewset"""
queryset = GoogleWorkspaceProviderGroup.objects.all().select_related("group") queryset = GoogleWorkspaceProviderGroup.objects.all().select_related("group")
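The viewset moves from ModelViewSet to an explicit mixin composition, which removes the update endpoints. A generic sketch of that pattern; the class name below is illustrative:

from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet

class ConnectionViewSet(
    mixins.CreateModelMixin,
    mixins.RetrieveModelMixin,
    mixins.DestroyModelMixin,
    mixins.ListModelMixin,
    GenericViewSet,
):
    # Composing mixins instead of ModelViewSet drops update/partial_update:
    # connection objects can be created, listed, fetched and deleted, not edited.
    ...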

View File

@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field from drf_spectacular.utils import extend_schema_field
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from authentik.core.api.property_mappings import PropertyMappingSerializer from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderMapping from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderMapping

View File

@ -1,16 +1,14 @@
"""GoogleWorkspaceProviderUser API Views""" """GoogleWorkspaceProviderUser API Views"""
from rest_framework import mixins from rest_framework.viewsets import ModelViewSet
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.groups import GroupMemberSerializer from authentik.core.api.groups import GroupMemberSerializer
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderUser from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderUser
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
class GoogleWorkspaceProviderUserSerializer(ModelSerializer): class GoogleWorkspaceProviderUserSerializer(SourceSerializer):
"""GoogleWorkspaceProviderUser Serializer""" """GoogleWorkspaceProviderUser Serializer"""
user_obj = GroupMemberSerializer(source="user", read_only=True) user_obj = GroupMemberSerializer(source="user", read_only=True)
@ -20,24 +18,12 @@ class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
model = GoogleWorkspaceProviderUser model = GoogleWorkspaceProviderUser
fields = [ fields = [
"id", "id",
"google_id",
"user", "user",
"user_obj", "user_obj",
"provider",
"attributes",
] ]
extra_kwargs = {"attributes": {"read_only": True}}
class GoogleWorkspaceProviderUserViewSet( class GoogleWorkspaceProviderUserViewSet(UsedByMixin, ModelViewSet):
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""GoogleWorkspaceProviderUser Viewset""" """GoogleWorkspaceProviderUser Viewset"""
queryset = GoogleWorkspaceProviderUser.objects.all().select_related("user") queryset = GoogleWorkspaceProviderUser.objects.all().select_related("user")

View File

@ -1,22 +1,28 @@
from deepmerge import always_merger
from django.db import transaction from django.db import transaction
from django.utils.text import slugify from django.utils.text import slugify
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.models import Group from authentik.core.models import Group
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
from authentik.enterprise.providers.google_workspace.models import ( from authentik.enterprise.providers.google_workspace.models import (
GoogleWorkspaceProvider,
GoogleWorkspaceProviderGroup, GoogleWorkspaceProviderGroup,
GoogleWorkspaceProviderMapping, GoogleWorkspaceProviderMapping,
GoogleWorkspaceProviderUser, GoogleWorkspaceProviderUser,
) )
from authentik.lib.sync.mapper import PropertyMappingManager from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import ( from authentik.lib.sync.outgoing.exceptions import (
NotFoundSyncException, NotFoundSyncException,
ObjectExistsSyncException, ObjectExistsSyncException,
StopSync,
TransientSyncException, TransientSyncException,
) )
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
from authentik.lib.utils.errors import exception_to_string
class GoogleWorkspaceGroupClient( class GoogleWorkspaceGroupClient(
@ -28,21 +34,41 @@ class GoogleWorkspaceGroupClient(
connection_type_query = "group" connection_type_query = "group"
can_discover = True can_discover = True
def __init__(self, provider: GoogleWorkspaceProvider) -> None: def to_schema(self, obj: Group, creating: bool) -> dict:
super().__init__(provider)
self.mapper = PropertyMappingManager(
self.provider.property_mappings_group.all().order_by("name").select_subclasses(),
GoogleWorkspaceProviderMapping,
["group", "provider", "connection"],
)
def to_schema(self, obj: Group, connection: GoogleWorkspaceProviderGroup) -> dict:
"""Convert authentik group""" """Convert authentik group"""
return super().to_schema( raw_google_group = {
obj, "email": f"{slugify(obj.name)}@{self.provider.default_group_email_domain}"
connection=connection, }
email=f"{slugify(obj.name)}@{self.provider.default_group_email_domain}", for mapping in (
) self.provider.property_mappings_group.all().order_by("name").select_subclasses()
):
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
continue
try:
value = mapping.evaluate(
user=None,
request=None,
group=obj,
provider=self.provider,
creating=creating,
)
if value is None:
continue
always_merger.merge(raw_google_group, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=mapping,
).save()
raise StopSync(exc, obj, mapping) from exc
if not raw_google_group:
raise StopSync(ValueError("No group mappings configured"), obj)
return raw_google_group
def delete(self, obj: Group): def delete(self, obj: Group):
"""Delete group""" """Delete group"""
@ -61,7 +87,7 @@ class GoogleWorkspaceGroupClient(
def create(self, group: Group): def create(self, group: Group):
"""Create group from scratch and create a connection object""" """Create group from scratch and create a connection object"""
google_group = self.to_schema(group, None) google_group = self.to_schema(group, True)
self.check_email_valid(google_group["email"]) self.check_email_valid(google_group["email"])
with transaction.atomic(): with transaction.atomic():
try: try:
@ -74,32 +100,24 @@ class GoogleWorkspaceGroupClient(
self.directory_service.groups().get(groupKey=google_group["email"]) self.directory_service.groups().get(groupKey=google_group["email"])
) )
return GoogleWorkspaceProviderGroup.objects.create( return GoogleWorkspaceProviderGroup.objects.create(
provider=self.provider, provider=self.provider, group=group, google_id=group_data["id"]
group=group,
google_id=group_data["id"],
attributes=group_data,
) )
else: else:
return GoogleWorkspaceProviderGroup.objects.create( return GoogleWorkspaceProviderGroup.objects.create(
provider=self.provider, provider=self.provider, group=group, google_id=response["id"]
group=group,
google_id=response["id"],
attributes=response,
) )
def update(self, group: Group, connection: GoogleWorkspaceProviderGroup): def update(self, group: Group, connection: GoogleWorkspaceProviderGroup):
"""Update existing group""" """Update existing group"""
google_group = self.to_schema(group, connection) google_group = self.to_schema(group, False)
self.check_email_valid(google_group["email"]) self.check_email_valid(google_group["email"])
try: try:
response = self._request( return self._request(
self.directory_service.groups().update( self.directory_service.groups().update(
groupKey=connection.google_id, groupKey=connection.google_id,
body=google_group, body=google_group,
) )
) )
connection.attributes = response
connection.save()
except NotFoundSyncException: except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group # Resource missing is handled by self.write, which will re-create the group
raise raise
@ -212,9 +230,4 @@ class GoogleWorkspaceGroupClient(
provider=self.provider, provider=self.provider,
group=matching_authentik_group, group=matching_authentik_group,
google_id=google_id, google_id=google_id,
attributes=group,
) )
def update_single_attribute(self, connection: GoogleWorkspaceProviderUser):
group = self.directory_service.groups().get(connection.google_id)
connection.attributes = group
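Both sides merge the result of every property mapping into a single dict. A condensed sketch of that shared pattern, assuming each mapping's evaluate() accepts the context as keyword arguments and returns a dict or None:

from deepmerge import always_merger

def merge_mapping_results(mappings, **context) -> dict:
    result: dict = {}
    for mapping in mappings:
        value = mapping.evaluate(**context)
        if value is None:
            continue
        # Deep-merge later mappings over earlier ones, ordered by name.
        always_merger.merge(result, value)
    return result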

View File

@ -1,18 +1,24 @@
from deepmerge import always_merger
from django.db import transaction from django.db import transaction
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.models import User from authentik.core.models import User
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
from authentik.enterprise.providers.google_workspace.models import ( from authentik.enterprise.providers.google_workspace.models import (
GoogleWorkspaceProvider,
GoogleWorkspaceProviderMapping, GoogleWorkspaceProviderMapping,
GoogleWorkspaceProviderUser, GoogleWorkspaceProviderUser,
) )
from authentik.lib.sync.mapper import PropertyMappingManager from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.exceptions import ( from authentik.lib.sync.outgoing.exceptions import (
ObjectExistsSyncException, ObjectExistsSyncException,
StopSync,
TransientSyncException, TransientSyncException,
) )
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.utils import delete_none_values from authentik.policies.utils import delete_none_values
@ -23,17 +29,37 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
connection_type_query = "user" connection_type_query = "user"
can_discover = True can_discover = True
def __init__(self, provider: GoogleWorkspaceProvider) -> None: def to_schema(self, obj: User, creating: bool) -> dict:
super().__init__(provider)
self.mapper = PropertyMappingManager(
self.provider.property_mappings.all().order_by("name").select_subclasses(),
GoogleWorkspaceProviderMapping,
["provider", "connection"],
)
def to_schema(self, obj: User, connection: GoogleWorkspaceProviderUser) -> dict:
"""Convert authentik user""" """Convert authentik user"""
return delete_none_values(super().to_schema(obj, connection, primaryEmail=obj.email)) raw_google_user = {}
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
continue
try:
value = mapping.evaluate(
user=obj,
request=None,
provider=self.provider,
creating=creating,
)
if value is None:
continue
always_merger.merge(raw_google_user, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=mapping,
).save()
raise StopSync(exc, obj, mapping) from exc
if not raw_google_user:
raise StopSync(ValueError("No user mappings configured"), obj)
if "primaryEmail" not in raw_google_user:
raw_google_user["primaryEmail"] = str(obj.email)
return delete_none_values(raw_google_user)
def delete(self, obj: User): def delete(self, obj: User):
"""Delete user""" """Delete user"""
@ -60,7 +86,7 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
def create(self, user: User): def create(self, user: User):
"""Create user from scratch and create a connection object""" """Create user from scratch and create a connection object"""
google_user = self.to_schema(user, None) google_user = self.to_schema(user, True)
self.check_email_valid( self.check_email_valid(
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])] google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
) )
@ -70,29 +96,24 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
except ObjectExistsSyncException: except ObjectExistsSyncException:
# user already exists in google workspace, so we can connect them manually # user already exists in google workspace, so we can connect them manually
return GoogleWorkspaceProviderUser.objects.create( return GoogleWorkspaceProviderUser.objects.create(
provider=self.provider, user=user, google_id=user.email, attributes={} provider=self.provider, user=user, google_id=user.email
) )
except TransientSyncException as exc: except TransientSyncException as exc:
raise exc raise exc
else: else:
return GoogleWorkspaceProviderUser.objects.create( return GoogleWorkspaceProviderUser.objects.create(
provider=self.provider, provider=self.provider, user=user, google_id=response["primaryEmail"]
user=user,
google_id=response["primaryEmail"],
attributes=response,
) )
def update(self, user: User, connection: GoogleWorkspaceProviderUser): def update(self, user: User, connection: GoogleWorkspaceProviderUser):
"""Update existing user""" """Update existing user"""
google_user = self.to_schema(user, connection) google_user = self.to_schema(user, False)
self.check_email_valid( self.check_email_valid(
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])] google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
) )
response = self._request( self._request(
self.directory_service.users().update(userKey=connection.google_id, body=google_user) self.directory_service.users().update(userKey=connection.google_id, body=google_user)
) )
connection.attributes = response
connection.save()
def discover(self): def discover(self):
"""Iterate through all users and connect them with authentik users if possible""" """Iterate through all users and connect them with authentik users if possible"""
@ -117,9 +138,4 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
provider=self.provider, provider=self.provider,
user=matching_authentik_user, user=matching_authentik_user,
google_id=email, google_id=email,
attributes=user,
) )
def update_single_attribute(self, connection: GoogleWorkspaceProviderUser):
user = self.directory_service.users().get(connection.google_id)
connection.attributes = user
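The right-hand side also defaults primaryEmail to the authentik e-mail and strips None values before calling the API. A small sketch of that finalization step:

def finalize_user_payload(user, mapped: dict) -> dict:
    # Sketch mirroring the inline logic: fall back to the authentik e-mail when
    # no mapping set primaryEmail, then drop None values before the API call.
    if "primaryEmail" not in mapped:
        mapped["primaryEmail"] = str(user.email)
    return {key: value for key, value in mapped.items() if value is not None}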

View File

@ -1,26 +0,0 @@
# Generated by Django 5.0.6 on 2024-05-23 20:48
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
(
"authentik_providers_google_workspace",
"0001_squashed_0002_alter_googleworkspaceprovidergroup_options_and_more",
),
]
operations = [
migrations.AddField(
model_name="googleworkspaceprovidergroup",
name="attributes",
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name="googleworkspaceprovideruser",
name="attributes",
field=models.JSONField(default=dict),
),
]

View File

@ -5,7 +5,6 @@ from uuid import uuid4
from django.db import models from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from google.oauth2.service_account import Credentials from google.oauth2.service_account import Credentials
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -31,58 +30,6 @@ def default_scopes() -> list[str]:
] ]
class GoogleWorkspaceProviderUser(SerializerModel):
"""Mapping of a user and provider to a Google user ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
google_id = models.TextField()
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.ForeignKey("GoogleWorkspaceProvider", on_delete=models.CASCADE)
attributes = models.JSONField(default=dict)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.google_workspace.api.users import (
GoogleWorkspaceProviderUserSerializer,
)
return GoogleWorkspaceProviderUserSerializer
class Meta:
verbose_name = _("Google Workspace Provider User")
verbose_name_plural = _("Google Workspace Provider Users")
unique_together = (("google_id", "user", "provider"),)
def __str__(self) -> str:
return f"Google Workspace Provider User {self.user_id} to {self.provider_id}"
class GoogleWorkspaceProviderGroup(SerializerModel):
"""Mapping of a group and provider to a Google group ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
google_id = models.TextField()
group = models.ForeignKey(Group, on_delete=models.CASCADE)
provider = models.ForeignKey("GoogleWorkspaceProvider", on_delete=models.CASCADE)
attributes = models.JSONField(default=dict)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.google_workspace.api.groups import (
GoogleWorkspaceProviderGroupSerializer,
)
return GoogleWorkspaceProviderGroupSerializer
class Meta:
verbose_name = _("Google Workspace Provider Group")
verbose_name_plural = _("Google Workspace Provider Groups")
unique_together = (("google_id", "group", "provider"),)
def __str__(self) -> str:
return f"Google Workspace Provider Group {self.group_id} to {self.provider_id}"
class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider): class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
"""Sync users from authentik into Google Workspace.""" """Sync users from authentik into Google Workspace."""
@ -111,16 +58,15 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
) )
def client_for_model( def client_for_model(
self, self, model: type[User | Group]
model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup],
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
if issubclass(model, User | GoogleWorkspaceProviderUser): if issubclass(model, User):
from authentik.enterprise.providers.google_workspace.clients.users import ( from authentik.enterprise.providers.google_workspace.clients.users import (
GoogleWorkspaceUserClient, GoogleWorkspaceUserClient,
) )
return GoogleWorkspaceUserClient(self) return GoogleWorkspaceUserClient(self)
if issubclass(model, Group | GoogleWorkspaceProviderGroup): if issubclass(model, Group):
from authentik.enterprise.providers.google_workspace.clients.groups import ( from authentik.enterprise.providers.google_workspace.clients.groups import (
GoogleWorkspaceGroupClient, GoogleWorkspaceGroupClient,
) )
@ -152,10 +98,6 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
).with_subject(self.delegated_subject), ).with_subject(self.delegated_subject),
} }
@property
def icon_url(self) -> str | None:
return static("authentik/sources/google.svg")
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-provider-google-workspace-form" return "ak-provider-google-workspace-form"
@ -197,3 +139,53 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
class Meta: class Meta:
verbose_name = _("Google Workspace Provider Mapping") verbose_name = _("Google Workspace Provider Mapping")
verbose_name_plural = _("Google Workspace Provider Mappings") verbose_name_plural = _("Google Workspace Provider Mappings")
class GoogleWorkspaceProviderUser(SerializerModel):
"""Mapping of a user and provider to a Google user ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
google_id = models.TextField()
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.google_workspace.api.users import (
GoogleWorkspaceProviderUserSerializer,
)
return GoogleWorkspaceProviderUserSerializer
class Meta:
verbose_name = _("Google Workspace Provider User")
verbose_name_plural = _("Google Workspace Provider Users")
unique_together = (("google_id", "user", "provider"),)
def __str__(self) -> str:
return f"Google Workspace Provider User {self.user_id} to {self.provider_id}"
class GoogleWorkspaceProviderGroup(SerializerModel):
"""Mapping of a group and provider to a Google group ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
google_id = models.TextField()
group = models.ForeignKey(Group, on_delete=models.CASCADE)
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.google_workspace.api.groups import (
GoogleWorkspaceProviderGroupSerializer,
)
return GoogleWorkspaceProviderGroupSerializer
class Meta:
verbose_name = _("Google Workspace Provider Group")
verbose_name_plural = _("Google Workspace Provider Groups")
unique_together = (("google_id", "group", "provider"),)
def __str__(self) -> str:
return f"Google Workspace Provider Group {self.group_id} to {self.provider_id}"

View File

@ -82,27 +82,6 @@ class GoogleWorkspaceGroupTests(TestCase):
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists()) self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 2) self.assertEqual(len(http.requests()), 2)
def test_group_not_created(self):
"""Test without group property mappings, no group is created"""
self.provider.property_mappings_group.clear()
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
group = Group.objects.create(name=uid)
google_group = GoogleWorkspaceProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNone(google_group)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 1)
def test_group_create_update(self): def test_group_create_update(self):
"""Test group updating""" """Test group updating"""
uid = generate_id() uid = generate_id()

View File

@ -86,31 +86,6 @@ class GoogleWorkspaceUserTests(TestCase):
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists()) self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 2) self.assertEqual(len(http.requests()), 2)
def test_user_not_created(self):
"""Test without property mappings, no group is created"""
self.provider.property_mappings.clear()
uid = generate_id()
http = MockHTTP()
http.add_response(
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
domains_list_v1_mock,
)
with patch(
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
):
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
google_user = GoogleWorkspaceProviderUser.objects.filter(
provider=self.provider, user=user
).first()
self.assertIsNone(google_user)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
self.assertEqual(len(http.requests()), 1)
def test_user_create_update(self): def test_user_create_update(self):
"""Test user updating""" """Test user updating"""
uid = generate_id() uid = generate_id()

View File

@ -1,16 +1,14 @@
"""MicrosoftEntraProviderGroup API Views""" """MicrosoftEntraProviderGroup API Views"""
from rest_framework import mixins from rest_framework.viewsets import ModelViewSet
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.users import UserGroupSerializer from authentik.core.api.users import UserGroupSerializer
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderGroup from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderGroup
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
class MicrosoftEntraProviderGroupSerializer(ModelSerializer): class MicrosoftEntraProviderGroupSerializer(SourceSerializer):
"""MicrosoftEntraProviderGroup Serializer""" """MicrosoftEntraProviderGroup Serializer"""
group_obj = UserGroupSerializer(source="group", read_only=True) group_obj = UserGroupSerializer(source="group", read_only=True)
@ -20,24 +18,12 @@ class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
model = MicrosoftEntraProviderGroup model = MicrosoftEntraProviderGroup
fields = [ fields = [
"id", "id",
"microsoft_id",
"group", "group",
"group_obj", "group_obj",
"provider",
"attributes",
] ]
extra_kwargs = {"attributes": {"read_only": True}}
class MicrosoftEntraProviderGroupViewSet( class MicrosoftEntraProviderGroupViewSet(UsedByMixin, ModelViewSet):
mixins.CreateModelMixin,
OutgoingSyncConnectionCreateMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""MicrosoftEntraProviderGroup Viewset""" """MicrosoftEntraProviderGroup Viewset"""
queryset = MicrosoftEntraProviderGroup.objects.all().select_related("group") queryset = MicrosoftEntraProviderGroup.objects.all().select_related("group")

View File

@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field from drf_spectacular.utils import extend_schema_field
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from authentik.core.api.property_mappings import PropertyMappingSerializer from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderMapping from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderMapping

View File

@ -1,16 +1,14 @@
"""MicrosoftEntraProviderUser API Views""" """MicrosoftEntraProviderUser API Views"""
from rest_framework import mixins from rest_framework.viewsets import ModelViewSet
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from authentik.core.api.groups import GroupMemberSerializer from authentik.core.api.groups import GroupMemberSerializer
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderUser from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderUser
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
class MicrosoftEntraProviderUserSerializer(ModelSerializer): class MicrosoftEntraProviderUserSerializer(SourceSerializer):
"""MicrosoftEntraProviderUser Serializer""" """MicrosoftEntraProviderUser Serializer"""
user_obj = GroupMemberSerializer(source="user", read_only=True) user_obj = GroupMemberSerializer(source="user", read_only=True)
@ -20,24 +18,12 @@ class MicrosoftEntraProviderUserSerializer(ModelSerializer):
model = MicrosoftEntraProviderUser model = MicrosoftEntraProviderUser
fields = [ fields = [
"id", "id",
"microsoft_id",
"user", "user",
"user_obj", "user_obj",
"provider",
"attributes",
] ]
extra_kwargs = {"attributes": {"read_only": True}}
class MicrosoftEntraProviderUserViewSet( class MicrosoftEntraProviderUserViewSet(UsedByMixin, ModelViewSet):
OutgoingSyncConnectionCreateMixin,
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""MicrosoftEntraProviderUser Viewset""" """MicrosoftEntraProviderUser Viewset"""
queryset = MicrosoftEntraProviderUser.objects.all().select_related("user") queryset = MicrosoftEntraProviderUser.objects.all().select_related("user")

View File

@ -1,6 +1,5 @@
from asyncio import run from asyncio import run
from collections.abc import Coroutine from collections.abc import Coroutine
from dataclasses import asdict
from typing import Any from typing import Any
from azure.core.exceptions import ( from azure.core.exceptions import (
@ -16,14 +15,12 @@ from kiota_authentication_azure.azure_identity_authentication_provider import (
AzureIdentityAuthenticationProvider, AzureIdentityAuthenticationProvider,
) )
from kiota_http.kiota_client_factory import KiotaClientFactory from kiota_http.kiota_client_factory import KiotaClientFactory
from msgraph.generated.models.entity import Entity
from msgraph.generated.models.o_data_errors.o_data_error import ODataError from msgraph.generated.models.o_data_errors.o_data_error import ODataError
from msgraph.graph_request_adapter import GraphRequestAdapter, options from msgraph.graph_request_adapter import GraphRequestAdapter, options
from msgraph.graph_service_client import GraphServiceClient from msgraph.graph_service_client import GraphServiceClient
from msgraph_core import GraphClientFactory from msgraph_core import GraphClientFactory
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.events.utils import sanitize_item
from authentik.lib.sync.outgoing import HTTP_CONFLICT from authentik.lib.sync.outgoing import HTTP_CONFLICT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.exceptions import ( from authentik.lib.sync.outgoing.exceptions import (
@ -101,10 +98,3 @@ class MicrosoftEntraSyncClient[TModel: Model, TConnection: Model, TSchema: dict]
for email in emails: for email in emails:
if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains): if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
raise BadRequestSyncException(f"Invalid email domain: {email}") raise BadRequestSyncException(f"Invalid email domain: {email}")
def entity_as_dict(self, entity: Entity) -> dict:
"""Create a dictionary of a model instance, making sure to remove (known) things
we can't JSON serialize"""
raw_data = asdict(entity)
raw_data.pop("backing_store", None)
return sanitize_item(raw_data)
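A condensed sketch of the entity_as_dict helper shown on the right-hand side; the original additionally passes the result through authentik's sanitize_item before persisting:

from dataclasses import asdict

def entity_to_dict(entity) -> dict:
    # msgraph models are dataclasses, so asdict() yields a plain dict; the kiota
    # backing_store is not JSON-serializable and is dropped before storing.
    data = asdict(entity)
    data.pop("backing_store", None)
    return data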

View File

@ -4,15 +4,18 @@ from msgraph.generated.groups.groups_request_builder import GroupsRequestBuilder
from msgraph.generated.models.group import Group as MSGroup from msgraph.generated.models.group import Group as MSGroup
from msgraph.generated.models.reference_create import ReferenceCreate from msgraph.generated.models.reference_create import ReferenceCreate
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.models import Group from authentik.core.models import Group
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
from authentik.enterprise.providers.microsoft_entra.models import ( from authentik.enterprise.providers.microsoft_entra.models import (
MicrosoftEntraProvider,
MicrosoftEntraProviderGroup, MicrosoftEntraProviderGroup,
MicrosoftEntraProviderMapping, MicrosoftEntraProviderMapping,
MicrosoftEntraProviderUser, MicrosoftEntraProviderUser,
) )
from authentik.lib.sync.mapper import PropertyMappingManager from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import ( from authentik.lib.sync.outgoing.exceptions import (
NotFoundSyncException, NotFoundSyncException,
@ -21,6 +24,7 @@ from authentik.lib.sync.outgoing.exceptions import (
TransientSyncException, TransientSyncException,
) )
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
from authentik.lib.utils.errors import exception_to_string
class MicrosoftEntraGroupClient( class MicrosoftEntraGroupClient(
@ -32,17 +36,37 @@ class MicrosoftEntraGroupClient(
connection_type_query = "group" connection_type_query = "group"
can_discover = True can_discover = True
def __init__(self, provider: MicrosoftEntraProvider) -> None: def to_schema(self, obj: Group, creating: bool) -> MSGroup:
super().__init__(provider)
self.mapper = PropertyMappingManager(
self.provider.property_mappings_group.all().order_by("name").select_subclasses(),
MicrosoftEntraProviderMapping,
["group", "provider", "connection"],
)
def to_schema(self, obj: Group, connection: MicrosoftEntraProviderGroup) -> MSGroup:
"""Convert authentik group""" """Convert authentik group"""
raw_microsoft_group = super().to_schema(obj, connection) raw_microsoft_group = {}
for mapping in (
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
):
if not isinstance(mapping, MicrosoftEntraProviderMapping):
continue
try:
value = mapping.evaluate(
user=None,
request=None,
group=obj,
provider=self.provider,
creating=creating,
)
if value is None:
continue
always_merger.merge(raw_microsoft_group, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=mapping,
).save()
raise StopSync(exc, obj, mapping) from exc
if not raw_microsoft_group:
raise StopSync(ValueError("No group mappings configured"), obj)
try: try:
return MSGroup(**raw_microsoft_group) return MSGroup(**raw_microsoft_group)
except TypeError as exc: except TypeError as exc:
@ -63,7 +87,7 @@ class MicrosoftEntraGroupClient(
def create(self, group: Group): def create(self, group: Group):
"""Create group from scratch and create a connection object""" """Create group from scratch and create a connection object"""
microsoft_group = self.to_schema(group, None) microsoft_group = self.to_schema(group, True)
with transaction.atomic(): with transaction.atomic():
try: try:
response = self._request(self.client.groups.post(microsoft_group)) response = self._request(self.client.groups.post(microsoft_group))
@ -80,37 +104,27 @@ class MicrosoftEntraGroupClient(
) )
) )
group_data = self._request(self.client.groups.get(request_configuration)) group_data = self._request(self.client.groups.get(request_configuration))
if group_data.odata_count < 1 or len(group_data.value) < 1: if group_data.odata_count < 1:
self.logger.warning( self.logger.warning(
"Group which could not be created also does not exist", group=group "Group which could not be created also does not exist", group=group
) )
return return
ms_group = group_data.value[0]
return MicrosoftEntraProviderGroup.objects.create( return MicrosoftEntraProviderGroup.objects.create(
provider=self.provider, provider=self.provider, group=group, microsoft_id=group_data.value[0].id
group=group,
microsoft_id=ms_group.id,
attributes=self.entity_as_dict(ms_group),
) )
else: else:
return MicrosoftEntraProviderGroup.objects.create( return MicrosoftEntraProviderGroup.objects.create(
provider=self.provider, provider=self.provider, group=group, microsoft_id=response.id
group=group,
microsoft_id=response.id,
attributes=self.entity_as_dict(response),
) )
def update(self, group: Group, connection: MicrosoftEntraProviderGroup): def update(self, group: Group, connection: MicrosoftEntraProviderGroup):
"""Update existing group""" """Update existing group"""
microsoft_group = self.to_schema(group, connection) microsoft_group = self.to_schema(group, False)
microsoft_group.id = connection.microsoft_id microsoft_group.id = connection.microsoft_id
try: try:
response = self._request( return self._request(
self.client.groups.by_group_id(connection.microsoft_id).patch(microsoft_group) self.client.groups.by_group_id(connection.microsoft_id).patch(microsoft_group)
) )
if response:
always_merger.merge(connection.attributes, self.entity_as_dict(response))
connection.save()
except NotFoundSyncException: except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group # Resource missing is handled by self.write, which will re-create the group
raise raise
@ -224,9 +238,4 @@ class MicrosoftEntraGroupClient(
provider=self.provider, provider=self.provider,
group=matching_authentik_group, group=matching_authentik_group,
microsoft_id=group.id, microsoft_id=group.id,
attributes=self.entity_as_dict(group),
) )
def update_single_attribute(self, connection: MicrosoftEntraProviderGroup):
data = self._request(self.client.groups.by_group_id(connection.microsoft_id).get())
connection.attributes = self.entity_as_dict(data)

View File

@ -3,20 +3,24 @@ from django.db import transaction
from msgraph.generated.models.user import User as MSUser from msgraph.generated.models.user import User as MSUser
from msgraph.generated.users.users_request_builder import UsersRequestBuilder from msgraph.generated.users.users_request_builder import UsersRequestBuilder
from authentik.core.expression.exceptions import (
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.core.models import User from authentik.core.models import User
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
from authentik.enterprise.providers.microsoft_entra.models import ( from authentik.enterprise.providers.microsoft_entra.models import (
MicrosoftEntraProvider,
MicrosoftEntraProviderMapping, MicrosoftEntraProviderMapping,
MicrosoftEntraProviderUser, MicrosoftEntraProviderUser,
) )
from authentik.lib.sync.mapper import PropertyMappingManager from authentik.events.models import Event, EventAction
from authentik.lib.sync.outgoing.exceptions import ( from authentik.lib.sync.outgoing.exceptions import (
ObjectExistsSyncException, ObjectExistsSyncException,
StopSync, StopSync,
TransientSyncException, TransientSyncException,
) )
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.utils import delete_none_values from authentik.policies.utils import delete_none_values
@ -27,17 +31,34 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
connection_type_query = "user" connection_type_query = "user"
can_discover = True can_discover = True
def __init__(self, provider: MicrosoftEntraProvider) -> None: def to_schema(self, obj: User, creating: bool) -> MSUser:
super().__init__(provider)
self.mapper = PropertyMappingManager(
self.provider.property_mappings.all().order_by("name").select_subclasses(),
MicrosoftEntraProviderMapping,
["provider", "connection"],
)
def to_schema(self, obj: User, connection: MicrosoftEntraProviderUser) -> MSUser:
"""Convert authentik user""" """Convert authentik user"""
raw_microsoft_user = super().to_schema(obj, connection) raw_microsoft_user = {}
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
if not isinstance(mapping, MicrosoftEntraProviderMapping):
continue
try:
value = mapping.evaluate(
user=obj,
request=None,
provider=self.provider,
creating=creating,
)
if value is None:
continue
always_merger.merge(raw_microsoft_user, value)
except SkipObjectException as exc:
raise exc from exc
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=mapping,
).save()
raise StopSync(exc, obj, mapping) from exc
if not raw_microsoft_user:
raise StopSync(ValueError("No user mappings configured"), obj)
try: try:
return MSUser(**delete_none_values(raw_microsoft_user)) return MSUser(**delete_none_values(raw_microsoft_user))
except TypeError as exc: except TypeError as exc:
@ -66,85 +87,48 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
microsoft_user.delete() microsoft_user.delete()
return response return response
def get_select_fields(self) -> list[str]:
"""All fields that should be selected when we fetch user data."""
# TODO: Make this customizable in the future
return [
# Default fields
"businessPhones",
"displayName",
"givenName",
"jobTitle",
"mail",
"mobilePhone",
"officeLocation",
"preferredLanguage",
"surname",
"userPrincipalName",
"id",
# Required for logging into M365 using authentik
"onPremisesImmutableId",
]
def create(self, user: User): def create(self, user: User):
"""Create user from scratch and create a connection object""" """Create user from scratch and create a connection object"""
microsoft_user = self.to_schema(user, None) microsoft_user = self.to_schema(user, True)
self.check_email_valid(microsoft_user.user_principal_name) self.check_email_valid(microsoft_user.user_principal_name)
with transaction.atomic(): with transaction.atomic():
try: try:
response = self._request(self.client.users.post(microsoft_user)) response = self._request(self.client.users.post(microsoft_user))
except ObjectExistsSyncException: except ObjectExistsSyncException:
# user already exists in microsoft entra, so we can connect them manually # user already exists in microsoft entra, so we can connect them manually
query_params = UsersRequestBuilder.UsersRequestBuilderGetQueryParameters()(
filter=f"mail eq '{microsoft_user.mail}'",
)
request_configuration = ( request_configuration = (
UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration( UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration(
query_parameters=UsersRequestBuilder.UsersRequestBuilderGetQueryParameters( query_parameters=query_params,
filter=f"mail eq '{microsoft_user.mail}'",
select=self.get_select_fields(),
),
) )
) )
user_data = self._request(self.client.users.get(request_configuration)) user_data = self._request(self.client.users.get(request_configuration))
if user_data.odata_count < 1 or len(user_data.value) < 1: if user_data.odata_count < 1:
self.logger.warning( self.logger.warning(
"User which could not be created also does not exist", user=user "User which could not be created also does not exist", user=user
) )
return return
ms_user = user_data.value[0]
return MicrosoftEntraProviderUser.objects.create( return MicrosoftEntraProviderUser.objects.create(
provider=self.provider, provider=self.provider, user=user, microsoft_id=user_data.value[0].id
user=user,
microsoft_id=ms_user.id,
attributes=self.entity_as_dict(ms_user),
) )
except TransientSyncException as exc: except TransientSyncException as exc:
raise exc raise exc
else: else:
return MicrosoftEntraProviderUser.objects.create( return MicrosoftEntraProviderUser.objects.create(
provider=self.provider, provider=self.provider, user=user, microsoft_id=response.id
user=user,
microsoft_id=response.id,
attributes=self.entity_as_dict(response),
) )
def update(self, user: User, connection: MicrosoftEntraProviderUser): def update(self, user: User, connection: MicrosoftEntraProviderUser):
"""Update existing user""" """Update existing user"""
microsoft_user = self.to_schema(user, connection) microsoft_user = self.to_schema(user, False)
self.check_email_valid(microsoft_user.user_principal_name) self.check_email_valid(microsoft_user.user_principal_name)
response = self._request( self._request(self.client.users.by_user_id(connection.microsoft_id).patch(microsoft_user))
self.client.users.by_user_id(connection.microsoft_id).patch(microsoft_user)
)
if response:
always_merger.merge(connection.attributes, self.entity_as_dict(response))
connection.save()
def discover(self): def discover(self):
"""Iterate through all users and connect them with authentik users if possible""" """Iterate through all users and connect them with authentik users if possible"""
request_configuration = UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration( users = self._request(self.client.users.get())
query_parameters=UsersRequestBuilder.UsersRequestBuilderGetQueryParameters(
select=self.get_select_fields(),
),
)
users = self._request(self.client.users.get(request_configuration))
next_link = True next_link = True
while next_link: while next_link:
for user in users.value: for user in users.value:
@ -163,16 +147,4 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
provider=self.provider, provider=self.provider,
user=matching_authentik_user, user=matching_authentik_user,
microsoft_id=user.id, microsoft_id=user.id,
attributes=self.entity_as_dict(user),
) )
def update_single_attribute(self, connection: MicrosoftEntraProviderUser):
request_configuration = UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration(
query_parameters=UsersRequestBuilder.UsersRequestBuilderGetQueryParameters(
select=self.get_select_fields(),
),
)
data = self._request(
self.client.users.by_user_id(connection.microsoft_id).get(request_configuration)
)
connection.attributes = self.entity_as_dict(data)
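Microsoft Graph only returns a default property set, so the client now requests an explicit $select list; onPremisesImmutableId is included because it is needed to sign in to Microsoft 365 through authentik. An abbreviated sketch using the request-builder classes from this diff:

from msgraph.generated.users.users_request_builder import UsersRequestBuilder

# Abbreviated field list; the full list lives in get_select_fields() above.
request_configuration = UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration(
    query_parameters=UsersRequestBuilder.UsersRequestBuilderGetQueryParameters(
        select=[
            "displayName",
            "mail",
            "userPrincipalName",
            "id",
            "onPremisesImmutableId",
        ],
    ),
)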

View File

@ -1,23 +0,0 @@
# Generated by Django 5.0.6 on 2024-05-23 20:48
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_microsoft_entra", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="microsoftentraprovidergroup",
name="attributes",
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name="microsoftentraprovideruser",
name="attributes",
field=models.JSONField(default=dict),
),
]

View File

@ -6,7 +6,6 @@ from uuid import uuid4
from azure.identity.aio import ClientSecretCredential from azure.identity.aio import ClientSecretCredential
from django.db import models from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -22,58 +21,6 @@ from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction, OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction, OutgoingSyncProvider
class MicrosoftEntraProviderUser(SerializerModel):
"""Mapping of a user and provider to a Microsoft user ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
microsoft_id = models.TextField()
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.ForeignKey("MicrosoftEntraProvider", on_delete=models.CASCADE)
attributes = models.JSONField(default=dict)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.microsoft_entra.api.users import (
MicrosoftEntraProviderUserSerializer,
)
return MicrosoftEntraProviderUserSerializer
class Meta:
verbose_name = _("Microsoft Entra Provider User")
verbose_name_plural = _("Microsoft Entra Provider User")
unique_together = (("microsoft_id", "user", "provider"),)
def __str__(self) -> str:
return f"Microsoft Entra Provider User {self.user_id} to {self.provider_id}"
class MicrosoftEntraProviderGroup(SerializerModel):
"""Mapping of a group and provider to a Microsoft group ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
microsoft_id = models.TextField()
group = models.ForeignKey(Group, on_delete=models.CASCADE)
provider = models.ForeignKey("MicrosoftEntraProvider", on_delete=models.CASCADE)
attributes = models.JSONField(default=dict)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.microsoft_entra.api.groups import (
MicrosoftEntraProviderGroupSerializer,
)
return MicrosoftEntraProviderGroupSerializer
class Meta:
verbose_name = _("Microsoft Entra Provider Group")
verbose_name_plural = _("Microsoft Entra Provider Groups")
unique_together = (("microsoft_id", "group", "provider"),)
def __str__(self) -> str:
return f"Microsoft Entra Provider Group {self.group_id} to {self.provider_id}"
class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider): class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
"""Sync users from authentik into Microsoft Entra.""" """Sync users from authentik into Microsoft Entra."""
@ -100,16 +47,15 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
) )
def client_for_model( def client_for_model(
self, self, model: type[User | Group]
model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup],
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
if issubclass(model, User | MicrosoftEntraProviderUser): if issubclass(model, User):
from authentik.enterprise.providers.microsoft_entra.clients.users import ( from authentik.enterprise.providers.microsoft_entra.clients.users import (
MicrosoftEntraUserClient, MicrosoftEntraUserClient,
) )
return MicrosoftEntraUserClient(self) return MicrosoftEntraUserClient(self)
if issubclass(model, Group | MicrosoftEntraProviderGroup): if issubclass(model, Group):
from authentik.enterprise.providers.microsoft_entra.clients.groups import ( from authentik.enterprise.providers.microsoft_entra.clients.groups import (
MicrosoftEntraGroupClient, MicrosoftEntraGroupClient,
) )
@ -141,10 +87,6 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
) )
} }
@property
def icon_url(self) -> str | None:
return static("authentik/sources/azuread.svg")
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-provider-microsoft-entra-form" return "ak-provider-microsoft-entra-form"
@ -186,3 +128,53 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
class Meta: class Meta:
verbose_name = _("Microsoft Entra Provider Mapping") verbose_name = _("Microsoft Entra Provider Mapping")
verbose_name_plural = _("Microsoft Entra Provider Mappings") verbose_name_plural = _("Microsoft Entra Provider Mappings")
class MicrosoftEntraProviderUser(SerializerModel):
"""Mapping of a user and provider to a Microsoft user ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
microsoft_id = models.TextField()
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.ForeignKey(MicrosoftEntraProvider, on_delete=models.CASCADE)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.microsoft_entra.api.users import (
MicrosoftEntraProviderUserSerializer,
)
return MicrosoftEntraProviderUserSerializer
class Meta:
verbose_name = _("Microsoft Entra Provider User")
verbose_name_plural = _("Microsoft Entra Provider User")
unique_together = (("microsoft_id", "user", "provider"),)
def __str__(self) -> str:
return f"Microsoft Entra Provider User {self.user_id} to {self.provider_id}"
class MicrosoftEntraProviderGroup(SerializerModel):
"""Mapping of a group and provider to a Microsoft group ID"""
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
microsoft_id = models.TextField()
group = models.ForeignKey(Group, on_delete=models.CASCADE)
provider = models.ForeignKey(MicrosoftEntraProvider, on_delete=models.CASCADE)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.microsoft_entra.api.groups import (
MicrosoftEntraProviderGroupSerializer,
)
return MicrosoftEntraProviderGroupSerializer
class Meta:
verbose_name = _("Microsoft Entra Provider Group")
verbose_name_plural = _("Microsoft Entra Provider Groups")
unique_together = (("microsoft_id", "group", "provider"),)
def __str__(self) -> str:
return f"Microsoft Entra Provider Group {self.group_id} to {self.provider_id}"

View File

@ -93,38 +93,6 @@ class MicrosoftEntraGroupTests(TestCase):
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists()) self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
group_create.assert_called_once() group_create.assert_called_once()
def test_group_not_created(self):
"""Test without group property mappings, no group is created"""
self.provider.property_mappings_group.clear()
uid = generate_id()
with (
patch(
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
MagicMock(return_value={"credentials": self.creds}),
),
patch(
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
AsyncMock(
return_value=OrganizationCollectionResponse(
value=[
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
]
)
),
),
patch(
"msgraph.generated.groups.groups_request_builder.GroupsRequestBuilder.post",
AsyncMock(return_value=MSGroup(id=generate_id())),
) as group_create,
):
group = Group.objects.create(name=uid)
microsoft_group = MicrosoftEntraProviderGroup.objects.filter(
provider=self.provider, group=group
).first()
self.assertIsNone(microsoft_group)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
group_create.assert_not_called()
def test_group_create_update(self): def test_group_create_update(self):
"""Test group updating""" """Test group updating"""
uid = generate_id() uid = generate_id()

View File

@ -3,18 +3,16 @@
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from azure.identity.aio import ClientSecretCredential from azure.identity.aio import ClientSecretCredential
from django.urls import reverse from django.test import TestCase
from msgraph.generated.models.group_collection_response import GroupCollectionResponse from msgraph.generated.models.group_collection_response import GroupCollectionResponse
from msgraph.generated.models.organization import Organization from msgraph.generated.models.organization import Organization
from msgraph.generated.models.organization_collection_response import OrganizationCollectionResponse from msgraph.generated.models.organization_collection_response import OrganizationCollectionResponse
from msgraph.generated.models.user import User as MSUser from msgraph.generated.models.user import User as MSUser
from msgraph.generated.models.user_collection_response import UserCollectionResponse from msgraph.generated.models.user_collection_response import UserCollectionResponse
from msgraph.generated.models.verified_domain import VerifiedDomain from msgraph.generated.models.verified_domain import VerifiedDomain
from rest_framework.test import APITestCase
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User from authentik.core.models import Application, Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.providers.microsoft_entra.models import ( from authentik.enterprise.providers.microsoft_entra.models import (
MicrosoftEntraProvider, MicrosoftEntraProvider,
MicrosoftEntraProviderMapping, MicrosoftEntraProviderMapping,
@ -27,12 +25,11 @@ from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
class MicrosoftEntraUserTests(APITestCase): class MicrosoftEntraUserTests(TestCase):
"""Microsoft Entra User tests""" """Microsoft Entra User tests"""
@apply_blueprint("system/providers-microsoft-entra.yaml") @apply_blueprint("system/providers-microsoft-entra.yaml")
def setUp(self) -> None: def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID # Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users # which will cause errors with multiple users
Tenant.objects.update(avatars="none") Tenant.objects.update(avatars="none")
@ -97,42 +94,6 @@ class MicrosoftEntraUserTests(APITestCase):
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists()) self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
user_create.assert_called_once() user_create.assert_called_once()
def test_user_not_created(self):
"""Test without property mappings, no group is created"""
self.provider.property_mappings.clear()
uid = generate_id()
with (
patch(
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
MagicMock(return_value={"credentials": self.creds}),
),
patch(
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
AsyncMock(
return_value=OrganizationCollectionResponse(
value=[
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
]
)
),
),
patch(
"msgraph.generated.users.users_request_builder.UsersRequestBuilder.post",
AsyncMock(return_value=MSUser(id=generate_id())),
) as user_create,
):
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
microsoft_user = MicrosoftEntraProviderUser.objects.filter(
provider=self.provider, user=user
).first()
self.assertIsNone(microsoft_user)
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
user_create.assert_not_called()
def test_user_create_update(self): def test_user_create_update(self):
"""Test user updating""" """Test user updating"""
uid = generate_id() uid = generate_id()
@ -374,45 +335,3 @@ class MicrosoftEntraUserTests(APITestCase):
) )
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists()) self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
user_list.assert_called_once() user_list.assert_called_once()
def test_connect_manual(self):
"""test manual user connection"""
uid = generate_id()
self.app.backchannel_providers.remove(self.provider)
admin = create_test_admin_user()
different_user = User.objects.create(
username=uid,
email=f"{uid}@goauthentik.io",
)
self.app.backchannel_providers.add(self.provider)
with (
patch(
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
MagicMock(return_value={"credentials": self.creds}),
),
patch(
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
AsyncMock(
return_value=OrganizationCollectionResponse(
value=[
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
]
)
),
),
patch(
"authentik.enterprise.providers.microsoft_entra.clients.users.MicrosoftEntraUserClient.update_single_attribute",
MagicMock(),
) as user_get,
):
self.client.force_login(admin)
response = self.client.post(
reverse("authentik_api:microsoftentraprovideruser-list"),
data={
"microsoft_id": generate_id(),
"user": different_user.pk,
"provider": self.provider.pk,
},
)
self.assertEqual(response.status_code, 201)
user_get.assert_called_once()

View File

@ -7,7 +7,7 @@ from drf_spectacular.utils import extend_schema_field
from rest_framework.fields import CharField from rest_framework.fields import CharField
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from authentik.core.api.property_mappings import PropertyMappingSerializer from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField from authentik.core.api.utils import JSONDictField
from authentik.enterprise.providers.rac.models import RACPropertyMapping from authentik.enterprise.providers.rac.models import RACPropertyMapping

View File

@ -7,7 +7,6 @@ from deepmerge import always_merger
from django.db import models from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.http import HttpRequest from django.http import HttpRequest
from django.templatetags.static import static
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -64,10 +63,6 @@ class RACProvider(Provider):
Can return None for providers that are not URL-based""" Can return None for providers that are not URL-based"""
return "goauthentik.io://providers/rac/launch" return "goauthentik.io://providers/rac/launch"
@property
def icon_url(self) -> str | None:
return static("authentik/sources/rac.svg")
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-provider-rac-form" return "ak-provider-rac-form"

View File

@ -18,12 +18,9 @@ class SourceStageSerializer(EnterpriseRequiredMixin, StageSerializer):
source = Source.objects.filter(pk=_source.pk).select_subclasses().first() source = Source.objects.filter(pk=_source.pk).select_subclasses().first()
if not source: if not source:
raise ValidationError("Invalid source") raise ValidationError("Invalid source")
if "request" in self.context: login_button = source.ui_login_button(self.context["request"])
login_button = source.ui_login_button(self.context["request"]) if not login_button:
if not login_button: raise ValidationError("Invalid source selected, only web-based sources are supported.")
raise ValidationError(
"Invalid source selected, only web-based sources are supported."
)
return source return source
class Meta: class Meta:

View File

@ -54,7 +54,7 @@ class SourceStageView(ChallengeStageView):
def create_flow_token(self) -> FlowToken: def create_flow_token(self) -> FlowToken:
"""Save the current flow state in a token that can be used to resume this flow""" """Save the current flow state in a token that can be used to resume this flow"""
pending_user: User = self.get_pending_user() pending_user: User = self.get_pending_user()
if pending_user.is_anonymous or not pending_user.pk: if pending_user.is_anonymous:
pending_user = get_anonymous_user() pending_user = get_anonymous_user()
current_stage: SourceStage = self.executor.current_stage current_stage: SourceStage = self.executor.current_stage
identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}") identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}")

View File

@ -19,8 +19,7 @@ from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from authentik.admin.api.metrics import CoordinateSerializer from authentik.admin.api.metrics import CoordinateSerializer
from authentik.core.api.object_types import TypeCreateSerializer from authentik.core.api.utils import PassiveSerializer, TypeCreateSerializer
from authentik.core.api.utils import PassiveSerializer
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction

View File

@ -45,7 +45,7 @@ class GeoIPContextProcessor(MMDBContextProcessor):
def enrich_context(self, request: HttpRequest) -> dict: def enrich_context(self, request: HttpRequest) -> dict:
# Different key `geoip` vs `geo` for legacy reasons # Different key `geoip` vs `geo` for legacy reasons
return {"geoip": self.city_dict(ClientIPMiddleware.get_client_ip(request))} return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))}
def city(self, ip_address: str) -> City | None: def city(self, ip_address: str) -> City | None:
"""Wrapper for Reader.city""" """Wrapper for Reader.city"""

View File

@ -10,7 +10,7 @@ from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.base.schema import BaseDatabaseSchemaEditor
import authentik.events.models import authentik.events.models
import authentik.lib.models import authentik.lib.validators
from authentik.lib.migrations import progress_bar from authentik.lib.migrations import progress_bar
@ -377,7 +377,7 @@ class Migration(migrations.Migration):
model_name="notificationtransport", model_name="notificationtransport",
name="webhook_url", name="webhook_url",
field=models.TextField( field=models.TextField(
blank=True, validators=[authentik.lib.models.DomainlessURLValidator()] blank=True, validators=[authentik.lib.validators.DomainlessURLValidator()]
), ),
), ),
] ]

View File

@ -41,10 +41,11 @@ from authentik.events.utils import (
sanitize_dict, sanitize_dict,
sanitize_item, sanitize_item,
) )
from authentik.lib.models import DomainlessURLValidator, SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.lib.validators import DomainlessURLValidator
from authentik.policies.models import PolicyBindingModel from authentik.policies.models import PolicyBindingModel
from authentik.root.middleware import ClientIPMiddleware from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage

View File

@ -10,10 +10,10 @@ from rest_framework.serializers import ModelSerializer, SerializerMethodField
from rest_framework.viewsets import GenericViewSet from rest_framework.viewsets import GenericViewSet
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import MetaNameSerializer from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
from authentik.core.types import UserSettingSerializer from authentik.core.types import UserSettingSerializer
from authentik.enterprise.apps import EnterpriseConfig
from authentik.flows.api.flows import FlowSetSerializer from authentik.flows.api.flows import FlowSetSerializer
from authentik.flows.models import ConfigurableStage, Stage from authentik.flows.models import ConfigurableStage, Stage
from authentik.lib.utils.reflection import all_subclasses from authentik.lib.utils.reflection import all_subclasses
@ -47,7 +47,6 @@ class StageSerializer(ModelSerializer, MetaNameSerializer):
class StageViewSet( class StageViewSet(
TypesMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.DestroyModelMixin, mixins.DestroyModelMixin,
UsedByMixin, UsedByMixin,
@ -64,6 +63,25 @@ class StageViewSet(
def get_queryset(self): # pragma: no cover def get_queryset(self): # pragma: no cover
return Stage.objects.select_subclasses().prefetch_related("flow_set") return Stage.objects.select_subclasses().prefetch_related("flow_set")
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[])
def types(self, request: Request) -> Response:
"""Get all creatable stage types"""
data = []
for subclass in all_subclasses(self.queryset.model, False):
subclass: Stage
data.append(
{
"name": subclass._meta.verbose_name,
"description": subclass.__doc__,
"component": subclass().component,
"model_name": subclass._meta.model_name,
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
}
)
data = sorted(data, key=lambda x: x["name"])
return Response(TypeCreateSerializer(data, many=True).data)
@extend_schema(responses={200: UserSettingSerializer(many=True)}) @extend_schema(responses={200: UserSettingSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[]) @action(detail=False, pagination_class=None, filter_backends=[])
def user_settings(self, request: Request) -> Response: def user_settings(self, request: Request) -> Response:

View File

@ -2,7 +2,7 @@
from base64 import b64encode from base64 import b64encode
from functools import cache as funccache from functools import cache as funccache
from hashlib import md5, sha256 from hashlib import md5
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from urllib.parse import urlencode from urllib.parse import urlencode
@ -20,7 +20,7 @@ from authentik.tenants.utils import get_current_tenant
if TYPE_CHECKING: if TYPE_CHECKING:
from authentik.core.models import User from authentik.core.models import User
GRAVATAR_URL = "https://www.gravatar.com" GRAVATAR_URL = "https://secure.gravatar.com"
DEFAULT_AVATAR = static("dist/assets/images/user_default.png") DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/" CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available" CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available"
@ -55,9 +55,10 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
if not cache.get(CACHE_KEY_GRAVATAR_AVAILABLE, True): if not cache.get(CACHE_KEY_GRAVATAR_AVAILABLE, True):
return None return None
mail_hash = sha256(user.email.lower().encode("utf-8")).hexdigest() # nosec # gravatar uses md5 for their URLs, so md5 can't be avoided
parameters = {"size": "158", "rating": "g", "default": "404"} mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters)}" parameters = [("size", "158"), ("rating", "g"), ("default", "404")]
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
full_key = CACHE_KEY_GRAVATAR + mail_hash full_key = CACHE_KEY_GRAVATAR + mail_hash
if cache.has_key(full_key): if cache.has_key(full_key):
@ -83,9 +84,7 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
def generate_colors(text: str) -> tuple[str, str]: def generate_colors(text: str) -> tuple[str, str]:
"""Generate colours based on `text`""" """Generate colours based on `text`"""
color = ( color = int(md5(text.lower().encode("utf-8")).hexdigest(), 16) % 0xFFFFFF # nosec
int(md5(text.lower().encode("utf-8"), usedforsecurity=False).hexdigest(), 16) % 0xFFFFFF
) # nosec
# Get a (somewhat arbitrarily) reduced scope of colors # Get a (somewhat arbitrarily) reduced scope of colors
# to avoid too dark or light backgrounds # to avoid too dark or light backgrounds
@ -180,7 +179,7 @@ def avatar_mode_generated(user: "User", mode: str) -> str | None:
def avatar_mode_url(user: "User", mode: str) -> str | None: def avatar_mode_url(user: "User", mode: str) -> str | None:
"""Format url""" """Format url"""
mail_hash = md5(user.email.lower().encode("utf-8"), usedforsecurity=False).hexdigest() # nosec mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec
return mode % { return mode % {
"username": user.username, "username": user.username,
"mail_hash": mail_hash, "mail_hash": mail_hash,

View File

@ -304,12 +304,6 @@ class ConfigLoader:
"""Wrapper for get that converts value into boolean""" """Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true" return str(self.get(path, default)).lower() == "true"
def get_keys(self, path: str, sep=".") -> list[str]:
"""List attribute keys by using yaml path"""
root = self.raw
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr({}))
return attr.keys()
def get_dict_from_b64_json(self, path: str, default=None) -> dict: def get_dict_from_b64_json(self, path: str, default=None) -> dict:
"""Wrapper for get that converts value from Base64 encoded string into dictionary""" """Wrapper for get that converts value from Base64 encoded string into dictionary"""
config_value = self.get(path) config_value = self.get(path)
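The get_keys() helper in this hunk lists the child keys under a dotted path; a small usage sketch, with the read_replicas layout borrowed from the default configuration shown later in this diff:

from authentik.lib.config import ConfigLoader

config = ConfigLoader()
config.set("postgresql.read_replicas.0.host", "replica1.example.com")
# get_keys returns the immediate child keys under the given path
assert list(config.get_keys("postgresql.read_replicas")) == ["0"]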

View File

@ -10,10 +10,6 @@ postgresql:
use_pgpool: false use_pgpool: false
test: test:
name: test_authentik name: test_authentik
read_replicas: {}
# For example
# 0:
# host: replica1.example.com
listen: listen:
listen_http: 0.0.0.0:9000 listen_http: 0.0.0.0:9000
@ -50,6 +46,7 @@ cache:
timeout: 300 timeout: 300
timeout_flows: 300 timeout_flows: 300
timeout_policies: 300 timeout_policies: 300
timeout_reputation: 300
# channel: # channel:
# url: "" # url: ""
@ -115,9 +112,6 @@ events:
context_processors: context_processors:
geoip: "/geoip/GeoLite2-City.mmdb" geoip: "/geoip/GeoLite2-City.mmdb"
asn: "/geoip/GeoLite2-ASN.mmdb" asn: "/geoip/GeoLite2-ASN.mmdb"
compliance:
fips:
enabled: false
cert_discovery_dir: /certs cert_discovery_dir: /certs

View File

@ -5,7 +5,6 @@ import socket
from collections.abc import Iterable from collections.abc import Iterable
from ipaddress import ip_address, ip_network from ipaddress import ip_address, ip_network
from textwrap import indent from textwrap import indent
from types import CodeType
from typing import Any from typing import Any
from cachetools import TLRUCache, cached from cachetools import TLRUCache, cached
@ -185,7 +184,7 @@ class BaseEvaluator:
full_expression += f"\nresult = handler({handler_signature})" full_expression += f"\nresult = handler({handler_signature})"
return full_expression return full_expression
def compile(self, expression: str) -> CodeType: def compile(self, expression: str) -> Any:
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect.""" """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
param_keys = self._context.keys() param_keys = self._context.keys()
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec") return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")

View File

@ -102,8 +102,6 @@ def get_logger_config():
"gunicorn": "INFO", "gunicorn": "INFO",
"requests_mock": "WARNING", "requests_mock": "WARNING",
"hpack": "WARNING", "hpack": "WARNING",
"httpx": "WARNING",
"azure": "WARNING",
} }
for handler_name, level in handler_level_map.items(): for handler_name, level in handler_level_map.items():
base_config["loggers"][handler_name] = { base_config["loggers"][handler_name] = {

View File

@ -1,13 +1,16 @@
"""Generic models""" """Generic models"""
import re from typing import Any
from django.core.validators import URLValidator
from django.db import models from django.db import models
from django.utils.regex_helper import _lazy_re_compile from django.dispatch import Signal
from django.utils import timezone
from model_utils.managers import InheritanceManager from model_utils.managers import InheritanceManager
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
pre_soft_delete = Signal()
post_soft_delete = Signal()
class SerializerModel(models.Model): class SerializerModel(models.Model):
"""Base Abstract Model which has a serializer""" """Base Abstract Model which has a serializer"""
@ -51,46 +54,57 @@ class InheritanceForeignKey(models.ForeignKey):
forward_related_accessor_class = InheritanceForwardManyToOneDescriptor forward_related_accessor_class = InheritanceForwardManyToOneDescriptor
class DomainlessURLValidator(URLValidator): class SoftDeleteQuerySet(models.QuerySet):
"""Subclass of URLValidator which doesn't check the domain
(to allow hostnames without domain)"""
def __init__(self, *args, **kwargs) -> None: def delete(self):
super().__init__(*args, **kwargs) for obj in self.all():
self.host_re = "(" + self.hostname_re + self.domain_re + "|localhost)" obj.delete()
self.regex = _lazy_re_compile(
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately def hard_delete(self):
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication return super().delete()
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
r"(?::\d{2,5})?" # port
r"(?:[/?#][^\s]*)?" # resource path class SoftDeleteManager(models.Manager):
r"\Z",
re.IGNORECASE, def get_queryset(self):
return SoftDeleteQuerySet(self.model, using=self._db).filter(deleted_at__isnull=True)
class DeletedSoftDeleteManager(models.Manager):
def get_queryset(self):
return super().get_queryset().exclude(deleted_at__isnull=True)
class SoftDeleteModel(models.Model):
"""Model which doesn't fully delete itself, but rather saved the delete status
so cleanup events can run."""
deleted_at = models.DateTimeField(blank=True, null=True)
objects = SoftDeleteManager()
deleted = DeletedSoftDeleteManager()
class Meta:
abstract = True
@property
def is_deleted(self):
return self.deleted_at is not None
def delete(self, using: Any = ..., keep_parents: bool = ...) -> tuple[int, dict[str, int]]:
pre_soft_delete.send(sender=self.__class__, instance=self)
now = timezone.now()
self.deleted_at = now
self.save(
update_fields=[
"deleted_at",
]
) )
self.schemes = ["http", "https", "blank"] + list(self.schemes) post_soft_delete.send(sender=self.__class__, instance=self)
return tuple()
def __call__(self, value: str): def force_delete(self, using: Any = ...):
# Check if the scheme is valid. if not self.deleted_at:
scheme = value.split("://")[0].lower() raise models.ProtectedError("Refusing to force delete non-deleted model", {self})
if scheme not in self.schemes: return super().delete(using=using)
value = "default" + value
super().__call__(value)
class DomainlessFormattedURLValidator(DomainlessURLValidator):
"""URL validator which allows for python format strings"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.formatter_re = r"([%\(\)a-zA-Z])*"
self.host_re = "(" + self.formatter_re + self.hostname_re + self.domain_re + "|localhost)"
self.regex = _lazy_re_compile(
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
r"(?::\d{2,5})?" # port
r"(?:[/?#][^\s]*)?" # resource path
r"\Z",
re.IGNORECASE,
)
self.schemes = ["http", "https", "blank"] + list(self.schemes)

View File

@ -1,69 +0,0 @@
from collections.abc import Generator
from django.db.models import QuerySet
from django.http import HttpRequest
from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import PropertyMapping, User
class PropertyMappingManager:
"""Pre-compile and cache property mappings when an identical
set is used multiple times"""
query_set: QuerySet[PropertyMapping]
mapping_subclass: type[PropertyMapping]
_evaluators: list[PropertyMappingEvaluator]
def __init__(
self,
qs: QuerySet[PropertyMapping],
# Expected subclass of PropertyMappings, any objects in the queryset
# that are not an instance of this class will be discarded
mapping_subclass: type[PropertyMapping],
# As the keys of parameters are part of the compilation,
# we need a list of all parameter names that will be used during evaluation
context_keys: list[str],
) -> None:
self.query_set = qs
self.mapping_subclass = mapping_subclass
self.context_keys = context_keys
self.compile()
def compile(self):
self._evaluators = []
for mapping in self.query_set:
if not isinstance(mapping, self.mapping_subclass):
continue
evaluator = PropertyMappingEvaluator(
mapping, **{key: None for key in self.context_keys}
)
# Compile and cache expression
evaluator.compile()
self._evaluators.append(evaluator)
def iter_eval(
self,
user: User | None,
request: HttpRequest | None,
return_mapping: bool = False,
**kwargs,
) -> Generator[tuple[dict, PropertyMapping], None]:
"""Iterate over all mappings that were pre-compiled and
execute all of them with the given context"""
for mapping in self._evaluators:
mapping.set_context(user, request, **kwargs)
try:
value = mapping.evaluate(mapping.model.expression)
except PropertyMappingExpressionException as exc:
raise exc from exc
except Exception as exc:
raise PropertyMappingExpressionException(exc, mapping.model) from exc
if value is None:
continue
if return_mapping:
yield value, mapping.model
else:
yield value
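The PropertyMappingManager shown in this file is consumed by the to_schema() implementation that appears later in this diff; a condensed usage sketch, with the queryset and context keys chosen for illustration:

from deepmerge import always_merger

from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderMapping
from authentik.lib.sync.mapper import PropertyMappingManager


def build_payload(provider, user) -> dict:
    """Evaluate every configured mapping once and merge the results into one dict."""
    manager = PropertyMappingManager(
        provider.property_mappings.all().select_subclasses(),  # mappings to pre-compile
        MicrosoftEntraProviderMapping,                          # discard mappings of any other type
        ["provider", "connection"],                             # extra names usable inside expressions
    )
    raw: dict = {}
    for value in manager.iter_eval(user=user, request=None, provider=provider, connection=None):
        always_merger.merge(raw, value)  # later mappings deep-merge over earlier ones
    return raw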

View File

@ -3,6 +3,3 @@
PAGE_SIZE = 100 PAGE_SIZE = 100
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
HTTP_CONFLICT = 409 HTTP_CONFLICT = 409
HTTP_NO_CONTENT = 204
HTTP_SERVICE_UNAVAILABLE = 503
HTTP_TOO_MANY_REQUESTS = 429

View File

@ -7,7 +7,6 @@ from rest_framework.decorators import action
from rest_framework.fields import BooleanField from rest_framework.fields import BooleanField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.events.api.tasks import SystemTaskSerializer from authentik.events.api.tasks import SystemTaskSerializer
@ -48,24 +47,8 @@ class OutgoingSyncProviderStatusMixin:
uid=slugify(provider.name), uid=slugify(provider.name),
) )
) )
with provider.sync_lock as lock_acquired: status = {
status = { "tasks": tasks,
"tasks": tasks, "is_running": provider.sync_lock.locked(),
# If we could not acquire the lock, it means a task is using it, and thus is running }
"is_running": not lock_acquired,
}
return Response(SyncStatusSerializer(status).data) return Response(SyncStatusSerializer(status).data)
class OutgoingSyncConnectionCreateMixin:
"""Mixin for connection objects that fetches remote data upon creation"""
def perform_create(self, serializer: ModelSerializer):
super().perform_create(serializer)
try:
instance = serializer.instance
client = instance.provider.client_for_model(instance.__class__)
client.update_single_attribute(instance)
instance.save()
except NotImplementedError:
pass

View File

@ -3,18 +3,10 @@
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from deepmerge import always_merger
from django.db import DatabaseError from django.db import DatabaseError
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import ( from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException
PropertyMappingExpressionException,
SkipObjectException,
)
from authentik.events.models import Event, EventAction
from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
from authentik.lib.utils.errors import exception_to_string
if TYPE_CHECKING: if TYPE_CHECKING:
from django.db.models import Model from django.db.models import Model
@ -36,7 +28,6 @@ class BaseOutgoingSyncClient[
provider: TProvider provider: TProvider
connection_type: type[TConnection] connection_type: type[TConnection]
connection_type_query: str connection_type_query: str
mapper: PropertyMappingManager
can_discover = False can_discover = False
@ -79,34 +70,9 @@ class BaseOutgoingSyncClient[
"""Delete object from destination""" """Delete object from destination"""
raise NotImplementedError() raise NotImplementedError()
def to_schema(self, obj: TModel, connection: TConnection | None, **defaults) -> TSchema: def to_schema(self, obj: TModel, creating: bool) -> TSchema:
"""Convert object to destination schema""" """Convert object to destination schema"""
raw_final_object = {} raise NotImplementedError()
try:
eval_kwargs = {
"request": None,
"provider": self.provider,
"connection": connection,
obj._meta.model_name: obj,
}
eval_kwargs.setdefault("user", None)
for value in self.mapper.iter_eval(**eval_kwargs):
always_merger.merge(raw_final_object, value)
except SkipObjectException as exc:
raise exc from exc
except PropertyMappingExpressionException as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
mapping=exc.mapping,
).save()
raise StopSync(exc, obj, exc.mapping) from exc
if not raw_final_object:
raise StopSync(ValueError("No mappings configured"), obj)
for key, value in defaults.items():
raw_final_object.setdefault(key, value)
return raw_final_object
def discover(self): def discover(self):
"""Optional method. Can be used to implement a "discovery" where """Optional method. Can be used to implement a "discovery" where
@ -114,8 +80,3 @@ class BaseOutgoingSyncClient[
pre-link any users/groups in the remote system with the respective pre-link any users/groups in the remote system with the respective
object in authentik based on a common identifier""" object in authentik based on a common identifier"""
raise NotImplementedError() raise NotImplementedError()
def update_single_attribute(self, connection: TConnection):
"""Update connection attributes on a connection object, when the connection
is manually created"""
raise NotImplementedError

View File

@ -1,10 +1,11 @@
from typing import Any, Self from typing import Any, Self
import pglock from django.core.cache import cache
from django.db import connection
from django.db.models import Model, QuerySet, TextChoices from django.db.models import Model, QuerySet, TextChoices
from redis.lock import Lock
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
@ -31,10 +32,10 @@ class OutgoingSyncProvider(Model):
raise NotImplementedError raise NotImplementedError
@property @property
def sync_lock(self) -> pglock.advisory: def sync_lock(self) -> Lock:
"""Postgres lock for syncing SCIM to prevent multiple parallel syncs happening""" """Redis lock to prevent multiple parallel syncs happening"""
return pglock.advisory( return Lock(
lock_id=f"goauthentik.io/{connection.schema_name}/providers/outgoing-sync/{str(self.pk)}", cache.client.get_client(),
timeout=0, name=f"goauthentik.io/providers/outgoing-sync/{str(self.pk)}",
side_effect=pglock.Return, timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
) )
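The sync task further down in this diff acquires this lock before paginating; a sketch of the caller for the pglock variant, where side_effect=pglock.Return makes entering the context yield whether the lock was acquired instead of blocking:

def sync_if_free(provider) -> bool:
    """Run a provider sync only when no other worker holds its lock."""
    with provider.sync_lock as lock_acquired:
        if not lock_acquired:
            # another task is already syncing this provider; skip instead of waiting
            return False
        # ... paginate users and groups and dispatch sync_objects here ...
        return True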

View File

@ -1,5 +1,4 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict
from celery.exceptions import Retry from celery.exceptions import Retry
from celery.result import allow_join_result from celery.result import allow_join_result
@ -14,7 +13,6 @@ from authentik.core.models import Group, User
from authentik.events.logs import LogEvent from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask from authentik.events.system_tasks import SystemTask
from authentik.events.utils import sanitize_item
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import ( from authentik.lib.sync.outgoing.exceptions import (
@ -66,16 +64,17 @@ class SyncTasks:
).first() ).first()
if not provider: if not provider:
return return
lock = provider.sync_lock
if lock.locked():
self.logger.debug("Sync locked, skipping task", source=provider.name)
return
task.set_uid(slugify(provider.name)) task.set_uid(slugify(provider.name))
messages = [] messages = []
messages.append(_("Starting full provider sync")) messages.append(_("Starting full provider sync"))
self.logger.debug("Starting provider sync") self.logger.debug("Starting provider sync")
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
with allow_join_result(), provider.sync_lock as lock_acquired: with allow_join_result(), lock:
if not lock_acquired:
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
return
try: try:
for page in users_paginator.page_range: for page in users_paginator.page_range:
messages.append(_("Syncing page %(page)d of users" % {"page": page})) messages.append(_("Syncing page %(page)d of users" % {"page": page}))
@ -84,7 +83,7 @@ class SyncTasks:
time_limit=PAGE_TIMEOUT, time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT, soft_time_limit=PAGE_TIMEOUT,
).get(): ).get():
messages.append(LogEvent(**msg)) messages.append(msg)
for page in groups_paginator.page_range: for page in groups_paginator.page_range:
messages.append(_("Syncing page %(page)d of groups" % {"page": page})) messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
for msg in sync_objects.apply_async( for msg in sync_objects.apply_async(
@ -92,7 +91,7 @@ class SyncTasks:
time_limit=PAGE_TIMEOUT, time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT, soft_time_limit=PAGE_TIMEOUT,
).get(): ).get():
messages.append(LogEvent(**msg)) messages.append(msg)
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc) self.logger.warning("transient sync exception", exc=exc)
raise task.retry(exc=exc) from exc raise task.retry(exc=exc) from exc
@ -126,70 +125,61 @@ class SyncTasks:
try: try:
client.write(obj) client.write(obj)
except SkipObjectException: except SkipObjectException:
self.logger.debug("skipping object due to SkipObject", obj=obj)
continue continue
except BadRequestSyncException as exc: except BadRequestSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, obj=obj) self.logger.warning("failed to sync object", exc=exc, obj=obj)
messages.append( messages.append(
asdict( LogEvent(
LogEvent( _(
_( (
( "Failed to sync {object_type} {object_name} "
"Failed to sync {object_type} {object_name} " "due to error: {error}"
"due to error: {error}" ).format_map(
).format_map( {
{ "object_type": obj._meta.verbose_name,
"object_type": obj._meta.verbose_name, "object_name": str(obj),
"object_name": str(obj), "error": str(exc),
"error": str(exc), }
} )
) ),
), log_level="warning",
log_level="warning", logger="",
logger=f"{provider._meta.verbose_name}@{object_type}", attributes={"arguments": exc.args[1:]},
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
)
) )
) )
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj) self.logger.warning("failed to sync object", exc=exc, user=obj)
messages.append( messages.append(
asdict( LogEvent(
LogEvent( _(
_( (
( "Failed to sync {object_type} {object_name} "
"Failed to sync {object_type} {object_name} " "due to transient error: {error}"
"due to transient error: {error}" ).format_map(
).format_map( {
{ "object_type": obj._meta.verbose_name,
"object_type": obj._meta.verbose_name, "object_name": str(obj),
"object_name": str(obj), "error": str(exc),
"error": str(exc), }
} )
) ),
), log_level="warning",
log_level="warning", logger="",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
) )
) )
except StopSync as exc: except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc) self.logger.warning("Stopping sync", exc=exc)
messages.append( messages.append(
asdict( LogEvent(
LogEvent( _(
_( "Stopping sync due to error: {error}".format_map(
"Stopping sync due to error: {error}".format_map( {
{ "error": exc.detail(),
"error": exc.detail(), }
} )
) ),
), log_level="warning",
log_level="warning", logger="",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
) )
) )
break break

View File

@ -169,9 +169,3 @@ class TestConfig(TestCase):
self.assertEqual(config.get("cache.timeout_flows"), "32m") self.assertEqual(config.get("cache.timeout_flows"), "32m")
self.assertEqual(config.get("cache.timeout_policies"), "3920ns") self.assertEqual(config.get("cache.timeout_policies"), "3920ns")
self.assertEqual(config.get("cache.timeout_reputation"), "298382us") self.assertEqual(config.get("cache.timeout_reputation"), "298382us")
def test_get_keys(self):
"""Test get_keys"""
config = ConfigLoader()
config.set("foo.bar", "baz")
self.assertEqual(list(config.get_keys("foo")), ["bar"])

View File

@ -12,7 +12,7 @@ from authentik.lib.config import CONFIG
SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST"
def all_subclasses[T](cls: T, sort=True) -> list[T] | set[T]: def all_subclasses(cls, sort=True):
"""Recursively return all subclassess of cls""" """Recursively return all subclassess of cls"""
classes = set(cls.__subclasses__()).union( classes = set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c, sort=sort)] [s for c in cls.__subclasses__() for s in all_subclasses(c, sort=sort)]

View File

@ -1,5 +1,9 @@
"""Serializer validators""" """Serializer validators"""
import re
from django.core.validators import URLValidator
from django.utils.regex_helper import _lazy_re_compile
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -29,3 +33,48 @@ class RequiredTogetherValidator:
def __repr__(self): def __repr__(self):
return f"<{self.__class__.__name__}(fields={smart_repr(self.fields)})>" return f"<{self.__class__.__name__}(fields={smart_repr(self.fields)})>"
class DomainlessURLValidator(URLValidator):
"""Subclass of URLValidator which doesn't check the domain
(to allow hostnames without domain)"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.host_re = "(" + self.hostname_re + self.domain_re + "|localhost)"
self.regex = _lazy_re_compile(
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
r"(?::\d{2,5})?" # port
r"(?:[/?#][^\s]*)?" # resource path
r"\Z",
re.IGNORECASE,
)
self.schemes = ["http", "https", "blank"] + list(self.schemes)
def __call__(self, value: str):
# Check if the scheme is valid.
scheme = value.split("://")[0].lower()
if scheme not in self.schemes:
value = "default" + value
super().__call__(value)
class DomainlessFormattedURLValidator(DomainlessURLValidator):
"""URL validator which allows for python format strings"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.formatter_re = r"([%\(\)a-zA-Z])*"
self.host_re = "(" + self.formatter_re + self.hostname_re + self.domain_re + "|localhost)"
self.regex = _lazy_re_compile(
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
r"(?::\d{2,5})?" # port
r"(?:[/?#][^\s]*)?" # resource path
r"\Z",
re.IGNORECASE,
)
self.schemes = ["http", "https", "blank"] + list(self.schemes)

View File

@ -6,7 +6,7 @@ from django_filters.filters import ModelMultipleChoiceFilter
from django_filters.filterset import FilterSet from django_filters.filterset import FilterSet
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, DateTimeField, SerializerMethodField from rest_framework.fields import BooleanField, CharField, DateTimeField
from rest_framework.relations import PrimaryKeyRelatedField from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
@ -18,7 +18,6 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField, PassiveSerializer from authentik.core.api.utils import JSONDictField, PassiveSerializer
from authentik.core.models import Provider from authentik.core.models import Provider
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.providers.rac.models import RACProvider from authentik.enterprise.providers.rac.models import RACProvider
from authentik.outposts.api.service_connections import ServiceConnectionSerializer from authentik.outposts.api.service_connections import ServiceConnectionSerializer
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
@ -118,12 +117,8 @@ class OutpostHealthSerializer(PassiveSerializer):
uid = CharField(read_only=True) uid = CharField(read_only=True)
last_seen = DateTimeField(read_only=True) last_seen = DateTimeField(read_only=True)
version = CharField(read_only=True) version = CharField(read_only=True)
golang_version = CharField(read_only=True)
openssl_enabled = BooleanField(read_only=True)
openssl_version = CharField(read_only=True)
fips_enabled = SerializerMethodField()
version_should = CharField(read_only=True) version_should = CharField(read_only=True)
version_outdated = BooleanField(read_only=True) version_outdated = BooleanField(read_only=True)
build_hash = CharField(read_only=True, required=False) build_hash = CharField(read_only=True, required=False)
@ -131,12 +126,6 @@ class OutpostHealthSerializer(PassiveSerializer):
hostname = CharField(read_only=True, required=False) hostname = CharField(read_only=True, required=False)
def get_fips_enabled(self, obj: dict) -> bool | None:
"""Get FIPS enabled"""
if not LicenseKey.get_total().is_valid():
return None
return obj["fips_enabled"]
class OutpostFilter(FilterSet): class OutpostFilter(FilterSet):
"""Filter for Outposts""" """Filter for Outposts"""
@ -184,10 +173,6 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
"version_should": state.version_should, "version_should": state.version_should,
"version_outdated": state.version_outdated, "version_outdated": state.version_outdated,
"build_hash": state.build_hash, "build_hash": state.build_hash,
"golang_version": state.golang_version,
"openssl_enabled": state.openssl_enabled,
"openssl_version": state.openssl_version,
"fips_enabled": state.fips_enabled,
"hostname": state.hostname, "hostname": state.hostname,
"build_hash_should": get_build_hash(), "build_hash_should": get_build_hash(),
} }

View File

@ -15,12 +15,9 @@ from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ( from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
MetaNameSerializer, from authentik.lib.utils.reflection import all_subclasses
PassiveSerializer,
)
from authentik.outposts.models import ( from authentik.outposts.models import (
DockerServiceConnection, DockerServiceConnection,
KubernetesServiceConnection, KubernetesServiceConnection,
@ -60,7 +57,6 @@ class ServiceConnectionStateSerializer(PassiveSerializer):
class ServiceConnectionViewSet( class ServiceConnectionViewSet(
TypesMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.DestroyModelMixin, mixins.DestroyModelMixin,
UsedByMixin, UsedByMixin,
@ -74,6 +70,23 @@ class ServiceConnectionViewSet(
search_fields = ["name"] search_fields = ["name"]
filterset_fields = ["name"] filterset_fields = ["name"]
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[])
def types(self, request: Request) -> Response:
"""Get all creatable service connection types"""
data = []
for subclass in all_subclasses(self.queryset.model):
subclass: OutpostServiceConnection
data.append(
{
"name": subclass._meta.verbose_name,
"description": subclass.__doc__,
"component": subclass().component,
"model_name": subclass._meta.model_name,
}
)
return Response(TypeCreateSerializer(data, many=True).data)
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)}) @extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
@action(detail=True, pagination_class=None, filter_backends=[]) @action(detail=True, pagination_class=None, filter_backends=[])
def state(self, request: Request, pk: str) -> Response: def state(self, request: Request, pk: str) -> Response:

View File

@ -121,10 +121,6 @@ class OutpostConsumer(JsonWebsocketConsumer):
if msg.instruction == WebsocketMessageInstruction.HELLO: if msg.instruction == WebsocketMessageInstruction.HELLO:
state.version = msg.args.pop("version", None) state.version = msg.args.pop("version", None)
state.build_hash = msg.args.pop("buildHash", "") state.build_hash = msg.args.pop("buildHash", "")
state.golang_version = msg.args.pop("golangVersion", "")
state.openssl_enabled = msg.args.pop("opensslEnabled", False)
state.openssl_version = msg.args.pop("opensslVersion", "")
state.fips_enabled = msg.args.pop("fipsEnabled", False)
state.args.update(msg.args) state.args.update(msg.args)
elif msg.instruction == WebsocketMessageInstruction.ACK: elif msg.instruction == WebsocketMessageInstruction.ACK:
return return

View File

@ -124,6 +124,7 @@ class KubernetesObjectReconciler(Generic[T]):
self.update(current, reference) self.update(current, reference)
self.logger.debug("Updating") self.logger.debug("Updating")
except (OpenApiException, HTTPError) as exc: except (OpenApiException, HTTPError) as exc:
if isinstance(exc, ApiException) and exc.status == 422: # noqa: PLR2004 if isinstance(exc, ApiException) and exc.status == 422: # noqa: PLR2004
self.logger.debug("Failed to update current, triggering re-create") self.logger.debug("Failed to update current, triggering re-create")
self._recreate(current=current, reference=reference) self._recreate(current=current, reference=reference)

View File

@ -0,0 +1,18 @@
# Generated by Django 5.0.4 on 2024-04-23 21:00
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_outposts", "0021_alter_outpost_type"),
]
operations = [
migrations.AddField(
model_name="outpost",
name="deleted_at",
field=models.DateTimeField(blank=True, null=True),
),
]

View File

@ -33,7 +33,7 @@ from authentik.core.models import (
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.models import InheritanceForeignKey, SerializerModel from authentik.lib.models import InheritanceForeignKey, SerializerModel, SoftDeleteModel
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.outposts.controllers.k8s.utils import get_namespace from authentik.outposts.controllers.k8s.utils import get_namespace
@ -131,7 +131,7 @@ class OutpostServiceConnection(models.Model):
verbose_name = _("Outpost Service-Connection") verbose_name = _("Outpost Service-Connection")
verbose_name_plural = _("Outpost Service-Connections") verbose_name_plural = _("Outpost Service-Connections")
def __str__(self) -> str: def __str__(self):
return f"Outpost service connection {self.name}" return f"Outpost service connection {self.name}"
@property @property
@ -241,7 +241,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection):
return "ak-service-connection-kubernetes-form" return "ak-service-connection-kubernetes-form"
class Outpost(SerializerModel, ManagedModel): class Outpost(SoftDeleteModel, SerializerModel, ManagedModel):
"""Outpost instance which manages a service user and token""" """Outpost instance which manages a service user and token"""
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
@ -434,10 +434,6 @@ class OutpostState:
version: str | None = field(default=None) version: str | None = field(default=None)
version_should: Version = field(default=OUR_VERSION) version_should: Version = field(default=OUR_VERSION)
build_hash: str = field(default="") build_hash: str = field(default="")
golang_version: str = field(default="")
openssl_enabled: bool = field(default=False)
openssl_version: str = field(default="")
fips_enabled: bool = field(default=False)
hostname: str = field(default="") hostname: str = field(default="")
args: dict = field(default_factory=dict) args: dict = field(default_factory=dict)

View File

@ -2,13 +2,14 @@
from django.core.cache import cache from django.core.cache import cache
from django.db.models import Model from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save from django.db.models.signals import m2m_changed, post_save, pre_save
from django.dispatch import receiver from django.dispatch import receiver
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.core.models import Provider from authentik.core.models import Provider
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.lib.models import post_soft_delete
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection from authentik.outposts.models import Outpost, OutpostServiceConnection
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
@ -67,9 +68,7 @@ def post_save_update(sender, instance: Model, created: bool, **_):
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
@receiver(pre_delete, sender=Outpost) @receiver(post_soft_delete, sender=Outpost)
def pre_delete_cleanup(sender, instance: Outpost, **_): def outpost_cleanup(sender, instance: Outpost, **_):
"""Ensure that Outpost's user is deleted (which will delete the token through cascade)""" """Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
instance.user.delete() outpost_controller.delay(instance.pk.hex, action="down")
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)

View File

@@ -129,17 +129,14 @@ def outpost_controller_all():
 @CELERY_APP.task(bind=True, base=SystemTask)
-def outpost_controller(
-    self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
-):
+def outpost_controller(self: SystemTask, outpost_pk: str, action: str = "up"):
     """Create/update/monitor/delete the deployment of an Outpost"""
     logs = []
-    if from_cache:
-        outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
-        LOGGER.debug("Getting outpost from cache to delete")
-    else:
-        outpost: Outpost = Outpost.objects.filter(pk=outpost_pk).first()
-        LOGGER.debug("Getting outpost from DB")
+    outpost: Outpost = None
+    if action == "up":
+        outpost = Outpost.objects.filter(pk=outpost_pk).first()
+    elif action == "down":
+        outpost = Outpost.deleted.filter(pk=outpost_pk).first()
     if not outpost:
         LOGGER.warning("No outpost")
         return

@@ -155,9 +152,10 @@ def outpost_controller(
     except (ControllerException, ServiceConnectionInvalid) as exc:
         self.set_error(exc)
     else:
-        if from_cache:
-            cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
         self.set_status(TaskStatus.SUCCESSFUL, *logs)
+    finally:
+        if outpost.deleted_at:
+            outpost.force_delete()


 @CELERY_APP.task(bind=True, base=SystemTask)
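The task now resolves soft-deleted outposts through the Outpost.deleted manager and hard-deletes them in the finally block once teardown has run. The manager itself is presumably added by the "manager for deleted objects" commit and is not shown here; a plausible shape for it, assuming the deleted_at column sketched earlier, would be:

from django.db import models


class NotDeletedManager(models.Manager):
    """Default manager: hide soft-deleted rows from normal queries."""

    def get_queryset(self):
        return super().get_queryset().filter(deleted_at__isnull=True)


class DeletedManager(models.Manager):
    """Secondary manager: only soft-deleted rows, e.g. Outpost.deleted.filter(...)."""

    def get_queryset(self):
        return super().get_queryset().filter(deleted_at__isnull=False)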

View File

@@ -13,13 +13,10 @@ from rest_framework.viewsets import GenericViewSet
 from structlog.stdlib import get_logger

 from authentik.core.api.applications import user_app_cache_key
-from authentik.core.api.object_types import TypesMixin
 from authentik.core.api.used_by import UsedByMixin
-from authentik.core.api.utils import (
-    CacheSerializer,
-    MetaNameSerializer,
-)
+from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer
 from authentik.events.logs import LogEventSerializer, capture_logs
+from authentik.lib.utils.reflection import all_subclasses
 from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer
 from authentik.policies.models import Policy, PolicyBinding
 from authentik.policies.process import PolicyProcess

@@ -72,7 +69,6 @@ class PolicySerializer(ModelSerializer, MetaNameSerializer):

 class PolicyViewSet(
-    TypesMixin,
     mixins.RetrieveModelMixin,
     mixins.DestroyModelMixin,
     UsedByMixin,

@@ -93,6 +89,23 @@ class PolicyViewSet(
     def get_queryset(self):  # pragma: no cover
         return Policy.objects.select_subclasses().prefetch_related("bindings", "promptstage_set")

+    @extend_schema(responses={200: TypeCreateSerializer(many=True)})
+    @action(detail=False, pagination_class=None, filter_backends=[])
+    def types(self, request: Request) -> Response:
+        """Get all creatable policy types"""
+        data = []
+        for subclass in all_subclasses(self.queryset.model):
+            subclass: Policy
+            data.append(
+                {
+                    "name": subclass._meta.verbose_name,
+                    "description": subclass.__doc__,
+                    "component": subclass().component,
+                    "model_name": subclass._meta.model_name,
+                }
+            )
+        return Response(TypeCreateSerializer(data, many=True).data)
+
     @permission_required(None, ["authentik_policies.view_policy_cache"])
     @extend_schema(responses={200: CacheSerializer(many=False)})
     @action(detail=False, pagination_class=None, filter_backends=[])
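On this side of the compare the viewset keeps its own types action built on all_subclasses instead of the shared TypesMixin. The recursive subclass walk behind it boils down to the following simplified sketch; authentik's helper in authentik.lib.utils.reflection may differ in details such as filtering abstract models:

def all_subclasses(cls):
    """Yield every direct and indirect subclass of cls."""
    for subclass in cls.__subclasses__():
        yield subclass
        yield from all_subclasses(subclass)


class Policy: ...
class EventMatcherPolicy(Policy): ...
class ReputationPolicy(Policy): ...


print([sub.__name__ for sub in all_subclasses(Policy)])
# ['EventMatcherPolicy', 'ReputationPolicy']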

View File

@@ -102,7 +102,7 @@ class EventMatcherPolicy(Policy):
             result = checker(request, event)
             if result is None:
                 continue
-            LOGGER.debug(
+            LOGGER.info(
                 "Event matcher check result",
                 checker=checker.__name__,
                 result=result,

View File

@@ -96,42 +96,16 @@ class TestEvaluator(TestCase):
             execution_logging=True,
             expression="ak_message(request.http_request.path)\nreturn True",
         )
-        expr2 = ExpressionPolicy.objects.create(
-            name=generate_id(),
-            execution_logging=True,
-            expression=f"""
-ak_message(request.http_request.path)
-res = ak_call_policy('{expr.name}')
-ak_message(request.http_request.path)
-for msg in res.messages:
-    ak_message(msg)
-""",
-        )
-        proc = PolicyProcess(PolicyBinding(policy=expr2), request=self.request, connection=None)
-        res = proc.profiling_wrapper()
-        self.assertEqual(res.messages, ("/", "/", "/"))
-
-    def test_call_policy_test_like(self):
-        """test ak_call_policy without `obj` set, as if it was when testing policies"""
-        expr = ExpressionPolicy.objects.create(
-            name=generate_id(),
-            execution_logging=True,
-            expression="ak_message(request.http_request.path)\nreturn True",
-        )
-        expr2 = ExpressionPolicy.objects.create(
-            name=generate_id(),
-            execution_logging=True,
-            expression=f"""
-ak_message(request.http_request.path)
-res = ak_call_policy('{expr.name}')
-ak_message(request.http_request.path)
-for msg in res.messages:
-    ak_message(msg)
-""",
-        )
-        self.request.obj = None
-        proc = PolicyProcess(PolicyBinding(policy=expr2), request=self.request, connection=None)
-        res = proc.profiling_wrapper()
+        tmpl = f"""
+ak_message(request.http_request.path)
+res = ak_call_policy('{expr.name}')
+ak_message(request.http_request.path)
+for msg in res.messages:
+    ak_message(msg)
+"""
+        evaluator = PolicyEvaluator("test")
+        evaluator.set_policy_request(self.request)
+        res = evaluator.evaluate(tmpl)
         self.assertEqual(res.messages, ("/", "/", "/"))

View File

@@ -128,8 +128,8 @@ class PolicyProcess(PROCESS_CLASS):
                 binding_order=self.binding.order,
                 binding_target_type=self.binding.target_type,
                 binding_target_name=self.binding.target_name,
-                object_pk=str(self.request.obj.pk) if self.request.obj else "",
-                object_type=class_to_path(self.request.obj.__class__) if self.request.obj else "",
+                object_pk=str(self.request.obj.pk),
+                object_type=class_to_path(self.request.obj.__class__),
                 mode="execute_process",
             ).time(),
         ):

View File

@@ -2,6 +2,8 @@
 from authentik.blueprints.apps import ManagedAppConfig

+CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
+

 class AuthentikPolicyReputationConfig(ManagedAppConfig):
     """Authentik reputation app config"""

View File

@@ -1,25 +0,0 @@
-# Generated by Django 5.0.6 on 2024-06-11 08:50
-
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ("authentik_policies_reputation", "0006_reputation_ip_asn_data"),
-    ]
-
-    operations = [
-        migrations.AddIndex(
-            model_name="reputation",
-            index=models.Index(fields=["identifier"], name="authentik_p_identif_9434d7_idx"),
-        ),
-        migrations.AddIndex(
-            model_name="reputation",
-            index=models.Index(fields=["ip"], name="authentik_p_ip_7ad0df_idx"),
-        ),
-        migrations.AddIndex(
-            model_name="reputation",
-            index=models.Index(fields=["ip", "identifier"], name="authentik_p_ip_d779aa_idx"),
-        ),
-    ]

View File

@@ -96,8 +96,3 @@ class Reputation(ExpiringModel, SerializerModel):
         verbose_name = _("Reputation Score")
         verbose_name_plural = _("Reputation Scores")
         unique_together = ("identifier", "ip")
-        indexes = [
-            models.Index(fields=["identifier"]),
-            models.Index(fields=["ip"]),
-            models.Index(fields=["ip", "identifier"]),
-        ]

View File

@@ -0,0 +1,11 @@
+"""Reputation Settings"""
+
+from celery.schedules import crontab
+
+CELERY_BEAT_SCHEDULE = {
+    "policies_reputation_save": {
+        "task": "authentik.policies.reputation.tasks.save_reputation",
+        "schedule": crontab(minute="1-59/5"),
+        "options": {"queue": "authentik_scheduled"},
+    },
+}
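The new per-app settings module schedules save_reputation on the authentik_scheduled queue at minutes 1, 6, 11, and so on of every hour. How authentik's root settings merge these per-app CELERY_BEAT_SCHEDULE dicts is not part of this diff; the plain-Celery equivalent would look roughly like this:

from celery import Celery
from celery.schedules import crontab

app = Celery("authentik")
app.conf.beat_schedule = {
    "policies_reputation_save": {
        "task": "authentik.policies.reputation.tasks.save_reputation",
        "schedule": crontab(minute="1-59/5"),  # minute 1, 6, 11, ... of every hour
        "options": {"queue": "authentik_scheduled"},
    },
}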

View File

@@ -1,35 +1,40 @@
 """authentik reputation request signals"""

 from django.contrib.auth.signals import user_logged_in
+from django.core.cache import cache
 from django.dispatch import receiver
 from django.http import HttpRequest
 from structlog.stdlib import get_logger

 from authentik.core.signals import login_failed
-from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
-from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
-from authentik.policies.reputation.models import Reputation, reputation_expiry
+from authentik.lib.config import CONFIG
+from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
+from authentik.policies.reputation.tasks import save_reputation
 from authentik.root.middleware import ClientIPMiddleware
 from authentik.stages.identification.signals import identification_failed

 LOGGER = get_logger()
+CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_reputation")


 def update_score(request: HttpRequest, identifier: str, amount: int):
     """Update score for IP and User"""
     remote_ip = ClientIPMiddleware.get_client_ip(request)

-    Reputation.objects.update_or_create(
-        ip=remote_ip,
-        identifier=identifier,
-        defaults={
-            "score": amount,
-            "ip_geo_data": GEOIP_CONTEXT_PROCESSOR.city_dict(remote_ip) or {},
-            "ip_asn_data": ASN_CONTEXT_PROCESSOR.asn_dict(remote_ip) or {},
-            "expires": reputation_expiry(),
-        },
-    )
+    try:
+        # We only update the cache here, as its faster than writing to the DB
+        score = cache.get_or_set(
+            CACHE_KEY_PREFIX + remote_ip + "/" + identifier,
+            {"ip": remote_ip, "identifier": identifier, "score": 0},
+            CACHE_TIMEOUT,
+        )
+        score["score"] += amount
+        cache.set(CACHE_KEY_PREFIX + remote_ip + "/" + identifier, score)
+    except ValueError as exc:
+        LOGGER.warning("failed to set reputation", exc=exc)

     LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
+    save_reputation.delay()


 @receiver(login_failed)
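The rewritten update_score only touches the cache and leaves persistence to the save_reputation task. The read-modify-write pattern can be illustrated standalone with Django's in-memory cache backend (the real code talks to the configured Redis cache and reads the timeout from the cache.timeout_reputation setting):

from django.conf import settings

settings.configure(
    CACHES={"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}}
)

from django.core.cache import cache

CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"


def update_score(remote_ip: str, identifier: str, amount: int, timeout: int = 300):
    key = CACHE_KEY_PREFIX + remote_ip + "/" + identifier
    # Seed the entry on first sight, then increment and write the score back.
    score = cache.get_or_set(key, {"ip": remote_ip, "identifier": identifier, "score": 0}, timeout)
    score["score"] += amount
    cache.set(key, score, timeout)
    return score


update_score("127.0.0.1", "test", -1)
print(cache.get(CACHE_KEY_PREFIX + "127.0.0.1/test"))
# {'ip': '127.0.0.1', 'identifier': 'test', 'score': -1}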

View File

@@ -0,0 +1,32 @@
+"""Reputation tasks"""
+
+from django.core.cache import cache
+from structlog.stdlib import get_logger
+
+from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
+from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
+from authentik.events.models import TaskStatus
+from authentik.events.system_tasks import SystemTask, prefill_task
+from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
+from authentik.policies.reputation.models import Reputation
+from authentik.root.celery import CELERY_APP
+
+LOGGER = get_logger()
+
+
+@CELERY_APP.task(bind=True, base=SystemTask)
+@prefill_task
+def save_reputation(self: SystemTask):
+    """Save currently cached reputation to database"""
+    objects_to_update = []
+    for _, score in cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")).items():
+        rep, _ = Reputation.objects.get_or_create(
+            ip=score["ip"],
+            identifier=score["identifier"],
+        )
+        rep.ip_geo_data = GEOIP_CONTEXT_PROCESSOR.city_dict(score["ip"]) or {}
+        rep.ip_asn_data = ASN_CONTEXT_PROCESSOR.asn_dict(score["ip"]) or {}
+        rep.score = score["score"]
+        objects_to_update.append(rep)
+    Reputation.objects.bulk_update(objects_to_update, ["score", "ip_geo_data"])
+    self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated Reputation")
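Note that cache.keys(CACHE_KEY_PREFIX + "*") is not part of Django's stock cache API; it is provided by the Redis-backed cache client (django-redis) that authentik uses, so this flush task and the wildcard cleanup in the tests below assume that backend.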

View File

@@ -1,11 +1,14 @@
 """test reputation signals and policy"""

+from django.core.cache import cache
 from django.test import RequestFactory, TestCase

 from authentik.core.models import User
 from authentik.lib.generators import generate_id
 from authentik.policies.reputation.api import ReputationPolicySerializer
+from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
 from authentik.policies.reputation.models import Reputation, ReputationPolicy
+from authentik.policies.reputation.tasks import save_reputation
 from authentik.policies.types import PolicyRequest
 from authentik.stages.password import BACKEND_INBUILT
 from authentik.stages.password.stage import authenticate

@@ -19,6 +22,8 @@ class TestReputationPolicy(TestCase):
         self.request = self.request_factory.get("/")
         self.test_ip = "127.0.0.1"
         self.test_username = "test"
+        keys = cache.keys(CACHE_KEY_PREFIX + "*")
+        cache.delete_many(keys)
         # We need a user for the one-to-one in userreputation
         self.user = User.objects.create(username=self.test_username)
         self.backends = [BACKEND_INBUILT]

@@ -29,6 +34,13 @@
         authenticate(
             self.request, self.backends, username=self.test_username, password=self.test_username
         )
+        # Test value in cache
+        self.assertEqual(
+            cache.get(CACHE_KEY_PREFIX + self.test_ip + "/" + self.test_username),
+            {"ip": "127.0.0.1", "identifier": "test", "score": -1},
+        )
+        # Save cache and check db values
+        save_reputation.delay().get()
         self.assertEqual(Reputation.objects.get(ip=self.test_ip).score, -1)

     def test_user_reputation(self):

@@ -37,6 +49,13 @@
         authenticate(
             self.request, self.backends, username=self.test_username, password=self.test_username
         )
+        # Test value in cache
+        self.assertEqual(
+            cache.get(CACHE_KEY_PREFIX + self.test_ip + "/" + self.test_username),
+            {"ip": "127.0.0.1", "identifier": "test", "score": -1},
+        )
+        # Save cache and check db values
+        save_reputation.delay().get()
         self.assertEqual(Reputation.objects.get(identifier=self.test_username).score, -1)

     def test_policy(self):

View File

@@ -3,7 +3,6 @@
 from collections.abc import Iterable

 from django.db import models
-from django.templatetags.static import static
 from django.utils.translation import gettext_lazy as _
 from rest_framework.serializers import Serializer

@@ -91,10 +90,6 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
     def component(self) -> str:
         return "ak-provider-ldap-form"

-    @property
-    def icon_url(self) -> str | None:
-        return static("authentik/sources/ldap.png")
-
     @property
     def serializer(self) -> type[Serializer]:
         from authentik.providers.ldap.api import LDAPProviderSerializer

View File

@@ -8,7 +8,7 @@ from rest_framework.fields import CharField
 from rest_framework.serializers import ValidationError
 from rest_framework.viewsets import ModelViewSet

-from authentik.core.api.property_mappings import PropertyMappingSerializer
+from authentik.core.api.propertymappings import PropertyMappingSerializer
 from authentik.core.api.used_by import UsedByMixin
 from authentik.providers.oauth2.models import ScopeMapping

View File

@@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
 from dacite.core import from_dict
 from django.db import models
 from django.http import HttpRequest
-from django.templatetags.static import static
 from django.urls import reverse
 from django.utils.translation import gettext_lazy as _
 from jwt import encode

@@ -263,10 +262,6 @@ class OAuth2Provider(Provider):
             LOGGER.warning("Failed to format launch url", exc=exc)
             return None

-    @property
-    def icon_url(self) -> str | None:
-        return static("authentik/sources/openidconnect.svg")
-
     @property
     def component(self) -> str:
         return "ak-provider-oauth2-form"

View File

@@ -15,6 +15,7 @@ from authentik.core.expression.exceptions import PropertyMappingExpressionExcept
 from authentik.events.models import Event, EventAction
 from authentik.flows.challenge import PermissionDict
 from authentik.providers.oauth2.constants import (
+    SCOPE_AUTHENTIK_API,
     SCOPE_GITHUB_ORG_READ,
     SCOPE_GITHUB_USER,
     SCOPE_GITHUB_USER_EMAIL,

@@ -56,6 +57,7 @@ class UserInfoView(View):
             SCOPE_GITHUB_USER_READ: _("GitHub Compatibility: Access your User Information"),
             SCOPE_GITHUB_USER_EMAIL: _("GitHub Compatibility: Access you Email addresses"),
             SCOPE_GITHUB_ORG_READ: _("GitHub Compatibility: Access your Groups"),
+            SCOPE_AUTHENTIK_API: _("authentik API Access on behalf of your user"),
         }
         for scope in scopes:
             if scope in special_scope_map:

Some files were not shown because too many files have changed in this diff.