Compare commits
132 Commits
core/soft-
...
healthchec
Author | SHA1 | Date | |
---|---|---|---|
93b960a1f1 | |||
578ff13868 | |||
d12acb5bcd | |||
0e8b9a6409 | |||
6171443e61 | |||
5fedd616d9 | |||
5dd6498694 | |||
cf5102ed20 | |||
d3b2032c33 | |||
1e5df1c405 | |||
96eabe269c | |||
3e869a0ec7 | |||
7276a416f6 | |||
a989390533 | |||
562c52a48b | |||
c3cb9bc778 | |||
5f65a7c6cc | |||
95d26563e7 | |||
1cac1492d7 | |||
6c1ac48bd9 | |||
97f11f7aa8 | |||
6db763f7dc | |||
16b5f692ee | |||
80c1bd690c | |||
040dcaa9d6 | |||
66a16752e4 | |||
70c0e1be99 | |||
5beea4624f | |||
50fffa72cc | |||
dae4bf0d6b | |||
823851652e | |||
ae7f7c9930 | |||
5ce4ed4dd3 | |||
5582cc7745 | |||
c384ed5f52 | |||
02e2ba8971 | |||
925d5c80df | |||
1de69a7bd6 | |||
c6979a48e0 | |||
6e73d60305 | |||
f388cac07c | |||
cf593e5cb9 | |||
c3a98e5d5f | |||
1048729599 | |||
72442b37e5 | |||
211cdb3a21 | |||
4cca16750e | |||
b2d261dd1c | |||
0663100429 | |||
66c3261eeb | |||
bf7570bc36 | |||
20b52d0dbd | |||
a1f5e284c4 | |||
0e4737d38f | |||
609b10f7f8 | |||
2cff3d15e7 | |||
4f1d49417c | |||
0766a47b4f | |||
bd1ddfebd6 | |||
a841743c74 | |||
0974456ac8 | |||
d44d5a44a1 | |||
edf5c8686a | |||
70ace8b209 | |||
c3509e63af | |||
89b8206176 | |||
908d87c142 | |||
4ab4e81fb0 | |||
6dae1a4fe7 | |||
d11de73e95 | |||
b08fb5fdf1 | |||
3c9e8c7287 | |||
691d0be41e | |||
dfbaccbab6 | |||
f3bdb189f6 | |||
85b3523639 | |||
9ff61a7120 | |||
f742b986a7 | |||
177bdfa689 | |||
c3445374c2 | |||
c2da6822dc | |||
493294ef9f | |||
17f807e8b0 | |||
96eb98500c | |||
ddd75f6d09 | |||
fbad02fac1 | |||
fbab822db1 | |||
d8316eea9b | |||
8182c9f7c2 | |||
5d94b97e97 | |||
35ddbb6d75 | |||
2b8bc38fc3 | |||
9b0b504531 | |||
c312430007 | |||
4e65c205e3 | |||
372a66c876 | |||
3630349388 | |||
347746cbcd | |||
ef2e1ad27b | |||
8a6b34eb5c | |||
26f72bcac4 | |||
f04466b3be | |||
4ba53d2f08 | |||
7a13046a27 | |||
939e2c1edd | |||
cf06b4177a | |||
f8079d63fa | |||
576a56c562 | |||
cf9b14213e | |||
73cbdb77ed | |||
fd66be9fa2 | |||
96bf9ee898 | |||
6c4c535d57 | |||
0ed4bba5a5 | |||
6e31e5b889 | |||
a5467c6e19 | |||
09832355e3 | |||
6ffef878f0 | |||
644090dc58 | |||
d07508b9a4 | |||
44d7e81a93 | |||
2e91b9d035 | |||
964c6a1050 | |||
90a1c5ab85 | |||
8162c1ec86 | |||
ab46610d9b | |||
6909b58279 | |||
6d7a06227f | |||
1459a13991 | |||
1921ce39f6 | |||
263cff6393 | |||
5a61688472 |
2
.github/actions/setup/docker-compose.yml
vendored
2
.github/actions/setup/docker-compose.yml
vendored
@ -1,5 +1,3 @@
|
||||
version: "3.7"
|
||||
|
||||
services:
|
||||
postgresql:
|
||||
image: docker.io/library/postgres:${PSQL_TAG:-16}
|
||||
|
1
.github/codespell-words.txt
vendored
1
.github/codespell-words.txt
vendored
@ -4,3 +4,4 @@ hass
|
||||
warmup
|
||||
ontext
|
||||
singed
|
||||
assertIn
|
||||
|
6
.github/workflows/ci-main.yml
vendored
6
.github/workflows/ci-main.yml
vendored
@ -50,7 +50,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
psql:
|
||||
- 12-alpine
|
||||
- 15-alpine
|
||||
- 16-alpine
|
||||
steps:
|
||||
@ -104,7 +103,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
psql:
|
||||
- 12-alpine
|
||||
- 15-alpine
|
||||
- 16-alpine
|
||||
steps:
|
||||
@ -252,8 +250,8 @@ jobs:
|
||||
push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
|
||||
build-args: |
|
||||
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache
|
||||
cache-to: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache,mode=max
|
||||
platforms: linux/${{ matrix.arch }}
|
||||
pr-comment:
|
||||
needs:
|
||||
|
4
.github/workflows/ci-outpost.yml
vendored
4
.github/workflows/ci-outpost.yml
vendored
@ -105,8 +105,8 @@ jobs:
|
||||
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
context: .
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache
|
||||
cache-to: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache,mode=max
|
||||
build-binary:
|
||||
timeout-minutes: 120
|
||||
needs:
|
||||
|
30
Dockerfile
30
Dockerfile
@ -38,7 +38,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
|
||||
RUN npm run build
|
||||
|
||||
# Stage 3: Build go proxy
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.3-bookworm AS go-builder
|
||||
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
@ -49,6 +49,11 @@ ARG GOARCH=$TARGETARCH
|
||||
|
||||
WORKDIR /go/src/goauthentik.io
|
||||
|
||||
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
|
||||
dpkg --add-architecture arm64 && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends crossbuild-essential-arm64 gcc-aarch64-linux-gnu
|
||||
|
||||
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
|
||||
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
|
||||
--mount=type=cache,target=/go/pkg/mod \
|
||||
@ -63,11 +68,11 @@ COPY ./internal /go/src/goauthentik.io/internal
|
||||
COPY ./go.mod /go/src/goauthentik.io/go.mod
|
||||
COPY ./go.sum /go/src/goauthentik.io/go.sum
|
||||
|
||||
ENV CGO_ENABLED=0
|
||||
|
||||
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
|
||||
GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server
|
||||
if [ "$TARGETARCH" = "arm64" ]; then export CC=aarch64-linux-gnu-gcc && export CC_FOR_TARGET=gcc-aarch64-linux-gnu; fi && \
|
||||
CGO_ENABLED=1 GOEXPERIMENT="systemcrypto" GOFLAGS="-tags=requirefips" GOARM="${TARGETVARIANT#v}" \
|
||||
go build -o /go/authentik ./cmd/server
|
||||
|
||||
# Stage 4: MaxMind GeoIP
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
|
||||
@ -84,7 +89,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||
|
||||
# Stage 5: Python dependencies
|
||||
FROM docker.io/python:3.12.3-slim-bookworm AS python-deps
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps
|
||||
|
||||
WORKDIR /ak-root/poetry
|
||||
|
||||
@ -97,7 +102,7 @@ RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloa
|
||||
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
|
||||
apt-get update && \
|
||||
# Required for installing pip packages
|
||||
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev
|
||||
apt-get install -y --no-install-recommends build-essential pkg-config libpq-dev
|
||||
|
||||
RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
|
||||
--mount=type=bind,target=./poetry.lock,src=./poetry.lock \
|
||||
@ -105,12 +110,13 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
|
||||
--mount=type=cache,target=/root/.cache/pypoetry \
|
||||
python -m venv /ak-root/venv/ && \
|
||||
bash -c "source ${VENV_PATH}/bin/activate && \
|
||||
pip3 install --upgrade pip && \
|
||||
pip3 install poetry && \
|
||||
poetry install --only=main --no-ansi --no-interaction --no-root"
|
||||
pip3 install --upgrade pip && \
|
||||
pip3 install poetry && \
|
||||
poetry install --only=main --no-ansi --no-interaction --no-root && \
|
||||
pip install --force-reinstall /wheels/*"
|
||||
|
||||
# Stage 6: Run
|
||||
FROM docker.io/python:3.12.3-slim-bookworm AS final-image
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image
|
||||
|
||||
ARG GIT_BUILD_HASH
|
||||
ARG VERSION
|
||||
@ -127,7 +133,7 @@ WORKDIR /
|
||||
# We cannot cache this layer otherwise we'll end up with a bigger image
|
||||
RUN apt-get update && \
|
||||
# Required for runtime
|
||||
apt-get install -y --no-install-recommends libpq5 openssl libxmlsec1-openssl libmaxminddb0 ca-certificates && \
|
||||
apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates && \
|
||||
# Required for bootstrap & healtcheck
|
||||
apt-get install -y --no-install-recommends runit && \
|
||||
apt-get clean && \
|
||||
@ -163,6 +169,8 @@ ENV TMPDIR=/dev/shm/ \
|
||||
VENV_PATH="/ak-root/venv" \
|
||||
POETRY_VIRTUALENVS_CREATE=false
|
||||
|
||||
ENV GOFIPS=1
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD [ "ak", "healthcheck" ]
|
||||
|
||||
ENTRYPOINT [ "dumb-init", "--", "ak" ]
|
||||
|
1
Makefile
1
Makefile
@ -253,6 +253,7 @@ website-watch: ## Build and watch the documentation website, updating automatic
|
||||
#########################
|
||||
|
||||
docker: ## Build a docker image of the current source tree
|
||||
mkdir -p ${GEN_API_TS}
|
||||
DOCKER_BUILDKIT=1 docker build . --progress plain --tag ${DOCKER_IMAGE}
|
||||
|
||||
#########################
|
||||
|
@ -2,17 +2,19 @@
|
||||
|
||||
import platform
|
||||
from datetime import datetime
|
||||
from ssl import OPENSSL_VERSION
|
||||
from sys import version as python_version
|
||||
from typing import TypedDict
|
||||
|
||||
from cryptography.hazmat.backends.openssl.backend import backend
|
||||
from django.utils.timezone import now
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from gunicorn import version_info as gunicorn_version
|
||||
from rest_framework.fields import SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from authentik import get_full_version
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.reflection import get_env
|
||||
@ -25,11 +27,13 @@ class RuntimeDict(TypedDict):
|
||||
"""Runtime information"""
|
||||
|
||||
python_version: str
|
||||
gunicorn_version: str
|
||||
environment: str
|
||||
architecture: str
|
||||
platform: str
|
||||
uname: str
|
||||
openssl_version: str
|
||||
openssl_fips_mode: bool
|
||||
authentik_version: str
|
||||
|
||||
|
||||
class SystemInfoSerializer(PassiveSerializer):
|
||||
@ -64,11 +68,13 @@ class SystemInfoSerializer(PassiveSerializer):
|
||||
def get_runtime(self, request: Request) -> RuntimeDict:
|
||||
"""Get versions"""
|
||||
return {
|
||||
"python_version": python_version,
|
||||
"gunicorn_version": ".".join(str(x) for x in gunicorn_version),
|
||||
"environment": get_env(),
|
||||
"architecture": platform.machine(),
|
||||
"authentik_version": get_full_version(),
|
||||
"environment": get_env(),
|
||||
"openssl_fips_enabled": backend._fips_enabled,
|
||||
"openssl_version": OPENSSL_VERSION,
|
||||
"platform": platform.platform(),
|
||||
"python_version": python_version,
|
||||
"uname": " ".join(platform.uname()),
|
||||
}
|
||||
|
||||
|
@ -75,7 +75,7 @@ class BlueprintEntry:
|
||||
_state: BlueprintEntryState = field(default_factory=BlueprintEntryState)
|
||||
|
||||
def __post_init__(self, *args, **kwargs) -> None:
|
||||
self.__tag_contexts: list["YAMLTagContext"] = []
|
||||
self.__tag_contexts: list[YAMLTagContext] = []
|
||||
|
||||
@staticmethod
|
||||
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry":
|
||||
|
@ -178,6 +178,14 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
||||
def list(self, request, *args, **kwargs):
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
@extend_schema(
|
||||
parameters=[
|
||||
OpenApiParameter("include_users", bool, default=True),
|
||||
]
|
||||
)
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
return super().retrieve(request, *args, **kwargs)
|
||||
|
||||
@permission_required("authentik_core.add_user_to_group")
|
||||
@extend_schema(
|
||||
request=UserAccountSerializer,
|
||||
|
79
authentik/core/api/object_types.py
Normal file
79
authentik/core/api/object_types.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""API Utilities"""
|
||||
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
)
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
|
||||
|
||||
class TypeCreateSerializer(PassiveSerializer):
|
||||
"""Types of an object that can be created"""
|
||||
|
||||
name = CharField(required=True)
|
||||
description = CharField(required=True)
|
||||
component = CharField(required=True)
|
||||
model_name = CharField(required=True)
|
||||
|
||||
icon_url = CharField(required=False)
|
||||
requires_enterprise = BooleanField(default=False)
|
||||
|
||||
|
||||
class CreatableType:
|
||||
"""Class to inherit from to mark a model as creatable, even if the model itself is marked
|
||||
as abstract"""
|
||||
|
||||
|
||||
class NonCreatableType:
|
||||
"""Class to inherit from to mark a model as non-creatable even if it is not abstract"""
|
||||
|
||||
|
||||
class TypesMixin:
|
||||
"""Mixin which adds an API endpoint to list all possible types that can be created"""
|
||||
|
||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def types(self, request: Request, additional: list[dict] | None = None) -> Response:
|
||||
"""Get all creatable types"""
|
||||
data = []
|
||||
for subclass in all_subclasses(self.queryset.model):
|
||||
instance = None
|
||||
if subclass._meta.abstract:
|
||||
if not issubclass(subclass, CreatableType):
|
||||
continue
|
||||
# Circumvent the django protection for not being able to instantiate
|
||||
# abstract models. We need a model instance to access .component
|
||||
# and further down .icon_url
|
||||
instance = subclass.__new__(subclass)
|
||||
# Django re-sets abstract = False so we need to override that
|
||||
instance.Meta.abstract = True
|
||||
else:
|
||||
if issubclass(subclass, NonCreatableType):
|
||||
continue
|
||||
instance = subclass()
|
||||
try:
|
||||
data.append(
|
||||
{
|
||||
"name": subclass._meta.verbose_name,
|
||||
"description": subclass.__doc__,
|
||||
"component": instance.component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
"icon_url": getattr(instance, "icon_url", None),
|
||||
"requires_enterprise": isinstance(
|
||||
subclass._meta.app_config, EnterpriseConfig
|
||||
),
|
||||
}
|
||||
)
|
||||
except NotImplementedError:
|
||||
continue
|
||||
if additional:
|
||||
data.extend(additional)
|
||||
data = sorted(data, key=lambda x: x["name"])
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
@ -9,18 +9,22 @@ from rest_framework import mixins
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.fields import BooleanField, CharField
|
||||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.blueprints.api import ManagedSerializer
|
||||
from authentik.core.api.object_types import TypesMixin
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||
from authentik.core.api.utils import (
|
||||
MetaNameSerializer,
|
||||
PassiveSerializer,
|
||||
)
|
||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
||||
from authentik.core.models import PropertyMapping
|
||||
from authentik.core.models import Group, PropertyMapping, User
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
from authentik.policies.api.exec import PolicyTestSerializer
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
@ -64,6 +68,7 @@ class PropertyMappingSerializer(ManagedSerializer, ModelSerializer, MetaNameSeri
|
||||
|
||||
|
||||
class PropertyMappingViewSet(
|
||||
TypesMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
@ -72,7 +77,13 @@ class PropertyMappingViewSet(
|
||||
):
|
||||
"""PropertyMapping Viewset"""
|
||||
|
||||
queryset = PropertyMapping.objects.none()
|
||||
class PropertyMappingTestSerializer(PolicyTestSerializer):
|
||||
"""Test property mapping execution for a user/group with context"""
|
||||
|
||||
user = PrimaryKeyRelatedField(queryset=User.objects.all(), required=False)
|
||||
group = PrimaryKeyRelatedField(queryset=Group.objects.all(), required=False)
|
||||
|
||||
queryset = PropertyMapping.objects.select_subclasses()
|
||||
serializer_class = PropertyMappingSerializer
|
||||
search_fields = [
|
||||
"name",
|
||||
@ -80,29 +91,9 @@ class PropertyMappingViewSet(
|
||||
filterset_fields = {"managed": ["isnull"]}
|
||||
ordering = ["name"]
|
||||
|
||||
def get_queryset(self): # pragma: no cover
|
||||
return PropertyMapping.objects.select_subclasses()
|
||||
|
||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def types(self, request: Request) -> Response:
|
||||
"""Get all creatable property-mapping types"""
|
||||
data = []
|
||||
for subclass in all_subclasses(self.queryset.model):
|
||||
subclass: PropertyMapping
|
||||
data.append(
|
||||
{
|
||||
"name": subclass._meta.verbose_name,
|
||||
"description": subclass.__doc__,
|
||||
"component": subclass().component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
}
|
||||
)
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
||||
|
||||
@permission_required("authentik_core.view_propertymapping")
|
||||
@extend_schema(
|
||||
request=PolicyTestSerializer(),
|
||||
request=PropertyMappingTestSerializer(),
|
||||
responses={
|
||||
200: PropertyMappingTestResultSerializer,
|
||||
400: OpenApiResponse(description="Invalid parameters"),
|
||||
@ -120,29 +111,39 @@ class PropertyMappingViewSet(
|
||||
"""Test Property Mapping"""
|
||||
_mapping: PropertyMapping = self.get_object()
|
||||
# Use `get_subclass` to get correct class and correct `.evaluate` implementation
|
||||
mapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk)
|
||||
mapping: PropertyMapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk)
|
||||
# FIXME: when we separate policy mappings between ones for sources
|
||||
# and ones for providers, we need to make the user field optional for the source mapping
|
||||
test_params = PolicyTestSerializer(data=request.data)
|
||||
test_params = self.PropertyMappingTestSerializer(data=request.data)
|
||||
if not test_params.is_valid():
|
||||
return Response(test_params.errors, status=400)
|
||||
|
||||
format_result = str(request.GET.get("format_result", "false")).lower() == "true"
|
||||
|
||||
# User permission check, only allow mapping testing for users that are readable
|
||||
users = get_objects_for_user(request.user, "authentik_core.view_user").filter(
|
||||
pk=test_params.validated_data["user"].pk
|
||||
)
|
||||
if not users.exists():
|
||||
raise PermissionDenied()
|
||||
context: dict = test_params.validated_data.get("context", {})
|
||||
context.setdefault("user", None)
|
||||
|
||||
if user := test_params.validated_data.get("user"):
|
||||
# User permission check, only allow mapping testing for users that are readable
|
||||
users = get_objects_for_user(request.user, "authentik_core.view_user").filter(
|
||||
pk=user.pk
|
||||
)
|
||||
if not users.exists():
|
||||
raise PermissionDenied()
|
||||
context["user"] = user
|
||||
if group := test_params.validated_data.get("group"):
|
||||
# Group permission check, only allow mapping testing for groups that are readable
|
||||
groups = get_objects_for_user(request.user, "authentik_core.view_group").filter(
|
||||
pk=group.pk
|
||||
)
|
||||
if not groups.exists():
|
||||
raise PermissionDenied()
|
||||
context["group"] = group
|
||||
context["request"] = self.request
|
||||
|
||||
response_data = {"successful": True, "result": ""}
|
||||
try:
|
||||
result = mapping.evaluate(
|
||||
users.first(),
|
||||
self.request,
|
||||
**test_params.validated_data.get("context", {}),
|
||||
)
|
||||
result = mapping.evaluate(**context)
|
||||
response_data["result"] = dumps(
|
||||
sanitize_item(result), indent=(4 if format_result else None)
|
||||
)
|
@ -5,20 +5,15 @@ from django.db.models.query import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters.filters import BooleanFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import mixins
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import ReadOnlyField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.object_types import TypesMixin
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||
from authentik.core.api.utils import MetaNameSerializer
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
|
||||
|
||||
class ProviderSerializer(ModelSerializer, MetaNameSerializer):
|
||||
@ -86,6 +81,7 @@ class ProviderFilter(FilterSet):
|
||||
|
||||
|
||||
class ProviderViewSet(
|
||||
TypesMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
@ -104,31 +100,3 @@ class ProviderViewSet(
|
||||
|
||||
def get_queryset(self): # pragma: no cover
|
||||
return Provider.objects.select_subclasses()
|
||||
|
||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def types(self, request: Request) -> Response:
|
||||
"""Get all creatable provider types"""
|
||||
data = []
|
||||
for subclass in all_subclasses(self.queryset.model):
|
||||
subclass: Provider
|
||||
if subclass._meta.abstract:
|
||||
continue
|
||||
data.append(
|
||||
{
|
||||
"name": subclass._meta.verbose_name,
|
||||
"description": subclass.__doc__,
|
||||
"component": subclass().component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||
}
|
||||
)
|
||||
data.append(
|
||||
{
|
||||
"name": _("SAML Provider from Metadata"),
|
||||
"description": _("Create a SAML Provider by importing its Metadata."),
|
||||
"component": "ak-provider-saml-import-form",
|
||||
"model_name": "",
|
||||
}
|
||||
)
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
||||
|
@ -17,8 +17,9 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
|
||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||
from authentik.core.api.object_types import TypesMixin
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||
from authentik.core.api.utils import MetaNameSerializer
|
||||
from authentik.core.models import Source, UserSourceConnection
|
||||
from authentik.core.types import UserSettingSerializer
|
||||
from authentik.lib.utils.file import (
|
||||
@ -27,7 +28,6 @@ from authentik.lib.utils.file import (
|
||||
set_file,
|
||||
set_file_url,
|
||||
)
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
@ -74,6 +74,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
|
||||
|
||||
|
||||
class SourceViewSet(
|
||||
TypesMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
@ -132,30 +133,6 @@ class SourceViewSet(
|
||||
source: Source = self.get_object()
|
||||
return set_file_url(request, source, "icon")
|
||||
|
||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def types(self, request: Request) -> Response:
|
||||
"""Get all creatable source types"""
|
||||
data = []
|
||||
for subclass in all_subclasses(self.queryset.model):
|
||||
subclass: Source
|
||||
component = ""
|
||||
if len(subclass.__subclasses__()) > 0:
|
||||
continue
|
||||
if subclass._meta.abstract:
|
||||
component = subclass.__bases__[0]().component
|
||||
else:
|
||||
component = subclass().component
|
||||
data.append(
|
||||
{
|
||||
"name": subclass._meta.verbose_name,
|
||||
"description": subclass.__doc__,
|
||||
"component": component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
}
|
||||
)
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
||||
|
||||
@extend_schema(responses={200: UserSettingSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def user_settings(self, request: Request) -> Response:
|
||||
|
@ -6,8 +6,16 @@ from django.db.models import Model
|
||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.plumbing import build_basic_type
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from rest_framework.fields import BooleanField, CharField, IntegerField, JSONField
|
||||
from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
|
||||
from rest_framework.fields import (
|
||||
CharField,
|
||||
IntegerField,
|
||||
JSONField,
|
||||
SerializerMethodField,
|
||||
)
|
||||
from rest_framework.serializers import (
|
||||
Serializer,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
|
||||
def is_dict(value: Any):
|
||||
@ -68,16 +76,6 @@ class MetaNameSerializer(PassiveSerializer):
|
||||
return f"{obj._meta.app_label}.{obj._meta.model_name}"
|
||||
|
||||
|
||||
class TypeCreateSerializer(PassiveSerializer):
|
||||
"""Types of an object that can be created"""
|
||||
|
||||
name = CharField(required=True)
|
||||
description = CharField(required=True)
|
||||
component = CharField(required=True)
|
||||
model_name = CharField(required=True)
|
||||
requires_enterprise = BooleanField(default=False)
|
||||
|
||||
|
||||
class CacheSerializer(PassiveSerializer):
|
||||
"""Generic cache stats for an object"""
|
||||
|
||||
|
@ -31,8 +31,9 @@ class InbuiltBackend(ModelBackend):
|
||||
# Since we can't directly pass other variables to signals, and we want to log the method
|
||||
# and the token used, we assume we're running in a flow and set a variable in the context
|
||||
flow_plan: FlowPlan = request.session.get(SESSION_KEY_PLAN, FlowPlan(""))
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD] = method
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS] = cleanse_dict(sanitize_dict(kwargs))
|
||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, method)
|
||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].update(cleanse_dict(sanitize_dict(kwargs)))
|
||||
request.session[SESSION_KEY_PLAN] = flow_plan
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Property Mapping Evaluator"""
|
||||
|
||||
from types import CodeType
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import Model
|
||||
@ -24,6 +25,8 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
||||
"""Custom Evaluator that adds some different context variables."""
|
||||
|
||||
dry_run: bool
|
||||
model: Model
|
||||
_compiled: CodeType | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -33,23 +36,32 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
||||
dry_run: bool | None = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
if hasattr(model, "name"):
|
||||
_filename = model.name
|
||||
else:
|
||||
_filename = str(model)
|
||||
super().__init__(filename=_filename)
|
||||
self.dry_run = dry_run
|
||||
self.set_context(user, request, **kwargs)
|
||||
|
||||
def set_context(
|
||||
self,
|
||||
user: User | None = None,
|
||||
request: HttpRequest | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
req = PolicyRequest(user=User())
|
||||
req.obj = model
|
||||
req.obj = self.model
|
||||
if user:
|
||||
req.user = user
|
||||
self._context["user"] = user
|
||||
if request:
|
||||
req.http_request = request
|
||||
self._context["request"] = req
|
||||
req.context.update(**kwargs)
|
||||
self._context["request"] = req
|
||||
self._context.update(**kwargs)
|
||||
self._globals["SkipObject"] = SkipObjectException
|
||||
self.dry_run = dry_run
|
||||
|
||||
def handle_error(self, exc: Exception, expression_source: str):
|
||||
"""Exception Handler"""
|
||||
@ -71,3 +83,9 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
||||
def evaluate(self, *args, **kwargs) -> Any:
|
||||
with PROPERTY_MAPPING_TIME.labels(mapping_name=self._filename).time():
|
||||
return super().evaluate(*args, **kwargs)
|
||||
|
||||
def compile(self, expression: str | None = None) -> Any:
|
||||
if not self._compiled:
|
||||
compiled = super().compile(expression or self.model.expression)
|
||||
self._compiled = compiled
|
||||
return self._compiled
|
||||
|
@ -6,6 +6,11 @@ from authentik.lib.sentry import SentryIgnoredException
|
||||
class PropertyMappingExpressionException(SentryIgnoredException):
|
||||
"""Error when a PropertyMapping Exception expression could not be parsed or evaluated."""
|
||||
|
||||
def __init__(self, exc: Exception, mapping) -> None:
|
||||
super().__init__()
|
||||
self.exc = exc
|
||||
self.mapping = mapping
|
||||
|
||||
|
||||
class SkipObjectException(PropertyMappingExpressionException):
|
||||
"""Exception which can be raised in a property mapping to skip syncing an object.
|
||||
|
@ -15,6 +15,7 @@ from django.http import HttpRequest
|
||||
from django.utils.functional import SimpleLazyObject, cached_property
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_cte import CTEQuerySet, With
|
||||
from guardian.conf import settings
|
||||
from guardian.mixins import GuardianUserMixin
|
||||
from model_utils.managers import InheritanceManager
|
||||
@ -56,6 +57,8 @@ options.DEFAULT_NAMES = options.DEFAULT_NAMES + (
|
||||
"authentik_used_by_shadows",
|
||||
)
|
||||
|
||||
GROUP_RECURSION_LIMIT = 20
|
||||
|
||||
|
||||
def default_token_duration() -> datetime:
|
||||
"""Default duration a Token is valid"""
|
||||
@ -96,6 +99,40 @@ class UserTypes(models.TextChoices):
|
||||
INTERNAL_SERVICE_ACCOUNT = "internal_service_account"
|
||||
|
||||
|
||||
class GroupQuerySet(CTEQuerySet):
|
||||
def with_children_recursive(self):
|
||||
"""Recursively get all groups that have the current queryset as parents
|
||||
or are indirectly related."""
|
||||
|
||||
def make_cte(cte):
|
||||
"""Build the query that ends up in WITH RECURSIVE"""
|
||||
# Start from self, aka the current query
|
||||
# Add a depth attribute to limit the recursion
|
||||
return self.annotate(
|
||||
relative_depth=models.Value(0, output_field=models.IntegerField())
|
||||
).union(
|
||||
# Here is the recursive part of the query. cte refers to the previous iteration
|
||||
# Only select groups for which the parent is part of the previous iteration
|
||||
# and increase the depth
|
||||
# Finally, limit the depth
|
||||
cte.join(Group, group_uuid=cte.col.parent_id)
|
||||
.annotate(
|
||||
relative_depth=models.ExpressionWrapper(
|
||||
cte.col.relative_depth
|
||||
+ models.Value(1, output_field=models.IntegerField()),
|
||||
output_field=models.IntegerField(),
|
||||
)
|
||||
)
|
||||
.filter(relative_depth__lt=GROUP_RECURSION_LIMIT),
|
||||
all=True,
|
||||
)
|
||||
|
||||
# Build the recursive query, see above
|
||||
cte = With.recursive(make_cte)
|
||||
# Return the result, as a usable queryset for Group.
|
||||
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
|
||||
|
||||
|
||||
class Group(SerializerModel):
|
||||
"""Group model which supports a basic hierarchy and has attributes"""
|
||||
|
||||
@ -118,6 +155,8 @@ class Group(SerializerModel):
|
||||
)
|
||||
attributes = models.JSONField(default=dict, blank=True)
|
||||
|
||||
objects = GroupQuerySet.as_manager()
|
||||
|
||||
@property
|
||||
def serializer(self) -> Serializer:
|
||||
from authentik.core.api.groups import GroupSerializer
|
||||
@ -136,36 +175,11 @@ class Group(SerializerModel):
|
||||
return user.all_groups().filter(group_uuid=self.group_uuid).exists()
|
||||
|
||||
def children_recursive(self: Self | QuerySet["Group"]) -> QuerySet["Group"]:
|
||||
"""Recursively get all groups that have this as parent or are indirectly related"""
|
||||
direct_groups = []
|
||||
if isinstance(self, QuerySet):
|
||||
direct_groups = list(x for x in self.all().values_list("pk", flat=True).iterator())
|
||||
else:
|
||||
direct_groups = [self.pk]
|
||||
if len(direct_groups) < 1:
|
||||
return Group.objects.none()
|
||||
query = """
|
||||
WITH RECURSIVE parents AS (
|
||||
SELECT authentik_core_group.*, 0 AS relative_depth
|
||||
FROM authentik_core_group
|
||||
WHERE authentik_core_group.group_uuid = ANY(%s)
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT authentik_core_group.*, parents.relative_depth + 1
|
||||
FROM authentik_core_group, parents
|
||||
WHERE (
|
||||
authentik_core_group.group_uuid = parents.parent_id and
|
||||
parents.relative_depth < 20
|
||||
)
|
||||
)
|
||||
SELECT group_uuid
|
||||
FROM parents
|
||||
GROUP BY group_uuid, name
|
||||
ORDER BY name;
|
||||
"""
|
||||
group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()]
|
||||
return Group.objects.filter(pk__in=group_pks)
|
||||
"""Compatibility layer for Group.objects.with_children_recursive()"""
|
||||
qs = self
|
||||
if not isinstance(self, QuerySet):
|
||||
qs = Group.objects.filter(group_uuid=self.group_uuid)
|
||||
return qs.with_children_recursive()
|
||||
|
||||
def __str__(self):
|
||||
return f"Group {self.name}"
|
||||
@ -232,10 +246,8 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
||||
return User._meta.get_field("path").default
|
||||
|
||||
def all_groups(self) -> QuerySet[Group]:
|
||||
"""Recursively get all groups this user is a member of.
|
||||
At least one query is done to get the direct groups of the user, with groups
|
||||
there are at most 3 queries done"""
|
||||
return Group.children_recursive(self.ak_groups.all())
|
||||
"""Recursively get all groups this user is a member of."""
|
||||
return self.ak_groups.all().with_children_recursive()
|
||||
|
||||
def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]:
|
||||
"""Get a dictionary containing the attributes from all groups the user belongs to,
|
||||
@ -377,6 +389,10 @@ class Provider(SerializerModel):
|
||||
Can return None for providers that are not URL-based"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
"""Return component used to edit this object"""
|
||||
@ -768,7 +784,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
||||
try:
|
||||
return evaluator.evaluate(self.expression)
|
||||
except Exception as exc:
|
||||
raise PropertyMappingExpressionException(exc) from exc
|
||||
raise PropertyMappingExpressionException(self, exc) from exc
|
||||
|
||||
def __str__(self):
|
||||
return f"Property Mapping {self.name}"
|
||||
|
@ -23,6 +23,17 @@ class TestGroupsAPI(APITestCase):
|
||||
response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_retrieve_with_users(self):
|
||||
"""Test retrieve with users"""
|
||||
admin = create_test_admin_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
self.client.force_login(admin)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:group-detail", kwargs={"pk": group.pk}),
|
||||
{"include_users": "true"},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_add_user(self):
|
||||
"""Test add_user"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
|
@ -1,14 +1,14 @@
|
||||
"""authentik core models tests"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from time import sleep
|
||||
from datetime import timedelta
|
||||
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.utils.timezone import now
|
||||
from freezegun import freeze_time
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
|
||||
from authentik.core.models import Provider, Source, Token
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
|
||||
|
||||
@ -17,18 +17,20 @@ class TestModels(TestCase):
|
||||
|
||||
def test_token_expire(self):
|
||||
"""Test token expiring"""
|
||||
token = Token.objects.create(expires=now(), user=get_anonymous_user())
|
||||
sleep(0.5)
|
||||
self.assertTrue(token.is_expired)
|
||||
with freeze_time() as freeze:
|
||||
token = Token.objects.create(expires=now(), user=get_anonymous_user())
|
||||
freeze.tick(timedelta(seconds=1))
|
||||
self.assertTrue(token.is_expired)
|
||||
|
||||
def test_token_expire_no_expire(self):
|
||||
"""Test token expiring with "expiring" set"""
|
||||
token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
|
||||
sleep(0.5)
|
||||
self.assertFalse(token.is_expired)
|
||||
with freeze_time() as freeze:
|
||||
token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
|
||||
freeze.tick(timedelta(seconds=1))
|
||||
self.assertFalse(token.is_expired)
|
||||
|
||||
|
||||
def source_tester_factory(test_model: type[Stage]) -> Callable:
|
||||
def source_tester_factory(test_model: type[Source]) -> Callable:
|
||||
"""Test source"""
|
||||
|
||||
factory = RequestFactory()
|
||||
@ -36,19 +38,19 @@ def source_tester_factory(test_model: type[Stage]) -> Callable:
|
||||
|
||||
def tester(self: TestModels):
|
||||
model_class = None
|
||||
if test_model._meta.abstract: # pragma: no cover
|
||||
model_class = test_model.__bases__[0]()
|
||||
if test_model._meta.abstract:
|
||||
model_class = [x for x in test_model.__bases__ if issubclass(x, Source)][0]()
|
||||
else:
|
||||
model_class = test_model()
|
||||
model_class.slug = "test"
|
||||
self.assertIsNotNone(model_class.component)
|
||||
_ = model_class.ui_login_button(request)
|
||||
_ = model_class.ui_user_settings()
|
||||
model_class.ui_login_button(request)
|
||||
model_class.ui_user_settings()
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def provider_tester_factory(test_model: type[Stage]) -> Callable:
|
||||
def provider_tester_factory(test_model: type[Provider]) -> Callable:
|
||||
"""Test provider"""
|
||||
|
||||
def tester(self: TestModels):
|
||||
|
@ -6,9 +6,10 @@ from django.urls import reverse
|
||||
from rest_framework.serializers import ValidationError
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.models import PropertyMapping
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.models import Group, PropertyMapping
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestPropertyMappingAPI(APITestCase):
|
||||
@ -16,23 +17,40 @@ class TestPropertyMappingAPI(APITestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.mapping = PropertyMapping.objects.create(
|
||||
name="dummy", expression="""return {'foo': 'bar'}"""
|
||||
)
|
||||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_test_call(self):
|
||||
"""Test PropertMappings's test endpoint"""
|
||||
"""Test PropertyMappings's test endpoint"""
|
||||
mapping = PropertyMapping.objects.create(
|
||||
name="dummy", expression="""return {'foo': 'bar', 'baz': user.username}"""
|
||||
)
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}),
|
||||
reverse("authentik_api:propertymapping-test", kwargs={"pk": mapping.pk}),
|
||||
data={
|
||||
"user": self.user.pk,
|
||||
},
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"result": dumps({"foo": "bar"}), "successful": True},
|
||||
{"result": dumps({"foo": "bar", "baz": self.user.username}), "successful": True},
|
||||
)
|
||||
|
||||
def test_test_call_group(self):
|
||||
"""Test PropertyMappings's test endpoint"""
|
||||
mapping = PropertyMapping.objects.create(
|
||||
name="dummy", expression="""return {'foo': 'bar', 'baz': group.name}"""
|
||||
)
|
||||
group = Group.objects.create(name=generate_id())
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:propertymapping-test", kwargs={"pk": mapping.pk}),
|
||||
data={
|
||||
"group": group.pk,
|
||||
},
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{"result": dumps({"foo": "bar", "baz": group.name}), "successful": True},
|
||||
)
|
||||
|
||||
def test_validate(self):
|
||||
|
@ -42,8 +42,8 @@ class TestUsersAvatars(APITestCase):
|
||||
with Mocker() as mocker:
|
||||
mocker.head(
|
||||
(
|
||||
"https://secure.gravatar.com/avatar/84730f9c1851d1ea03f1a"
|
||||
"a9ed85bd1ea?size=158&rating=g&default=404"
|
||||
"https://www.gravatar.com/avatar/76eb3c74c8beb6faa037f1b6e2ecb3e252bdac"
|
||||
"6cf71fb567ae36025a9d4ea86b?size=158&rating=g&default=404"
|
||||
),
|
||||
text="foo",
|
||||
)
|
||||
|
@ -12,7 +12,7 @@ from authentik.core.api.applications import ApplicationViewSet
|
||||
from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet
|
||||
from authentik.core.api.devices import AdminDeviceViewSet, DeviceViewSet
|
||||
from authentik.core.api.groups import GroupViewSet
|
||||
from authentik.core.api.propertymappings import PropertyMappingViewSet
|
||||
from authentik.core.api.property_mappings import PropertyMappingViewSet
|
||||
from authentik.core.api.providers import ProviderViewSet
|
||||
from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSet
|
||||
from authentik.core.api.tokens import TokenViewSet
|
||||
|
@ -92,7 +92,11 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
|
||||
@property
|
||||
def kid(self):
|
||||
"""Get Key ID used for JWKS"""
|
||||
return md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else "" # nosec
|
||||
return (
|
||||
md5(self.key_data.encode("utf-8"), usedforsecurity=False).hexdigest()
|
||||
if self.key_data
|
||||
else ""
|
||||
) # nosec
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Certificate-Key Pair {self.name}"
|
||||
|
@ -1,14 +1,15 @@
|
||||
"""GoogleWorkspaceProviderGroup API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from rest_framework import mixins
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.users import UserGroupSerializer
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderGroup
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderGroupSerializer(SourceSerializer):
|
||||
class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
|
||||
"""GoogleWorkspaceProviderGroup Serializer"""
|
||||
|
||||
group_obj = UserGroupSerializer(source="group", read_only=True)
|
||||
@ -20,10 +21,20 @@ class GoogleWorkspaceProviderGroupSerializer(SourceSerializer):
|
||||
"id",
|
||||
"group",
|
||||
"group_obj",
|
||||
"provider",
|
||||
"attributes",
|
||||
]
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderGroupViewSet(UsedByMixin, ModelViewSet):
|
||||
class GoogleWorkspaceProviderGroupViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
"""GoogleWorkspaceProviderGroup Viewset"""
|
||||
|
||||
queryset = GoogleWorkspaceProviderGroup.objects.all().select_related("group")
|
||||
|
@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderMapping
|
||||
|
||||
|
@ -1,14 +1,15 @@
|
||||
"""GoogleWorkspaceProviderUser API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from rest_framework import mixins
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.groups import GroupMemberSerializer
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderUser
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderUserSerializer(SourceSerializer):
|
||||
class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
|
||||
"""GoogleWorkspaceProviderUser Serializer"""
|
||||
|
||||
user_obj = GroupMemberSerializer(source="user", read_only=True)
|
||||
@ -20,10 +21,20 @@ class GoogleWorkspaceProviderUserSerializer(SourceSerializer):
|
||||
"id",
|
||||
"user",
|
||||
"user_obj",
|
||||
"provider",
|
||||
"attributes",
|
||||
]
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class GoogleWorkspaceProviderUserViewSet(UsedByMixin, ModelViewSet):
|
||||
class GoogleWorkspaceProviderUserViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
"""GoogleWorkspaceProviderUser Viewset"""
|
||||
|
||||
queryset = GoogleWorkspaceProviderUser.objects.all().select_related("user")
|
||||
|
@ -1,28 +1,22 @@
|
||||
from deepmerge import always_merger
|
||||
from django.db import transaction
|
||||
from django.utils.text import slugify
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import Group
|
||||
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
||||
from authentik.enterprise.providers.google_workspace.models import (
|
||||
GoogleWorkspaceProvider,
|
||||
GoogleWorkspaceProviderGroup,
|
||||
GoogleWorkspaceProviderMapping,
|
||||
GoogleWorkspaceProviderUser,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
NotFoundSyncException,
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
TransientSyncException,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
|
||||
class GoogleWorkspaceGroupClient(
|
||||
@ -34,41 +28,21 @@ class GoogleWorkspaceGroupClient(
|
||||
connection_type_query = "group"
|
||||
can_discover = True
|
||||
|
||||
def to_schema(self, obj: Group, creating: bool) -> dict:
|
||||
"""Convert authentik group"""
|
||||
raw_google_group = {
|
||||
"email": f"{slugify(obj.name)}@{self.provider.default_group_email_domain}"
|
||||
}
|
||||
for mapping in (
|
||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
|
||||
):
|
||||
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
|
||||
continue
|
||||
try:
|
||||
value = mapping.evaluate(
|
||||
user=None,
|
||||
request=None,
|
||||
group=obj,
|
||||
provider=self.provider,
|
||||
creating=creating,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_google_group, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_google_group:
|
||||
raise StopSync(ValueError("No group mappings configured"), obj)
|
||||
def __init__(self, provider: GoogleWorkspaceProvider) -> None:
|
||||
super().__init__(provider)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses(),
|
||||
GoogleWorkspaceProviderMapping,
|
||||
["group", "provider", "connection"],
|
||||
)
|
||||
|
||||
return raw_google_group
|
||||
def to_schema(self, obj: Group, connection: GoogleWorkspaceProviderGroup) -> dict:
|
||||
"""Convert authentik group"""
|
||||
return super().to_schema(
|
||||
obj,
|
||||
connection=connection,
|
||||
email=f"{slugify(obj.name)}@{self.provider.default_group_email_domain}",
|
||||
)
|
||||
|
||||
def delete(self, obj: Group):
|
||||
"""Delete group"""
|
||||
@ -87,7 +61,7 @@ class GoogleWorkspaceGroupClient(
|
||||
|
||||
def create(self, group: Group):
|
||||
"""Create group from scratch and create a connection object"""
|
||||
google_group = self.to_schema(group, True)
|
||||
google_group = self.to_schema(group, None)
|
||||
self.check_email_valid(google_group["email"])
|
||||
with transaction.atomic():
|
||||
try:
|
||||
@ -100,16 +74,22 @@ class GoogleWorkspaceGroupClient(
|
||||
self.directory_service.groups().get(groupKey=google_group["email"])
|
||||
)
|
||||
return GoogleWorkspaceProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, google_id=group_data["id"]
|
||||
provider=self.provider,
|
||||
group=group,
|
||||
google_id=group_data["id"],
|
||||
attributes=group_data,
|
||||
)
|
||||
else:
|
||||
return GoogleWorkspaceProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, google_id=response["id"]
|
||||
provider=self.provider,
|
||||
group=group,
|
||||
google_id=response["id"],
|
||||
attributes=response,
|
||||
)
|
||||
|
||||
def update(self, group: Group, connection: GoogleWorkspaceProviderGroup):
|
||||
"""Update existing group"""
|
||||
google_group = self.to_schema(group, False)
|
||||
google_group = self.to_schema(group, connection)
|
||||
self.check_email_valid(google_group["email"])
|
||||
try:
|
||||
return self._request(
|
||||
@ -230,4 +210,5 @@ class GoogleWorkspaceGroupClient(
|
||||
provider=self.provider,
|
||||
group=matching_authentik_group,
|
||||
google_id=google_id,
|
||||
attributes=group,
|
||||
)
|
||||
|
@ -1,24 +1,18 @@
|
||||
from deepmerge import always_merger
|
||||
from django.db import transaction
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
||||
from authentik.enterprise.providers.google_workspace.models import (
|
||||
GoogleWorkspaceProvider,
|
||||
GoogleWorkspaceProviderMapping,
|
||||
GoogleWorkspaceProviderUser,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
TransientSyncException,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.utils import delete_none_values
|
||||
|
||||
|
||||
@ -29,37 +23,17 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
||||
connection_type_query = "user"
|
||||
can_discover = True
|
||||
|
||||
def to_schema(self, obj: User, creating: bool) -> dict:
|
||||
def __init__(self, provider: GoogleWorkspaceProvider) -> None:
|
||||
super().__init__(provider)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self.provider.property_mappings.all().order_by("name").select_subclasses(),
|
||||
GoogleWorkspaceProviderMapping,
|
||||
["provider", "connection"],
|
||||
)
|
||||
|
||||
def to_schema(self, obj: User, connection: GoogleWorkspaceProviderUser) -> dict:
|
||||
"""Convert authentik user"""
|
||||
raw_google_user = {}
|
||||
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
|
||||
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
|
||||
continue
|
||||
try:
|
||||
value = mapping.evaluate(
|
||||
user=obj,
|
||||
request=None,
|
||||
provider=self.provider,
|
||||
creating=creating,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_google_user, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_google_user:
|
||||
raise StopSync(ValueError("No user mappings configured"), obj)
|
||||
if "primaryEmail" not in raw_google_user:
|
||||
raw_google_user["primaryEmail"] = str(obj.email)
|
||||
return delete_none_values(raw_google_user)
|
||||
return delete_none_values(super().to_schema(obj, connection, primaryEmail=obj.email))
|
||||
|
||||
def delete(self, obj: User):
|
||||
"""Delete user"""
|
||||
@ -86,7 +60,7 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
||||
|
||||
def create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
google_user = self.to_schema(user, True)
|
||||
google_user = self.to_schema(user, None)
|
||||
self.check_email_valid(
|
||||
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
||||
)
|
||||
@ -96,18 +70,21 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
||||
except ObjectExistsSyncException:
|
||||
# user already exists in google workspace, so we can connect them manually
|
||||
return GoogleWorkspaceProviderUser.objects.create(
|
||||
provider=self.provider, user=user, google_id=user.email
|
||||
provider=self.provider, user=user, google_id=user.email, attributes={}
|
||||
)
|
||||
except TransientSyncException as exc:
|
||||
raise exc
|
||||
else:
|
||||
return GoogleWorkspaceProviderUser.objects.create(
|
||||
provider=self.provider, user=user, google_id=response["primaryEmail"]
|
||||
provider=self.provider,
|
||||
user=user,
|
||||
google_id=response["primaryEmail"],
|
||||
attributes=response,
|
||||
)
|
||||
|
||||
def update(self, user: User, connection: GoogleWorkspaceProviderUser):
|
||||
"""Update existing user"""
|
||||
google_user = self.to_schema(user, False)
|
||||
google_user = self.to_schema(user, connection)
|
||||
self.check_email_valid(
|
||||
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
||||
)
|
||||
@ -138,4 +115,5 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
||||
provider=self.provider,
|
||||
user=matching_authentik_user,
|
||||
google_id=email,
|
||||
attributes=user,
|
||||
)
|
||||
|
@ -0,0 +1,26 @@
|
||||
# Generated by Django 5.0.6 on 2024-05-23 20:48
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_providers_google_workspace",
|
||||
"0001_squashed_0002_alter_googleworkspaceprovidergroup_options_and_more",
|
||||
),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="googleworkspaceprovidergroup",
|
||||
name="attributes",
|
||||
field=models.JSONField(default=dict),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="googleworkspaceprovideruser",
|
||||
name="attributes",
|
||||
field=models.JSONField(default=dict),
|
||||
),
|
||||
]
|
@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from google.oauth2.service_account import Credentials
|
||||
from rest_framework.serializers import Serializer
|
||||
@ -98,6 +99,10 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
).with_subject(self.delegated_subject),
|
||||
}
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/google.svg")
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-provider-google-workspace-form"
|
||||
@ -148,6 +153,7 @@ class GoogleWorkspaceProviderUser(SerializerModel):
|
||||
google_id = models.TextField()
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -173,6 +179,7 @@ class GoogleWorkspaceProviderGroup(SerializerModel):
|
||||
google_id = models.TextField()
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -82,6 +82,27 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 2)
|
||||
|
||||
def test_group_not_created(self):
|
||||
"""Test without group property mappings, no group is created"""
|
||||
self.provider.property_mappings_group.clear()
|
||||
uid = generate_id()
|
||||
http = MockHTTP()
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
|
||||
domains_list_v1_mock,
|
||||
)
|
||||
with patch(
|
||||
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
|
||||
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
|
||||
):
|
||||
group = Group.objects.create(name=uid)
|
||||
google_group = GoogleWorkspaceProviderGroup.objects.filter(
|
||||
provider=self.provider, group=group
|
||||
).first()
|
||||
self.assertIsNone(google_group)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 1)
|
||||
|
||||
def test_group_create_update(self):
|
||||
"""Test group updating"""
|
||||
uid = generate_id()
|
||||
|
@ -86,6 +86,31 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 2)
|
||||
|
||||
def test_user_not_created(self):
|
||||
"""Test without property mappings, no group is created"""
|
||||
self.provider.property_mappings.clear()
|
||||
uid = generate_id()
|
||||
http = MockHTTP()
|
||||
http.add_response(
|
||||
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
|
||||
domains_list_v1_mock,
|
||||
)
|
||||
with patch(
|
||||
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
|
||||
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
|
||||
):
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
)
|
||||
google_user = GoogleWorkspaceProviderUser.objects.filter(
|
||||
provider=self.provider, user=user
|
||||
).first()
|
||||
self.assertIsNone(google_user)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
self.assertEqual(len(http.requests()), 1)
|
||||
|
||||
def test_user_create_update(self):
|
||||
"""Test user updating"""
|
||||
uid = generate_id()
|
||||
|
@ -1,14 +1,15 @@
|
||||
"""MicrosoftEntraProviderGroup API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from rest_framework import mixins
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.users import UserGroupSerializer
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderGroup
|
||||
|
||||
|
||||
class MicrosoftEntraProviderGroupSerializer(SourceSerializer):
|
||||
class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
|
||||
"""MicrosoftEntraProviderGroup Serializer"""
|
||||
|
||||
group_obj = UserGroupSerializer(source="group", read_only=True)
|
||||
@ -20,10 +21,20 @@ class MicrosoftEntraProviderGroupSerializer(SourceSerializer):
|
||||
"id",
|
||||
"group",
|
||||
"group_obj",
|
||||
"provider",
|
||||
"attributes",
|
||||
]
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class MicrosoftEntraProviderGroupViewSet(UsedByMixin, ModelViewSet):
|
||||
class MicrosoftEntraProviderGroupViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
"""MicrosoftEntraProviderGroup Viewset"""
|
||||
|
||||
queryset = MicrosoftEntraProviderGroup.objects.all().select_related("group")
|
||||
|
@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderMapping
|
||||
|
||||
|
@ -1,14 +1,15 @@
|
||||
"""MicrosoftEntraProviderUser API Views"""
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from rest_framework import mixins
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from authentik.core.api.groups import GroupMemberSerializer
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderUser
|
||||
|
||||
|
||||
class MicrosoftEntraProviderUserSerializer(SourceSerializer):
|
||||
class MicrosoftEntraProviderUserSerializer(ModelSerializer):
|
||||
"""MicrosoftEntraProviderUser Serializer"""
|
||||
|
||||
user_obj = GroupMemberSerializer(source="user", read_only=True)
|
||||
@ -20,10 +21,20 @@ class MicrosoftEntraProviderUserSerializer(SourceSerializer):
|
||||
"id",
|
||||
"user",
|
||||
"user_obj",
|
||||
"provider",
|
||||
"attributes",
|
||||
]
|
||||
extra_kwargs = {"attributes": {"read_only": True}}
|
||||
|
||||
|
||||
class MicrosoftEntraProviderUserViewSet(UsedByMixin, ModelViewSet):
|
||||
class MicrosoftEntraProviderUserViewSet(
|
||||
mixins.CreateModelMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
mixins.ListModelMixin,
|
||||
GenericViewSet,
|
||||
):
|
||||
"""MicrosoftEntraProviderUser Viewset"""
|
||||
|
||||
queryset = MicrosoftEntraProviderUser.objects.all().select_related("user")
|
||||
|
@ -1,5 +1,6 @@
|
||||
from asyncio import run
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
from azure.core.exceptions import (
|
||||
@ -15,6 +16,7 @@ from kiota_authentication_azure.azure_identity_authentication_provider import (
|
||||
AzureIdentityAuthenticationProvider,
|
||||
)
|
||||
from kiota_http.kiota_client_factory import KiotaClientFactory
|
||||
from msgraph.generated.models.entity import Entity
|
||||
from msgraph.generated.models.o_data_errors.o_data_error import ODataError
|
||||
from msgraph.graph_request_adapter import GraphRequestAdapter, options
|
||||
from msgraph.graph_service_client import GraphServiceClient
|
||||
@ -98,3 +100,10 @@ class MicrosoftEntraSyncClient[TModel: Model, TConnection: Model, TSchema: dict]
|
||||
for email in emails:
|
||||
if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
|
||||
raise BadRequestSyncException(f"Invalid email domain: {email}")
|
||||
|
||||
def entity_as_dict(self, entity: Entity) -> dict:
|
||||
"""Create a dictionary of a model instance, making sure to remove (known) things
|
||||
we can't JSON serialize"""
|
||||
raw_data = asdict(entity)
|
||||
raw_data.pop("backing_store", None)
|
||||
return raw_data
|
||||
|
@ -1,21 +1,17 @@
|
||||
from deepmerge import always_merger
|
||||
from django.db import transaction
|
||||
from msgraph.generated.groups.groups_request_builder import GroupsRequestBuilder
|
||||
from msgraph.generated.models.group import Group as MSGroup
|
||||
from msgraph.generated.models.reference_create import ReferenceCreate
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import Group
|
||||
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
|
||||
from authentik.enterprise.providers.microsoft_entra.models import (
|
||||
MicrosoftEntraProvider,
|
||||
MicrosoftEntraProviderGroup,
|
||||
MicrosoftEntraProviderMapping,
|
||||
MicrosoftEntraProviderUser,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
NotFoundSyncException,
|
||||
@ -24,7 +20,6 @@ from authentik.lib.sync.outgoing.exceptions import (
|
||||
TransientSyncException,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
|
||||
class MicrosoftEntraGroupClient(
|
||||
@ -36,37 +31,17 @@ class MicrosoftEntraGroupClient(
|
||||
connection_type_query = "group"
|
||||
can_discover = True
|
||||
|
||||
def to_schema(self, obj: Group, creating: bool) -> MSGroup:
|
||||
def __init__(self, provider: MicrosoftEntraProvider) -> None:
|
||||
super().__init__(provider)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses(),
|
||||
MicrosoftEntraProviderMapping,
|
||||
["group", "provider", "connection"],
|
||||
)
|
||||
|
||||
def to_schema(self, obj: Group, connection: MicrosoftEntraProviderGroup) -> MSGroup:
|
||||
"""Convert authentik group"""
|
||||
raw_microsoft_group = {}
|
||||
for mapping in (
|
||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
|
||||
):
|
||||
if not isinstance(mapping, MicrosoftEntraProviderMapping):
|
||||
continue
|
||||
try:
|
||||
value = mapping.evaluate(
|
||||
user=None,
|
||||
request=None,
|
||||
group=obj,
|
||||
provider=self.provider,
|
||||
creating=creating,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_microsoft_group, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_microsoft_group:
|
||||
raise StopSync(ValueError("No group mappings configured"), obj)
|
||||
raw_microsoft_group = super().to_schema(obj, connection)
|
||||
try:
|
||||
return MSGroup(**raw_microsoft_group)
|
||||
except TypeError as exc:
|
||||
@ -87,7 +62,7 @@ class MicrosoftEntraGroupClient(
|
||||
|
||||
def create(self, group: Group):
|
||||
"""Create group from scratch and create a connection object"""
|
||||
microsoft_group = self.to_schema(group, True)
|
||||
microsoft_group = self.to_schema(group, None)
|
||||
with transaction.atomic():
|
||||
try:
|
||||
response = self._request(self.client.groups.post(microsoft_group))
|
||||
@ -104,22 +79,29 @@ class MicrosoftEntraGroupClient(
|
||||
)
|
||||
)
|
||||
group_data = self._request(self.client.groups.get(request_configuration))
|
||||
if group_data.odata_count < 1:
|
||||
if group_data.odata_count < 1 or len(group_data.value) < 1:
|
||||
self.logger.warning(
|
||||
"Group which could not be created also does not exist", group=group
|
||||
)
|
||||
return
|
||||
ms_group = group_data.value[0]
|
||||
return MicrosoftEntraProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, microsoft_id=group_data.value[0].id
|
||||
provider=self.provider,
|
||||
group=group,
|
||||
microsoft_id=ms_group.id,
|
||||
attributes=self.entity_as_dict(ms_group),
|
||||
)
|
||||
else:
|
||||
return MicrosoftEntraProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, microsoft_id=response.id
|
||||
provider=self.provider,
|
||||
group=group,
|
||||
microsoft_id=response.id,
|
||||
attributes=self.entity_as_dict(response),
|
||||
)
|
||||
|
||||
def update(self, group: Group, connection: MicrosoftEntraProviderGroup):
|
||||
"""Update existing group"""
|
||||
microsoft_group = self.to_schema(group, False)
|
||||
microsoft_group = self.to_schema(group, connection)
|
||||
microsoft_group.id = connection.microsoft_id
|
||||
try:
|
||||
return self._request(
|
||||
@ -238,4 +220,5 @@ class MicrosoftEntraGroupClient(
|
||||
provider=self.provider,
|
||||
group=matching_authentik_group,
|
||||
microsoft_id=group.id,
|
||||
attributes=self.entity_as_dict(group),
|
||||
)
|
||||
|
@ -1,26 +1,21 @@
|
||||
from deepmerge import always_merger
|
||||
from django.db import transaction
|
||||
from msgraph.generated.models.user import User as MSUser
|
||||
from msgraph.generated.users.users_request_builder import UsersRequestBuilder
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
|
||||
from authentik.enterprise.providers.microsoft_entra.models import (
|
||||
MicrosoftEntraProvider,
|
||||
MicrosoftEntraProviderMapping,
|
||||
MicrosoftEntraProviderUser,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
TransientSyncException,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.utils import delete_none_values
|
||||
|
||||
|
||||
@ -31,34 +26,17 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
connection_type_query = "user"
|
||||
can_discover = True
|
||||
|
||||
def to_schema(self, obj: User, creating: bool) -> MSUser:
|
||||
def __init__(self, provider: MicrosoftEntraProvider) -> None:
|
||||
super().__init__(provider)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self.provider.property_mappings.all().order_by("name").select_subclasses(),
|
||||
MicrosoftEntraProviderMapping,
|
||||
["provider", "connection"],
|
||||
)
|
||||
|
||||
def to_schema(self, obj: User, connection: MicrosoftEntraProviderUser) -> MSUser:
|
||||
"""Convert authentik user"""
|
||||
raw_microsoft_user = {}
|
||||
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
|
||||
if not isinstance(mapping, MicrosoftEntraProviderMapping):
|
||||
continue
|
||||
try:
|
||||
value = mapping.evaluate(
|
||||
user=obj,
|
||||
request=None,
|
||||
provider=self.provider,
|
||||
creating=creating,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_microsoft_user, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_microsoft_user:
|
||||
raise StopSync(ValueError("No user mappings configured"), obj)
|
||||
raw_microsoft_user = super().to_schema(obj, connection)
|
||||
try:
|
||||
return MSUser(**delete_none_values(raw_microsoft_user))
|
||||
except TypeError as exc:
|
||||
@ -89,7 +67,7 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
|
||||
def create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
microsoft_user = self.to_schema(user, True)
|
||||
microsoft_user = self.to_schema(user, None)
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
with transaction.atomic():
|
||||
try:
|
||||
@ -105,24 +83,32 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
)
|
||||
)
|
||||
user_data = self._request(self.client.users.get(request_configuration))
|
||||
if user_data.odata_count < 1:
|
||||
if user_data.odata_count < 1 or len(user_data.value) < 1:
|
||||
self.logger.warning(
|
||||
"User which could not be created also does not exist", user=user
|
||||
)
|
||||
return
|
||||
ms_user = user_data.value[0]
|
||||
return MicrosoftEntraProviderUser.objects.create(
|
||||
provider=self.provider, user=user, microsoft_id=user_data.value[0].id
|
||||
provider=self.provider,
|
||||
user=user,
|
||||
microsoft_id=ms_user.id,
|
||||
attributes=self.entity_as_dict(ms_user),
|
||||
)
|
||||
except TransientSyncException as exc:
|
||||
raise exc
|
||||
else:
|
||||
print(self.entity_as_dict(response))
|
||||
return MicrosoftEntraProviderUser.objects.create(
|
||||
provider=self.provider, user=user, microsoft_id=response.id
|
||||
provider=self.provider,
|
||||
user=user,
|
||||
microsoft_id=response.id,
|
||||
attributes=self.entity_as_dict(response),
|
||||
)
|
||||
|
||||
def update(self, user: User, connection: MicrosoftEntraProviderUser):
|
||||
"""Update existing user"""
|
||||
microsoft_user = self.to_schema(user, False)
|
||||
microsoft_user = self.to_schema(user, connection)
|
||||
self.check_email_valid(microsoft_user.user_principal_name)
|
||||
self._request(self.client.users.by_user_id(connection.microsoft_id).patch(microsoft_user))
|
||||
|
||||
@ -147,4 +133,5 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
||||
provider=self.provider,
|
||||
user=matching_authentik_user,
|
||||
microsoft_id=user.id,
|
||||
attributes=self.entity_as_dict(user),
|
||||
)
|
||||
|
@ -0,0 +1,23 @@
|
||||
# Generated by Django 5.0.6 on 2024-05-23 20:48
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_microsoft_entra", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="microsoftentraprovidergroup",
|
||||
name="attributes",
|
||||
field=models.JSONField(default=dict),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="microsoftentraprovideruser",
|
||||
name="attributes",
|
||||
field=models.JSONField(default=dict),
|
||||
),
|
||||
]
|
@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
from azure.identity.aio import ClientSecretCredential
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
@ -87,6 +88,10 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/azuread.svg")
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-provider-microsoft-entra-form"
|
||||
@ -137,6 +142,7 @@ class MicrosoftEntraProviderUser(SerializerModel):
|
||||
microsoft_id = models.TextField()
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(MicrosoftEntraProvider, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -162,6 +168,7 @@ class MicrosoftEntraProviderGroup(SerializerModel):
|
||||
microsoft_id = models.TextField()
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(MicrosoftEntraProvider, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -93,6 +93,38 @@ class MicrosoftEntraGroupTests(TestCase):
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
group_create.assert_called_once()
|
||||
|
||||
def test_group_not_created(self):
|
||||
"""Test without group property mappings, no group is created"""
|
||||
self.provider.property_mappings_group.clear()
|
||||
uid = generate_id()
|
||||
with (
|
||||
patch(
|
||||
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
|
||||
MagicMock(return_value={"credentials": self.creds}),
|
||||
),
|
||||
patch(
|
||||
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
|
||||
AsyncMock(
|
||||
return_value=OrganizationCollectionResponse(
|
||||
value=[
|
||||
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
|
||||
]
|
||||
)
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"msgraph.generated.groups.groups_request_builder.GroupsRequestBuilder.post",
|
||||
AsyncMock(return_value=MSGroup(id=generate_id())),
|
||||
) as group_create,
|
||||
):
|
||||
group = Group.objects.create(name=uid)
|
||||
microsoft_group = MicrosoftEntraProviderGroup.objects.filter(
|
||||
provider=self.provider, group=group
|
||||
).first()
|
||||
self.assertIsNone(microsoft_group)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
group_create.assert_not_called()
|
||||
|
||||
def test_group_create_update(self):
|
||||
"""Test group updating"""
|
||||
uid = generate_id()
|
||||
|
@ -94,6 +94,42 @@ class MicrosoftEntraUserTests(TestCase):
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
user_create.assert_called_once()
|
||||
|
||||
def test_user_not_created(self):
|
||||
"""Test without property mappings, no group is created"""
|
||||
self.provider.property_mappings.clear()
|
||||
uid = generate_id()
|
||||
with (
|
||||
patch(
|
||||
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
|
||||
MagicMock(return_value={"credentials": self.creds}),
|
||||
),
|
||||
patch(
|
||||
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
|
||||
AsyncMock(
|
||||
return_value=OrganizationCollectionResponse(
|
||||
value=[
|
||||
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
|
||||
]
|
||||
)
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"msgraph.generated.users.users_request_builder.UsersRequestBuilder.post",
|
||||
AsyncMock(return_value=MSUser(id=generate_id())),
|
||||
) as user_create,
|
||||
):
|
||||
user = User.objects.create(
|
||||
username=uid,
|
||||
name=f"{uid} {uid}",
|
||||
email=f"{uid}@goauthentik.io",
|
||||
)
|
||||
microsoft_user = MicrosoftEntraProviderUser.objects.filter(
|
||||
provider=self.provider, user=user
|
||||
).first()
|
||||
self.assertIsNone(microsoft_user)
|
||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||
user_create.assert_not_called()
|
||||
|
||||
def test_user_create_update(self):
|
||||
"""Test user updating"""
|
||||
uid = generate_id()
|
||||
|
@ -7,7 +7,7 @@ from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import JSONDictField
|
||||
from authentik.enterprise.providers.rac.models import RACPropertyMapping
|
||||
|
@ -7,6 +7,7 @@ from deepmerge import always_merger
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.http import HttpRequest
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
@ -63,6 +64,10 @@ class RACProvider(Provider):
|
||||
Can return None for providers that are not URL-based"""
|
||||
return "goauthentik.io://providers/rac/launch"
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/rac.svg")
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-provider-rac-form"
|
||||
|
@ -18,9 +18,12 @@ class SourceStageSerializer(EnterpriseRequiredMixin, StageSerializer):
|
||||
source = Source.objects.filter(pk=_source.pk).select_subclasses().first()
|
||||
if not source:
|
||||
raise ValidationError("Invalid source")
|
||||
login_button = source.ui_login_button(self.context["request"])
|
||||
if not login_button:
|
||||
raise ValidationError("Invalid source selected, only web-based sources are supported.")
|
||||
if "request" in self.context:
|
||||
login_button = source.ui_login_button(self.context["request"])
|
||||
if not login_button:
|
||||
raise ValidationError(
|
||||
"Invalid source selected, only web-based sources are supported."
|
||||
)
|
||||
return source
|
||||
|
||||
class Meta:
|
||||
|
@ -54,7 +54,7 @@ class SourceStageView(ChallengeStageView):
|
||||
def create_flow_token(self) -> FlowToken:
|
||||
"""Save the current flow state in a token that can be used to resume this flow"""
|
||||
pending_user: User = self.get_pending_user()
|
||||
if pending_user.is_anonymous:
|
||||
if pending_user.is_anonymous or not pending_user.pk:
|
||||
pending_user = get_anonymous_user()
|
||||
current_stage: SourceStage = self.executor.current_stage
|
||||
identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}")
|
||||
|
@ -19,7 +19,8 @@ from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.admin.api.metrics import CoordinateSerializer
|
||||
from authentik.core.api.utils import PassiveSerializer, TypeCreateSerializer
|
||||
from authentik.core.api.object_types import TypeCreateSerializer
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.events.models import Event, EventAction
|
||||
|
||||
|
||||
|
@ -45,7 +45,7 @@ class GeoIPContextProcessor(MMDBContextProcessor):
|
||||
|
||||
def enrich_context(self, request: HttpRequest) -> dict:
|
||||
# Different key `geoip` vs `geo` for legacy reasons
|
||||
return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))}
|
||||
return {"geoip": self.city_dict(ClientIPMiddleware.get_client_ip(request))}
|
||||
|
||||
def city(self, ip_address: str) -> City | None:
|
||||
"""Wrapper for Reader.city"""
|
||||
|
@ -10,10 +10,10 @@ from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.object_types import TypesMixin
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||
from authentik.core.api.utils import MetaNameSerializer
|
||||
from authentik.core.types import UserSettingSerializer
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
from authentik.flows.api.flows import FlowSetSerializer
|
||||
from authentik.flows.models import ConfigurableStage, Stage
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
@ -47,6 +47,7 @@ class StageSerializer(ModelSerializer, MetaNameSerializer):
|
||||
|
||||
|
||||
class StageViewSet(
|
||||
TypesMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
@ -63,25 +64,6 @@ class StageViewSet(
|
||||
def get_queryset(self): # pragma: no cover
|
||||
return Stage.objects.select_subclasses().prefetch_related("flow_set")
|
||||
|
||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def types(self, request: Request) -> Response:
|
||||
"""Get all creatable stage types"""
|
||||
data = []
|
||||
for subclass in all_subclasses(self.queryset.model, False):
|
||||
subclass: Stage
|
||||
data.append(
|
||||
{
|
||||
"name": subclass._meta.verbose_name,
|
||||
"description": subclass.__doc__,
|
||||
"component": subclass().component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||
}
|
||||
)
|
||||
data = sorted(data, key=lambda x: x["name"])
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
||||
|
||||
@extend_schema(responses={200: UserSettingSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def user_settings(self, request: Request) -> Response:
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from base64 import b64encode
|
||||
from functools import cache as funccache
|
||||
from hashlib import md5
|
||||
from hashlib import md5, sha256
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlencode
|
||||
|
||||
@ -20,7 +20,7 @@ from authentik.tenants.utils import get_current_tenant
|
||||
if TYPE_CHECKING:
|
||||
from authentik.core.models import User
|
||||
|
||||
GRAVATAR_URL = "https://secure.gravatar.com"
|
||||
GRAVATAR_URL = "https://www.gravatar.com"
|
||||
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
||||
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
|
||||
CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available"
|
||||
@ -55,10 +55,9 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
|
||||
if not cache.get(CACHE_KEY_GRAVATAR_AVAILABLE, True):
|
||||
return None
|
||||
|
||||
# gravatar uses md5 for their URLs, so md5 can't be avoided
|
||||
mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||
parameters = [("size", "158"), ("rating", "g"), ("default", "404")]
|
||||
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
|
||||
mail_hash = sha256(user.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||
parameters = {"size": "158", "rating": "g", "default": "404"}
|
||||
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters)}"
|
||||
|
||||
full_key = CACHE_KEY_GRAVATAR + mail_hash
|
||||
if cache.has_key(full_key):
|
||||
@ -84,7 +83,9 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
|
||||
|
||||
def generate_colors(text: str) -> tuple[str, str]:
|
||||
"""Generate colours based on `text`"""
|
||||
color = int(md5(text.lower().encode("utf-8")).hexdigest(), 16) % 0xFFFFFF # nosec
|
||||
color = (
|
||||
int(md5(text.lower().encode("utf-8"), usedforsecurity=False).hexdigest(), 16) % 0xFFFFFF
|
||||
) # nosec
|
||||
|
||||
# Get a (somewhat arbitrarily) reduced scope of colors
|
||||
# to avoid too dark or light backgrounds
|
||||
@ -179,7 +180,7 @@ def avatar_mode_generated(user: "User", mode: str) -> str | None:
|
||||
|
||||
def avatar_mode_url(user: "User", mode: str) -> str | None:
|
||||
"""Format url"""
|
||||
mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||
mail_hash = md5(user.email.lower().encode("utf-8"), usedforsecurity=False).hexdigest() # nosec
|
||||
return mode % {
|
||||
"username": user.username,
|
||||
"mail_hash": mail_hash,
|
||||
|
@ -304,6 +304,12 @@ class ConfigLoader:
|
||||
"""Wrapper for get that converts value into boolean"""
|
||||
return str(self.get(path, default)).lower() == "true"
|
||||
|
||||
def get_keys(self, path: str, sep=".") -> list[str]:
|
||||
"""List attribute keys by using yaml path"""
|
||||
root = self.raw
|
||||
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr({}))
|
||||
return attr.keys()
|
||||
|
||||
def get_dict_from_b64_json(self, path: str, default=None) -> dict:
|
||||
"""Wrapper for get that converts value from Base64 encoded string into dictionary"""
|
||||
config_value = self.get(path)
|
||||
|
@ -10,6 +10,10 @@ postgresql:
|
||||
use_pgpool: false
|
||||
test:
|
||||
name: test_authentik
|
||||
read_replicas: {}
|
||||
# For example
|
||||
# 0:
|
||||
# host: replica1.example.com
|
||||
|
||||
listen:
|
||||
listen_http: 0.0.0.0:9000
|
||||
|
@ -5,6 +5,7 @@ import socket
|
||||
from collections.abc import Iterable
|
||||
from ipaddress import ip_address, ip_network
|
||||
from textwrap import indent
|
||||
from types import CodeType
|
||||
from typing import Any
|
||||
|
||||
from cachetools import TLRUCache, cached
|
||||
@ -184,7 +185,7 @@ class BaseEvaluator:
|
||||
full_expression += f"\nresult = handler({handler_signature})"
|
||||
return full_expression
|
||||
|
||||
def compile(self, expression: str) -> Any:
|
||||
def compile(self, expression: str) -> CodeType:
|
||||
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
|
||||
param_keys = self._context.keys()
|
||||
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
|
||||
|
67
authentik/lib/sync/mapper.py
Normal file
67
authentik/lib/sync/mapper.py
Normal file
@ -0,0 +1,67 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.http import HttpRequest
|
||||
|
||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
||||
from authentik.core.expression.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import PropertyMapping, User
|
||||
|
||||
|
||||
class PropertyMappingManager:
|
||||
"""Pre-compile and cache property mappings when an identical
|
||||
set is used multiple times"""
|
||||
|
||||
query_set: QuerySet[PropertyMapping]
|
||||
mapping_subclass: type[PropertyMapping]
|
||||
|
||||
_evaluators: list[PropertyMappingEvaluator]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qs: QuerySet[PropertyMapping],
|
||||
# Expected subclass of PropertyMappings, any objects in the queryset
|
||||
# that are not an instance of this class will be discarded
|
||||
mapping_subclass: type[PropertyMapping],
|
||||
# As they keys of parameters are part of the compilation,
|
||||
# we need a list of all parameter names that will be used during evaluation
|
||||
context_keys: list[str],
|
||||
) -> None:
|
||||
self.query_set = qs
|
||||
self.mapping_subclass = mapping_subclass
|
||||
self.context_keys = context_keys
|
||||
self.compile()
|
||||
|
||||
def compile(self):
|
||||
self._evaluators = []
|
||||
for mapping in self.query_set:
|
||||
if not isinstance(mapping, self.mapping_subclass):
|
||||
continue
|
||||
evaluator = PropertyMappingEvaluator(
|
||||
mapping, **{key: None for key in self.context_keys}
|
||||
)
|
||||
# Compile and cache expression
|
||||
evaluator.compile()
|
||||
self._evaluators.append(evaluator)
|
||||
|
||||
def iter_eval(
|
||||
self,
|
||||
user: User | None,
|
||||
request: HttpRequest | None,
|
||||
return_mapping: bool = False,
|
||||
**kwargs,
|
||||
) -> Generator[tuple[dict, PropertyMapping], None]:
|
||||
"""Iterate over all mappings that were pre-compiled and
|
||||
execute all of them with the given context"""
|
||||
for mapping in self._evaluators:
|
||||
mapping.set_context(user, request, **kwargs)
|
||||
try:
|
||||
value = mapping.evaluate(mapping.model.expression)
|
||||
except Exception as exc:
|
||||
raise PropertyMappingExpressionException(exc, mapping.model) from exc
|
||||
if value is None:
|
||||
continue
|
||||
if return_mapping:
|
||||
yield value, mapping.model
|
||||
else:
|
||||
yield value
|
@ -3,3 +3,6 @@
|
||||
PAGE_SIZE = 100
|
||||
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
|
||||
HTTP_CONFLICT = 409
|
||||
HTTP_NO_CONTENT = 204
|
||||
HTTP_SERVICE_UNAVAILABLE = 503
|
||||
HTTP_TOO_MANY_REQUESTS = 429
|
||||
|
@ -47,8 +47,10 @@ class OutgoingSyncProviderStatusMixin:
|
||||
uid=slugify(provider.name),
|
||||
)
|
||||
)
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
"is_running": provider.sync_lock.locked(),
|
||||
}
|
||||
with provider.sync_lock as lock_acquired:
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
# If we could not acquire the lock, it means a task is using it, and thus is running
|
||||
"is_running": not lock_acquired,
|
||||
}
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
|
@ -3,10 +3,18 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deepmerge import always_merger
|
||||
from django.db import DatabaseError
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.db.models import Model
|
||||
@ -28,6 +36,7 @@ class BaseOutgoingSyncClient[
|
||||
provider: TProvider
|
||||
connection_type: type[TConnection]
|
||||
connection_type_query: str
|
||||
mapper: PropertyMappingManager
|
||||
|
||||
can_discover = False
|
||||
|
||||
@ -70,9 +79,35 @@ class BaseOutgoingSyncClient[
|
||||
"""Delete object from destination"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def to_schema(self, obj: TModel, creating: bool) -> TSchema:
|
||||
def to_schema(self, obj: TModel, connection: TConnection | None, **defaults) -> TSchema:
|
||||
"""Convert object to destination schema"""
|
||||
raise NotImplementedError()
|
||||
raw_final_object = {}
|
||||
try:
|
||||
eval_kwargs = {
|
||||
"request": None,
|
||||
"provider": self.provider,
|
||||
"connection": connection,
|
||||
obj._meta.model_name: obj,
|
||||
}
|
||||
eval_kwargs.setdefault("user", None)
|
||||
for value in self.mapper.iter_eval(**eval_kwargs):
|
||||
try:
|
||||
always_merger.merge(raw_final_object, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except PropertyMappingExpressionException as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=exc.mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, exc.mapping) from exc
|
||||
if not raw_final_object:
|
||||
raise StopSync(ValueError("No user mappings configured"), obj)
|
||||
for key, value in defaults.items():
|
||||
raw_final_object.setdefault(key, value)
|
||||
return raw_final_object
|
||||
|
||||
def discover(self):
|
||||
"""Optional method. Can be used to implement a "discovery" where
|
||||
|
@ -1,11 +1,10 @@
|
||||
from typing import Any, Self
|
||||
|
||||
from django.core.cache import cache
|
||||
import pglock
|
||||
from django.db import connection
|
||||
from django.db.models import Model, QuerySet, TextChoices
|
||||
from redis.lock import Lock
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.lib.sync.outgoing import PAGE_TIMEOUT
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
|
||||
|
||||
@ -32,10 +31,10 @@ class OutgoingSyncProvider(Model):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def sync_lock(self) -> Lock:
|
||||
"""Redis lock to prevent multiple parallel syncs happening"""
|
||||
return Lock(
|
||||
cache.client.get_client(),
|
||||
name=f"goauthentik.io/providers/outgoing-sync/{str(self.pk)}",
|
||||
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
|
||||
def sync_lock(self) -> pglock.advisory:
|
||||
"""Postgres lock for syncing SCIM to prevent multiple parallel syncs happening"""
|
||||
return pglock.advisory(
|
||||
lock_id=f"goauthentik.io/{connection.schema_name}/providers/outgoing-sync/{str(self.pk)}",
|
||||
timeout=0,
|
||||
side_effect=pglock.Return,
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
|
||||
from celery.exceptions import Retry
|
||||
from celery.result import allow_join_result
|
||||
@ -64,17 +65,16 @@ class SyncTasks:
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
lock = provider.sync_lock
|
||||
if lock.locked():
|
||||
self.logger.debug("Sync locked, skipping task", source=provider.name)
|
||||
return
|
||||
task.set_uid(slugify(provider.name))
|
||||
messages = []
|
||||
messages.append(_("Starting full provider sync"))
|
||||
self.logger.debug("Starting provider sync")
|
||||
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
|
||||
with allow_join_result(), lock:
|
||||
with allow_join_result(), provider.sync_lock as lock_acquired:
|
||||
if not lock_acquired:
|
||||
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
|
||||
return
|
||||
try:
|
||||
for page in users_paginator.page_range:
|
||||
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
|
||||
@ -83,7 +83,7 @@ class SyncTasks:
|
||||
time_limit=PAGE_TIMEOUT,
|
||||
soft_time_limit=PAGE_TIMEOUT,
|
||||
).get():
|
||||
messages.append(msg)
|
||||
messages.append(LogEvent(**msg))
|
||||
for page in groups_paginator.page_range:
|
||||
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
|
||||
for msg in sync_objects.apply_async(
|
||||
@ -91,7 +91,7 @@ class SyncTasks:
|
||||
time_limit=PAGE_TIMEOUT,
|
||||
soft_time_limit=PAGE_TIMEOUT,
|
||||
).get():
|
||||
messages.append(msg)
|
||||
messages.append(LogEvent(**msg))
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient sync exception", exc=exc)
|
||||
raise task.retry(exc=exc) from exc
|
||||
@ -129,57 +129,63 @@ class SyncTasks:
|
||||
except BadRequestSyncException as exc:
|
||||
self.logger.warning("failed to sync object", exc=exc, obj=obj)
|
||||
messages.append(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger="",
|
||||
attributes={"arguments": exc.args[1:]},
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger="",
|
||||
attributes={"arguments": exc.args[1:]},
|
||||
)
|
||||
)
|
||||
)
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("failed to sync object", exc=exc, user=obj)
|
||||
messages.append(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to transient error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger="",
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to transient error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger="",
|
||||
)
|
||||
)
|
||||
)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc)
|
||||
messages.append(
|
||||
LogEvent(
|
||||
_(
|
||||
"Stopping sync due to error: {error}".format_map(
|
||||
{
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger="",
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
"Stopping sync due to error: {error}".format_map(
|
||||
{
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger="",
|
||||
)
|
||||
)
|
||||
)
|
||||
break
|
||||
|
@ -169,3 +169,9 @@ class TestConfig(TestCase):
|
||||
self.assertEqual(config.get("cache.timeout_flows"), "32m")
|
||||
self.assertEqual(config.get("cache.timeout_policies"), "3920ns")
|
||||
self.assertEqual(config.get("cache.timeout_reputation"), "298382us")
|
||||
|
||||
def test_get_keys(self):
|
||||
"""Test get_keys"""
|
||||
config = ConfigLoader()
|
||||
config.set("foo.bar", "baz")
|
||||
self.assertEqual(list(config.get_keys("foo")), ["bar"])
|
||||
|
@ -12,7 +12,7 @@ from authentik.lib.config import CONFIG
|
||||
SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST"
|
||||
|
||||
|
||||
def all_subclasses(cls, sort=True):
|
||||
def all_subclasses[T](cls: T, sort=True) -> list[T] | set[T]:
|
||||
"""Recursively return all subclassess of cls"""
|
||||
classes = set(cls.__subclasses__()).union(
|
||||
[s for c in cls.__subclasses__() for s in all_subclasses(c, sort=sort)]
|
||||
|
@ -117,8 +117,12 @@ class OutpostHealthSerializer(PassiveSerializer):
|
||||
uid = CharField(read_only=True)
|
||||
last_seen = DateTimeField(read_only=True)
|
||||
version = CharField(read_only=True)
|
||||
version_should = CharField(read_only=True)
|
||||
golang_version = CharField(read_only=True)
|
||||
openssl_enabled = BooleanField(read_only=True)
|
||||
openssl_version = CharField(read_only=True)
|
||||
fips_enabled = BooleanField(read_only=True)
|
||||
|
||||
version_should = CharField(read_only=True)
|
||||
version_outdated = BooleanField(read_only=True)
|
||||
|
||||
build_hash = CharField(read_only=True, required=False)
|
||||
@ -173,6 +177,10 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
|
||||
"version_should": state.version_should,
|
||||
"version_outdated": state.version_outdated,
|
||||
"build_hash": state.build_hash,
|
||||
"golang_version": state.golang_version,
|
||||
"openssl_enabled": state.openssl_enabled,
|
||||
"openssl_version": state.openssl_version,
|
||||
"fips_enabled": state.fips_enabled,
|
||||
"hostname": state.hostname,
|
||||
"build_hash_should": get_build_hash(),
|
||||
}
|
||||
|
@ -15,9 +15,12 @@ from rest_framework.response import Response
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
|
||||
from authentik.core.api.object_types import TypesMixin
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
from authentik.core.api.utils import (
|
||||
MetaNameSerializer,
|
||||
PassiveSerializer,
|
||||
)
|
||||
from authentik.outposts.models import (
|
||||
DockerServiceConnection,
|
||||
KubernetesServiceConnection,
|
||||
@ -57,6 +60,7 @@ class ServiceConnectionStateSerializer(PassiveSerializer):
|
||||
|
||||
|
||||
class ServiceConnectionViewSet(
|
||||
TypesMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
@ -70,23 +74,6 @@ class ServiceConnectionViewSet(
|
||||
search_fields = ["name"]
|
||||
filterset_fields = ["name"]
|
||||
|
||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def types(self, request: Request) -> Response:
|
||||
"""Get all creatable service connection types"""
|
||||
data = []
|
||||
for subclass in all_subclasses(self.queryset.model):
|
||||
subclass: OutpostServiceConnection
|
||||
data.append(
|
||||
{
|
||||
"name": subclass._meta.verbose_name,
|
||||
"description": subclass.__doc__,
|
||||
"component": subclass().component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
}
|
||||
)
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
||||
|
||||
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def state(self, request: Request, pk: str) -> Response:
|
||||
|
@ -121,6 +121,10 @@ class OutpostConsumer(JsonWebsocketConsumer):
|
||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||
state.version = msg.args.pop("version", None)
|
||||
state.build_hash = msg.args.pop("buildHash", "")
|
||||
state.golang_version = msg.args.pop("golangVersion", "")
|
||||
state.openssl_enabled = msg.args.pop("opensslEnabled", False)
|
||||
state.openssl_version = msg.args.pop("opensslVersion", "")
|
||||
state.fips_enabled = msg.args.pop("fipsEnabled", False)
|
||||
state.args.update(msg.args)
|
||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||
return
|
||||
|
@ -124,7 +124,6 @@ class KubernetesObjectReconciler(Generic[T]):
|
||||
self.update(current, reference)
|
||||
self.logger.debug("Updating")
|
||||
except (OpenApiException, HTTPError) as exc:
|
||||
|
||||
if isinstance(exc, ApiException) and exc.status == 422: # noqa: PLR2004
|
||||
self.logger.debug("Failed to update current, triggering re-create")
|
||||
self._recreate(current=current, reference=reference)
|
||||
|
@ -131,7 +131,7 @@ class OutpostServiceConnection(models.Model):
|
||||
verbose_name = _("Outpost Service-Connection")
|
||||
verbose_name_plural = _("Outpost Service-Connections")
|
||||
|
||||
def __str__(self) -> __version__:
|
||||
def __str__(self) -> str:
|
||||
return f"Outpost service connection {self.name}"
|
||||
|
||||
@property
|
||||
@ -434,6 +434,10 @@ class OutpostState:
|
||||
version: str | None = field(default=None)
|
||||
version_should: Version = field(default=OUR_VERSION)
|
||||
build_hash: str = field(default="")
|
||||
golang_version: str = field(default="")
|
||||
openssl_enabled: bool = field(default=False)
|
||||
openssl_version: str = field(default="")
|
||||
fips_enabled: bool = field(default=False)
|
||||
hostname: str = field(default="")
|
||||
args: dict = field(default_factory=dict)
|
||||
|
||||
|
@ -13,10 +13,13 @@ from rest_framework.viewsets import GenericViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.applications import user_app_cache_key
|
||||
from authentik.core.api.object_types import TypesMixin
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer
|
||||
from authentik.core.api.utils import (
|
||||
CacheSerializer,
|
||||
MetaNameSerializer,
|
||||
)
|
||||
from authentik.events.logs import LogEventSerializer, capture_logs
|
||||
from authentik.lib.utils.reflection import all_subclasses
|
||||
from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer
|
||||
from authentik.policies.models import Policy, PolicyBinding
|
||||
from authentik.policies.process import PolicyProcess
|
||||
@ -69,6 +72,7 @@ class PolicySerializer(ModelSerializer, MetaNameSerializer):
|
||||
|
||||
|
||||
class PolicyViewSet(
|
||||
TypesMixin,
|
||||
mixins.RetrieveModelMixin,
|
||||
mixins.DestroyModelMixin,
|
||||
UsedByMixin,
|
||||
@ -89,23 +93,6 @@ class PolicyViewSet(
|
||||
def get_queryset(self): # pragma: no cover
|
||||
return Policy.objects.select_subclasses().prefetch_related("bindings", "promptstage_set")
|
||||
|
||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def types(self, request: Request) -> Response:
|
||||
"""Get all creatable policy types"""
|
||||
data = []
|
||||
for subclass in all_subclasses(self.queryset.model):
|
||||
subclass: Policy
|
||||
data.append(
|
||||
{
|
||||
"name": subclass._meta.verbose_name,
|
||||
"description": subclass.__doc__,
|
||||
"component": subclass().component,
|
||||
"model_name": subclass._meta.model_name,
|
||||
}
|
||||
)
|
||||
return Response(TypeCreateSerializer(data, many=True).data)
|
||||
|
||||
@permission_required(None, ["authentik_policies.view_policy_cache"])
|
||||
@extend_schema(responses={200: CacheSerializer(many=False)})
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
|
@ -96,16 +96,42 @@ class TestEvaluator(TestCase):
|
||||
execution_logging=True,
|
||||
expression="ak_message(request.http_request.path)\nreturn True",
|
||||
)
|
||||
tmpl = f"""
|
||||
ak_message(request.http_request.path)
|
||||
res = ak_call_policy('{expr.name}')
|
||||
ak_message(request.http_request.path)
|
||||
for msg in res.messages:
|
||||
ak_message(msg)
|
||||
"""
|
||||
evaluator = PolicyEvaluator("test")
|
||||
evaluator.set_policy_request(self.request)
|
||||
res = evaluator.evaluate(tmpl)
|
||||
expr2 = ExpressionPolicy.objects.create(
|
||||
name=generate_id(),
|
||||
execution_logging=True,
|
||||
expression=f"""
|
||||
ak_message(request.http_request.path)
|
||||
res = ak_call_policy('{expr.name}')
|
||||
ak_message(request.http_request.path)
|
||||
for msg in res.messages:
|
||||
ak_message(msg)
|
||||
""",
|
||||
)
|
||||
proc = PolicyProcess(PolicyBinding(policy=expr2), request=self.request, connection=None)
|
||||
res = proc.profiling_wrapper()
|
||||
self.assertEqual(res.messages, ("/", "/", "/"))
|
||||
|
||||
def test_call_policy_test_like(self):
|
||||
"""test ak_call_policy without `obj` set, as if it was when testing policies"""
|
||||
expr = ExpressionPolicy.objects.create(
|
||||
name=generate_id(),
|
||||
execution_logging=True,
|
||||
expression="ak_message(request.http_request.path)\nreturn True",
|
||||
)
|
||||
expr2 = ExpressionPolicy.objects.create(
|
||||
name=generate_id(),
|
||||
execution_logging=True,
|
||||
expression=f"""
|
||||
ak_message(request.http_request.path)
|
||||
res = ak_call_policy('{expr.name}')
|
||||
ak_message(request.http_request.path)
|
||||
for msg in res.messages:
|
||||
ak_message(msg)
|
||||
""",
|
||||
)
|
||||
self.request.obj = None
|
||||
proc = PolicyProcess(PolicyBinding(policy=expr2), request=self.request, connection=None)
|
||||
res = proc.profiling_wrapper()
|
||||
self.assertEqual(res.messages, ("/", "/", "/"))
|
||||
|
||||
|
||||
|
@ -128,8 +128,8 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
binding_order=self.binding.order,
|
||||
binding_target_type=self.binding.target_type,
|
||||
binding_target_name=self.binding.target_name,
|
||||
object_pk=str(self.request.obj.pk),
|
||||
object_type=class_to_path(self.request.obj.__class__),
|
||||
object_pk=str(self.request.obj.pk) if self.request.obj else "",
|
||||
object_type=class_to_path(self.request.obj.__class__) if self.request.obj else "",
|
||||
mode="execute_process",
|
||||
).time(),
|
||||
):
|
||||
|
@ -3,6 +3,7 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
from django.db import models
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
@ -90,6 +91,10 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
||||
def component(self) -> str:
|
||||
return "ak-provider-ldap-form"
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/ldap.png")
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.providers.ldap.api import LDAPProviderSerializer
|
||||
|
@ -8,7 +8,7 @@ from rest_framework.fields import CharField
|
||||
from rest_framework.serializers import ValidationError
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.providers.oauth2.models import ScopeMapping
|
||||
|
||||
|
@ -15,6 +15,7 @@ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
|
||||
from dacite.core import from_dict
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
from django.templatetags.static import static
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from jwt import encode
|
||||
@ -262,6 +263,10 @@ class OAuth2Provider(Provider):
|
||||
LOGGER.warning("Failed to format launch url", exc=exc)
|
||||
return None
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/openidconnect.svg")
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-provider-oauth2-form"
|
||||
|
@ -15,7 +15,6 @@ from authentik.core.expression.exceptions import PropertyMappingExpressionExcept
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.challenge import PermissionDict
|
||||
from authentik.providers.oauth2.constants import (
|
||||
SCOPE_AUTHENTIK_API,
|
||||
SCOPE_GITHUB_ORG_READ,
|
||||
SCOPE_GITHUB_USER,
|
||||
SCOPE_GITHUB_USER_EMAIL,
|
||||
@ -57,7 +56,6 @@ class UserInfoView(View):
|
||||
SCOPE_GITHUB_USER_READ: _("GitHub Compatibility: Access your User Information"),
|
||||
SCOPE_GITHUB_USER_EMAIL: _("GitHub Compatibility: Access you Email addresses"),
|
||||
SCOPE_GITHUB_ORG_READ: _("GitHub Compatibility: Access your Groups"),
|
||||
SCOPE_AUTHENTIK_API: _("authentik API Access on behalf of your user"),
|
||||
}
|
||||
for scope in scopes:
|
||||
if scope in special_scope_map:
|
||||
|
@ -6,6 +6,7 @@ from random import SystemRandom
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from django.db import models
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
@ -115,6 +116,10 @@ class ProxyProvider(OutpostModel, OAuth2Provider):
|
||||
def component(self) -> str:
|
||||
return "ak-provider-proxy-form"
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/proxy.svg")
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.providers.proxy.api import ProxyProviderSerializer
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Radius Provider"""
|
||||
|
||||
from django.db import models
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
@ -46,6 +47,10 @@ class RadiusProvider(OutpostModel, Provider):
|
||||
def component(self) -> str:
|
||||
return "ak-provider-radius-form"
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/radius.svg")
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.providers.radius.api import RadiusProviderSerializer
|
||||
|
@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.providers.saml.models import SAMLPropertyMapping
|
||||
|
@ -1,11 +1,13 @@
|
||||
"""authentik saml_idp Models"""
|
||||
"""authentik SAML Provider Models"""
|
||||
|
||||
from django.db import models
|
||||
from django.templatetags.static import static
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.object_types import CreatableType
|
||||
from authentik.core.models import PropertyMapping, Provider
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
@ -159,6 +161,10 @@ class SAMLProvider(Provider):
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return None
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/saml.png")
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.providers.saml.api.providers import SAMLProviderSerializer
|
||||
@ -189,7 +195,7 @@ class SAMLPropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.providers.saml.api.property_mapping import SAMLPropertyMappingSerializer
|
||||
from authentik.providers.saml.api.property_mappings import SAMLPropertyMappingSerializer
|
||||
|
||||
return SAMLPropertyMappingSerializer
|
||||
|
||||
@ -200,3 +206,20 @@ class SAMLPropertyMapping(PropertyMapping):
|
||||
class Meta:
|
||||
verbose_name = _("SAML Property Mapping")
|
||||
verbose_name_plural = _("SAML Property Mappings")
|
||||
|
||||
|
||||
class SAMLProviderImportModel(CreatableType, Provider):
|
||||
"""Create a SAML Provider by importing its Metadata."""
|
||||
|
||||
@property
|
||||
def component(self):
|
||||
return "ak-provider-saml-import-form"
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/saml.png")
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
verbose_name = _("SAML Provider from Metadata")
|
||||
verbose_name_plural = _("SAML Providers from Metadata")
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from authentik.providers.saml.api.property_mapping import SAMLPropertyMappingViewSet
|
||||
from authentik.providers.saml.api.property_mappings import SAMLPropertyMappingViewSet
|
||||
from authentik.providers.saml.api.providers import SAMLProviderViewSet
|
||||
from authentik.providers.saml.views import metadata, slo, sso
|
||||
|
||||
|
@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.providers.scim.models import SCIMMapping
|
||||
|
||||
|
@ -6,9 +6,18 @@ from django.http import HttpResponseBadRequest, HttpResponseNotFound
|
||||
from pydantic import ValidationError
|
||||
from requests import RequestException, Session
|
||||
|
||||
from authentik.lib.sync.outgoing import HTTP_CONFLICT
|
||||
from authentik.lib.sync.outgoing import (
|
||||
HTTP_CONFLICT,
|
||||
HTTP_NO_CONTENT,
|
||||
HTTP_SERVICE_UNAVAILABLE,
|
||||
HTTP_TOO_MANY_REQUESTS,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, ObjectExistsSyncException
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
NotFoundSyncException,
|
||||
ObjectExistsSyncException,
|
||||
TransientSyncException,
|
||||
)
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.providers.scim.clients.exceptions import SCIMRequestException
|
||||
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
|
||||
@ -61,13 +70,15 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"](
|
||||
if response.status_code >= HttpResponseBadRequest.status_code:
|
||||
if response.status_code == HttpResponseNotFound.status_code:
|
||||
raise NotFoundSyncException(response)
|
||||
if response.status_code in [HTTP_TOO_MANY_REQUESTS, HTTP_SERVICE_UNAVAILABLE]:
|
||||
raise TransientSyncException()
|
||||
if response.status_code == HTTP_CONFLICT:
|
||||
raise ObjectExistsSyncException(response)
|
||||
self.logger.warning(
|
||||
"Failed to send SCIM request", path=path, method=method, response=response.text
|
||||
)
|
||||
raise SCIMRequestException(response)
|
||||
if response.status_code == 204: # noqa: PLR2004
|
||||
if response.status_code == HTTP_NO_CONTENT:
|
||||
return {}
|
||||
return response.json()
|
||||
|
||||
|
@ -1,31 +1,25 @@
|
||||
"""Group client"""
|
||||
|
||||
from deepmerge import always_merger
|
||||
from pydantic import ValidationError
|
||||
from pydanticscim.group import GroupMember
|
||||
from pydanticscim.responses import PatchOp, PatchOperation
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import Group
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
NotFoundSyncException,
|
||||
ObjectExistsSyncException,
|
||||
StopSync,
|
||||
)
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.utils import delete_none_values
|
||||
from authentik.providers.scim.clients.base import SCIMClient
|
||||
from authentik.providers.scim.clients.exceptions import (
|
||||
SCIMRequestException,
|
||||
)
|
||||
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest
|
||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
|
||||
from authentik.providers.scim.clients.schema import PatchRequest
|
||||
from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMUser
|
||||
from authentik.providers.scim.models import SCIMGroup, SCIMMapping, SCIMProvider, SCIMUser
|
||||
|
||||
|
||||
class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
@ -33,41 +27,23 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
|
||||
connection_type = SCIMGroup
|
||||
connection_type_query = "group"
|
||||
mapper: PropertyMappingManager
|
||||
|
||||
def to_schema(self, obj: Group, creating: bool) -> SCIMGroupSchema:
|
||||
def __init__(self, provider: SCIMProvider):
|
||||
super().__init__(provider)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses(),
|
||||
SCIMMapping,
|
||||
["group", "provider", "connection"],
|
||||
)
|
||||
|
||||
def to_schema(self, obj: Group, connection: SCIMGroup) -> SCIMGroupSchema:
|
||||
"""Convert authentik user into SCIM"""
|
||||
raw_scim_group = {
|
||||
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:Group",),
|
||||
}
|
||||
for mapping in (
|
||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
|
||||
):
|
||||
if not isinstance(mapping, SCIMMapping):
|
||||
continue
|
||||
try:
|
||||
mapping: SCIMMapping
|
||||
value = mapping.evaluate(
|
||||
user=None,
|
||||
request=None,
|
||||
group=obj,
|
||||
provider=self.provider,
|
||||
creating=creating,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_scim_group, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_scim_group:
|
||||
raise StopSync(ValueError("No group mappings configured"), obj)
|
||||
raw_scim_group = super().to_schema(
|
||||
obj,
|
||||
connection,
|
||||
schemas=(SCIM_GROUP_SCHEMA,),
|
||||
)
|
||||
try:
|
||||
scim_group = SCIMGroupSchema.model_validate(delete_none_values(raw_scim_group))
|
||||
except ValidationError as exc:
|
||||
@ -100,7 +76,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
|
||||
def create(self, group: Group):
|
||||
"""Create group from scratch and create a connection object"""
|
||||
scim_group = self.to_schema(group, True)
|
||||
scim_group = self.to_schema(group, None)
|
||||
response = self._request(
|
||||
"POST",
|
||||
"/Groups",
|
||||
@ -116,7 +92,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroup, SCIMGroupSchema]):
|
||||
|
||||
def update(self, group: Group, connection: SCIMGroup):
|
||||
"""Update existing group"""
|
||||
scim_group = self.to_schema(group, False)
|
||||
scim_group = self.to_schema(group, connection)
|
||||
scim_group.id = connection.scim_id
|
||||
try:
|
||||
return self._request(
|
||||
|
@ -1,20 +1,15 @@
|
||||
"""User client"""
|
||||
|
||||
from deepmerge import always_merger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from authentik.core.expression.exceptions import (
|
||||
PropertyMappingExpressionException,
|
||||
SkipObjectException,
|
||||
)
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.utils import delete_none_values
|
||||
from authentik.providers.scim.clients.base import SCIMClient
|
||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMUser
|
||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider, SCIMUser
|
||||
|
||||
|
||||
class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
@ -22,38 +17,23 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
|
||||
connection_type = SCIMUser
|
||||
connection_type_query = "user"
|
||||
mapper: PropertyMappingManager
|
||||
|
||||
def to_schema(self, obj: User, creating: bool) -> SCIMUserSchema:
|
||||
def __init__(self, provider: SCIMProvider):
|
||||
super().__init__(provider)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self.provider.property_mappings.all().order_by("name").select_subclasses(),
|
||||
SCIMMapping,
|
||||
["provider", "connection"],
|
||||
)
|
||||
|
||||
def to_schema(self, obj: User, connection: SCIMUser) -> SCIMUserSchema:
|
||||
"""Convert authentik user into SCIM"""
|
||||
raw_scim_user = {
|
||||
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:User",),
|
||||
}
|
||||
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
|
||||
if not isinstance(mapping, SCIMMapping):
|
||||
continue
|
||||
try:
|
||||
mapping: SCIMMapping
|
||||
value = mapping.evaluate(
|
||||
user=obj,
|
||||
request=None,
|
||||
provider=self.provider,
|
||||
creating=creating,
|
||||
)
|
||||
if value is None:
|
||||
continue
|
||||
always_merger.merge(raw_scim_user, value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=mapping,
|
||||
).save()
|
||||
raise StopSync(exc, obj, mapping) from exc
|
||||
if not raw_scim_user:
|
||||
raise StopSync(ValueError("No user mappings configured"), obj)
|
||||
raw_scim_user = super().to_schema(
|
||||
obj,
|
||||
connection,
|
||||
schemas=(SCIM_USER_SCHEMA,),
|
||||
)
|
||||
try:
|
||||
scim_user = SCIMUserSchema.model_validate(delete_none_values(raw_scim_user))
|
||||
except ValidationError as exc:
|
||||
@ -74,7 +54,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
|
||||
def create(self, user: User):
|
||||
"""Create user from scratch and create a connection object"""
|
||||
scim_user = self.to_schema(user, True)
|
||||
scim_user = self.to_schema(user, None)
|
||||
response = self._request(
|
||||
"POST",
|
||||
"/Users",
|
||||
@ -90,7 +70,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUser, SCIMUserSchema]):
|
||||
|
||||
def update(self, user: User, connection: SCIMUser):
|
||||
"""Update existing user"""
|
||||
scim_user = self.to_schema(user, False)
|
||||
scim_user = self.to_schema(user, connection)
|
||||
scim_user.id = connection.scim_id
|
||||
self._request(
|
||||
"PUT",
|
||||
|
@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
@ -32,6 +33,10 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
help_text=_("Property mappings used for group creation/updating."),
|
||||
)
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return static("authentik/sources/scim.png")
|
||||
|
||||
def client_for_model(
|
||||
self, model: type[User | Group]
|
||||
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
|
||||
|
@ -0,0 +1,28 @@
|
||||
# Generated by Django 5.0.6 on 2024-05-19 14:17
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_rbac", "0003_alter_systempermission_options"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="systempermission",
|
||||
options={
|
||||
"default_permissions": (),
|
||||
"managed": False,
|
||||
"permissions": [
|
||||
("view_system_info", "Can view system info"),
|
||||
("access_admin_interface", "Can access admin interface"),
|
||||
("view_system_settings", "Can view system settings"),
|
||||
("edit_system_settings", "Can edit system settings"),
|
||||
],
|
||||
"verbose_name": "System permission",
|
||||
"verbose_name_plural": "System permissions",
|
||||
},
|
||||
),
|
||||
]
|
@ -67,8 +67,6 @@ class SystemPermission(models.Model):
|
||||
verbose_name_plural = _("System permissions")
|
||||
permissions = [
|
||||
("view_system_info", _("Can view system info")),
|
||||
("view_system_tasks", _("Can view system tasks")),
|
||||
("run_system_tasks", _("Can run system tasks")),
|
||||
("access_admin_interface", _("Can access admin interface")),
|
||||
("view_system_settings", _("Can view system settings")),
|
||||
("edit_system_settings", _("Can edit system settings")),
|
||||
|
@ -1,9 +1,10 @@
|
||||
"""rbac signals"""
|
||||
|
||||
from django.contrib.auth.models import Group as DjangoGroup
|
||||
from django.db.models.signals import m2m_changed, pre_save
|
||||
from django.db.models.signals import m2m_changed, pre_delete, pre_save
|
||||
from django.db.transaction import atomic
|
||||
from django.dispatch import receiver
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Group
|
||||
@ -21,23 +22,42 @@ def rbac_role_pre_save(sender: type[Role], instance: Role, **_):
|
||||
instance.group = group
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=Role)
|
||||
@receiver(pre_delete, sender=Group)
|
||||
def rbac_pre_delete_cleanup(sender: type[Group] | type[Role], instance: Group | Role, **_):
|
||||
"""RBAC: remove permissions from users when a group is deleted"""
|
||||
if sender == Group:
|
||||
for role in instance.roles.all():
|
||||
role.group.user_set.clear()
|
||||
if sender == Role:
|
||||
instance.group.user_set.clear()
|
||||
|
||||
|
||||
@receiver(m2m_changed, sender=Group.roles.through)
|
||||
def rbac_group_role_m2m(sender: type[Group], action: str, instance: Group, reverse: bool, **_):
|
||||
def rbac_group_role_m2m(
|
||||
sender: type[Group], action: str, instance: Group, reverse: bool, pk_set: set, **_
|
||||
):
|
||||
"""RBAC: Sync group members into roles when roles are assigned"""
|
||||
if action == "pre_add":
|
||||
# Validation: check that any of the added roles are not used in any other groups
|
||||
if Group.objects.filter(roles__in=pk_set).exclude(pk=instance.pk).exists():
|
||||
raise ValidationError("Roles can only be used with a single group.")
|
||||
if action not in ["post_add", "post_remove", "post_clear"]:
|
||||
return
|
||||
with atomic():
|
||||
group_users = list(
|
||||
instance.children_recursive()
|
||||
group_users = (
|
||||
Group.objects.filter(group_uuid=instance.group_uuid)
|
||||
.with_children_recursive()
|
||||
.exclude(users__isnull=True)
|
||||
.values_list("users", flat=True)
|
||||
)
|
||||
if not group_users:
|
||||
return
|
||||
for role in instance.roles.all():
|
||||
role: Role
|
||||
role.group.user_set.set(group_users)
|
||||
LOGGER.debug("Updated users in group", group=instance)
|
||||
for role in Role.objects.filter(pk__in=pk_set):
|
||||
if action == "post_add":
|
||||
role.group.user_set.add(*group_users)
|
||||
# Role(s) in pk_set were removed from group, so remove the users that we added
|
||||
if action == "post_remove":
|
||||
role.group.user_set.remove(*group_users)
|
||||
LOGGER.debug("Updated users in group", group=instance, direction=action, users=group_users)
|
||||
|
||||
|
||||
@receiver(m2m_changed, sender=Group.users.through)
|
||||
|
27
authentik/rbac/tests/test_api.py
Normal file
27
authentik/rbac/tests/test_api.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Test RBACPermissionViewSet api"""
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.rbac.models import Role
|
||||
|
||||
|
||||
class TestRBACAPI(APITestCase):
|
||||
"""Test RBACPermissionViewSet api"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.superuser = create_test_admin_user()
|
||||
|
||||
self.user = create_test_user()
|
||||
self.role = Role.objects.create(name=generate_id())
|
||||
self.group = Group.objects.create(name=generate_id())
|
||||
self.group.roles.add(self.role)
|
||||
self.group.users.add(self.user)
|
||||
|
||||
def test_list(self):
|
||||
self.client.force_login(self.superuser)
|
||||
res = self.client.get(reverse("authentik_api:permission-list"))
|
||||
self.assertEqual(res.status_code, 200)
|
75
authentik/rbac/tests/test_api_permissions_roles.py
Normal file
75
authentik/rbac/tests/test_api_permissions_roles.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""Test RolePermissionViewSet api"""
|
||||
|
||||
from django.urls import reverse
|
||||
from guardian.models import GroupObjectPermission
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.rbac.models import Role
|
||||
from authentik.stages.invitation.models import Invitation
|
||||
|
||||
|
||||
class TestRBACPermissionRoles(APITestCase):
|
||||
"""Test RolePermissionViewSet api"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.superuser = create_test_admin_user()
|
||||
|
||||
self.user = create_test_user()
|
||||
self.role = Role.objects.create(name=generate_id())
|
||||
self.group = Group.objects.create(name=generate_id())
|
||||
self.group.roles.add(self.role)
|
||||
self.group.users.add(self.user)
|
||||
|
||||
def test_list(self):
|
||||
"""Test list of all permissions"""
|
||||
self.client.force_login(self.superuser)
|
||||
inv = Invitation.objects.create(
|
||||
name=generate_id(),
|
||||
created_by=self.superuser,
|
||||
)
|
||||
self.role.assign_permission("authentik_stages_invitation.view_invitation", obj=inv)
|
||||
res = self.client.get(reverse("authentik_api:permissions-roles-list"))
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_list_role(self):
|
||||
"""Test list of all permissions"""
|
||||
self.client.force_login(self.superuser)
|
||||
inv = Invitation.objects.create(
|
||||
name=generate_id(),
|
||||
created_by=self.superuser,
|
||||
)
|
||||
self.role.assign_permission("authentik_stages_invitation.view_invitation", obj=inv)
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:permissions-roles-list") + f"?uuid={self.role.pk}"
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(
|
||||
res.content,
|
||||
{
|
||||
"pagination": {
|
||||
"next": 0,
|
||||
"previous": 0,
|
||||
"count": 1,
|
||||
"current": 1,
|
||||
"total_pages": 1,
|
||||
"start_index": 1,
|
||||
"end_index": 1,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"id": GroupObjectPermission.objects.filter(object_pk=inv.pk).first().pk,
|
||||
"codename": "view_invitation",
|
||||
"model": "invitation",
|
||||
"app_label": "authentik_stages_invitation",
|
||||
"object_pk": str(inv.pk),
|
||||
"name": "Can view Invitation",
|
||||
"app_label_verbose": "authentik Stages.Invitation",
|
||||
"model_verbose": "Invitation",
|
||||
"object_description": str(inv),
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
76
authentik/rbac/tests/test_api_permissions_users.py
Normal file
76
authentik/rbac/tests/test_api_permissions_users.py
Normal file
@ -0,0 +1,76 @@
|
||||
"""Test UserPermissionViewSet api"""
|
||||
|
||||
from django.urls import reverse
|
||||
from guardian.models import UserObjectPermission
|
||||
from guardian.shortcuts import assign_perm
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.rbac.models import Role
|
||||
from authentik.stages.invitation.models import Invitation
|
||||
|
||||
|
||||
class TestRBACPermissionUsers(APITestCase):
|
||||
"""Test UserPermissionViewSet api"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.superuser = create_test_admin_user()
|
||||
|
||||
self.user = create_test_user()
|
||||
self.role = Role.objects.create(name=generate_id())
|
||||
self.group = Group.objects.create(name=generate_id())
|
||||
self.group.roles.add(self.role)
|
||||
self.group.users.add(self.user)
|
||||
|
||||
def test_list(self):
|
||||
"""Test list of all permissions"""
|
||||
self.client.force_login(self.superuser)
|
||||
inv = Invitation.objects.create(
|
||||
name=generate_id(),
|
||||
created_by=self.superuser,
|
||||
)
|
||||
assign_perm("authentik_stages_invitation.view_invitation", self.user, inv)
|
||||
res = self.client.get(reverse("authentik_api:permissions-users-list"))
|
||||
self.assertEqual(res.status_code, 400)
|
||||
|
||||
def test_list_role(self):
|
||||
"""Test list of all permissions"""
|
||||
self.client.force_login(self.superuser)
|
||||
inv = Invitation.objects.create(
|
||||
name=generate_id(),
|
||||
created_by=self.superuser,
|
||||
)
|
||||
assign_perm("authentik_stages_invitation.view_invitation", self.user, inv)
|
||||
res = self.client.get(
|
||||
reverse("authentik_api:permissions-users-list") + f"?user_id={self.user.pk}"
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(
|
||||
res.content,
|
||||
{
|
||||
"pagination": {
|
||||
"next": 0,
|
||||
"previous": 0,
|
||||
"count": 1,
|
||||
"current": 1,
|
||||
"total_pages": 1,
|
||||
"start_index": 1,
|
||||
"end_index": 1,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"id": UserObjectPermission.objects.filter(object_pk=inv.pk).first().pk,
|
||||
"codename": "view_invitation",
|
||||
"model": "invitation",
|
||||
"app_label": "authentik_stages_invitation",
|
||||
"object_pk": str(inv.pk),
|
||||
"name": "Can view Invitation",
|
||||
"app_label_verbose": "authentik Stages.Invitation",
|
||||
"model_verbose": "Invitation",
|
||||
"object_description": str(inv),
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
@ -1,9 +1,10 @@
|
||||
"""RBAC role tests"""
|
||||
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.rbac.models import Role
|
||||
|
||||
@ -13,18 +14,30 @@ class TestRoles(APITestCase):
|
||||
|
||||
def test_role_create(self):
|
||||
"""Test creation"""
|
||||
user = create_test_admin_user()
|
||||
user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.save()
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
group.roles.add(role)
|
||||
group.users.add(user)
|
||||
self.assertEqual(list(role.group.user_set.all()), [user])
|
||||
self.assertTrue(user.has_perm("authentik_core.view_application"))
|
||||
|
||||
def test_role_create_remove(self):
|
||||
def test_role_create_add_reverse(self):
|
||||
"""Test creation (add user in reverse)"""
|
||||
user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
group.roles.add(role)
|
||||
user.ak_groups.add(group)
|
||||
self.assertEqual(list(role.group.user_set.all()), [user])
|
||||
self.assertTrue(user.has_perm("authentik_core.view_application"))
|
||||
|
||||
def test_remove_group_delete(self):
|
||||
"""Test creation and remove"""
|
||||
user = create_test_admin_user()
|
||||
user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
@ -32,5 +45,77 @@ class TestRoles(APITestCase):
|
||||
group.users.add(user)
|
||||
self.assertEqual(list(role.group.user_set.all()), [user])
|
||||
self.assertTrue(user.has_perm("authentik_core.view_application"))
|
||||
user.delete()
|
||||
group.delete()
|
||||
user = User.objects.get(username=user.username)
|
||||
self.assertFalse(user.has_perm("authentik_core.view_application"))
|
||||
self.assertEqual(list(role.group.user_set.all()), [])
|
||||
|
||||
def test_remove_roles_remove(self):
|
||||
"""Test assigning permission to role, then removing role from group"""
|
||||
user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
group.roles.add(role)
|
||||
group.users.add(user)
|
||||
self.assertEqual(list(role.group.user_set.all()), [user])
|
||||
self.assertTrue(user.has_perm("authentik_core.view_application"))
|
||||
group.roles.remove(role)
|
||||
user = User.objects.get(username=user.username)
|
||||
self.assertFalse(user.has_perm("authentik_core.view_application"))
|
||||
self.assertEqual(list(role.group.user_set.all()), [])
|
||||
|
||||
def test_remove_role_delete(self):
|
||||
"""Test assigning permissions to role, then removing role"""
|
||||
user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
group.roles.add(role)
|
||||
group.users.add(user)
|
||||
self.assertEqual(list(role.group.user_set.all()), [user])
|
||||
self.assertTrue(user.has_perm("authentik_core.view_application"))
|
||||
role.delete()
|
||||
user = User.objects.get(username=user.username)
|
||||
self.assertFalse(user.has_perm("authentik_core.view_application"))
|
||||
self.assertEqual(list(role.group.user_set.all()), [])
|
||||
|
||||
def test_role_assign_twice(self):
|
||||
"""Test assigning role to two groups"""
|
||||
group1 = Group.objects.create(name=generate_id())
|
||||
group2 = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
group1.roles.add(role)
|
||||
with self.assertRaises(ValidationError):
|
||||
group2.roles.add(role)
|
||||
|
||||
def test_remove_users_remove(self):
|
||||
"""Test assigning permission to role, then removing user from group"""
|
||||
user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
group.roles.add(role)
|
||||
group.users.add(user)
|
||||
self.assertEqual(list(role.group.user_set.all()), [user])
|
||||
self.assertTrue(user.has_perm("authentik_core.view_application"))
|
||||
group.users.remove(user)
|
||||
user = User.objects.get(username=user.username)
|
||||
self.assertFalse(user.has_perm("authentik_core.view_application"))
|
||||
self.assertEqual(list(role.group.user_set.all()), [])
|
||||
|
||||
def test_remove_users_remove_reverse(self):
|
||||
"""Test assigning permission to role, then removing user from group in reverse"""
|
||||
user = create_test_user()
|
||||
group = Group.objects.create(name=generate_id())
|
||||
role = Role.objects.create(name=generate_id())
|
||||
role.assign_permission("authentik_core.view_application")
|
||||
group.roles.add(role)
|
||||
group.users.add(user)
|
||||
self.assertEqual(list(role.group.user_set.all()), [user])
|
||||
self.assertTrue(user.has_perm("authentik_core.view_application"))
|
||||
user.ak_groups.remove(group)
|
||||
user = User.objects.get(username=user.username)
|
||||
self.assertFalse(user.has_perm("authentik_core.view_application"))
|
||||
self.assertEqual(list(role.group.user_set.all()), [])
|
||||
|
@ -10,8 +10,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
|
||||
def get_connection_params(self):
|
||||
"""Refresh DB credentials before getting connection params"""
|
||||
CONFIG.refresh("postgresql.password")
|
||||
conn_params = super().get_connection_params()
|
||||
conn_params["user"] = CONFIG.get("postgresql.user")
|
||||
conn_params["password"] = CONFIG.get("postgresql.password")
|
||||
|
||||
prefix = "postgresql"
|
||||
if self.alias.startswith("replica_"):
|
||||
prefix = f"postgresql.read_replicas.{self.alias.removeprefix('replica_')}"
|
||||
|
||||
for setting in ("host", "port", "user", "password"):
|
||||
conn_params[setting] = CONFIG.refresh(f"{prefix}.{setting}")
|
||||
if conn_params[setting] is None and self.alias.startswith("replica_"):
|
||||
conn_params[setting] = CONFIG.refresh(f"postgresql.{setting}")
|
||||
|
||||
return conn_params
|
||||
|
@ -47,8 +47,10 @@ class ReadyView(View):
|
||||
|
||||
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||
try:
|
||||
db_conn = connections["default"]
|
||||
_ = db_conn.cursor()
|
||||
for db_conn in connections.all():
|
||||
# Force connection reload
|
||||
db_conn.connect()
|
||||
_ = db_conn.cursor()
|
||||
except OperationalError: # pragma: no cover
|
||||
return HttpResponse(status=503)
|
||||
try:
|
||||
|
@ -60,6 +60,8 @@ SHARED_APPS = [
|
||||
"django_filters",
|
||||
"drf_spectacular",
|
||||
"django_prometheus",
|
||||
"pgactivity",
|
||||
"pglock",
|
||||
"channels",
|
||||
]
|
||||
TENANT_APPS = [
|
||||
@ -293,7 +295,7 @@ DATABASES = {
|
||||
"NAME": CONFIG.get("postgresql.name"),
|
||||
"USER": CONFIG.get("postgresql.user"),
|
||||
"PASSWORD": CONFIG.get("postgresql.password"),
|
||||
"PORT": CONFIG.get_int("postgresql.port"),
|
||||
"PORT": CONFIG.get("postgresql.port"),
|
||||
"SSLMODE": CONFIG.get("postgresql.sslmode"),
|
||||
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
|
||||
"SSLCERT": CONFIG.get("postgresql.sslcert"),
|
||||
@ -313,7 +315,23 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
|
||||
# https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
|
||||
DATABASES["default"]["CONN_MAX_AGE"] = None # persistent
|
||||
|
||||
DATABASE_ROUTERS = ("django_tenants.routers.TenantSyncRouter",)
|
||||
for replica in CONFIG.get_keys("postgresql.read_replicas"):
|
||||
_database = DATABASES["default"].copy()
|
||||
for setting in DATABASES["default"].keys():
|
||||
default = object()
|
||||
if setting in ("TEST",):
|
||||
continue
|
||||
override = CONFIG.get(
|
||||
f"postgresql.read_replicas.{replica}.{setting.lower()}", default=default
|
||||
)
|
||||
if override is not default:
|
||||
_database[setting] = override
|
||||
DATABASES[f"replica_{replica}"] = _database
|
||||
|
||||
DATABASE_ROUTERS = (
|
||||
"authentik.tenants.db.FailoverRouter",
|
||||
"django_tenants.routers.TenantSyncRouter",
|
||||
)
|
||||
|
||||
# Email
|
||||
# These values should never actually be used, emails are only sent from email stages, which
|
||||
|
@ -1,6 +1,8 @@
|
||||
from os import environ
|
||||
from ssl import OPENSSL_VERSION
|
||||
|
||||
import pytest
|
||||
from cryptography.hazmat.backends.openssl.backend import backend
|
||||
|
||||
from authentik import get_full_version
|
||||
|
||||
@ -18,4 +20,7 @@ def pytest_sessionstart(*_, **__):
|
||||
@pytest.hookimpl(trylast=True)
|
||||
def pytest_report_header(*_, **__):
|
||||
"""Add authentik version to pytest output"""
|
||||
return [f"authentik version: {get_full_version()}"]
|
||||
return [
|
||||
f"authentik version: {get_full_version()}",
|
||||
f"OpenSSL version: {OPENSSL_VERSION}, FIPS: {backend._fips_enabled}",
|
||||
]
|
||||
|
@ -16,7 +16,7 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
@ -143,10 +143,12 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
uid__startswith=source.slug,
|
||||
)
|
||||
)
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
"is_running": source.sync_lock.locked(),
|
||||
}
|
||||
with source.sync_lock as lock_acquired:
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
# If we could not acquire the lock, it means a task is using it, and thus is running
|
||||
"is_running": not lock_acquired,
|
||||
}
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
|
||||
@extend_schema(
|
||||
|
@ -6,12 +6,12 @@ from shutil import rmtree
|
||||
from ssl import CERT_REQUIRED
|
||||
from tempfile import NamedTemporaryFile, mkdtemp
|
||||
|
||||
from django.core.cache import cache
|
||||
import pglock
|
||||
from django.db import connection, models
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from ldap3 import ALL, NONE, RANDOM, Connection, Server, ServerPool, Tls
|
||||
from ldap3.core.exceptions import LDAPException, LDAPInsufficientAccessRightsResult, LDAPSchemaError
|
||||
from redis.lock import Lock
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import Group, PropertyMapping, Source
|
||||
@ -126,6 +126,10 @@ class LDAPSource(Source):
|
||||
|
||||
return LDAPSourceSerializer
|
||||
|
||||
@property
|
||||
def icon_url(self) -> str:
|
||||
return static("authentik/sources/ldap.png")
|
||||
|
||||
def server(self, **kwargs) -> ServerPool:
|
||||
"""Get LDAP Server/ServerPool"""
|
||||
servers = []
|
||||
@ -204,15 +208,12 @@ class LDAPSource(Source):
|
||||
return RuntimeError("Failed to bind")
|
||||
|
||||
@property
|
||||
def sync_lock(self) -> Lock:
|
||||
"""Redis lock for syncing LDAP to prevent multiple parallel syncs happening"""
|
||||
return Lock(
|
||||
cache.client.get_client(),
|
||||
name=f"goauthentik.io/sources/ldap/sync/{connection.schema_name}-{self.slug}",
|
||||
# Convert task timeout hours to seconds, and multiply times 3
|
||||
# (see authentik/sources/ldap/tasks.py:54)
|
||||
# multiply by 3 to add even more leeway
|
||||
timeout=(60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) * 3,
|
||||
def sync_lock(self) -> pglock.advisory:
|
||||
"""Postgres lock for syncing LDAP to prevent multiple parallel syncs happening"""
|
||||
return pglock.advisory(
|
||||
lock_id=f"goauthentik.io/{connection.schema_name}/sources/ldap/sync/{self.slug}",
|
||||
timeout=0,
|
||||
side_effect=pglock.Return,
|
||||
)
|
||||
|
||||
def check_connection(self) -> dict[str, dict[str, str]]:
|
||||
|
@ -5,7 +5,6 @@ from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.query import QuerySet
|
||||
from ldap3 import DEREF_ALWAYS, SUBTREE, Connection
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
@ -16,8 +15,11 @@ from authentik.core.expression.exceptions import (
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG, set_path_in_dict
|
||||
from authentik.lib.merge import MERGE_LIST_UNIQUE
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME
|
||||
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
|
||||
from authentik.sources.ldap.models import LDAPSource
|
||||
|
||||
LDAP_UNIQUENESS = "ldap_uniq"
|
||||
|
||||
@ -38,6 +40,7 @@ class BaseLDAPSynchronizer:
|
||||
_logger: BoundLogger
|
||||
_connection: Connection
|
||||
_messages: list[str]
|
||||
mapper: PropertyMappingManager
|
||||
|
||||
def __init__(self, source: LDAPSource):
|
||||
self._source = source
|
||||
@ -139,52 +142,47 @@ class BaseLDAPSynchronizer:
|
||||
|
||||
def build_user_properties(self, user_dn: str, **kwargs) -> dict[str, Any]:
|
||||
"""Build attributes for User object based on property mappings."""
|
||||
props = self._build_object_properties(user_dn, self._source.property_mappings, **kwargs)
|
||||
props = self._build_object_properties(user_dn, **kwargs)
|
||||
props.setdefault("path", self._source.get_user_path())
|
||||
return props
|
||||
|
||||
def build_group_properties(self, group_dn: str, **kwargs) -> dict[str, Any]:
|
||||
"""Build attributes for Group object based on property mappings."""
|
||||
return self._build_object_properties(
|
||||
group_dn, self._source.property_mappings_group, **kwargs
|
||||
)
|
||||
return self._build_object_properties(group_dn, **kwargs)
|
||||
|
||||
def _build_object_properties(
|
||||
self, object_dn: str, mappings: QuerySet, **kwargs
|
||||
) -> dict[str, dict[Any, Any]]:
|
||||
def _build_object_properties(self, object_dn: str, **kwargs) -> dict[str, dict[Any, Any]]:
|
||||
properties = {"attributes": {}}
|
||||
for mapping in mappings.all().select_subclasses():
|
||||
if not isinstance(mapping, LDAPPropertyMapping):
|
||||
continue
|
||||
mapping: LDAPPropertyMapping
|
||||
try:
|
||||
value = mapping.evaluate(
|
||||
user=None, request=None, ldap=kwargs, dn=object_dn, source=self._source
|
||||
)
|
||||
if value is None:
|
||||
self._logger.warning("property mapping returned None", mapping=mapping)
|
||||
continue
|
||||
if isinstance(value, (bytes)):
|
||||
self._logger.warning("property mapping returned bytes", mapping=mapping)
|
||||
continue
|
||||
object_field = mapping.object_field
|
||||
if object_field.startswith("attributes."):
|
||||
# Because returning a list might desired, we can't
|
||||
# rely on flatten here. Instead, just save the result as-is
|
||||
set_path_in_dict(properties, object_field, value)
|
||||
else:
|
||||
properties[object_field] = flatten(value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except PropertyMappingExpressionException as exc:
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping: '{mapping.name}'",
|
||||
source=self._source,
|
||||
mapping=mapping,
|
||||
).save()
|
||||
self._logger.warning("Mapping failed to evaluate", exc=exc, mapping=mapping)
|
||||
continue
|
||||
try:
|
||||
for value, mapping in self.mapper.iter_eval(
|
||||
user=None,
|
||||
request=None,
|
||||
return_mapping=True,
|
||||
ldap=kwargs,
|
||||
dn=object_dn,
|
||||
source=self._source,
|
||||
):
|
||||
try:
|
||||
if isinstance(value, (bytes)):
|
||||
self._logger.warning("property mapping returned bytes", mapping=mapping)
|
||||
continue
|
||||
object_field = mapping.object_field
|
||||
if object_field.startswith("attributes."):
|
||||
# Because returning a list might desired, we can't
|
||||
# rely on flatten here. Instead, just save the result as-is
|
||||
set_path_in_dict(properties, object_field, value)
|
||||
else:
|
||||
properties[object_field] = flatten(value)
|
||||
except SkipObjectException as exc:
|
||||
raise exc from exc
|
||||
except PropertyMappingExpressionException as exc:
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=exc.mapping,
|
||||
).save()
|
||||
self._logger.warning("Mapping failed to evaluate", exc=exc, mapping=exc.mapping)
|
||||
raise StopSync(exc, None, exc.mapping) from exc
|
||||
if self._source.object_uniqueness_field in kwargs:
|
||||
properties["attributes"][LDAP_UNIQUENESS] = flatten(
|
||||
kwargs.get(self._source.object_uniqueness_field)
|
||||
|
@ -9,12 +9,22 @@ from ldap3 import ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, SUBTREE
|
||||
from authentik.core.expression.exceptions import SkipObjectException
|
||||
from authentik.core.models import Group
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
|
||||
from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchronizer, flatten
|
||||
|
||||
|
||||
class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
"""Sync LDAP Users and groups into authentik"""
|
||||
|
||||
def __init__(self, source: LDAPSource):
|
||||
super().__init__(source)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self._source.property_mappings_group.all().order_by("name").select_subclasses(),
|
||||
LDAPPropertyMapping,
|
||||
["ldap", "dn", "source"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "groups"
|
||||
|
@ -9,6 +9,8 @@ from ldap3 import ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, SUBTREE
|
||||
from authentik.core.expression.exceptions import SkipObjectException
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource
|
||||
from authentik.sources.ldap.sync.base import LDAP_UNIQUENESS, BaseLDAPSynchronizer, flatten
|
||||
from authentik.sources.ldap.sync.vendor.freeipa import FreeIPA
|
||||
from authentik.sources.ldap.sync.vendor.ms_ad import MicrosoftActiveDirectory
|
||||
@ -17,6 +19,14 @@ from authentik.sources.ldap.sync.vendor.ms_ad import MicrosoftActiveDirectory
|
||||
class UserLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
"""Sync LDAP Users into authentik"""
|
||||
|
||||
def __init__(self, source: LDAPSource):
|
||||
super().__init__(source)
|
||||
self.mapper = PropertyMappingManager(
|
||||
self._source.property_mappings.all().order_by("name").select_subclasses(),
|
||||
LDAPPropertyMapping,
|
||||
["ldap", "dn", "source"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "users"
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user