Compare commits
6 Commits
version/20
...
core/soft-
Author | SHA1 | Date | |
---|---|---|---|
a5379c35aa | |||
e4c11a5284 | |||
a4853a1e09 | |||
b65b72d910 | |||
cd7be6a1a4 | |||
e5cb8ef541 |
@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 2024.6.0-rc1
|
current_version = 2024.4.2
|
||||||
tag = True
|
tag = True
|
||||||
commit = True
|
commit = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
||||||
|
2
.github/actions/setup/docker-compose.yml
vendored
2
.github/actions/setup/docker-compose.yml
vendored
@ -1,3 +1,5 @@
|
|||||||
|
version: "3.7"
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgresql:
|
postgresql:
|
||||||
image: docker.io/library/postgres:${PSQL_TAG:-16}
|
image: docker.io/library/postgres:${PSQL_TAG:-16}
|
||||||
|
1
.github/codespell-words.txt
vendored
1
.github/codespell-words.txt
vendored
@ -4,4 +4,3 @@ hass
|
|||||||
warmup
|
warmup
|
||||||
ontext
|
ontext
|
||||||
singed
|
singed
|
||||||
assertIn
|
|
||||||
|
6
.github/workflows/ci-main.yml
vendored
6
.github/workflows/ci-main.yml
vendored
@ -50,6 +50,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
psql:
|
psql:
|
||||||
|
- 12-alpine
|
||||||
- 15-alpine
|
- 15-alpine
|
||||||
- 16-alpine
|
- 16-alpine
|
||||||
steps:
|
steps:
|
||||||
@ -103,6 +104,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
psql:
|
psql:
|
||||||
|
- 12-alpine
|
||||||
- 15-alpine
|
- 15-alpine
|
||||||
- 16-alpine
|
- 16-alpine
|
||||||
steps:
|
steps:
|
||||||
@ -250,8 +252,8 @@ jobs:
|
|||||||
push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
|
push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
|
||||||
build-args: |
|
build-args: |
|
||||||
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
|
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
|
||||||
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache
|
cache-from: type=gha
|
||||||
cache-to: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache,mode=max
|
cache-to: type=gha,mode=max
|
||||||
platforms: linux/${{ matrix.arch }}
|
platforms: linux/${{ matrix.arch }}
|
||||||
pr-comment:
|
pr-comment:
|
||||||
needs:
|
needs:
|
||||||
|
4
.github/workflows/ci-outpost.yml
vendored
4
.github/workflows/ci-outpost.yml
vendored
@ -105,8 +105,8 @@ jobs:
|
|||||||
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
|
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
context: .
|
context: .
|
||||||
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache
|
cache-from: type=gha
|
||||||
cache-to: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache,mode=max
|
cache-to: type=gha,mode=max
|
||||||
build-binary:
|
build-binary:
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
needs:
|
needs:
|
||||||
|
30
Dockerfile
30
Dockerfile
@ -38,7 +38,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
|
|||||||
RUN npm run build
|
RUN npm run build
|
||||||
|
|
||||||
# Stage 3: Build go proxy
|
# Stage 3: Build go proxy
|
||||||
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
|
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.3-bookworm AS go-builder
|
||||||
|
|
||||||
ARG TARGETOS
|
ARG TARGETOS
|
||||||
ARG TARGETARCH
|
ARG TARGETARCH
|
||||||
@ -49,11 +49,6 @@ ARG GOARCH=$TARGETARCH
|
|||||||
|
|
||||||
WORKDIR /go/src/goauthentik.io
|
WORKDIR /go/src/goauthentik.io
|
||||||
|
|
||||||
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
|
|
||||||
dpkg --add-architecture arm64 && \
|
|
||||||
apt-get update && \
|
|
||||||
apt-get install -y --no-install-recommends crossbuild-essential-arm64 gcc-aarch64-linux-gnu
|
|
||||||
|
|
||||||
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
|
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
|
||||||
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
|
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
|
||||||
--mount=type=cache,target=/go/pkg/mod \
|
--mount=type=cache,target=/go/pkg/mod \
|
||||||
@ -68,11 +63,11 @@ COPY ./internal /go/src/goauthentik.io/internal
|
|||||||
COPY ./go.mod /go/src/goauthentik.io/go.mod
|
COPY ./go.mod /go/src/goauthentik.io/go.mod
|
||||||
COPY ./go.sum /go/src/goauthentik.io/go.sum
|
COPY ./go.sum /go/src/goauthentik.io/go.sum
|
||||||
|
|
||||||
|
ENV CGO_ENABLED=0
|
||||||
|
|
||||||
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||||
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
|
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
|
||||||
if [ "$TARGETARCH" = "arm64" ]; then export CC=aarch64-linux-gnu-gcc && export CC_FOR_TARGET=gcc-aarch64-linux-gnu; fi && \
|
GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server
|
||||||
CGO_ENABLED=1 GOEXPERIMENT="systemcrypto" GOFLAGS="-tags=requirefips" GOARM="${TARGETVARIANT#v}" \
|
|
||||||
go build -o /go/authentik ./cmd/server
|
|
||||||
|
|
||||||
# Stage 4: MaxMind GeoIP
|
# Stage 4: MaxMind GeoIP
|
||||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
|
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
|
||||||
@ -89,7 +84,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
|||||||
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||||
|
|
||||||
# Stage 5: Python dependencies
|
# Stage 5: Python dependencies
|
||||||
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps
|
FROM docker.io/python:3.12.3-slim-bookworm AS python-deps
|
||||||
|
|
||||||
WORKDIR /ak-root/poetry
|
WORKDIR /ak-root/poetry
|
||||||
|
|
||||||
@ -102,7 +97,7 @@ RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloa
|
|||||||
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
|
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
|
||||||
apt-get update && \
|
apt-get update && \
|
||||||
# Required for installing pip packages
|
# Required for installing pip packages
|
||||||
apt-get install -y --no-install-recommends build-essential pkg-config libpq-dev
|
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev
|
||||||
|
|
||||||
RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
|
RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
|
||||||
--mount=type=bind,target=./poetry.lock,src=./poetry.lock \
|
--mount=type=bind,target=./poetry.lock,src=./poetry.lock \
|
||||||
@ -110,13 +105,12 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
|
|||||||
--mount=type=cache,target=/root/.cache/pypoetry \
|
--mount=type=cache,target=/root/.cache/pypoetry \
|
||||||
python -m venv /ak-root/venv/ && \
|
python -m venv /ak-root/venv/ && \
|
||||||
bash -c "source ${VENV_PATH}/bin/activate && \
|
bash -c "source ${VENV_PATH}/bin/activate && \
|
||||||
pip3 install --upgrade pip && \
|
pip3 install --upgrade pip && \
|
||||||
pip3 install poetry && \
|
pip3 install poetry && \
|
||||||
poetry install --only=main --no-ansi --no-interaction --no-root && \
|
poetry install --only=main --no-ansi --no-interaction --no-root"
|
||||||
pip install --force-reinstall /wheels/*"
|
|
||||||
|
|
||||||
# Stage 6: Run
|
# Stage 6: Run
|
||||||
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image
|
FROM docker.io/python:3.12.3-slim-bookworm AS final-image
|
||||||
|
|
||||||
ARG GIT_BUILD_HASH
|
ARG GIT_BUILD_HASH
|
||||||
ARG VERSION
|
ARG VERSION
|
||||||
@ -133,7 +127,7 @@ WORKDIR /
|
|||||||
# We cannot cache this layer otherwise we'll end up with a bigger image
|
# We cannot cache this layer otherwise we'll end up with a bigger image
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
# Required for runtime
|
# Required for runtime
|
||||||
apt-get install -y --no-install-recommends libpq5 libmaxminddb0 ca-certificates && \
|
apt-get install -y --no-install-recommends libpq5 openssl libxmlsec1-openssl libmaxminddb0 ca-certificates && \
|
||||||
# Required for bootstrap & healtcheck
|
# Required for bootstrap & healtcheck
|
||||||
apt-get install -y --no-install-recommends runit && \
|
apt-get install -y --no-install-recommends runit && \
|
||||||
apt-get clean && \
|
apt-get clean && \
|
||||||
@ -169,8 +163,6 @@ ENV TMPDIR=/dev/shm/ \
|
|||||||
VENV_PATH="/ak-root/venv" \
|
VENV_PATH="/ak-root/venv" \
|
||||||
POETRY_VIRTUALENVS_CREATE=false
|
POETRY_VIRTUALENVS_CREATE=false
|
||||||
|
|
||||||
ENV GOFIPS=1
|
|
||||||
|
|
||||||
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD [ "ak", "healthcheck" ]
|
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD [ "ak", "healthcheck" ]
|
||||||
|
|
||||||
ENTRYPOINT [ "dumb-init", "--", "ak" ]
|
ENTRYPOINT [ "dumb-init", "--", "ak" ]
|
||||||
|
1
Makefile
1
Makefile
@ -253,7 +253,6 @@ website-watch: ## Build and watch the documentation website, updating automatic
|
|||||||
#########################
|
#########################
|
||||||
|
|
||||||
docker: ## Build a docker image of the current source tree
|
docker: ## Build a docker image of the current source tree
|
||||||
mkdir -p ${GEN_API_TS}
|
|
||||||
DOCKER_BUILDKIT=1 docker build . --progress plain --tag ${DOCKER_IMAGE}
|
DOCKER_BUILDKIT=1 docker build . --progress plain --tag ${DOCKER_IMAGE}
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from os import environ
|
from os import environ
|
||||||
|
|
||||||
__version__ = "2024.6.0"
|
__version__ = "2024.4.2"
|
||||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,21 +2,18 @@
|
|||||||
|
|
||||||
import platform
|
import platform
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from ssl import OPENSSL_VERSION
|
|
||||||
from sys import version as python_version
|
from sys import version as python_version
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from drf_spectacular.utils import extend_schema
|
from drf_spectacular.utils import extend_schema
|
||||||
|
from gunicorn import version_info as gunicorn_version
|
||||||
from rest_framework.fields import SerializerMethodField
|
from rest_framework.fields import SerializerMethodField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
from authentik import get_full_version
|
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
from authentik.core.api.utils import PassiveSerializer
|
||||||
from authentik.enterprise.license import LicenseKey
|
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.utils.reflection import get_env
|
from authentik.lib.utils.reflection import get_env
|
||||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||||
@ -28,13 +25,11 @@ class RuntimeDict(TypedDict):
|
|||||||
"""Runtime information"""
|
"""Runtime information"""
|
||||||
|
|
||||||
python_version: str
|
python_version: str
|
||||||
|
gunicorn_version: str
|
||||||
environment: str
|
environment: str
|
||||||
architecture: str
|
architecture: str
|
||||||
platform: str
|
platform: str
|
||||||
uname: str
|
uname: str
|
||||||
openssl_version: str
|
|
||||||
openssl_fips_enabled: bool | None
|
|
||||||
authentik_version: str
|
|
||||||
|
|
||||||
|
|
||||||
class SystemInfoSerializer(PassiveSerializer):
|
class SystemInfoSerializer(PassiveSerializer):
|
||||||
@ -69,15 +64,11 @@ class SystemInfoSerializer(PassiveSerializer):
|
|||||||
def get_runtime(self, request: Request) -> RuntimeDict:
|
def get_runtime(self, request: Request) -> RuntimeDict:
|
||||||
"""Get versions"""
|
"""Get versions"""
|
||||||
return {
|
return {
|
||||||
"architecture": platform.machine(),
|
|
||||||
"authentik_version": get_full_version(),
|
|
||||||
"environment": get_env(),
|
|
||||||
"openssl_fips_enabled": (
|
|
||||||
backend._fips_enabled if LicenseKey.get_total().is_valid() else None
|
|
||||||
),
|
|
||||||
"openssl_version": OPENSSL_VERSION,
|
|
||||||
"platform": platform.platform(),
|
|
||||||
"python_version": python_version,
|
"python_version": python_version,
|
||||||
|
"gunicorn_version": ".".join(str(x) for x in gunicorn_version),
|
||||||
|
"environment": get_env(),
|
||||||
|
"architecture": platform.machine(),
|
||||||
|
"platform": platform.platform(),
|
||||||
"uname": " ".join(platform.uname()),
|
"uname": " ".join(platform.uname()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ class BlueprintEntry:
|
|||||||
_state: BlueprintEntryState = field(default_factory=BlueprintEntryState)
|
_state: BlueprintEntryState = field(default_factory=BlueprintEntryState)
|
||||||
|
|
||||||
def __post_init__(self, *args, **kwargs) -> None:
|
def __post_init__(self, *args, **kwargs) -> None:
|
||||||
self.__tag_contexts: list[YAMLTagContext] = []
|
self.__tag_contexts: list["YAMLTagContext"] = []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry":
|
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry":
|
||||||
|
@ -4,7 +4,6 @@ from collections.abc import Iterable
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from django.apps import apps
|
from django.apps import apps
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.db.models import Model, Q, QuerySet
|
from django.db.models import Model, Q, QuerySet
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
@ -47,8 +46,6 @@ class Exporter:
|
|||||||
def get_model_instances(self, model: type[Model]) -> QuerySet:
|
def get_model_instances(self, model: type[Model]) -> QuerySet:
|
||||||
"""Return a queryset for `model`. Can be used to filter some
|
"""Return a queryset for `model`. Can be used to filter some
|
||||||
objects on some models"""
|
objects on some models"""
|
||||||
if model == get_user_model():
|
|
||||||
return model.objects.exclude_anonymous()
|
|
||||||
return model.objects.all()
|
return model.objects.all()
|
||||||
|
|
||||||
def _pre_export(self, blueprint: Blueprint):
|
def _pre_export(self, blueprint: Blueprint):
|
||||||
|
@ -58,7 +58,7 @@ from authentik.outposts.models import OutpostServiceConnection
|
|||||||
from authentik.policies.models import Policy, PolicyBindingModel
|
from authentik.policies.models import Policy, PolicyBindingModel
|
||||||
from authentik.policies.reputation.models import Reputation
|
from authentik.policies.reputation.models import Reputation
|
||||||
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
|
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
|
||||||
from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
|
from authentik.providers.scim.models import SCIMGroup, SCIMUser
|
||||||
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
|
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
|
||||||
from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
|
from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
|
||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
@ -97,8 +97,8 @@ def excluded_models() -> list[type[Model]]:
|
|||||||
# FIXME: these shouldn't need to be explicitly listed, but rather based off of a mixin
|
# FIXME: these shouldn't need to be explicitly listed, but rather based off of a mixin
|
||||||
FlowToken,
|
FlowToken,
|
||||||
LicenseUsage,
|
LicenseUsage,
|
||||||
SCIMProviderGroup,
|
SCIMGroup,
|
||||||
SCIMProviderUser,
|
SCIMUser,
|
||||||
Tenant,
|
Tenant,
|
||||||
SystemTask,
|
SystemTask,
|
||||||
ConnectionToken,
|
ConnectionToken,
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from json import loads
|
from json import loads
|
||||||
|
|
||||||
from django.db.models import Prefetch
|
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
from django_filters.filters import CharFilter, ModelMultipleChoiceFilter
|
from django_filters.filters import CharFilter, ModelMultipleChoiceFilter
|
||||||
from django_filters.filterset import FilterSet
|
from django_filters.filterset import FilterSet
|
||||||
@ -167,14 +166,8 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
|||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
base_qs = Group.objects.all().select_related("parent").prefetch_related("roles")
|
base_qs = Group.objects.all().select_related("parent").prefetch_related("roles")
|
||||||
|
|
||||||
if self.serializer_class(context={"request": self.request})._should_include_users:
|
if self.serializer_class(context={"request": self.request})._should_include_users:
|
||||||
base_qs = base_qs.prefetch_related("users")
|
base_qs = base_qs.prefetch_related("users")
|
||||||
else:
|
|
||||||
base_qs = base_qs.prefetch_related(
|
|
||||||
Prefetch("users", queryset=User.objects.all().only("id"))
|
|
||||||
)
|
|
||||||
|
|
||||||
return base_qs
|
return base_qs
|
||||||
|
|
||||||
@extend_schema(
|
@extend_schema(
|
||||||
@ -185,14 +178,6 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
|
|||||||
def list(self, request, *args, **kwargs):
|
def list(self, request, *args, **kwargs):
|
||||||
return super().list(request, *args, **kwargs)
|
return super().list(request, *args, **kwargs)
|
||||||
|
|
||||||
@extend_schema(
|
|
||||||
parameters=[
|
|
||||||
OpenApiParameter("include_users", bool, default=True),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
def retrieve(self, request, *args, **kwargs):
|
|
||||||
return super().retrieve(request, *args, **kwargs)
|
|
||||||
|
|
||||||
@permission_required("authentik_core.add_user_to_group")
|
@permission_required("authentik_core.add_user_to_group")
|
||||||
@extend_schema(
|
@extend_schema(
|
||||||
request=UserAccountSerializer,
|
request=UserAccountSerializer,
|
||||||
|
@ -1,79 +0,0 @@
|
|||||||
"""API Utilities"""
|
|
||||||
|
|
||||||
from drf_spectacular.utils import extend_schema
|
|
||||||
from rest_framework.decorators import action
|
|
||||||
from rest_framework.fields import (
|
|
||||||
BooleanField,
|
|
||||||
CharField,
|
|
||||||
)
|
|
||||||
from rest_framework.request import Request
|
|
||||||
from rest_framework.response import Response
|
|
||||||
|
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
|
||||||
from authentik.enterprise.apps import EnterpriseConfig
|
|
||||||
from authentik.lib.utils.reflection import all_subclasses
|
|
||||||
|
|
||||||
|
|
||||||
class TypeCreateSerializer(PassiveSerializer):
|
|
||||||
"""Types of an object that can be created"""
|
|
||||||
|
|
||||||
name = CharField(required=True)
|
|
||||||
description = CharField(required=True)
|
|
||||||
component = CharField(required=True)
|
|
||||||
model_name = CharField(required=True)
|
|
||||||
|
|
||||||
icon_url = CharField(required=False)
|
|
||||||
requires_enterprise = BooleanField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class CreatableType:
|
|
||||||
"""Class to inherit from to mark a model as creatable, even if the model itself is marked
|
|
||||||
as abstract"""
|
|
||||||
|
|
||||||
|
|
||||||
class NonCreatableType:
|
|
||||||
"""Class to inherit from to mark a model as non-creatable even if it is not abstract"""
|
|
||||||
|
|
||||||
|
|
||||||
class TypesMixin:
|
|
||||||
"""Mixin which adds an API endpoint to list all possible types that can be created"""
|
|
||||||
|
|
||||||
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
|
||||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
|
||||||
def types(self, request: Request, additional: list[dict] | None = None) -> Response:
|
|
||||||
"""Get all creatable types"""
|
|
||||||
data = []
|
|
||||||
for subclass in all_subclasses(self.queryset.model):
|
|
||||||
instance = None
|
|
||||||
if subclass._meta.abstract:
|
|
||||||
if not issubclass(subclass, CreatableType):
|
|
||||||
continue
|
|
||||||
# Circumvent the django protection for not being able to instantiate
|
|
||||||
# abstract models. We need a model instance to access .component
|
|
||||||
# and further down .icon_url
|
|
||||||
instance = subclass.__new__(subclass)
|
|
||||||
# Django re-sets abstract = False so we need to override that
|
|
||||||
instance.Meta.abstract = True
|
|
||||||
else:
|
|
||||||
if issubclass(subclass, NonCreatableType):
|
|
||||||
continue
|
|
||||||
instance = subclass()
|
|
||||||
try:
|
|
||||||
data.append(
|
|
||||||
{
|
|
||||||
"name": subclass._meta.verbose_name,
|
|
||||||
"description": subclass.__doc__,
|
|
||||||
"component": instance.component,
|
|
||||||
"model_name": subclass._meta.model_name,
|
|
||||||
"icon_url": getattr(instance, "icon_url", None),
|
|
||||||
"requires_enterprise": isinstance(
|
|
||||||
subclass._meta.app_config, EnterpriseConfig
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except NotImplementedError:
|
|
||||||
continue
|
|
||||||
if additional:
|
|
||||||
data.extend(additional)
|
|
||||||
data = sorted(data, key=lambda x: x["name"])
|
|
||||||
return Response(TypeCreateSerializer(data, many=True).data)
|
|
@ -9,22 +9,18 @@ from rest_framework import mixins
|
|||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.exceptions import PermissionDenied
|
from rest_framework.exceptions import PermissionDenied
|
||||||
from rest_framework.fields import BooleanField, CharField
|
from rest_framework.fields import BooleanField, CharField
|
||||||
from rest_framework.relations import PrimaryKeyRelatedField
|
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
||||||
from rest_framework.viewsets import GenericViewSet
|
from rest_framework.viewsets import GenericViewSet
|
||||||
|
|
||||||
from authentik.blueprints.api import ManagedSerializer
|
from authentik.blueprints.api import ManagedSerializer
|
||||||
from authentik.core.api.object_types import TypesMixin
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import (
|
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||||
MetaNameSerializer,
|
|
||||||
PassiveSerializer,
|
|
||||||
)
|
|
||||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
||||||
from authentik.core.models import Group, PropertyMapping, User
|
from authentik.core.models import PropertyMapping
|
||||||
from authentik.events.utils import sanitize_item
|
from authentik.events.utils import sanitize_item
|
||||||
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
from authentik.policies.api.exec import PolicyTestSerializer
|
from authentik.policies.api.exec import PolicyTestSerializer
|
||||||
from authentik.rbac.decorators import permission_required
|
from authentik.rbac.decorators import permission_required
|
||||||
|
|
||||||
@ -68,7 +64,6 @@ class PropertyMappingSerializer(ManagedSerializer, ModelSerializer, MetaNameSeri
|
|||||||
|
|
||||||
|
|
||||||
class PropertyMappingViewSet(
|
class PropertyMappingViewSet(
|
||||||
TypesMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
mixins.RetrieveModelMixin,
|
||||||
mixins.DestroyModelMixin,
|
mixins.DestroyModelMixin,
|
||||||
UsedByMixin,
|
UsedByMixin,
|
||||||
@ -77,15 +72,7 @@ class PropertyMappingViewSet(
|
|||||||
):
|
):
|
||||||
"""PropertyMapping Viewset"""
|
"""PropertyMapping Viewset"""
|
||||||
|
|
||||||
class PropertyMappingTestSerializer(PolicyTestSerializer):
|
queryset = PropertyMapping.objects.none()
|
||||||
"""Test property mapping execution for a user/group with context"""
|
|
||||||
|
|
||||||
user = PrimaryKeyRelatedField(queryset=User.objects.all(), required=False, allow_null=True)
|
|
||||||
group = PrimaryKeyRelatedField(
|
|
||||||
queryset=Group.objects.all(), required=False, allow_null=True
|
|
||||||
)
|
|
||||||
|
|
||||||
queryset = PropertyMapping.objects.select_subclasses()
|
|
||||||
serializer_class = PropertyMappingSerializer
|
serializer_class = PropertyMappingSerializer
|
||||||
search_fields = [
|
search_fields = [
|
||||||
"name",
|
"name",
|
||||||
@ -93,9 +80,29 @@ class PropertyMappingViewSet(
|
|||||||
filterset_fields = {"managed": ["isnull"]}
|
filterset_fields = {"managed": ["isnull"]}
|
||||||
ordering = ["name"]
|
ordering = ["name"]
|
||||||
|
|
||||||
|
def get_queryset(self): # pragma: no cover
|
||||||
|
return PropertyMapping.objects.select_subclasses()
|
||||||
|
|
||||||
|
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||||
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
|
def types(self, request: Request) -> Response:
|
||||||
|
"""Get all creatable property-mapping types"""
|
||||||
|
data = []
|
||||||
|
for subclass in all_subclasses(self.queryset.model):
|
||||||
|
subclass: PropertyMapping
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": subclass._meta.verbose_name,
|
||||||
|
"description": subclass.__doc__,
|
||||||
|
"component": subclass().component,
|
||||||
|
"model_name": subclass._meta.model_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
|
||||||
@permission_required("authentik_core.view_propertymapping")
|
@permission_required("authentik_core.view_propertymapping")
|
||||||
@extend_schema(
|
@extend_schema(
|
||||||
request=PropertyMappingTestSerializer(),
|
request=PolicyTestSerializer(),
|
||||||
responses={
|
responses={
|
||||||
200: PropertyMappingTestResultSerializer,
|
200: PropertyMappingTestResultSerializer,
|
||||||
400: OpenApiResponse(description="Invalid parameters"),
|
400: OpenApiResponse(description="Invalid parameters"),
|
||||||
@ -113,39 +120,29 @@ class PropertyMappingViewSet(
|
|||||||
"""Test Property Mapping"""
|
"""Test Property Mapping"""
|
||||||
_mapping: PropertyMapping = self.get_object()
|
_mapping: PropertyMapping = self.get_object()
|
||||||
# Use `get_subclass` to get correct class and correct `.evaluate` implementation
|
# Use `get_subclass` to get correct class and correct `.evaluate` implementation
|
||||||
mapping: PropertyMapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk)
|
mapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk)
|
||||||
# FIXME: when we separate policy mappings between ones for sources
|
# FIXME: when we separate policy mappings between ones for sources
|
||||||
# and ones for providers, we need to make the user field optional for the source mapping
|
# and ones for providers, we need to make the user field optional for the source mapping
|
||||||
test_params = self.PropertyMappingTestSerializer(data=request.data)
|
test_params = PolicyTestSerializer(data=request.data)
|
||||||
if not test_params.is_valid():
|
if not test_params.is_valid():
|
||||||
return Response(test_params.errors, status=400)
|
return Response(test_params.errors, status=400)
|
||||||
|
|
||||||
format_result = str(request.GET.get("format_result", "false")).lower() == "true"
|
format_result = str(request.GET.get("format_result", "false")).lower() == "true"
|
||||||
|
|
||||||
context: dict = test_params.validated_data.get("context", {})
|
# User permission check, only allow mapping testing for users that are readable
|
||||||
context.setdefault("user", None)
|
users = get_objects_for_user(request.user, "authentik_core.view_user").filter(
|
||||||
|
pk=test_params.validated_data["user"].pk
|
||||||
if user := test_params.validated_data.get("user"):
|
)
|
||||||
# User permission check, only allow mapping testing for users that are readable
|
if not users.exists():
|
||||||
users = get_objects_for_user(request.user, "authentik_core.view_user").filter(
|
raise PermissionDenied()
|
||||||
pk=user.pk
|
|
||||||
)
|
|
||||||
if not users.exists():
|
|
||||||
raise PermissionDenied()
|
|
||||||
context["user"] = user
|
|
||||||
if group := test_params.validated_data.get("group"):
|
|
||||||
# Group permission check, only allow mapping testing for groups that are readable
|
|
||||||
groups = get_objects_for_user(request.user, "authentik_core.view_group").filter(
|
|
||||||
pk=group.pk
|
|
||||||
)
|
|
||||||
if not groups.exists():
|
|
||||||
raise PermissionDenied()
|
|
||||||
context["group"] = group
|
|
||||||
context["request"] = self.request
|
|
||||||
|
|
||||||
response_data = {"successful": True, "result": ""}
|
response_data = {"successful": True, "result": ""}
|
||||||
try:
|
try:
|
||||||
result = mapping.evaluate(**context)
|
result = mapping.evaluate(
|
||||||
|
users.first(),
|
||||||
|
self.request,
|
||||||
|
**test_params.validated_data.get("context", {}),
|
||||||
|
)
|
||||||
response_data["result"] = dumps(
|
response_data["result"] = dumps(
|
||||||
sanitize_item(result), indent=(4 if format_result else None)
|
sanitize_item(result), indent=(4 if format_result else None)
|
||||||
)
|
)
|
@ -5,15 +5,20 @@ from django.db.models.query import Q
|
|||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django_filters.filters import BooleanFilter
|
from django_filters.filters import BooleanFilter
|
||||||
from django_filters.filterset import FilterSet
|
from django_filters.filterset import FilterSet
|
||||||
|
from drf_spectacular.utils import extend_schema
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import ReadOnlyField
|
from rest_framework.fields import ReadOnlyField
|
||||||
|
from rest_framework.request import Request
|
||||||
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
||||||
from rest_framework.viewsets import GenericViewSet
|
from rest_framework.viewsets import GenericViewSet
|
||||||
|
|
||||||
from authentik.core.api.object_types import TypesMixin
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import MetaNameSerializer
|
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||||
from authentik.core.models import Provider
|
from authentik.core.models import Provider
|
||||||
|
from authentik.enterprise.apps import EnterpriseConfig
|
||||||
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
|
|
||||||
|
|
||||||
class ProviderSerializer(ModelSerializer, MetaNameSerializer):
|
class ProviderSerializer(ModelSerializer, MetaNameSerializer):
|
||||||
@ -81,7 +86,6 @@ class ProviderFilter(FilterSet):
|
|||||||
|
|
||||||
|
|
||||||
class ProviderViewSet(
|
class ProviderViewSet(
|
||||||
TypesMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
mixins.RetrieveModelMixin,
|
||||||
mixins.DestroyModelMixin,
|
mixins.DestroyModelMixin,
|
||||||
UsedByMixin,
|
UsedByMixin,
|
||||||
@ -100,3 +104,31 @@ class ProviderViewSet(
|
|||||||
|
|
||||||
def get_queryset(self): # pragma: no cover
|
def get_queryset(self): # pragma: no cover
|
||||||
return Provider.objects.select_subclasses()
|
return Provider.objects.select_subclasses()
|
||||||
|
|
||||||
|
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||||
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
|
def types(self, request: Request) -> Response:
|
||||||
|
"""Get all creatable provider types"""
|
||||||
|
data = []
|
||||||
|
for subclass in all_subclasses(self.queryset.model):
|
||||||
|
subclass: Provider
|
||||||
|
if subclass._meta.abstract:
|
||||||
|
continue
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": subclass._meta.verbose_name,
|
||||||
|
"description": subclass.__doc__,
|
||||||
|
"component": subclass().component,
|
||||||
|
"model_name": subclass._meta.model_name,
|
||||||
|
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": _("SAML Provider from Metadata"),
|
||||||
|
"description": _("Create a SAML Provider by importing its Metadata."),
|
||||||
|
"component": "ak-provider-saml-import-form",
|
||||||
|
"model_name": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
@ -17,9 +17,8 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
|
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
|
||||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||||
from authentik.core.api.object_types import TypesMixin
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import MetaNameSerializer
|
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||||
from authentik.core.models import Source, UserSourceConnection
|
from authentik.core.models import Source, UserSourceConnection
|
||||||
from authentik.core.types import UserSettingSerializer
|
from authentik.core.types import UserSettingSerializer
|
||||||
from authentik.lib.utils.file import (
|
from authentik.lib.utils.file import (
|
||||||
@ -28,6 +27,7 @@ from authentik.lib.utils.file import (
|
|||||||
set_file,
|
set_file,
|
||||||
set_file_url,
|
set_file_url,
|
||||||
)
|
)
|
||||||
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
from authentik.rbac.decorators import permission_required
|
from authentik.rbac.decorators import permission_required
|
||||||
|
|
||||||
@ -74,7 +74,6 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
|
|||||||
|
|
||||||
|
|
||||||
class SourceViewSet(
|
class SourceViewSet(
|
||||||
TypesMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
mixins.RetrieveModelMixin,
|
||||||
mixins.DestroyModelMixin,
|
mixins.DestroyModelMixin,
|
||||||
UsedByMixin,
|
UsedByMixin,
|
||||||
@ -133,6 +132,30 @@ class SourceViewSet(
|
|||||||
source: Source = self.get_object()
|
source: Source = self.get_object()
|
||||||
return set_file_url(request, source, "icon")
|
return set_file_url(request, source, "icon")
|
||||||
|
|
||||||
|
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||||
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
|
def types(self, request: Request) -> Response:
|
||||||
|
"""Get all creatable source types"""
|
||||||
|
data = []
|
||||||
|
for subclass in all_subclasses(self.queryset.model):
|
||||||
|
subclass: Source
|
||||||
|
component = ""
|
||||||
|
if len(subclass.__subclasses__()) > 0:
|
||||||
|
continue
|
||||||
|
if subclass._meta.abstract:
|
||||||
|
component = subclass.__bases__[0]().component
|
||||||
|
else:
|
||||||
|
component = subclass().component
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": subclass._meta.verbose_name,
|
||||||
|
"description": subclass.__doc__,
|
||||||
|
"component": component,
|
||||||
|
"model_name": subclass._meta.model_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
|
||||||
@extend_schema(responses={200: UserSettingSerializer(many=True)})
|
@extend_schema(responses={200: UserSettingSerializer(many=True)})
|
||||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
def user_settings(self, request: Request) -> Response:
|
def user_settings(self, request: Request) -> Response:
|
||||||
|
@ -39,12 +39,12 @@ def get_delete_action(manager: Manager) -> str:
|
|||||||
"""Get the delete action from the Foreign key, falls back to cascade"""
|
"""Get the delete action from the Foreign key, falls back to cascade"""
|
||||||
if hasattr(manager, "field"):
|
if hasattr(manager, "field"):
|
||||||
if manager.field.remote_field.on_delete.__name__ == SET_NULL.__name__:
|
if manager.field.remote_field.on_delete.__name__ == SET_NULL.__name__:
|
||||||
return DeleteAction.SET_NULL.value
|
return DeleteAction.SET_NULL.name
|
||||||
if manager.field.remote_field.on_delete.__name__ == SET_DEFAULT.__name__:
|
if manager.field.remote_field.on_delete.__name__ == SET_DEFAULT.__name__:
|
||||||
return DeleteAction.SET_DEFAULT.value
|
return DeleteAction.SET_DEFAULT.name
|
||||||
if hasattr(manager, "source_field"):
|
if hasattr(manager, "source_field"):
|
||||||
return DeleteAction.CASCADE_MANY.value
|
return DeleteAction.CASCADE_MANY.name
|
||||||
return DeleteAction.CASCADE.value
|
return DeleteAction.CASCADE.name
|
||||||
|
|
||||||
|
|
||||||
class UsedByMixin:
|
class UsedByMixin:
|
||||||
|
@ -408,7 +408,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
|||||||
filterset_class = UsersFilter
|
filterset_class = UsersFilter
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
base_qs = User.objects.all().exclude_anonymous()
|
base_qs = User.objects.all()
|
||||||
if self.serializer_class(context={"request": self.request})._should_include_groups:
|
if self.serializer_class(context={"request": self.request})._should_include_groups:
|
||||||
base_qs = base_qs.prefetch_related("ak_groups")
|
base_qs = base_qs.prefetch_related("ak_groups")
|
||||||
return base_qs
|
return base_qs
|
||||||
|
@ -6,16 +6,8 @@ from django.db.models import Model
|
|||||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||||
from drf_spectacular.plumbing import build_basic_type
|
from drf_spectacular.plumbing import build_basic_type
|
||||||
from drf_spectacular.types import OpenApiTypes
|
from drf_spectacular.types import OpenApiTypes
|
||||||
from rest_framework.fields import (
|
from rest_framework.fields import BooleanField, CharField, IntegerField, JSONField
|
||||||
CharField,
|
from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
|
||||||
IntegerField,
|
|
||||||
JSONField,
|
|
||||||
SerializerMethodField,
|
|
||||||
)
|
|
||||||
from rest_framework.serializers import (
|
|
||||||
Serializer,
|
|
||||||
ValidationError,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_dict(value: Any):
|
def is_dict(value: Any):
|
||||||
@ -76,6 +68,16 @@ class MetaNameSerializer(PassiveSerializer):
|
|||||||
return f"{obj._meta.app_label}.{obj._meta.model_name}"
|
return f"{obj._meta.app_label}.{obj._meta.model_name}"
|
||||||
|
|
||||||
|
|
||||||
|
class TypeCreateSerializer(PassiveSerializer):
|
||||||
|
"""Types of an object that can be created"""
|
||||||
|
|
||||||
|
name = CharField(required=True)
|
||||||
|
description = CharField(required=True)
|
||||||
|
component = CharField(required=True)
|
||||||
|
model_name = CharField(required=True)
|
||||||
|
requires_enterprise = BooleanField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class CacheSerializer(PassiveSerializer):
|
class CacheSerializer(PassiveSerializer):
|
||||||
"""Generic cache stats for an object"""
|
"""Generic cache stats for an object"""
|
||||||
|
|
||||||
|
@ -31,9 +31,8 @@ class InbuiltBackend(ModelBackend):
|
|||||||
# Since we can't directly pass other variables to signals, and we want to log the method
|
# Since we can't directly pass other variables to signals, and we want to log the method
|
||||||
# and the token used, we assume we're running in a flow and set a variable in the context
|
# and the token used, we assume we're running in a flow and set a variable in the context
|
||||||
flow_plan: FlowPlan = request.session.get(SESSION_KEY_PLAN, FlowPlan(""))
|
flow_plan: FlowPlan = request.session.get(SESSION_KEY_PLAN, FlowPlan(""))
|
||||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, method)
|
flow_plan.context[PLAN_CONTEXT_METHOD] = method
|
||||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS] = cleanse_dict(sanitize_dict(kwargs))
|
||||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].update(cleanse_dict(sanitize_dict(kwargs)))
|
|
||||||
request.session[SESSION_KEY_PLAN] = flow_plan
|
request.session[SESSION_KEY_PLAN] = flow_plan
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
"""Property Mapping Evaluator"""
|
"""Property Mapping Evaluator"""
|
||||||
|
|
||||||
from types import CodeType
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
@ -25,8 +24,6 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
|||||||
"""Custom Evaluator that adds some different context variables."""
|
"""Custom Evaluator that adds some different context variables."""
|
||||||
|
|
||||||
dry_run: bool
|
dry_run: bool
|
||||||
model: Model
|
|
||||||
_compiled: CodeType | None = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -36,32 +33,23 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
|||||||
dry_run: bool | None = False,
|
dry_run: bool | None = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model = model
|
|
||||||
if hasattr(model, "name"):
|
if hasattr(model, "name"):
|
||||||
_filename = model.name
|
_filename = model.name
|
||||||
else:
|
else:
|
||||||
_filename = str(model)
|
_filename = str(model)
|
||||||
super().__init__(filename=_filename)
|
super().__init__(filename=_filename)
|
||||||
self.dry_run = dry_run
|
|
||||||
self.set_context(user, request, **kwargs)
|
|
||||||
|
|
||||||
def set_context(
|
|
||||||
self,
|
|
||||||
user: User | None = None,
|
|
||||||
request: HttpRequest | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
req = PolicyRequest(user=User())
|
req = PolicyRequest(user=User())
|
||||||
req.obj = self.model
|
req.obj = model
|
||||||
if user:
|
if user:
|
||||||
req.user = user
|
req.user = user
|
||||||
self._context["user"] = user
|
self._context["user"] = user
|
||||||
if request:
|
if request:
|
||||||
req.http_request = request
|
req.http_request = request
|
||||||
req.context.update(**kwargs)
|
|
||||||
self._context["request"] = req
|
self._context["request"] = req
|
||||||
|
req.context.update(**kwargs)
|
||||||
self._context.update(**kwargs)
|
self._context.update(**kwargs)
|
||||||
self._globals["SkipObject"] = SkipObjectException
|
self._globals["SkipObject"] = SkipObjectException
|
||||||
|
self.dry_run = dry_run
|
||||||
|
|
||||||
def handle_error(self, exc: Exception, expression_source: str):
|
def handle_error(self, exc: Exception, expression_source: str):
|
||||||
"""Exception Handler"""
|
"""Exception Handler"""
|
||||||
@ -83,9 +71,3 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
|||||||
def evaluate(self, *args, **kwargs) -> Any:
|
def evaluate(self, *args, **kwargs) -> Any:
|
||||||
with PROPERTY_MAPPING_TIME.labels(mapping_name=self._filename).time():
|
with PROPERTY_MAPPING_TIME.labels(mapping_name=self._filename).time():
|
||||||
return super().evaluate(*args, **kwargs)
|
return super().evaluate(*args, **kwargs)
|
||||||
|
|
||||||
def compile(self, expression: str | None = None) -> Any:
|
|
||||||
if not self._compiled:
|
|
||||||
compiled = super().compile(expression or self.model.expression)
|
|
||||||
self._compiled = compiled
|
|
||||||
return self._compiled
|
|
||||||
|
@ -6,11 +6,6 @@ from authentik.lib.sentry import SentryIgnoredException
|
|||||||
class PropertyMappingExpressionException(SentryIgnoredException):
|
class PropertyMappingExpressionException(SentryIgnoredException):
|
||||||
"""Error when a PropertyMapping Exception expression could not be parsed or evaluated."""
|
"""Error when a PropertyMapping Exception expression could not be parsed or evaluated."""
|
||||||
|
|
||||||
def __init__(self, exc: Exception, mapping) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.exc = exc
|
|
||||||
self.mapping = mapping
|
|
||||||
|
|
||||||
|
|
||||||
class SkipObjectException(PropertyMappingExpressionException):
|
class SkipObjectException(PropertyMappingExpressionException):
|
||||||
"""Exception which can be raised in a property mapping to skip syncing an object.
|
"""Exception which can be raised in a property mapping to skip syncing an object.
|
||||||
|
@ -10,7 +10,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
|||||||
from django.db.models import Count
|
from django.db.models import Count
|
||||||
|
|
||||||
import authentik.core.models
|
import authentik.core.models
|
||||||
import authentik.lib.models
|
import authentik.lib.validators
|
||||||
|
|
||||||
|
|
||||||
def migrate_sessions(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
def migrate_sessions(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
@ -160,7 +160,7 @@ class Migration(migrations.Migration):
|
|||||||
field=models.TextField(
|
field=models.TextField(
|
||||||
blank=True,
|
blank=True,
|
||||||
default="",
|
default="",
|
||||||
validators=[authentik.lib.models.DomainlessFormattedURLValidator()],
|
validators=[authentik.lib.validators.DomainlessFormattedURLValidator()],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
migrations.RunPython(
|
migrations.RunPython(
|
||||||
|
23
authentik/core/migrations/0036_user_group_soft_delete.py
Normal file
23
authentik/core/migrations/0036_user_group_soft_delete.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# Generated by Django 5.0.4 on 2024-04-23 16:59
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("authentik_core", "0035_alter_group_options_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="group",
|
||||||
|
name="deleted_at",
|
||||||
|
field=models.DateTimeField(blank=True, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="user",
|
||||||
|
name="deleted_at",
|
||||||
|
field=models.DateTimeField(blank=True, null=True),
|
||||||
|
),
|
||||||
|
]
|
@ -15,7 +15,6 @@ from django.http import HttpRequest
|
|||||||
from django.utils.functional import SimpleLazyObject, cached_property
|
from django.utils.functional import SimpleLazyObject, cached_property
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django_cte import CTEQuerySet, With
|
|
||||||
from guardian.conf import settings
|
from guardian.conf import settings
|
||||||
from guardian.mixins import GuardianUserMixin
|
from guardian.mixins import GuardianUserMixin
|
||||||
from model_utils.managers import InheritanceManager
|
from model_utils.managers import InheritanceManager
|
||||||
@ -29,10 +28,12 @@ from authentik.lib.avatars import get_avatar
|
|||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.models import (
|
from authentik.lib.models import (
|
||||||
CreatedUpdatedModel,
|
CreatedUpdatedModel,
|
||||||
DomainlessFormattedURLValidator,
|
|
||||||
SerializerModel,
|
SerializerModel,
|
||||||
|
SoftDeleteModel,
|
||||||
|
SoftDeleteQuerySet,
|
||||||
)
|
)
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
|
from authentik.lib.validators import DomainlessFormattedURLValidator
|
||||||
from authentik.policies.models import PolicyBindingModel
|
from authentik.policies.models import PolicyBindingModel
|
||||||
from authentik.tenants.models import DEFAULT_TOKEN_DURATION, DEFAULT_TOKEN_LENGTH
|
from authentik.tenants.models import DEFAULT_TOKEN_DURATION, DEFAULT_TOKEN_LENGTH
|
||||||
from authentik.tenants.utils import get_current_tenant, get_unique_identifier
|
from authentik.tenants.utils import get_current_tenant, get_unique_identifier
|
||||||
@ -57,8 +58,6 @@ options.DEFAULT_NAMES = options.DEFAULT_NAMES + (
|
|||||||
"authentik_used_by_shadows",
|
"authentik_used_by_shadows",
|
||||||
)
|
)
|
||||||
|
|
||||||
GROUP_RECURSION_LIMIT = 20
|
|
||||||
|
|
||||||
|
|
||||||
def default_token_duration() -> datetime:
|
def default_token_duration() -> datetime:
|
||||||
"""Default duration a Token is valid"""
|
"""Default duration a Token is valid"""
|
||||||
@ -99,41 +98,7 @@ class UserTypes(models.TextChoices):
|
|||||||
INTERNAL_SERVICE_ACCOUNT = "internal_service_account"
|
INTERNAL_SERVICE_ACCOUNT = "internal_service_account"
|
||||||
|
|
||||||
|
|
||||||
class GroupQuerySet(CTEQuerySet):
|
class Group(SoftDeleteModel, SerializerModel):
|
||||||
def with_children_recursive(self):
|
|
||||||
"""Recursively get all groups that have the current queryset as parents
|
|
||||||
or are indirectly related."""
|
|
||||||
|
|
||||||
def make_cte(cte):
|
|
||||||
"""Build the query that ends up in WITH RECURSIVE"""
|
|
||||||
# Start from self, aka the current query
|
|
||||||
# Add a depth attribute to limit the recursion
|
|
||||||
return self.annotate(
|
|
||||||
relative_depth=models.Value(0, output_field=models.IntegerField())
|
|
||||||
).union(
|
|
||||||
# Here is the recursive part of the query. cte refers to the previous iteration
|
|
||||||
# Only select groups for which the parent is part of the previous iteration
|
|
||||||
# and increase the depth
|
|
||||||
# Finally, limit the depth
|
|
||||||
cte.join(Group, group_uuid=cte.col.parent_id)
|
|
||||||
.annotate(
|
|
||||||
relative_depth=models.ExpressionWrapper(
|
|
||||||
cte.col.relative_depth
|
|
||||||
+ models.Value(1, output_field=models.IntegerField()),
|
|
||||||
output_field=models.IntegerField(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.filter(relative_depth__lt=GROUP_RECURSION_LIMIT),
|
|
||||||
all=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build the recursive query, see above
|
|
||||||
cte = With.recursive(make_cte)
|
|
||||||
# Return the result, as a usable queryset for Group.
|
|
||||||
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
|
|
||||||
|
|
||||||
|
|
||||||
class Group(SerializerModel):
|
|
||||||
"""Group model which supports a basic hierarchy and has attributes"""
|
"""Group model which supports a basic hierarchy and has attributes"""
|
||||||
|
|
||||||
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
@ -155,8 +120,6 @@ class Group(SerializerModel):
|
|||||||
)
|
)
|
||||||
attributes = models.JSONField(default=dict, blank=True)
|
attributes = models.JSONField(default=dict, blank=True)
|
||||||
|
|
||||||
objects = GroupQuerySet.as_manager()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serializer(self) -> Serializer:
|
def serializer(self) -> Serializer:
|
||||||
from authentik.core.api.groups import GroupSerializer
|
from authentik.core.api.groups import GroupSerializer
|
||||||
@ -175,11 +138,36 @@ class Group(SerializerModel):
|
|||||||
return user.all_groups().filter(group_uuid=self.group_uuid).exists()
|
return user.all_groups().filter(group_uuid=self.group_uuid).exists()
|
||||||
|
|
||||||
def children_recursive(self: Self | QuerySet["Group"]) -> QuerySet["Group"]:
|
def children_recursive(self: Self | QuerySet["Group"]) -> QuerySet["Group"]:
|
||||||
"""Compatibility layer for Group.objects.with_children_recursive()"""
|
"""Recursively get all groups that have this as parent or are indirectly related"""
|
||||||
qs = self
|
direct_groups = []
|
||||||
if not isinstance(self, QuerySet):
|
if isinstance(self, QuerySet):
|
||||||
qs = Group.objects.filter(group_uuid=self.group_uuid)
|
direct_groups = list(x for x in self.all().values_list("pk", flat=True).iterator())
|
||||||
return qs.with_children_recursive()
|
else:
|
||||||
|
direct_groups = [self.pk]
|
||||||
|
if len(direct_groups) < 1:
|
||||||
|
return Group.objects.none()
|
||||||
|
query = """
|
||||||
|
WITH RECURSIVE parents AS (
|
||||||
|
SELECT authentik_core_group.*, 0 AS relative_depth
|
||||||
|
FROM authentik_core_group
|
||||||
|
WHERE authentik_core_group.group_uuid = ANY(%s)
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
SELECT authentik_core_group.*, parents.relative_depth + 1
|
||||||
|
FROM authentik_core_group, parents
|
||||||
|
WHERE (
|
||||||
|
authentik_core_group.group_uuid = parents.parent_id and
|
||||||
|
parents.relative_depth < 20
|
||||||
|
)
|
||||||
|
)
|
||||||
|
SELECT group_uuid
|
||||||
|
FROM parents
|
||||||
|
GROUP BY group_uuid, name
|
||||||
|
ORDER BY name;
|
||||||
|
"""
|
||||||
|
group_pks = [group.pk for group in Group.objects.raw(query, [direct_groups]).iterator()]
|
||||||
|
return Group.objects.filter(pk__in=group_pks)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"Group {self.name}"
|
return f"Group {self.name}"
|
||||||
@ -200,31 +188,21 @@ class Group(SerializerModel):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class UserQuerySet(models.QuerySet):
|
|
||||||
"""User queryset"""
|
|
||||||
|
|
||||||
def exclude_anonymous(self):
|
|
||||||
"""Exclude anonymous user"""
|
|
||||||
return self.exclude(**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME})
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager(DjangoUserManager):
|
class UserManager(DjangoUserManager):
|
||||||
"""User manager that doesn't assign is_superuser and is_staff"""
|
"""User manager that doesn't assign is_superuser and is_staff"""
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
"""Create special user queryset"""
|
"""Create special user queryset"""
|
||||||
return UserQuerySet(self.model, using=self._db)
|
return SoftDeleteQuerySet(self.model, using=self._db).exclude(
|
||||||
|
**{User.USERNAME_FIELD: settings.ANONYMOUS_USER_NAME}
|
||||||
|
)
|
||||||
|
|
||||||
def create_user(self, username, email=None, password=None, **extra_fields):
|
def create_user(self, username, email=None, password=None, **extra_fields):
|
||||||
"""User manager that doesn't assign is_superuser and is_staff"""
|
"""User manager that doesn't assign is_superuser and is_staff"""
|
||||||
return self._create_user(username, email, password, **extra_fields)
|
return self._create_user(username, email, password, **extra_fields)
|
||||||
|
|
||||||
def exclude_anonymous(self) -> QuerySet:
|
|
||||||
"""Exclude anonymous user"""
|
|
||||||
return self.get_queryset().exclude_anonymous()
|
|
||||||
|
|
||||||
|
class User(SoftDeleteModel, SerializerModel, GuardianUserMixin, AbstractUser):
|
||||||
class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
|
||||||
"""authentik User model, based on django's contrib auth user model."""
|
"""authentik User model, based on django's contrib auth user model."""
|
||||||
|
|
||||||
uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
|
uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
|
||||||
@ -246,8 +224,10 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):
|
|||||||
return User._meta.get_field("path").default
|
return User._meta.get_field("path").default
|
||||||
|
|
||||||
def all_groups(self) -> QuerySet[Group]:
|
def all_groups(self) -> QuerySet[Group]:
|
||||||
"""Recursively get all groups this user is a member of."""
|
"""Recursively get all groups this user is a member of.
|
||||||
return self.ak_groups.all().with_children_recursive()
|
At least one query is done to get the direct groups of the user, with groups
|
||||||
|
there are at most 3 queries done"""
|
||||||
|
return Group.children_recursive(self.ak_groups.all())
|
||||||
|
|
||||||
def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]:
|
def group_attributes(self, request: HttpRequest | None = None) -> dict[str, Any]:
|
||||||
"""Get a dictionary containing the attributes from all groups the user belongs to,
|
"""Get a dictionary containing the attributes from all groups the user belongs to,
|
||||||
@ -389,10 +369,6 @@ class Provider(SerializerModel):
|
|||||||
Can return None for providers that are not URL-based"""
|
Can return None for providers that are not URL-based"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
|
||||||
def icon_url(self) -> str | None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
"""Return component used to edit this object"""
|
"""Return component used to edit this object"""
|
||||||
@ -784,7 +760,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
|||||||
try:
|
try:
|
||||||
return evaluator.evaluate(self.expression)
|
return evaluator.evaluate(self.expression)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise PropertyMappingExpressionException(self, exc) from exc
|
raise PropertyMappingExpressionException(exc) from exc
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"Property Mapping {self.name}"
|
return f"Property Mapping {self.name}"
|
||||||
|
@ -23,17 +23,6 @@ class TestGroupsAPI(APITestCase):
|
|||||||
response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
|
response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
def test_retrieve_with_users(self):
|
|
||||||
"""Test retrieve with users"""
|
|
||||||
admin = create_test_admin_user()
|
|
||||||
group = Group.objects.create(name=generate_id())
|
|
||||||
self.client.force_login(admin)
|
|
||||||
response = self.client.get(
|
|
||||||
reverse("authentik_api:group-detail", kwargs={"pk": group.pk}),
|
|
||||||
{"include_users": "true"},
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
def test_add_user(self):
|
def test_add_user(self):
|
||||||
"""Test add_user"""
|
"""Test add_user"""
|
||||||
group = Group.objects.create(name=generate_id())
|
group = Group.objects.create(name=generate_id())
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
"""authentik core models tests"""
|
"""authentik core models tests"""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import timedelta
|
from time import sleep
|
||||||
|
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory, TestCase
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from freezegun import freeze_time
|
|
||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
|
|
||||||
from authentik.core.models import Provider, Source, Token
|
from authentik.core.models import Provider, Source, Token
|
||||||
|
from authentik.flows.models import Stage
|
||||||
from authentik.lib.utils.reflection import all_subclasses
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
|
|
||||||
|
|
||||||
@ -17,20 +17,18 @@ class TestModels(TestCase):
|
|||||||
|
|
||||||
def test_token_expire(self):
|
def test_token_expire(self):
|
||||||
"""Test token expiring"""
|
"""Test token expiring"""
|
||||||
with freeze_time() as freeze:
|
token = Token.objects.create(expires=now(), user=get_anonymous_user())
|
||||||
token = Token.objects.create(expires=now(), user=get_anonymous_user())
|
sleep(0.5)
|
||||||
freeze.tick(timedelta(seconds=1))
|
self.assertTrue(token.is_expired)
|
||||||
self.assertTrue(token.is_expired)
|
|
||||||
|
|
||||||
def test_token_expire_no_expire(self):
|
def test_token_expire_no_expire(self):
|
||||||
"""Test token expiring with "expiring" set"""
|
"""Test token expiring with "expiring" set"""
|
||||||
with freeze_time() as freeze:
|
token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
|
||||||
token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
|
sleep(0.5)
|
||||||
freeze.tick(timedelta(seconds=1))
|
self.assertFalse(token.is_expired)
|
||||||
self.assertFalse(token.is_expired)
|
|
||||||
|
|
||||||
|
|
||||||
def source_tester_factory(test_model: type[Source]) -> Callable:
|
def source_tester_factory(test_model: type[Stage]) -> Callable:
|
||||||
"""Test source"""
|
"""Test source"""
|
||||||
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
@ -38,19 +36,19 @@ def source_tester_factory(test_model: type[Source]) -> Callable:
|
|||||||
|
|
||||||
def tester(self: TestModels):
|
def tester(self: TestModels):
|
||||||
model_class = None
|
model_class = None
|
||||||
if test_model._meta.abstract:
|
if test_model._meta.abstract: # pragma: no cover
|
||||||
model_class = [x for x in test_model.__bases__ if issubclass(x, Source)][0]()
|
model_class = test_model.__bases__[0]()
|
||||||
else:
|
else:
|
||||||
model_class = test_model()
|
model_class = test_model()
|
||||||
model_class.slug = "test"
|
model_class.slug = "test"
|
||||||
self.assertIsNotNone(model_class.component)
|
self.assertIsNotNone(model_class.component)
|
||||||
model_class.ui_login_button(request)
|
_ = model_class.ui_login_button(request)
|
||||||
model_class.ui_user_settings()
|
_ = model_class.ui_user_settings()
|
||||||
|
|
||||||
return tester
|
return tester
|
||||||
|
|
||||||
|
|
||||||
def provider_tester_factory(test_model: type[Provider]) -> Callable:
|
def provider_tester_factory(test_model: type[Stage]) -> Callable:
|
||||||
"""Test provider"""
|
"""Test provider"""
|
||||||
|
|
||||||
def tester(self: TestModels):
|
def tester(self: TestModels):
|
||||||
|
@ -6,10 +6,9 @@ from django.urls import reverse
|
|||||||
from rest_framework.serializers import ValidationError
|
from rest_framework.serializers import ValidationError
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||||
from authentik.core.models import Group, PropertyMapping
|
from authentik.core.models import PropertyMapping
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
from authentik.lib.generators import generate_id
|
|
||||||
|
|
||||||
|
|
||||||
class TestPropertyMappingAPI(APITestCase):
|
class TestPropertyMappingAPI(APITestCase):
|
||||||
@ -17,40 +16,23 @@ class TestPropertyMappingAPI(APITestCase):
|
|||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
self.mapping = PropertyMapping.objects.create(
|
||||||
|
name="dummy", expression="""return {'foo': 'bar'}"""
|
||||||
|
)
|
||||||
self.user = create_test_admin_user()
|
self.user = create_test_admin_user()
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
|
|
||||||
def test_test_call(self):
|
def test_test_call(self):
|
||||||
"""Test PropertyMappings's test endpoint"""
|
"""Test PropertMappings's test endpoint"""
|
||||||
mapping = PropertyMapping.objects.create(
|
|
||||||
name="dummy", expression="""return {'foo': 'bar', 'baz': user.username}"""
|
|
||||||
)
|
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse("authentik_api:propertymapping-test", kwargs={"pk": mapping.pk}),
|
reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}),
|
||||||
data={
|
data={
|
||||||
"user": self.user.pk,
|
"user": self.user.pk,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
response.content.decode(),
|
response.content.decode(),
|
||||||
{"result": dumps({"foo": "bar", "baz": self.user.username}), "successful": True},
|
{"result": dumps({"foo": "bar"}), "successful": True},
|
||||||
)
|
|
||||||
|
|
||||||
def test_test_call_group(self):
|
|
||||||
"""Test PropertyMappings's test endpoint"""
|
|
||||||
mapping = PropertyMapping.objects.create(
|
|
||||||
name="dummy", expression="""return {'foo': 'bar', 'baz': group.name}"""
|
|
||||||
)
|
|
||||||
group = Group.objects.create(name=generate_id())
|
|
||||||
response = self.client.post(
|
|
||||||
reverse("authentik_api:propertymapping-test", kwargs={"pk": mapping.pk}),
|
|
||||||
data={
|
|
||||||
"group": group.pk,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertJSONEqual(
|
|
||||||
response.content.decode(),
|
|
||||||
{"result": dumps({"foo": "bar", "baz": group.name}), "successful": True},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_validate(self):
|
def test_validate(self):
|
||||||
|
@ -42,8 +42,8 @@ class TestUsersAvatars(APITestCase):
|
|||||||
with Mocker() as mocker:
|
with Mocker() as mocker:
|
||||||
mocker.head(
|
mocker.head(
|
||||||
(
|
(
|
||||||
"https://www.gravatar.com/avatar/76eb3c74c8beb6faa037f1b6e2ecb3e252bdac"
|
"https://secure.gravatar.com/avatar/84730f9c1851d1ea03f1a"
|
||||||
"6cf71fb567ae36025a9d4ea86b?size=158&rating=g&default=404"
|
"a9ed85bd1ea?size=158&rating=g&default=404"
|
||||||
),
|
),
|
||||||
text="foo",
|
text="foo",
|
||||||
)
|
)
|
||||||
|
@ -12,7 +12,7 @@ from authentik.core.api.applications import ApplicationViewSet
|
|||||||
from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet
|
from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet
|
||||||
from authentik.core.api.devices import AdminDeviceViewSet, DeviceViewSet
|
from authentik.core.api.devices import AdminDeviceViewSet, DeviceViewSet
|
||||||
from authentik.core.api.groups import GroupViewSet
|
from authentik.core.api.groups import GroupViewSet
|
||||||
from authentik.core.api.property_mappings import PropertyMappingViewSet
|
from authentik.core.api.propertymappings import PropertyMappingViewSet
|
||||||
from authentik.core.api.providers import ProviderViewSet
|
from authentik.core.api.providers import ProviderViewSet
|
||||||
from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSet
|
from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSet
|
||||||
from authentik.core.api.tokens import TokenViewSet
|
from authentik.core.api.tokens import TokenViewSet
|
||||||
|
@ -92,11 +92,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
|
|||||||
@property
|
@property
|
||||||
def kid(self):
|
def kid(self):
|
||||||
"""Get Key ID used for JWKS"""
|
"""Get Key ID used for JWKS"""
|
||||||
return (
|
return md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else "" # nosec
|
||||||
md5(self.key_data.encode("utf-8"), usedforsecurity=False).hexdigest()
|
|
||||||
if self.key_data
|
|
||||||
else ""
|
|
||||||
) # nosec
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"Certificate-Key Pair {self.name}"
|
return f"Certificate-Key Pair {self.name}"
|
||||||
|
@ -241,7 +241,7 @@ class TestCrypto(APITestCase):
|
|||||||
"model_name": "oauth2provider",
|
"model_name": "oauth2provider",
|
||||||
"pk": str(provider.pk),
|
"pk": str(provider.pk),
|
||||||
"name": str(provider),
|
"name": str(provider),
|
||||||
"action": DeleteAction.SET_NULL.value,
|
"action": DeleteAction.SET_NULL.name,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -132,7 +132,7 @@ class LicenseKey:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def base_user_qs() -> QuerySet:
|
def base_user_qs() -> QuerySet:
|
||||||
"""Base query set for all users"""
|
"""Base query set for all users"""
|
||||||
return User.objects.all().exclude_anonymous().exclude(is_active=False)
|
return User.objects.all().exclude(is_active=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_user_count():
|
def get_default_user_count():
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
"""GoogleWorkspaceProviderGroup API Views"""
|
"""GoogleWorkspaceProviderGroup API Views"""
|
||||||
|
|
||||||
from rest_framework import mixins
|
from rest_framework.viewsets import ModelViewSet
|
||||||
from rest_framework.serializers import ModelSerializer
|
|
||||||
from rest_framework.viewsets import GenericViewSet
|
|
||||||
|
|
||||||
|
from authentik.core.api.sources import SourceSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.users import UserGroupSerializer
|
from authentik.core.api.users import UserGroupSerializer
|
||||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderGroup
|
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderGroup
|
||||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
|
class GoogleWorkspaceProviderGroupSerializer(SourceSerializer):
|
||||||
"""GoogleWorkspaceProviderGroup Serializer"""
|
"""GoogleWorkspaceProviderGroup Serializer"""
|
||||||
|
|
||||||
group_obj = UserGroupSerializer(source="group", read_only=True)
|
group_obj = UserGroupSerializer(source="group", read_only=True)
|
||||||
@ -20,24 +18,12 @@ class GoogleWorkspaceProviderGroupSerializer(ModelSerializer):
|
|||||||
model = GoogleWorkspaceProviderGroup
|
model = GoogleWorkspaceProviderGroup
|
||||||
fields = [
|
fields = [
|
||||||
"id",
|
"id",
|
||||||
"google_id",
|
|
||||||
"group",
|
"group",
|
||||||
"group_obj",
|
"group_obj",
|
||||||
"provider",
|
|
||||||
"attributes",
|
|
||||||
]
|
]
|
||||||
extra_kwargs = {"attributes": {"read_only": True}}
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceProviderGroupViewSet(
|
class GoogleWorkspaceProviderGroupViewSet(UsedByMixin, ModelViewSet):
|
||||||
mixins.CreateModelMixin,
|
|
||||||
OutgoingSyncConnectionCreateMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
|
||||||
mixins.DestroyModelMixin,
|
|
||||||
UsedByMixin,
|
|
||||||
mixins.ListModelMixin,
|
|
||||||
GenericViewSet,
|
|
||||||
):
|
|
||||||
"""GoogleWorkspaceProviderGroup Viewset"""
|
"""GoogleWorkspaceProviderGroup Viewset"""
|
||||||
|
|
||||||
queryset = GoogleWorkspaceProviderGroup.objects.all().select_related("group")
|
queryset = GoogleWorkspaceProviderGroup.objects.all().select_related("group")
|
||||||
|
@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
|
|||||||
from drf_spectacular.utils import extend_schema_field
|
from drf_spectacular.utils import extend_schema_field
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderMapping
|
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderMapping
|
||||||
|
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
"""GoogleWorkspaceProviderUser API Views"""
|
"""GoogleWorkspaceProviderUser API Views"""
|
||||||
|
|
||||||
from rest_framework import mixins
|
from rest_framework.viewsets import ModelViewSet
|
||||||
from rest_framework.serializers import ModelSerializer
|
|
||||||
from rest_framework.viewsets import GenericViewSet
|
|
||||||
|
|
||||||
from authentik.core.api.groups import GroupMemberSerializer
|
from authentik.core.api.groups import GroupMemberSerializer
|
||||||
|
from authentik.core.api.sources import SourceSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderUser
|
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProviderUser
|
||||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
|
class GoogleWorkspaceProviderUserSerializer(SourceSerializer):
|
||||||
"""GoogleWorkspaceProviderUser Serializer"""
|
"""GoogleWorkspaceProviderUser Serializer"""
|
||||||
|
|
||||||
user_obj = GroupMemberSerializer(source="user", read_only=True)
|
user_obj = GroupMemberSerializer(source="user", read_only=True)
|
||||||
@ -20,24 +18,12 @@ class GoogleWorkspaceProviderUserSerializer(ModelSerializer):
|
|||||||
model = GoogleWorkspaceProviderUser
|
model = GoogleWorkspaceProviderUser
|
||||||
fields = [
|
fields = [
|
||||||
"id",
|
"id",
|
||||||
"google_id",
|
|
||||||
"user",
|
"user",
|
||||||
"user_obj",
|
"user_obj",
|
||||||
"provider",
|
|
||||||
"attributes",
|
|
||||||
]
|
]
|
||||||
extra_kwargs = {"attributes": {"read_only": True}}
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceProviderUserViewSet(
|
class GoogleWorkspaceProviderUserViewSet(UsedByMixin, ModelViewSet):
|
||||||
mixins.CreateModelMixin,
|
|
||||||
OutgoingSyncConnectionCreateMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
|
||||||
mixins.DestroyModelMixin,
|
|
||||||
UsedByMixin,
|
|
||||||
mixins.ListModelMixin,
|
|
||||||
GenericViewSet,
|
|
||||||
):
|
|
||||||
"""GoogleWorkspaceProviderUser Viewset"""
|
"""GoogleWorkspaceProviderUser Viewset"""
|
||||||
|
|
||||||
queryset = GoogleWorkspaceProviderUser.objects.all().select_related("user")
|
queryset = GoogleWorkspaceProviderUser.objects.all().select_related("user")
|
||||||
|
@ -1,22 +1,28 @@
|
|||||||
|
from deepmerge import always_merger
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.utils.text import slugify
|
from django.utils.text import slugify
|
||||||
|
|
||||||
|
from authentik.core.expression.exceptions import (
|
||||||
|
PropertyMappingExpressionException,
|
||||||
|
SkipObjectException,
|
||||||
|
)
|
||||||
from authentik.core.models import Group
|
from authentik.core.models import Group
|
||||||
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
||||||
from authentik.enterprise.providers.google_workspace.models import (
|
from authentik.enterprise.providers.google_workspace.models import (
|
||||||
GoogleWorkspaceProvider,
|
|
||||||
GoogleWorkspaceProviderGroup,
|
GoogleWorkspaceProviderGroup,
|
||||||
GoogleWorkspaceProviderMapping,
|
GoogleWorkspaceProviderMapping,
|
||||||
GoogleWorkspaceProviderUser,
|
GoogleWorkspaceProviderUser,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.sync.outgoing.base import Direction
|
from authentik.lib.sync.outgoing.base import Direction
|
||||||
from authentik.lib.sync.outgoing.exceptions import (
|
from authentik.lib.sync.outgoing.exceptions import (
|
||||||
NotFoundSyncException,
|
NotFoundSyncException,
|
||||||
ObjectExistsSyncException,
|
ObjectExistsSyncException,
|
||||||
|
StopSync,
|
||||||
TransientSyncException,
|
TransientSyncException,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||||
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceGroupClient(
|
class GoogleWorkspaceGroupClient(
|
||||||
@ -28,21 +34,41 @@ class GoogleWorkspaceGroupClient(
|
|||||||
connection_type_query = "group"
|
connection_type_query = "group"
|
||||||
can_discover = True
|
can_discover = True
|
||||||
|
|
||||||
def __init__(self, provider: GoogleWorkspaceProvider) -> None:
|
def to_schema(self, obj: Group, creating: bool) -> dict:
|
||||||
super().__init__(provider)
|
|
||||||
self.mapper = PropertyMappingManager(
|
|
||||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses(),
|
|
||||||
GoogleWorkspaceProviderMapping,
|
|
||||||
["group", "provider", "connection"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_schema(self, obj: Group, connection: GoogleWorkspaceProviderGroup) -> dict:
|
|
||||||
"""Convert authentik group"""
|
"""Convert authentik group"""
|
||||||
return super().to_schema(
|
raw_google_group = {
|
||||||
obj,
|
"email": f"{slugify(obj.name)}@{self.provider.default_group_email_domain}"
|
||||||
connection=connection,
|
}
|
||||||
email=f"{slugify(obj.name)}@{self.provider.default_group_email_domain}",
|
for mapping in (
|
||||||
)
|
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
|
||||||
|
):
|
||||||
|
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
value = mapping.evaluate(
|
||||||
|
user=None,
|
||||||
|
request=None,
|
||||||
|
group=obj,
|
||||||
|
provider=self.provider,
|
||||||
|
creating=creating,
|
||||||
|
)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
always_merger.merge(raw_google_group, value)
|
||||||
|
except SkipObjectException as exc:
|
||||||
|
raise exc from exc
|
||||||
|
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||||
|
# Value error can be raised when assigning invalid data to an attribute
|
||||||
|
Event.new(
|
||||||
|
EventAction.CONFIGURATION_ERROR,
|
||||||
|
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||||
|
mapping=mapping,
|
||||||
|
).save()
|
||||||
|
raise StopSync(exc, obj, mapping) from exc
|
||||||
|
if not raw_google_group:
|
||||||
|
raise StopSync(ValueError("No group mappings configured"), obj)
|
||||||
|
|
||||||
|
return raw_google_group
|
||||||
|
|
||||||
def delete(self, obj: Group):
|
def delete(self, obj: Group):
|
||||||
"""Delete group"""
|
"""Delete group"""
|
||||||
@ -61,7 +87,7 @@ class GoogleWorkspaceGroupClient(
|
|||||||
|
|
||||||
def create(self, group: Group):
|
def create(self, group: Group):
|
||||||
"""Create group from scratch and create a connection object"""
|
"""Create group from scratch and create a connection object"""
|
||||||
google_group = self.to_schema(group, None)
|
google_group = self.to_schema(group, True)
|
||||||
self.check_email_valid(google_group["email"])
|
self.check_email_valid(google_group["email"])
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
try:
|
try:
|
||||||
@ -74,32 +100,24 @@ class GoogleWorkspaceGroupClient(
|
|||||||
self.directory_service.groups().get(groupKey=google_group["email"])
|
self.directory_service.groups().get(groupKey=google_group["email"])
|
||||||
)
|
)
|
||||||
return GoogleWorkspaceProviderGroup.objects.create(
|
return GoogleWorkspaceProviderGroup.objects.create(
|
||||||
provider=self.provider,
|
provider=self.provider, group=group, google_id=group_data["id"]
|
||||||
group=group,
|
|
||||||
google_id=group_data["id"],
|
|
||||||
attributes=group_data,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return GoogleWorkspaceProviderGroup.objects.create(
|
return GoogleWorkspaceProviderGroup.objects.create(
|
||||||
provider=self.provider,
|
provider=self.provider, group=group, google_id=response["id"]
|
||||||
group=group,
|
|
||||||
google_id=response["id"],
|
|
||||||
attributes=response,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self, group: Group, connection: GoogleWorkspaceProviderGroup):
|
def update(self, group: Group, connection: GoogleWorkspaceProviderGroup):
|
||||||
"""Update existing group"""
|
"""Update existing group"""
|
||||||
google_group = self.to_schema(group, connection)
|
google_group = self.to_schema(group, False)
|
||||||
self.check_email_valid(google_group["email"])
|
self.check_email_valid(google_group["email"])
|
||||||
try:
|
try:
|
||||||
response = self._request(
|
return self._request(
|
||||||
self.directory_service.groups().update(
|
self.directory_service.groups().update(
|
||||||
groupKey=connection.google_id,
|
groupKey=connection.google_id,
|
||||||
body=google_group,
|
body=google_group,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
connection.attributes = response
|
|
||||||
connection.save()
|
|
||||||
except NotFoundSyncException:
|
except NotFoundSyncException:
|
||||||
# Resource missing is handled by self.write, which will re-create the group
|
# Resource missing is handled by self.write, which will re-create the group
|
||||||
raise
|
raise
|
||||||
@ -212,9 +230,4 @@ class GoogleWorkspaceGroupClient(
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
group=matching_authentik_group,
|
group=matching_authentik_group,
|
||||||
google_id=google_id,
|
google_id=google_id,
|
||||||
attributes=group,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_single_attribute(self, connection: GoogleWorkspaceProviderUser):
|
|
||||||
group = self.directory_service.groups().get(connection.google_id)
|
|
||||||
connection.attributes = group
|
|
||||||
|
@ -1,18 +1,24 @@
|
|||||||
|
from deepmerge import always_merger
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
|
|
||||||
|
from authentik.core.expression.exceptions import (
|
||||||
|
PropertyMappingExpressionException,
|
||||||
|
SkipObjectException,
|
||||||
|
)
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
from authentik.enterprise.providers.google_workspace.clients.base import GoogleWorkspaceSyncClient
|
||||||
from authentik.enterprise.providers.google_workspace.models import (
|
from authentik.enterprise.providers.google_workspace.models import (
|
||||||
GoogleWorkspaceProvider,
|
|
||||||
GoogleWorkspaceProviderMapping,
|
GoogleWorkspaceProviderMapping,
|
||||||
GoogleWorkspaceProviderUser,
|
GoogleWorkspaceProviderUser,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.sync.outgoing.exceptions import (
|
from authentik.lib.sync.outgoing.exceptions import (
|
||||||
ObjectExistsSyncException,
|
ObjectExistsSyncException,
|
||||||
|
StopSync,
|
||||||
TransientSyncException,
|
TransientSyncException,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||||
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
from authentik.policies.utils import delete_none_values
|
from authentik.policies.utils import delete_none_values
|
||||||
|
|
||||||
|
|
||||||
@ -23,17 +29,37 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
|||||||
connection_type_query = "user"
|
connection_type_query = "user"
|
||||||
can_discover = True
|
can_discover = True
|
||||||
|
|
||||||
def __init__(self, provider: GoogleWorkspaceProvider) -> None:
|
def to_schema(self, obj: User, creating: bool) -> dict:
|
||||||
super().__init__(provider)
|
|
||||||
self.mapper = PropertyMappingManager(
|
|
||||||
self.provider.property_mappings.all().order_by("name").select_subclasses(),
|
|
||||||
GoogleWorkspaceProviderMapping,
|
|
||||||
["provider", "connection"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_schema(self, obj: User, connection: GoogleWorkspaceProviderUser) -> dict:
|
|
||||||
"""Convert authentik user"""
|
"""Convert authentik user"""
|
||||||
return delete_none_values(super().to_schema(obj, connection, primaryEmail=obj.email))
|
raw_google_user = {}
|
||||||
|
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
|
||||||
|
if not isinstance(mapping, GoogleWorkspaceProviderMapping):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
value = mapping.evaluate(
|
||||||
|
user=obj,
|
||||||
|
request=None,
|
||||||
|
provider=self.provider,
|
||||||
|
creating=creating,
|
||||||
|
)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
always_merger.merge(raw_google_user, value)
|
||||||
|
except SkipObjectException as exc:
|
||||||
|
raise exc from exc
|
||||||
|
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||||
|
# Value error can be raised when assigning invalid data to an attribute
|
||||||
|
Event.new(
|
||||||
|
EventAction.CONFIGURATION_ERROR,
|
||||||
|
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||||
|
mapping=mapping,
|
||||||
|
).save()
|
||||||
|
raise StopSync(exc, obj, mapping) from exc
|
||||||
|
if not raw_google_user:
|
||||||
|
raise StopSync(ValueError("No user mappings configured"), obj)
|
||||||
|
if "primaryEmail" not in raw_google_user:
|
||||||
|
raw_google_user["primaryEmail"] = str(obj.email)
|
||||||
|
return delete_none_values(raw_google_user)
|
||||||
|
|
||||||
def delete(self, obj: User):
|
def delete(self, obj: User):
|
||||||
"""Delete user"""
|
"""Delete user"""
|
||||||
@ -60,7 +86,7 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
|||||||
|
|
||||||
def create(self, user: User):
|
def create(self, user: User):
|
||||||
"""Create user from scratch and create a connection object"""
|
"""Create user from scratch and create a connection object"""
|
||||||
google_user = self.to_schema(user, None)
|
google_user = self.to_schema(user, True)
|
||||||
self.check_email_valid(
|
self.check_email_valid(
|
||||||
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
||||||
)
|
)
|
||||||
@ -70,29 +96,24 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
|||||||
except ObjectExistsSyncException:
|
except ObjectExistsSyncException:
|
||||||
# user already exists in google workspace, so we can connect them manually
|
# user already exists in google workspace, so we can connect them manually
|
||||||
return GoogleWorkspaceProviderUser.objects.create(
|
return GoogleWorkspaceProviderUser.objects.create(
|
||||||
provider=self.provider, user=user, google_id=user.email, attributes={}
|
provider=self.provider, user=user, google_id=user.email
|
||||||
)
|
)
|
||||||
except TransientSyncException as exc:
|
except TransientSyncException as exc:
|
||||||
raise exc
|
raise exc
|
||||||
else:
|
else:
|
||||||
return GoogleWorkspaceProviderUser.objects.create(
|
return GoogleWorkspaceProviderUser.objects.create(
|
||||||
provider=self.provider,
|
provider=self.provider, user=user, google_id=response["primaryEmail"]
|
||||||
user=user,
|
|
||||||
google_id=response["primaryEmail"],
|
|
||||||
attributes=response,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self, user: User, connection: GoogleWorkspaceProviderUser):
|
def update(self, user: User, connection: GoogleWorkspaceProviderUser):
|
||||||
"""Update existing user"""
|
"""Update existing user"""
|
||||||
google_user = self.to_schema(user, connection)
|
google_user = self.to_schema(user, False)
|
||||||
self.check_email_valid(
|
self.check_email_valid(
|
||||||
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
google_user["primaryEmail"], *[x["address"] for x in google_user.get("emails", [])]
|
||||||
)
|
)
|
||||||
response = self._request(
|
self._request(
|
||||||
self.directory_service.users().update(userKey=connection.google_id, body=google_user)
|
self.directory_service.users().update(userKey=connection.google_id, body=google_user)
|
||||||
)
|
)
|
||||||
connection.attributes = response
|
|
||||||
connection.save()
|
|
||||||
|
|
||||||
def discover(self):
|
def discover(self):
|
||||||
"""Iterate through all users and connect them with authentik users if possible"""
|
"""Iterate through all users and connect them with authentik users if possible"""
|
||||||
@ -117,9 +138,4 @@ class GoogleWorkspaceUserClient(GoogleWorkspaceSyncClient[User, GoogleWorkspaceP
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=matching_authentik_user,
|
user=matching_authentik_user,
|
||||||
google_id=email,
|
google_id=email,
|
||||||
attributes=user,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_single_attribute(self, connection: GoogleWorkspaceProviderUser):
|
|
||||||
user = self.directory_service.users().get(connection.google_id)
|
|
||||||
connection.attributes = user
|
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
# Generated by Django 5.0.6 on 2024-05-23 20:48
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
(
|
|
||||||
"authentik_providers_google_workspace",
|
|
||||||
"0001_squashed_0002_alter_googleworkspaceprovidergroup_options_and_more",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="googleworkspaceprovidergroup",
|
|
||||||
name="attributes",
|
|
||||||
field=models.JSONField(default=dict),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="googleworkspaceprovideruser",
|
|
||||||
name="attributes",
|
|
||||||
field=models.JSONField(default=dict),
|
|
||||||
),
|
|
||||||
]
|
|
@ -5,7 +5,6 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from django.templatetags.static import static
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from google.oauth2.service_account import Credentials
|
from google.oauth2.service_account import Credentials
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
@ -31,58 +30,6 @@ def default_scopes() -> list[str]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceProviderUser(SerializerModel):
|
|
||||||
"""Mapping of a user and provider to a Google user ID"""
|
|
||||||
|
|
||||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
|
||||||
google_id = models.TextField()
|
|
||||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
|
||||||
provider = models.ForeignKey("GoogleWorkspaceProvider", on_delete=models.CASCADE)
|
|
||||||
attributes = models.JSONField(default=dict)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def serializer(self) -> type[Serializer]:
|
|
||||||
from authentik.enterprise.providers.google_workspace.api.users import (
|
|
||||||
GoogleWorkspaceProviderUserSerializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
return GoogleWorkspaceProviderUserSerializer
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
verbose_name = _("Google Workspace Provider User")
|
|
||||||
verbose_name_plural = _("Google Workspace Provider Users")
|
|
||||||
unique_together = (("google_id", "user", "provider"),)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"Google Workspace Provider User {self.user_id} to {self.provider_id}"
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceProviderGroup(SerializerModel):
|
|
||||||
"""Mapping of a group and provider to a Google group ID"""
|
|
||||||
|
|
||||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
|
||||||
google_id = models.TextField()
|
|
||||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
|
||||||
provider = models.ForeignKey("GoogleWorkspaceProvider", on_delete=models.CASCADE)
|
|
||||||
attributes = models.JSONField(default=dict)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def serializer(self) -> type[Serializer]:
|
|
||||||
from authentik.enterprise.providers.google_workspace.api.groups import (
|
|
||||||
GoogleWorkspaceProviderGroupSerializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
return GoogleWorkspaceProviderGroupSerializer
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
verbose_name = _("Google Workspace Provider Group")
|
|
||||||
verbose_name_plural = _("Google Workspace Provider Groups")
|
|
||||||
unique_together = (("google_id", "group", "provider"),)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"Google Workspace Provider Group {self.group_id} to {self.provider_id}"
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||||
"""Sync users from authentik into Google Workspace."""
|
"""Sync users from authentik into Google Workspace."""
|
||||||
|
|
||||||
@ -111,16 +58,15 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def client_for_model(
|
def client_for_model(
|
||||||
self,
|
self, model: type[User | Group]
|
||||||
model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup],
|
|
||||||
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
|
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
|
||||||
if issubclass(model, User | GoogleWorkspaceProviderUser):
|
if issubclass(model, User):
|
||||||
from authentik.enterprise.providers.google_workspace.clients.users import (
|
from authentik.enterprise.providers.google_workspace.clients.users import (
|
||||||
GoogleWorkspaceUserClient,
|
GoogleWorkspaceUserClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
return GoogleWorkspaceUserClient(self)
|
return GoogleWorkspaceUserClient(self)
|
||||||
if issubclass(model, Group | GoogleWorkspaceProviderGroup):
|
if issubclass(model, Group):
|
||||||
from authentik.enterprise.providers.google_workspace.clients.groups import (
|
from authentik.enterprise.providers.google_workspace.clients.groups import (
|
||||||
GoogleWorkspaceGroupClient,
|
GoogleWorkspaceGroupClient,
|
||||||
)
|
)
|
||||||
@ -152,10 +98,6 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
|||||||
).with_subject(self.delegated_subject),
|
).with_subject(self.delegated_subject),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
|
||||||
def icon_url(self) -> str | None:
|
|
||||||
return static("authentik/sources/google.svg")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-provider-google-workspace-form"
|
return "ak-provider-google-workspace-form"
|
||||||
@ -197,3 +139,53 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
|
|||||||
class Meta:
|
class Meta:
|
||||||
verbose_name = _("Google Workspace Provider Mapping")
|
verbose_name = _("Google Workspace Provider Mapping")
|
||||||
verbose_name_plural = _("Google Workspace Provider Mappings")
|
verbose_name_plural = _("Google Workspace Provider Mappings")
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleWorkspaceProviderUser(SerializerModel):
|
||||||
|
"""Mapping of a user and provider to a Google user ID"""
|
||||||
|
|
||||||
|
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
|
google_id = models.TextField()
|
||||||
|
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||||
|
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self) -> type[Serializer]:
|
||||||
|
from authentik.enterprise.providers.google_workspace.api.users import (
|
||||||
|
GoogleWorkspaceProviderUserSerializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return GoogleWorkspaceProviderUserSerializer
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = _("Google Workspace Provider User")
|
||||||
|
verbose_name_plural = _("Google Workspace Provider Users")
|
||||||
|
unique_together = (("google_id", "user", "provider"),)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Google Workspace Provider User {self.user_id} to {self.provider_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleWorkspaceProviderGroup(SerializerModel):
|
||||||
|
"""Mapping of a group and provider to a Google group ID"""
|
||||||
|
|
||||||
|
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
|
google_id = models.TextField()
|
||||||
|
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||||
|
provider = models.ForeignKey(GoogleWorkspaceProvider, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self) -> type[Serializer]:
|
||||||
|
from authentik.enterprise.providers.google_workspace.api.groups import (
|
||||||
|
GoogleWorkspaceProviderGroupSerializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return GoogleWorkspaceProviderGroupSerializer
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = _("Google Workspace Provider Group")
|
||||||
|
verbose_name_plural = _("Google Workspace Provider Groups")
|
||||||
|
unique_together = (("google_id", "group", "provider"),)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Google Workspace Provider Group {self.group_id} to {self.provider_id}"
|
||||||
|
@ -82,27 +82,6 @@ class GoogleWorkspaceGroupTests(TestCase):
|
|||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||||
self.assertEqual(len(http.requests()), 2)
|
self.assertEqual(len(http.requests()), 2)
|
||||||
|
|
||||||
def test_group_not_created(self):
|
|
||||||
"""Test without group property mappings, no group is created"""
|
|
||||||
self.provider.property_mappings_group.clear()
|
|
||||||
uid = generate_id()
|
|
||||||
http = MockHTTP()
|
|
||||||
http.add_response(
|
|
||||||
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
|
|
||||||
domains_list_v1_mock,
|
|
||||||
)
|
|
||||||
with patch(
|
|
||||||
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
|
|
||||||
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
|
|
||||||
):
|
|
||||||
group = Group.objects.create(name=uid)
|
|
||||||
google_group = GoogleWorkspaceProviderGroup.objects.filter(
|
|
||||||
provider=self.provider, group=group
|
|
||||||
).first()
|
|
||||||
self.assertIsNone(google_group)
|
|
||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
|
||||||
self.assertEqual(len(http.requests()), 1)
|
|
||||||
|
|
||||||
def test_group_create_update(self):
|
def test_group_create_update(self):
|
||||||
"""Test group updating"""
|
"""Test group updating"""
|
||||||
uid = generate_id()
|
uid = generate_id()
|
||||||
|
@ -86,31 +86,6 @@ class GoogleWorkspaceUserTests(TestCase):
|
|||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||||
self.assertEqual(len(http.requests()), 2)
|
self.assertEqual(len(http.requests()), 2)
|
||||||
|
|
||||||
def test_user_not_created(self):
|
|
||||||
"""Test without property mappings, no group is created"""
|
|
||||||
self.provider.property_mappings.clear()
|
|
||||||
uid = generate_id()
|
|
||||||
http = MockHTTP()
|
|
||||||
http.add_response(
|
|
||||||
f"https://admin.googleapis.com/admin/directory/v1/customer/my_customer/domains?key={self.api_key}&alt=json",
|
|
||||||
domains_list_v1_mock,
|
|
||||||
)
|
|
||||||
with patch(
|
|
||||||
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
|
|
||||||
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
|
|
||||||
):
|
|
||||||
user = User.objects.create(
|
|
||||||
username=uid,
|
|
||||||
name=f"{uid} {uid}",
|
|
||||||
email=f"{uid}@goauthentik.io",
|
|
||||||
)
|
|
||||||
google_user = GoogleWorkspaceProviderUser.objects.filter(
|
|
||||||
provider=self.provider, user=user
|
|
||||||
).first()
|
|
||||||
self.assertIsNone(google_user)
|
|
||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
|
||||||
self.assertEqual(len(http.requests()), 1)
|
|
||||||
|
|
||||||
def test_user_create_update(self):
|
def test_user_create_update(self):
|
||||||
"""Test user updating"""
|
"""Test user updating"""
|
||||||
uid = generate_id()
|
uid = generate_id()
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
"""MicrosoftEntraProviderGroup API Views"""
|
"""MicrosoftEntraProviderGroup API Views"""
|
||||||
|
|
||||||
from rest_framework import mixins
|
from rest_framework.viewsets import ModelViewSet
|
||||||
from rest_framework.serializers import ModelSerializer
|
|
||||||
from rest_framework.viewsets import GenericViewSet
|
|
||||||
|
|
||||||
|
from authentik.core.api.sources import SourceSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.users import UserGroupSerializer
|
from authentik.core.api.users import UserGroupSerializer
|
||||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderGroup
|
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderGroup
|
||||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
|
class MicrosoftEntraProviderGroupSerializer(SourceSerializer):
|
||||||
"""MicrosoftEntraProviderGroup Serializer"""
|
"""MicrosoftEntraProviderGroup Serializer"""
|
||||||
|
|
||||||
group_obj = UserGroupSerializer(source="group", read_only=True)
|
group_obj = UserGroupSerializer(source="group", read_only=True)
|
||||||
@ -20,24 +18,12 @@ class MicrosoftEntraProviderGroupSerializer(ModelSerializer):
|
|||||||
model = MicrosoftEntraProviderGroup
|
model = MicrosoftEntraProviderGroup
|
||||||
fields = [
|
fields = [
|
||||||
"id",
|
"id",
|
||||||
"microsoft_id",
|
|
||||||
"group",
|
"group",
|
||||||
"group_obj",
|
"group_obj",
|
||||||
"provider",
|
|
||||||
"attributes",
|
|
||||||
]
|
]
|
||||||
extra_kwargs = {"attributes": {"read_only": True}}
|
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraProviderGroupViewSet(
|
class MicrosoftEntraProviderGroupViewSet(UsedByMixin, ModelViewSet):
|
||||||
mixins.CreateModelMixin,
|
|
||||||
OutgoingSyncConnectionCreateMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
|
||||||
mixins.DestroyModelMixin,
|
|
||||||
UsedByMixin,
|
|
||||||
mixins.ListModelMixin,
|
|
||||||
GenericViewSet,
|
|
||||||
):
|
|
||||||
"""MicrosoftEntraProviderGroup Viewset"""
|
"""MicrosoftEntraProviderGroup Viewset"""
|
||||||
|
|
||||||
queryset = MicrosoftEntraProviderGroup.objects.all().select_related("group")
|
queryset = MicrosoftEntraProviderGroup.objects.all().select_related("group")
|
||||||
|
@ -6,7 +6,7 @@ from drf_spectacular.types import OpenApiTypes
|
|||||||
from drf_spectacular.utils import extend_schema_field
|
from drf_spectacular.utils import extend_schema_field
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderMapping
|
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderMapping
|
||||||
|
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
"""MicrosoftEntraProviderUser API Views"""
|
"""MicrosoftEntraProviderUser API Views"""
|
||||||
|
|
||||||
from rest_framework import mixins
|
from rest_framework.viewsets import ModelViewSet
|
||||||
from rest_framework.serializers import ModelSerializer
|
|
||||||
from rest_framework.viewsets import GenericViewSet
|
|
||||||
|
|
||||||
from authentik.core.api.groups import GroupMemberSerializer
|
from authentik.core.api.groups import GroupMemberSerializer
|
||||||
|
from authentik.core.api.sources import SourceSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderUser
|
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProviderUser
|
||||||
from authentik.lib.sync.outgoing.api import OutgoingSyncConnectionCreateMixin
|
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraProviderUserSerializer(ModelSerializer):
|
class MicrosoftEntraProviderUserSerializer(SourceSerializer):
|
||||||
"""MicrosoftEntraProviderUser Serializer"""
|
"""MicrosoftEntraProviderUser Serializer"""
|
||||||
|
|
||||||
user_obj = GroupMemberSerializer(source="user", read_only=True)
|
user_obj = GroupMemberSerializer(source="user", read_only=True)
|
||||||
@ -20,24 +18,12 @@ class MicrosoftEntraProviderUserSerializer(ModelSerializer):
|
|||||||
model = MicrosoftEntraProviderUser
|
model = MicrosoftEntraProviderUser
|
||||||
fields = [
|
fields = [
|
||||||
"id",
|
"id",
|
||||||
"microsoft_id",
|
|
||||||
"user",
|
"user",
|
||||||
"user_obj",
|
"user_obj",
|
||||||
"provider",
|
|
||||||
"attributes",
|
|
||||||
]
|
]
|
||||||
extra_kwargs = {"attributes": {"read_only": True}}
|
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraProviderUserViewSet(
|
class MicrosoftEntraProviderUserViewSet(UsedByMixin, ModelViewSet):
|
||||||
OutgoingSyncConnectionCreateMixin,
|
|
||||||
mixins.CreateModelMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
|
||||||
mixins.DestroyModelMixin,
|
|
||||||
UsedByMixin,
|
|
||||||
mixins.ListModelMixin,
|
|
||||||
GenericViewSet,
|
|
||||||
):
|
|
||||||
"""MicrosoftEntraProviderUser Viewset"""
|
"""MicrosoftEntraProviderUser Viewset"""
|
||||||
|
|
||||||
queryset = MicrosoftEntraProviderUser.objects.all().select_related("user")
|
queryset = MicrosoftEntraProviderUser.objects.all().select_related("user")
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from asyncio import run
|
from asyncio import run
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from dataclasses import asdict
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from azure.core.exceptions import (
|
from azure.core.exceptions import (
|
||||||
@ -16,14 +15,12 @@ from kiota_authentication_azure.azure_identity_authentication_provider import (
|
|||||||
AzureIdentityAuthenticationProvider,
|
AzureIdentityAuthenticationProvider,
|
||||||
)
|
)
|
||||||
from kiota_http.kiota_client_factory import KiotaClientFactory
|
from kiota_http.kiota_client_factory import KiotaClientFactory
|
||||||
from msgraph.generated.models.entity import Entity
|
|
||||||
from msgraph.generated.models.o_data_errors.o_data_error import ODataError
|
from msgraph.generated.models.o_data_errors.o_data_error import ODataError
|
||||||
from msgraph.graph_request_adapter import GraphRequestAdapter, options
|
from msgraph.graph_request_adapter import GraphRequestAdapter, options
|
||||||
from msgraph.graph_service_client import GraphServiceClient
|
from msgraph.graph_service_client import GraphServiceClient
|
||||||
from msgraph_core import GraphClientFactory
|
from msgraph_core import GraphClientFactory
|
||||||
|
|
||||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||||
from authentik.events.utils import sanitize_item
|
|
||||||
from authentik.lib.sync.outgoing import HTTP_CONFLICT
|
from authentik.lib.sync.outgoing import HTTP_CONFLICT
|
||||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||||
from authentik.lib.sync.outgoing.exceptions import (
|
from authentik.lib.sync.outgoing.exceptions import (
|
||||||
@ -101,10 +98,3 @@ class MicrosoftEntraSyncClient[TModel: Model, TConnection: Model, TSchema: dict]
|
|||||||
for email in emails:
|
for email in emails:
|
||||||
if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
|
if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
|
||||||
raise BadRequestSyncException(f"Invalid email domain: {email}")
|
raise BadRequestSyncException(f"Invalid email domain: {email}")
|
||||||
|
|
||||||
def entity_as_dict(self, entity: Entity) -> dict:
|
|
||||||
"""Create a dictionary of a model instance, making sure to remove (known) things
|
|
||||||
we can't JSON serialize"""
|
|
||||||
raw_data = asdict(entity)
|
|
||||||
raw_data.pop("backing_store", None)
|
|
||||||
return sanitize_item(raw_data)
|
|
||||||
|
@ -4,15 +4,18 @@ from msgraph.generated.groups.groups_request_builder import GroupsRequestBuilder
|
|||||||
from msgraph.generated.models.group import Group as MSGroup
|
from msgraph.generated.models.group import Group as MSGroup
|
||||||
from msgraph.generated.models.reference_create import ReferenceCreate
|
from msgraph.generated.models.reference_create import ReferenceCreate
|
||||||
|
|
||||||
|
from authentik.core.expression.exceptions import (
|
||||||
|
PropertyMappingExpressionException,
|
||||||
|
SkipObjectException,
|
||||||
|
)
|
||||||
from authentik.core.models import Group
|
from authentik.core.models import Group
|
||||||
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
|
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
|
||||||
from authentik.enterprise.providers.microsoft_entra.models import (
|
from authentik.enterprise.providers.microsoft_entra.models import (
|
||||||
MicrosoftEntraProvider,
|
|
||||||
MicrosoftEntraProviderGroup,
|
MicrosoftEntraProviderGroup,
|
||||||
MicrosoftEntraProviderMapping,
|
MicrosoftEntraProviderMapping,
|
||||||
MicrosoftEntraProviderUser,
|
MicrosoftEntraProviderUser,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.sync.outgoing.base import Direction
|
from authentik.lib.sync.outgoing.base import Direction
|
||||||
from authentik.lib.sync.outgoing.exceptions import (
|
from authentik.lib.sync.outgoing.exceptions import (
|
||||||
NotFoundSyncException,
|
NotFoundSyncException,
|
||||||
@ -21,6 +24,7 @@ from authentik.lib.sync.outgoing.exceptions import (
|
|||||||
TransientSyncException,
|
TransientSyncException,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||||
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraGroupClient(
|
class MicrosoftEntraGroupClient(
|
||||||
@ -32,17 +36,37 @@ class MicrosoftEntraGroupClient(
|
|||||||
connection_type_query = "group"
|
connection_type_query = "group"
|
||||||
can_discover = True
|
can_discover = True
|
||||||
|
|
||||||
def __init__(self, provider: MicrosoftEntraProvider) -> None:
|
def to_schema(self, obj: Group, creating: bool) -> MSGroup:
|
||||||
super().__init__(provider)
|
|
||||||
self.mapper = PropertyMappingManager(
|
|
||||||
self.provider.property_mappings_group.all().order_by("name").select_subclasses(),
|
|
||||||
MicrosoftEntraProviderMapping,
|
|
||||||
["group", "provider", "connection"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_schema(self, obj: Group, connection: MicrosoftEntraProviderGroup) -> MSGroup:
|
|
||||||
"""Convert authentik group"""
|
"""Convert authentik group"""
|
||||||
raw_microsoft_group = super().to_schema(obj, connection)
|
raw_microsoft_group = {}
|
||||||
|
for mapping in (
|
||||||
|
self.provider.property_mappings_group.all().order_by("name").select_subclasses()
|
||||||
|
):
|
||||||
|
if not isinstance(mapping, MicrosoftEntraProviderMapping):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
value = mapping.evaluate(
|
||||||
|
user=None,
|
||||||
|
request=None,
|
||||||
|
group=obj,
|
||||||
|
provider=self.provider,
|
||||||
|
creating=creating,
|
||||||
|
)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
always_merger.merge(raw_microsoft_group, value)
|
||||||
|
except SkipObjectException as exc:
|
||||||
|
raise exc from exc
|
||||||
|
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||||
|
# Value error can be raised when assigning invalid data to an attribute
|
||||||
|
Event.new(
|
||||||
|
EventAction.CONFIGURATION_ERROR,
|
||||||
|
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||||
|
mapping=mapping,
|
||||||
|
).save()
|
||||||
|
raise StopSync(exc, obj, mapping) from exc
|
||||||
|
if not raw_microsoft_group:
|
||||||
|
raise StopSync(ValueError("No group mappings configured"), obj)
|
||||||
try:
|
try:
|
||||||
return MSGroup(**raw_microsoft_group)
|
return MSGroup(**raw_microsoft_group)
|
||||||
except TypeError as exc:
|
except TypeError as exc:
|
||||||
@ -63,7 +87,7 @@ class MicrosoftEntraGroupClient(
|
|||||||
|
|
||||||
def create(self, group: Group):
|
def create(self, group: Group):
|
||||||
"""Create group from scratch and create a connection object"""
|
"""Create group from scratch and create a connection object"""
|
||||||
microsoft_group = self.to_schema(group, None)
|
microsoft_group = self.to_schema(group, True)
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
try:
|
try:
|
||||||
response = self._request(self.client.groups.post(microsoft_group))
|
response = self._request(self.client.groups.post(microsoft_group))
|
||||||
@ -80,37 +104,27 @@ class MicrosoftEntraGroupClient(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
group_data = self._request(self.client.groups.get(request_configuration))
|
group_data = self._request(self.client.groups.get(request_configuration))
|
||||||
if group_data.odata_count < 1 or len(group_data.value) < 1:
|
if group_data.odata_count < 1:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Group which could not be created also does not exist", group=group
|
"Group which could not be created also does not exist", group=group
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
ms_group = group_data.value[0]
|
|
||||||
return MicrosoftEntraProviderGroup.objects.create(
|
return MicrosoftEntraProviderGroup.objects.create(
|
||||||
provider=self.provider,
|
provider=self.provider, group=group, microsoft_id=group_data.value[0].id
|
||||||
group=group,
|
|
||||||
microsoft_id=ms_group.id,
|
|
||||||
attributes=self.entity_as_dict(ms_group),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return MicrosoftEntraProviderGroup.objects.create(
|
return MicrosoftEntraProviderGroup.objects.create(
|
||||||
provider=self.provider,
|
provider=self.provider, group=group, microsoft_id=response.id
|
||||||
group=group,
|
|
||||||
microsoft_id=response.id,
|
|
||||||
attributes=self.entity_as_dict(response),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self, group: Group, connection: MicrosoftEntraProviderGroup):
|
def update(self, group: Group, connection: MicrosoftEntraProviderGroup):
|
||||||
"""Update existing group"""
|
"""Update existing group"""
|
||||||
microsoft_group = self.to_schema(group, connection)
|
microsoft_group = self.to_schema(group, False)
|
||||||
microsoft_group.id = connection.microsoft_id
|
microsoft_group.id = connection.microsoft_id
|
||||||
try:
|
try:
|
||||||
response = self._request(
|
return self._request(
|
||||||
self.client.groups.by_group_id(connection.microsoft_id).patch(microsoft_group)
|
self.client.groups.by_group_id(connection.microsoft_id).patch(microsoft_group)
|
||||||
)
|
)
|
||||||
if response:
|
|
||||||
always_merger.merge(connection.attributes, self.entity_as_dict(response))
|
|
||||||
connection.save()
|
|
||||||
except NotFoundSyncException:
|
except NotFoundSyncException:
|
||||||
# Resource missing is handled by self.write, which will re-create the group
|
# Resource missing is handled by self.write, which will re-create the group
|
||||||
raise
|
raise
|
||||||
@ -224,9 +238,4 @@ class MicrosoftEntraGroupClient(
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
group=matching_authentik_group,
|
group=matching_authentik_group,
|
||||||
microsoft_id=group.id,
|
microsoft_id=group.id,
|
||||||
attributes=self.entity_as_dict(group),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_single_attribute(self, connection: MicrosoftEntraProviderGroup):
|
|
||||||
data = self._request(self.client.groups.by_group_id(connection.microsoft_id).get())
|
|
||||||
connection.attributes = self.entity_as_dict(data)
|
|
||||||
|
@ -3,20 +3,24 @@ from django.db import transaction
|
|||||||
from msgraph.generated.models.user import User as MSUser
|
from msgraph.generated.models.user import User as MSUser
|
||||||
from msgraph.generated.users.users_request_builder import UsersRequestBuilder
|
from msgraph.generated.users.users_request_builder import UsersRequestBuilder
|
||||||
|
|
||||||
|
from authentik.core.expression.exceptions import (
|
||||||
|
PropertyMappingExpressionException,
|
||||||
|
SkipObjectException,
|
||||||
|
)
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
|
from authentik.enterprise.providers.microsoft_entra.clients.base import MicrosoftEntraSyncClient
|
||||||
from authentik.enterprise.providers.microsoft_entra.models import (
|
from authentik.enterprise.providers.microsoft_entra.models import (
|
||||||
MicrosoftEntraProvider,
|
|
||||||
MicrosoftEntraProviderMapping,
|
MicrosoftEntraProviderMapping,
|
||||||
MicrosoftEntraProviderUser,
|
MicrosoftEntraProviderUser,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.sync.outgoing.exceptions import (
|
from authentik.lib.sync.outgoing.exceptions import (
|
||||||
ObjectExistsSyncException,
|
ObjectExistsSyncException,
|
||||||
StopSync,
|
StopSync,
|
||||||
TransientSyncException,
|
TransientSyncException,
|
||||||
)
|
)
|
||||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
||||||
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
from authentik.policies.utils import delete_none_values
|
from authentik.policies.utils import delete_none_values
|
||||||
|
|
||||||
|
|
||||||
@ -27,17 +31,34 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
|||||||
connection_type_query = "user"
|
connection_type_query = "user"
|
||||||
can_discover = True
|
can_discover = True
|
||||||
|
|
||||||
def __init__(self, provider: MicrosoftEntraProvider) -> None:
|
def to_schema(self, obj: User, creating: bool) -> MSUser:
|
||||||
super().__init__(provider)
|
|
||||||
self.mapper = PropertyMappingManager(
|
|
||||||
self.provider.property_mappings.all().order_by("name").select_subclasses(),
|
|
||||||
MicrosoftEntraProviderMapping,
|
|
||||||
["provider", "connection"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_schema(self, obj: User, connection: MicrosoftEntraProviderUser) -> MSUser:
|
|
||||||
"""Convert authentik user"""
|
"""Convert authentik user"""
|
||||||
raw_microsoft_user = super().to_schema(obj, connection)
|
raw_microsoft_user = {}
|
||||||
|
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
|
||||||
|
if not isinstance(mapping, MicrosoftEntraProviderMapping):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
value = mapping.evaluate(
|
||||||
|
user=obj,
|
||||||
|
request=None,
|
||||||
|
provider=self.provider,
|
||||||
|
creating=creating,
|
||||||
|
)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
always_merger.merge(raw_microsoft_user, value)
|
||||||
|
except SkipObjectException as exc:
|
||||||
|
raise exc from exc
|
||||||
|
except (PropertyMappingExpressionException, ValueError) as exc:
|
||||||
|
# Value error can be raised when assigning invalid data to an attribute
|
||||||
|
Event.new(
|
||||||
|
EventAction.CONFIGURATION_ERROR,
|
||||||
|
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||||
|
mapping=mapping,
|
||||||
|
).save()
|
||||||
|
raise StopSync(exc, obj, mapping) from exc
|
||||||
|
if not raw_microsoft_user:
|
||||||
|
raise StopSync(ValueError("No user mappings configured"), obj)
|
||||||
try:
|
try:
|
||||||
return MSUser(**delete_none_values(raw_microsoft_user))
|
return MSUser(**delete_none_values(raw_microsoft_user))
|
||||||
except TypeError as exc:
|
except TypeError as exc:
|
||||||
@ -66,85 +87,48 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
|||||||
microsoft_user.delete()
|
microsoft_user.delete()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_select_fields(self) -> list[str]:
|
|
||||||
"""All fields that should be selected when we fetch user data."""
|
|
||||||
# TODO: Make this customizable in the future
|
|
||||||
return [
|
|
||||||
# Default fields
|
|
||||||
"businessPhones",
|
|
||||||
"displayName",
|
|
||||||
"givenName",
|
|
||||||
"jobTitle",
|
|
||||||
"mail",
|
|
||||||
"mobilePhone",
|
|
||||||
"officeLocation",
|
|
||||||
"preferredLanguage",
|
|
||||||
"surname",
|
|
||||||
"userPrincipalName",
|
|
||||||
"id",
|
|
||||||
# Required for logging into M365 using authentik
|
|
||||||
"onPremisesImmutableId",
|
|
||||||
]
|
|
||||||
|
|
||||||
def create(self, user: User):
|
def create(self, user: User):
|
||||||
"""Create user from scratch and create a connection object"""
|
"""Create user from scratch and create a connection object"""
|
||||||
microsoft_user = self.to_schema(user, None)
|
microsoft_user = self.to_schema(user, True)
|
||||||
self.check_email_valid(microsoft_user.user_principal_name)
|
self.check_email_valid(microsoft_user.user_principal_name)
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
try:
|
try:
|
||||||
response = self._request(self.client.users.post(microsoft_user))
|
response = self._request(self.client.users.post(microsoft_user))
|
||||||
except ObjectExistsSyncException:
|
except ObjectExistsSyncException:
|
||||||
# user already exists in microsoft entra, so we can connect them manually
|
# user already exists in microsoft entra, so we can connect them manually
|
||||||
|
query_params = UsersRequestBuilder.UsersRequestBuilderGetQueryParameters()(
|
||||||
|
filter=f"mail eq '{microsoft_user.mail}'",
|
||||||
|
)
|
||||||
request_configuration = (
|
request_configuration = (
|
||||||
UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration(
|
UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration(
|
||||||
query_parameters=UsersRequestBuilder.UsersRequestBuilderGetQueryParameters(
|
query_parameters=query_params,
|
||||||
filter=f"mail eq '{microsoft_user.mail}'",
|
|
||||||
select=self.get_select_fields(),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
user_data = self._request(self.client.users.get(request_configuration))
|
user_data = self._request(self.client.users.get(request_configuration))
|
||||||
if user_data.odata_count < 1 or len(user_data.value) < 1:
|
if user_data.odata_count < 1:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"User which could not be created also does not exist", user=user
|
"User which could not be created also does not exist", user=user
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
ms_user = user_data.value[0]
|
|
||||||
return MicrosoftEntraProviderUser.objects.create(
|
return MicrosoftEntraProviderUser.objects.create(
|
||||||
provider=self.provider,
|
provider=self.provider, user=user, microsoft_id=user_data.value[0].id
|
||||||
user=user,
|
|
||||||
microsoft_id=ms_user.id,
|
|
||||||
attributes=self.entity_as_dict(ms_user),
|
|
||||||
)
|
)
|
||||||
except TransientSyncException as exc:
|
except TransientSyncException as exc:
|
||||||
raise exc
|
raise exc
|
||||||
else:
|
else:
|
||||||
return MicrosoftEntraProviderUser.objects.create(
|
return MicrosoftEntraProviderUser.objects.create(
|
||||||
provider=self.provider,
|
provider=self.provider, user=user, microsoft_id=response.id
|
||||||
user=user,
|
|
||||||
microsoft_id=response.id,
|
|
||||||
attributes=self.entity_as_dict(response),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self, user: User, connection: MicrosoftEntraProviderUser):
|
def update(self, user: User, connection: MicrosoftEntraProviderUser):
|
||||||
"""Update existing user"""
|
"""Update existing user"""
|
||||||
microsoft_user = self.to_schema(user, connection)
|
microsoft_user = self.to_schema(user, False)
|
||||||
self.check_email_valid(microsoft_user.user_principal_name)
|
self.check_email_valid(microsoft_user.user_principal_name)
|
||||||
response = self._request(
|
self._request(self.client.users.by_user_id(connection.microsoft_id).patch(microsoft_user))
|
||||||
self.client.users.by_user_id(connection.microsoft_id).patch(microsoft_user)
|
|
||||||
)
|
|
||||||
if response:
|
|
||||||
always_merger.merge(connection.attributes, self.entity_as_dict(response))
|
|
||||||
connection.save()
|
|
||||||
|
|
||||||
def discover(self):
|
def discover(self):
|
||||||
"""Iterate through all users and connect them with authentik users if possible"""
|
"""Iterate through all users and connect them with authentik users if possible"""
|
||||||
request_configuration = UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration(
|
users = self._request(self.client.users.get())
|
||||||
query_parameters=UsersRequestBuilder.UsersRequestBuilderGetQueryParameters(
|
|
||||||
select=self.get_select_fields(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
users = self._request(self.client.users.get(request_configuration))
|
|
||||||
next_link = True
|
next_link = True
|
||||||
while next_link:
|
while next_link:
|
||||||
for user in users.value:
|
for user in users.value:
|
||||||
@ -163,16 +147,4 @@ class MicrosoftEntraUserClient(MicrosoftEntraSyncClient[User, MicrosoftEntraProv
|
|||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
user=matching_authentik_user,
|
user=matching_authentik_user,
|
||||||
microsoft_id=user.id,
|
microsoft_id=user.id,
|
||||||
attributes=self.entity_as_dict(user),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_single_attribute(self, connection: MicrosoftEntraProviderUser):
|
|
||||||
request_configuration = UsersRequestBuilder.UsersRequestBuilderGetRequestConfiguration(
|
|
||||||
query_parameters=UsersRequestBuilder.UsersRequestBuilderGetQueryParameters(
|
|
||||||
select=self.get_select_fields(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
data = self._request(
|
|
||||||
self.client.users.by_user_id(connection.microsoft_id).get(request_configuration)
|
|
||||||
)
|
|
||||||
connection.attributes = self.entity_as_dict(data)
|
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
# Generated by Django 5.0.6 on 2024-05-23 20:48
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("authentik_providers_microsoft_entra", "0001_initial"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="microsoftentraprovidergroup",
|
|
||||||
name="attributes",
|
|
||||||
field=models.JSONField(default=dict),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="microsoftentraprovideruser",
|
|
||||||
name="attributes",
|
|
||||||
field=models.JSONField(default=dict),
|
|
||||||
),
|
|
||||||
]
|
|
@ -6,7 +6,6 @@ from uuid import uuid4
|
|||||||
from azure.identity.aio import ClientSecretCredential
|
from azure.identity.aio import ClientSecretCredential
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from django.templatetags.static import static
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
|
|
||||||
@ -22,58 +21,6 @@ from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
|||||||
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction, OutgoingSyncProvider
|
from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction, OutgoingSyncProvider
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraProviderUser(SerializerModel):
|
|
||||||
"""Mapping of a user and provider to a Microsoft user ID"""
|
|
||||||
|
|
||||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
|
||||||
microsoft_id = models.TextField()
|
|
||||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
|
||||||
provider = models.ForeignKey("MicrosoftEntraProvider", on_delete=models.CASCADE)
|
|
||||||
attributes = models.JSONField(default=dict)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def serializer(self) -> type[Serializer]:
|
|
||||||
from authentik.enterprise.providers.microsoft_entra.api.users import (
|
|
||||||
MicrosoftEntraProviderUserSerializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
return MicrosoftEntraProviderUserSerializer
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
verbose_name = _("Microsoft Entra Provider User")
|
|
||||||
verbose_name_plural = _("Microsoft Entra Provider User")
|
|
||||||
unique_together = (("microsoft_id", "user", "provider"),)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"Microsoft Entra Provider User {self.user_id} to {self.provider_id}"
|
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraProviderGroup(SerializerModel):
|
|
||||||
"""Mapping of a group and provider to a Microsoft group ID"""
|
|
||||||
|
|
||||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
|
||||||
microsoft_id = models.TextField()
|
|
||||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
|
||||||
provider = models.ForeignKey("MicrosoftEntraProvider", on_delete=models.CASCADE)
|
|
||||||
attributes = models.JSONField(default=dict)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def serializer(self) -> type[Serializer]:
|
|
||||||
from authentik.enterprise.providers.microsoft_entra.api.groups import (
|
|
||||||
MicrosoftEntraProviderGroupSerializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
return MicrosoftEntraProviderGroupSerializer
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
verbose_name = _("Microsoft Entra Provider Group")
|
|
||||||
verbose_name_plural = _("Microsoft Entra Provider Groups")
|
|
||||||
unique_together = (("microsoft_id", "group", "provider"),)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"Microsoft Entra Provider Group {self.group_id} to {self.provider_id}"
|
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||||
"""Sync users from authentik into Microsoft Entra."""
|
"""Sync users from authentik into Microsoft Entra."""
|
||||||
|
|
||||||
@ -100,16 +47,15 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def client_for_model(
|
def client_for_model(
|
||||||
self,
|
self, model: type[User | Group]
|
||||||
model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup],
|
|
||||||
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
|
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:
|
||||||
if issubclass(model, User | MicrosoftEntraProviderUser):
|
if issubclass(model, User):
|
||||||
from authentik.enterprise.providers.microsoft_entra.clients.users import (
|
from authentik.enterprise.providers.microsoft_entra.clients.users import (
|
||||||
MicrosoftEntraUserClient,
|
MicrosoftEntraUserClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MicrosoftEntraUserClient(self)
|
return MicrosoftEntraUserClient(self)
|
||||||
if issubclass(model, Group | MicrosoftEntraProviderGroup):
|
if issubclass(model, Group):
|
||||||
from authentik.enterprise.providers.microsoft_entra.clients.groups import (
|
from authentik.enterprise.providers.microsoft_entra.clients.groups import (
|
||||||
MicrosoftEntraGroupClient,
|
MicrosoftEntraGroupClient,
|
||||||
)
|
)
|
||||||
@ -141,10 +87,6 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
|
||||||
def icon_url(self) -> str | None:
|
|
||||||
return static("authentik/sources/azuread.svg")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-provider-microsoft-entra-form"
|
return "ak-provider-microsoft-entra-form"
|
||||||
@ -186,3 +128,53 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
|
|||||||
class Meta:
|
class Meta:
|
||||||
verbose_name = _("Microsoft Entra Provider Mapping")
|
verbose_name = _("Microsoft Entra Provider Mapping")
|
||||||
verbose_name_plural = _("Microsoft Entra Provider Mappings")
|
verbose_name_plural = _("Microsoft Entra Provider Mappings")
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftEntraProviderUser(SerializerModel):
|
||||||
|
"""Mapping of a user and provider to a Microsoft user ID"""
|
||||||
|
|
||||||
|
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
|
microsoft_id = models.TextField()
|
||||||
|
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||||
|
provider = models.ForeignKey(MicrosoftEntraProvider, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self) -> type[Serializer]:
|
||||||
|
from authentik.enterprise.providers.microsoft_entra.api.users import (
|
||||||
|
MicrosoftEntraProviderUserSerializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return MicrosoftEntraProviderUserSerializer
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = _("Microsoft Entra Provider User")
|
||||||
|
verbose_name_plural = _("Microsoft Entra Provider User")
|
||||||
|
unique_together = (("microsoft_id", "user", "provider"),)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Microsoft Entra Provider User {self.user_id} to {self.provider_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftEntraProviderGroup(SerializerModel):
|
||||||
|
"""Mapping of a group and provider to a Microsoft group ID"""
|
||||||
|
|
||||||
|
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||||
|
microsoft_id = models.TextField()
|
||||||
|
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||||
|
provider = models.ForeignKey(MicrosoftEntraProvider, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self) -> type[Serializer]:
|
||||||
|
from authentik.enterprise.providers.microsoft_entra.api.groups import (
|
||||||
|
MicrosoftEntraProviderGroupSerializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return MicrosoftEntraProviderGroupSerializer
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = _("Microsoft Entra Provider Group")
|
||||||
|
verbose_name_plural = _("Microsoft Entra Provider Groups")
|
||||||
|
unique_together = (("microsoft_id", "group", "provider"),)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Microsoft Entra Provider Group {self.group_id} to {self.provider_id}"
|
||||||
|
@ -93,38 +93,6 @@ class MicrosoftEntraGroupTests(TestCase):
|
|||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||||
group_create.assert_called_once()
|
group_create.assert_called_once()
|
||||||
|
|
||||||
def test_group_not_created(self):
|
|
||||||
"""Test without group property mappings, no group is created"""
|
|
||||||
self.provider.property_mappings_group.clear()
|
|
||||||
uid = generate_id()
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
|
|
||||||
MagicMock(return_value={"credentials": self.creds}),
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
|
|
||||||
AsyncMock(
|
|
||||||
return_value=OrganizationCollectionResponse(
|
|
||||||
value=[
|
|
||||||
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"msgraph.generated.groups.groups_request_builder.GroupsRequestBuilder.post",
|
|
||||||
AsyncMock(return_value=MSGroup(id=generate_id())),
|
|
||||||
) as group_create,
|
|
||||||
):
|
|
||||||
group = Group.objects.create(name=uid)
|
|
||||||
microsoft_group = MicrosoftEntraProviderGroup.objects.filter(
|
|
||||||
provider=self.provider, group=group
|
|
||||||
).first()
|
|
||||||
self.assertIsNone(microsoft_group)
|
|
||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
|
||||||
group_create.assert_not_called()
|
|
||||||
|
|
||||||
def test_group_create_update(self):
|
def test_group_create_update(self):
|
||||||
"""Test group updating"""
|
"""Test group updating"""
|
||||||
uid = generate_id()
|
uid = generate_id()
|
||||||
|
@ -3,18 +3,16 @@
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from azure.identity.aio import ClientSecretCredential
|
from azure.identity.aio import ClientSecretCredential
|
||||||
from django.urls import reverse
|
from django.test import TestCase
|
||||||
from msgraph.generated.models.group_collection_response import GroupCollectionResponse
|
from msgraph.generated.models.group_collection_response import GroupCollectionResponse
|
||||||
from msgraph.generated.models.organization import Organization
|
from msgraph.generated.models.organization import Organization
|
||||||
from msgraph.generated.models.organization_collection_response import OrganizationCollectionResponse
|
from msgraph.generated.models.organization_collection_response import OrganizationCollectionResponse
|
||||||
from msgraph.generated.models.user import User as MSUser
|
from msgraph.generated.models.user import User as MSUser
|
||||||
from msgraph.generated.models.user_collection_response import UserCollectionResponse
|
from msgraph.generated.models.user_collection_response import UserCollectionResponse
|
||||||
from msgraph.generated.models.verified_domain import VerifiedDomain
|
from msgraph.generated.models.verified_domain import VerifiedDomain
|
||||||
from rest_framework.test import APITestCase
|
|
||||||
|
|
||||||
from authentik.blueprints.tests import apply_blueprint
|
from authentik.blueprints.tests import apply_blueprint
|
||||||
from authentik.core.models import Application, Group, User
|
from authentik.core.models import Application, Group, User
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
|
||||||
from authentik.enterprise.providers.microsoft_entra.models import (
|
from authentik.enterprise.providers.microsoft_entra.models import (
|
||||||
MicrosoftEntraProvider,
|
MicrosoftEntraProvider,
|
||||||
MicrosoftEntraProviderMapping,
|
MicrosoftEntraProviderMapping,
|
||||||
@ -27,12 +25,11 @@ from authentik.lib.sync.outgoing.models import OutgoingSyncDeleteAction
|
|||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
|
|
||||||
|
|
||||||
class MicrosoftEntraUserTests(APITestCase):
|
class MicrosoftEntraUserTests(TestCase):
|
||||||
"""Microsoft Entra User tests"""
|
"""Microsoft Entra User tests"""
|
||||||
|
|
||||||
@apply_blueprint("system/providers-microsoft-entra.yaml")
|
@apply_blueprint("system/providers-microsoft-entra.yaml")
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
|
|
||||||
# Delete all users and groups as the mocked HTTP responses only return one ID
|
# Delete all users and groups as the mocked HTTP responses only return one ID
|
||||||
# which will cause errors with multiple users
|
# which will cause errors with multiple users
|
||||||
Tenant.objects.update(avatars="none")
|
Tenant.objects.update(avatars="none")
|
||||||
@ -97,42 +94,6 @@ class MicrosoftEntraUserTests(APITestCase):
|
|||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||||
user_create.assert_called_once()
|
user_create.assert_called_once()
|
||||||
|
|
||||||
def test_user_not_created(self):
|
|
||||||
"""Test without property mappings, no group is created"""
|
|
||||||
self.provider.property_mappings.clear()
|
|
||||||
uid = generate_id()
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
|
|
||||||
MagicMock(return_value={"credentials": self.creds}),
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
|
|
||||||
AsyncMock(
|
|
||||||
return_value=OrganizationCollectionResponse(
|
|
||||||
value=[
|
|
||||||
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"msgraph.generated.users.users_request_builder.UsersRequestBuilder.post",
|
|
||||||
AsyncMock(return_value=MSUser(id=generate_id())),
|
|
||||||
) as user_create,
|
|
||||||
):
|
|
||||||
user = User.objects.create(
|
|
||||||
username=uid,
|
|
||||||
name=f"{uid} {uid}",
|
|
||||||
email=f"{uid}@goauthentik.io",
|
|
||||||
)
|
|
||||||
microsoft_user = MicrosoftEntraProviderUser.objects.filter(
|
|
||||||
provider=self.provider, user=user
|
|
||||||
).first()
|
|
||||||
self.assertIsNone(microsoft_user)
|
|
||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
|
||||||
user_create.assert_not_called()
|
|
||||||
|
|
||||||
def test_user_create_update(self):
|
def test_user_create_update(self):
|
||||||
"""Test user updating"""
|
"""Test user updating"""
|
||||||
uid = generate_id()
|
uid = generate_id()
|
||||||
@ -374,45 +335,3 @@ class MicrosoftEntraUserTests(APITestCase):
|
|||||||
)
|
)
|
||||||
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
self.assertFalse(Event.objects.filter(action=EventAction.SYSTEM_EXCEPTION).exists())
|
||||||
user_list.assert_called_once()
|
user_list.assert_called_once()
|
||||||
|
|
||||||
def test_connect_manual(self):
|
|
||||||
"""test manual user connection"""
|
|
||||||
uid = generate_id()
|
|
||||||
self.app.backchannel_providers.remove(self.provider)
|
|
||||||
admin = create_test_admin_user()
|
|
||||||
different_user = User.objects.create(
|
|
||||||
username=uid,
|
|
||||||
email=f"{uid}@goauthentik.io",
|
|
||||||
)
|
|
||||||
self.app.backchannel_providers.add(self.provider)
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider.microsoft_credentials",
|
|
||||||
MagicMock(return_value={"credentials": self.creds}),
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"msgraph.generated.organization.organization_request_builder.OrganizationRequestBuilder.get",
|
|
||||||
AsyncMock(
|
|
||||||
return_value=OrganizationCollectionResponse(
|
|
||||||
value=[
|
|
||||||
Organization(verified_domains=[VerifiedDomain(name="goauthentik.io")])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"authentik.enterprise.providers.microsoft_entra.clients.users.MicrosoftEntraUserClient.update_single_attribute",
|
|
||||||
MagicMock(),
|
|
||||||
) as user_get,
|
|
||||||
):
|
|
||||||
self.client.force_login(admin)
|
|
||||||
response = self.client.post(
|
|
||||||
reverse("authentik_api:microsoftentraprovideruser-list"),
|
|
||||||
data={
|
|
||||||
"microsoft_id": generate_id(),
|
|
||||||
"user": different_user.pk,
|
|
||||||
"provider": self.provider.pk,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 201)
|
|
||||||
user_get.assert_called_once()
|
|
||||||
|
@ -7,7 +7,7 @@ from drf_spectacular.utils import extend_schema_field
|
|||||||
from rest_framework.fields import CharField
|
from rest_framework.fields import CharField
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import JSONDictField
|
from authentik.core.api.utils import JSONDictField
|
||||||
from authentik.enterprise.providers.rac.models import RACPropertyMapping
|
from authentik.enterprise.providers.rac.models import RACPropertyMapping
|
||||||
|
@ -7,7 +7,6 @@ from deepmerge import always_merger
|
|||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.templatetags.static import static
|
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
@ -64,10 +63,6 @@ class RACProvider(Provider):
|
|||||||
Can return None for providers that are not URL-based"""
|
Can return None for providers that are not URL-based"""
|
||||||
return "goauthentik.io://providers/rac/launch"
|
return "goauthentik.io://providers/rac/launch"
|
||||||
|
|
||||||
@property
|
|
||||||
def icon_url(self) -> str | None:
|
|
||||||
return static("authentik/sources/rac.svg")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-provider-rac-form"
|
return "ak-provider-rac-form"
|
||||||
|
@ -18,12 +18,9 @@ class SourceStageSerializer(EnterpriseRequiredMixin, StageSerializer):
|
|||||||
source = Source.objects.filter(pk=_source.pk).select_subclasses().first()
|
source = Source.objects.filter(pk=_source.pk).select_subclasses().first()
|
||||||
if not source:
|
if not source:
|
||||||
raise ValidationError("Invalid source")
|
raise ValidationError("Invalid source")
|
||||||
if "request" in self.context:
|
login_button = source.ui_login_button(self.context["request"])
|
||||||
login_button = source.ui_login_button(self.context["request"])
|
if not login_button:
|
||||||
if not login_button:
|
raise ValidationError("Invalid source selected, only web-based sources are supported.")
|
||||||
raise ValidationError(
|
|
||||||
"Invalid source selected, only web-based sources are supported."
|
|
||||||
)
|
|
||||||
return source
|
return source
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -54,7 +54,7 @@ class SourceStageView(ChallengeStageView):
|
|||||||
def create_flow_token(self) -> FlowToken:
|
def create_flow_token(self) -> FlowToken:
|
||||||
"""Save the current flow state in a token that can be used to resume this flow"""
|
"""Save the current flow state in a token that can be used to resume this flow"""
|
||||||
pending_user: User = self.get_pending_user()
|
pending_user: User = self.get_pending_user()
|
||||||
if pending_user.is_anonymous or not pending_user.pk:
|
if pending_user.is_anonymous:
|
||||||
pending_user = get_anonymous_user()
|
pending_user = get_anonymous_user()
|
||||||
current_stage: SourceStage = self.executor.current_stage
|
current_stage: SourceStage = self.executor.current_stage
|
||||||
identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}")
|
identifier = slugify(f"ak-source-stage-{current_stage.name}-{str(uuid4())}")
|
||||||
|
@ -19,8 +19,7 @@ from rest_framework.serializers import ModelSerializer
|
|||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.admin.api.metrics import CoordinateSerializer
|
from authentik.admin.api.metrics import CoordinateSerializer
|
||||||
from authentik.core.api.object_types import TypeCreateSerializer
|
from authentik.core.api.utils import PassiveSerializer, TypeCreateSerializer
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ class GeoIPContextProcessor(MMDBContextProcessor):
|
|||||||
|
|
||||||
def enrich_context(self, request: HttpRequest) -> dict:
|
def enrich_context(self, request: HttpRequest) -> dict:
|
||||||
# Different key `geoip` vs `geo` for legacy reasons
|
# Different key `geoip` vs `geo` for legacy reasons
|
||||||
return {"geoip": self.city_dict(ClientIPMiddleware.get_client_ip(request))}
|
return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))}
|
||||||
|
|
||||||
def city(self, ip_address: str) -> City | None:
|
def city(self, ip_address: str) -> City | None:
|
||||||
"""Wrapper for Reader.city"""
|
"""Wrapper for Reader.city"""
|
||||||
|
@ -10,7 +10,7 @@ from django.db import migrations, models
|
|||||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||||
|
|
||||||
import authentik.events.models
|
import authentik.events.models
|
||||||
import authentik.lib.models
|
import authentik.lib.validators
|
||||||
from authentik.lib.migrations import progress_bar
|
from authentik.lib.migrations import progress_bar
|
||||||
|
|
||||||
|
|
||||||
@ -377,7 +377,7 @@ class Migration(migrations.Migration):
|
|||||||
model_name="notificationtransport",
|
model_name="notificationtransport",
|
||||||
name="webhook_url",
|
name="webhook_url",
|
||||||
field=models.TextField(
|
field=models.TextField(
|
||||||
blank=True, validators=[authentik.lib.models.DomainlessURLValidator()]
|
blank=True, validators=[authentik.lib.validators.DomainlessURLValidator()]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
@ -41,10 +41,11 @@ from authentik.events.utils import (
|
|||||||
sanitize_dict,
|
sanitize_dict,
|
||||||
sanitize_item,
|
sanitize_item,
|
||||||
)
|
)
|
||||||
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
|
from authentik.lib.validators import DomainlessURLValidator
|
||||||
from authentik.policies.models import PolicyBindingModel
|
from authentik.policies.models import PolicyBindingModel
|
||||||
from authentik.root.middleware import ClientIPMiddleware
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
from authentik.stages.email.utils import TemplateEmailMessage
|
from authentik.stages.email.utils import TemplateEmailMessage
|
||||||
|
@ -10,10 +10,10 @@ from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
|||||||
from rest_framework.viewsets import GenericViewSet
|
from rest_framework.viewsets import GenericViewSet
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.api.object_types import TypesMixin
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import MetaNameSerializer
|
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||||
from authentik.core.types import UserSettingSerializer
|
from authentik.core.types import UserSettingSerializer
|
||||||
|
from authentik.enterprise.apps import EnterpriseConfig
|
||||||
from authentik.flows.api.flows import FlowSetSerializer
|
from authentik.flows.api.flows import FlowSetSerializer
|
||||||
from authentik.flows.models import ConfigurableStage, Stage
|
from authentik.flows.models import ConfigurableStage, Stage
|
||||||
from authentik.lib.utils.reflection import all_subclasses
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
@ -47,7 +47,6 @@ class StageSerializer(ModelSerializer, MetaNameSerializer):
|
|||||||
|
|
||||||
|
|
||||||
class StageViewSet(
|
class StageViewSet(
|
||||||
TypesMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
mixins.RetrieveModelMixin,
|
||||||
mixins.DestroyModelMixin,
|
mixins.DestroyModelMixin,
|
||||||
UsedByMixin,
|
UsedByMixin,
|
||||||
@ -64,6 +63,25 @@ class StageViewSet(
|
|||||||
def get_queryset(self): # pragma: no cover
|
def get_queryset(self): # pragma: no cover
|
||||||
return Stage.objects.select_subclasses().prefetch_related("flow_set")
|
return Stage.objects.select_subclasses().prefetch_related("flow_set")
|
||||||
|
|
||||||
|
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||||
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
|
def types(self, request: Request) -> Response:
|
||||||
|
"""Get all creatable stage types"""
|
||||||
|
data = []
|
||||||
|
for subclass in all_subclasses(self.queryset.model, False):
|
||||||
|
subclass: Stage
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": subclass._meta.verbose_name,
|
||||||
|
"description": subclass.__doc__,
|
||||||
|
"component": subclass().component,
|
||||||
|
"model_name": subclass._meta.model_name,
|
||||||
|
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
data = sorted(data, key=lambda x: x["name"])
|
||||||
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
|
||||||
@extend_schema(responses={200: UserSettingSerializer(many=True)})
|
@extend_schema(responses={200: UserSettingSerializer(many=True)})
|
||||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
def user_settings(self, request: Request) -> Response:
|
def user_settings(self, request: Request) -> Response:
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from functools import cache as funccache
|
from functools import cache as funccache
|
||||||
from hashlib import md5, sha256
|
from hashlib import md5
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ from authentik.tenants.utils import get_current_tenant
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
|
|
||||||
GRAVATAR_URL = "https://www.gravatar.com"
|
GRAVATAR_URL = "https://secure.gravatar.com"
|
||||||
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
DEFAULT_AVATAR = static("dist/assets/images/user_default.png")
|
||||||
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
|
CACHE_KEY_GRAVATAR = "goauthentik.io/lib/avatars/"
|
||||||
CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available"
|
CACHE_KEY_GRAVATAR_AVAILABLE = "goauthentik.io/lib/avatars/gravatar_available"
|
||||||
@ -55,9 +55,10 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
|
|||||||
if not cache.get(CACHE_KEY_GRAVATAR_AVAILABLE, True):
|
if not cache.get(CACHE_KEY_GRAVATAR_AVAILABLE, True):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
mail_hash = sha256(user.email.lower().encode("utf-8")).hexdigest() # nosec
|
# gravatar uses md5 for their URLs, so md5 can't be avoided
|
||||||
parameters = {"size": "158", "rating": "g", "default": "404"}
|
mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||||
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters)}"
|
parameters = [("size", "158"), ("rating", "g"), ("default", "404")]
|
||||||
|
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
|
||||||
|
|
||||||
full_key = CACHE_KEY_GRAVATAR + mail_hash
|
full_key = CACHE_KEY_GRAVATAR + mail_hash
|
||||||
if cache.has_key(full_key):
|
if cache.has_key(full_key):
|
||||||
@ -83,9 +84,7 @@ def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
|
|||||||
|
|
||||||
def generate_colors(text: str) -> tuple[str, str]:
|
def generate_colors(text: str) -> tuple[str, str]:
|
||||||
"""Generate colours based on `text`"""
|
"""Generate colours based on `text`"""
|
||||||
color = (
|
color = int(md5(text.lower().encode("utf-8")).hexdigest(), 16) % 0xFFFFFF # nosec
|
||||||
int(md5(text.lower().encode("utf-8"), usedforsecurity=False).hexdigest(), 16) % 0xFFFFFF
|
|
||||||
) # nosec
|
|
||||||
|
|
||||||
# Get a (somewhat arbitrarily) reduced scope of colors
|
# Get a (somewhat arbitrarily) reduced scope of colors
|
||||||
# to avoid too dark or light backgrounds
|
# to avoid too dark or light backgrounds
|
||||||
@ -180,7 +179,7 @@ def avatar_mode_generated(user: "User", mode: str) -> str | None:
|
|||||||
|
|
||||||
def avatar_mode_url(user: "User", mode: str) -> str | None:
|
def avatar_mode_url(user: "User", mode: str) -> str | None:
|
||||||
"""Format url"""
|
"""Format url"""
|
||||||
mail_hash = md5(user.email.lower().encode("utf-8"), usedforsecurity=False).hexdigest() # nosec
|
mail_hash = md5(user.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||||
return mode % {
|
return mode % {
|
||||||
"username": user.username,
|
"username": user.username,
|
||||||
"mail_hash": mail_hash,
|
"mail_hash": mail_hash,
|
||||||
|
@ -304,12 +304,6 @@ class ConfigLoader:
|
|||||||
"""Wrapper for get that converts value into boolean"""
|
"""Wrapper for get that converts value into boolean"""
|
||||||
return str(self.get(path, default)).lower() == "true"
|
return str(self.get(path, default)).lower() == "true"
|
||||||
|
|
||||||
def get_keys(self, path: str, sep=".") -> list[str]:
|
|
||||||
"""List attribute keys by using yaml path"""
|
|
||||||
root = self.raw
|
|
||||||
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr({}))
|
|
||||||
return attr.keys()
|
|
||||||
|
|
||||||
def get_dict_from_b64_json(self, path: str, default=None) -> dict:
|
def get_dict_from_b64_json(self, path: str, default=None) -> dict:
|
||||||
"""Wrapper for get that converts value from Base64 encoded string into dictionary"""
|
"""Wrapper for get that converts value from Base64 encoded string into dictionary"""
|
||||||
config_value = self.get(path)
|
config_value = self.get(path)
|
||||||
|
@ -10,10 +10,6 @@ postgresql:
|
|||||||
use_pgpool: false
|
use_pgpool: false
|
||||||
test:
|
test:
|
||||||
name: test_authentik
|
name: test_authentik
|
||||||
read_replicas: {}
|
|
||||||
# For example
|
|
||||||
# 0:
|
|
||||||
# host: replica1.example.com
|
|
||||||
|
|
||||||
listen:
|
listen:
|
||||||
listen_http: 0.0.0.0:9000
|
listen_http: 0.0.0.0:9000
|
||||||
@ -50,6 +46,7 @@ cache:
|
|||||||
timeout: 300
|
timeout: 300
|
||||||
timeout_flows: 300
|
timeout_flows: 300
|
||||||
timeout_policies: 300
|
timeout_policies: 300
|
||||||
|
timeout_reputation: 300
|
||||||
|
|
||||||
# channel:
|
# channel:
|
||||||
# url: ""
|
# url: ""
|
||||||
@ -115,9 +112,6 @@ events:
|
|||||||
context_processors:
|
context_processors:
|
||||||
geoip: "/geoip/GeoLite2-City.mmdb"
|
geoip: "/geoip/GeoLite2-City.mmdb"
|
||||||
asn: "/geoip/GeoLite2-ASN.mmdb"
|
asn: "/geoip/GeoLite2-ASN.mmdb"
|
||||||
compliance:
|
|
||||||
fips:
|
|
||||||
enabled: false
|
|
||||||
|
|
||||||
cert_discovery_dir: /certs
|
cert_discovery_dir: /certs
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ import socket
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_address, ip_network
|
||||||
from textwrap import indent
|
from textwrap import indent
|
||||||
from types import CodeType
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cachetools import TLRUCache, cached
|
from cachetools import TLRUCache, cached
|
||||||
@ -185,7 +184,7 @@ class BaseEvaluator:
|
|||||||
full_expression += f"\nresult = handler({handler_signature})"
|
full_expression += f"\nresult = handler({handler_signature})"
|
||||||
return full_expression
|
return full_expression
|
||||||
|
|
||||||
def compile(self, expression: str) -> CodeType:
|
def compile(self, expression: str) -> Any:
|
||||||
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
|
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
|
||||||
param_keys = self._context.keys()
|
param_keys = self._context.keys()
|
||||||
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
|
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
|
||||||
|
@ -102,8 +102,6 @@ def get_logger_config():
|
|||||||
"gunicorn": "INFO",
|
"gunicorn": "INFO",
|
||||||
"requests_mock": "WARNING",
|
"requests_mock": "WARNING",
|
||||||
"hpack": "WARNING",
|
"hpack": "WARNING",
|
||||||
"httpx": "WARNING",
|
|
||||||
"azure": "WARNING",
|
|
||||||
}
|
}
|
||||||
for handler_name, level in handler_level_map.items():
|
for handler_name, level in handler_level_map.items():
|
||||||
base_config["loggers"][handler_name] = {
|
base_config["loggers"][handler_name] = {
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
"""Generic models"""
|
"""Generic models"""
|
||||||
|
|
||||||
import re
|
from typing import Any
|
||||||
|
|
||||||
from django.core.validators import URLValidator
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.regex_helper import _lazy_re_compile
|
from django.dispatch import Signal
|
||||||
|
from django.utils import timezone
|
||||||
from model_utils.managers import InheritanceManager
|
from model_utils.managers import InheritanceManager
|
||||||
from rest_framework.serializers import BaseSerializer
|
from rest_framework.serializers import BaseSerializer
|
||||||
|
|
||||||
|
pre_soft_delete = Signal()
|
||||||
|
post_soft_delete = Signal()
|
||||||
|
|
||||||
|
|
||||||
class SerializerModel(models.Model):
|
class SerializerModel(models.Model):
|
||||||
"""Base Abstract Model which has a serializer"""
|
"""Base Abstract Model which has a serializer"""
|
||||||
@ -51,46 +54,57 @@ class InheritanceForeignKey(models.ForeignKey):
|
|||||||
forward_related_accessor_class = InheritanceForwardManyToOneDescriptor
|
forward_related_accessor_class = InheritanceForwardManyToOneDescriptor
|
||||||
|
|
||||||
|
|
||||||
class DomainlessURLValidator(URLValidator):
|
class SoftDeleteQuerySet(models.QuerySet):
|
||||||
"""Subclass of URLValidator which doesn't check the domain
|
|
||||||
(to allow hostnames without domain)"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def delete(self):
|
||||||
super().__init__(*args, **kwargs)
|
for obj in self.all():
|
||||||
self.host_re = "(" + self.hostname_re + self.domain_re + "|localhost)"
|
obj.delete()
|
||||||
self.regex = _lazy_re_compile(
|
|
||||||
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
|
def hard_delete(self):
|
||||||
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
|
return super().delete()
|
||||||
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
|
|
||||||
r"(?::\d{2,5})?" # port
|
|
||||||
r"(?:[/?#][^\s]*)?" # resource path
|
class SoftDeleteManager(models.Manager):
|
||||||
r"\Z",
|
|
||||||
re.IGNORECASE,
|
def get_queryset(self):
|
||||||
|
return SoftDeleteQuerySet(self.model, using=self._db).filter(deleted_at__isnull=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DeletedSoftDeleteManager(models.Manager):
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
return super().get_queryset().exclude(deleted_at__isnull=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftDeleteModel(models.Model):
|
||||||
|
"""Model which doesn't fully delete itself, but rather saved the delete status
|
||||||
|
so cleanup events can run."""
|
||||||
|
|
||||||
|
deleted_at = models.DateTimeField(blank=True, null=True)
|
||||||
|
|
||||||
|
objects = SoftDeleteManager()
|
||||||
|
deleted = DeletedSoftDeleteManager()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_deleted(self):
|
||||||
|
return self.deleted_at is not None
|
||||||
|
|
||||||
|
def delete(self, using: Any = ..., keep_parents: bool = ...) -> tuple[int, dict[str, int]]:
|
||||||
|
pre_soft_delete.send(sender=self.__class__, instance=self)
|
||||||
|
now = timezone.now()
|
||||||
|
self.deleted_at = now
|
||||||
|
self.save(
|
||||||
|
update_fields=[
|
||||||
|
"deleted_at",
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.schemes = ["http", "https", "blank"] + list(self.schemes)
|
post_soft_delete.send(sender=self.__class__, instance=self)
|
||||||
|
return tuple()
|
||||||
|
|
||||||
def __call__(self, value: str):
|
def force_delete(self, using: Any = ...):
|
||||||
# Check if the scheme is valid.
|
if not self.deleted_at:
|
||||||
scheme = value.split("://")[0].lower()
|
raise models.ProtectedError("Refusing to force delete non-deleted model", {self})
|
||||||
if scheme not in self.schemes:
|
return super().delete(using=using)
|
||||||
value = "default" + value
|
|
||||||
super().__call__(value)
|
|
||||||
|
|
||||||
|
|
||||||
class DomainlessFormattedURLValidator(DomainlessURLValidator):
|
|
||||||
"""URL validator which allows for python format strings"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.formatter_re = r"([%\(\)a-zA-Z])*"
|
|
||||||
self.host_re = "(" + self.formatter_re + self.hostname_re + self.domain_re + "|localhost)"
|
|
||||||
self.regex = _lazy_re_compile(
|
|
||||||
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
|
|
||||||
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
|
|
||||||
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
|
|
||||||
r"(?::\d{2,5})?" # port
|
|
||||||
r"(?:[/?#][^\s]*)?" # resource path
|
|
||||||
r"\Z",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
self.schemes = ["http", "https", "blank"] + list(self.schemes)
|
|
||||||
|
@ -1,69 +0,0 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
|
||||||
from django.http import HttpRequest
|
|
||||||
|
|
||||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
|
||||||
from authentik.core.expression.exceptions import PropertyMappingExpressionException
|
|
||||||
from authentik.core.models import PropertyMapping, User
|
|
||||||
|
|
||||||
|
|
||||||
class PropertyMappingManager:
|
|
||||||
"""Pre-compile and cache property mappings when an identical
|
|
||||||
set is used multiple times"""
|
|
||||||
|
|
||||||
query_set: QuerySet[PropertyMapping]
|
|
||||||
mapping_subclass: type[PropertyMapping]
|
|
||||||
|
|
||||||
_evaluators: list[PropertyMappingEvaluator]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
qs: QuerySet[PropertyMapping],
|
|
||||||
# Expected subclass of PropertyMappings, any objects in the queryset
|
|
||||||
# that are not an instance of this class will be discarded
|
|
||||||
mapping_subclass: type[PropertyMapping],
|
|
||||||
# As they keys of parameters are part of the compilation,
|
|
||||||
# we need a list of all parameter names that will be used during evaluation
|
|
||||||
context_keys: list[str],
|
|
||||||
) -> None:
|
|
||||||
self.query_set = qs
|
|
||||||
self.mapping_subclass = mapping_subclass
|
|
||||||
self.context_keys = context_keys
|
|
||||||
self.compile()
|
|
||||||
|
|
||||||
def compile(self):
|
|
||||||
self._evaluators = []
|
|
||||||
for mapping in self.query_set:
|
|
||||||
if not isinstance(mapping, self.mapping_subclass):
|
|
||||||
continue
|
|
||||||
evaluator = PropertyMappingEvaluator(
|
|
||||||
mapping, **{key: None for key in self.context_keys}
|
|
||||||
)
|
|
||||||
# Compile and cache expression
|
|
||||||
evaluator.compile()
|
|
||||||
self._evaluators.append(evaluator)
|
|
||||||
|
|
||||||
def iter_eval(
|
|
||||||
self,
|
|
||||||
user: User | None,
|
|
||||||
request: HttpRequest | None,
|
|
||||||
return_mapping: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> Generator[tuple[dict, PropertyMapping], None]:
|
|
||||||
"""Iterate over all mappings that were pre-compiled and
|
|
||||||
execute all of them with the given context"""
|
|
||||||
for mapping in self._evaluators:
|
|
||||||
mapping.set_context(user, request, **kwargs)
|
|
||||||
try:
|
|
||||||
value = mapping.evaluate(mapping.model.expression)
|
|
||||||
except PropertyMappingExpressionException as exc:
|
|
||||||
raise exc from exc
|
|
||||||
except Exception as exc:
|
|
||||||
raise PropertyMappingExpressionException(exc, mapping.model) from exc
|
|
||||||
if value is None:
|
|
||||||
continue
|
|
||||||
if return_mapping:
|
|
||||||
yield value, mapping.model
|
|
||||||
else:
|
|
||||||
yield value
|
|
@ -3,6 +3,3 @@
|
|||||||
PAGE_SIZE = 100
|
PAGE_SIZE = 100
|
||||||
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
|
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
|
||||||
HTTP_CONFLICT = 409
|
HTTP_CONFLICT = 409
|
||||||
HTTP_NO_CONTENT = 204
|
|
||||||
HTTP_SERVICE_UNAVAILABLE = 503
|
|
||||||
HTTP_TOO_MANY_REQUESTS = 429
|
|
||||||
|
@ -7,7 +7,6 @@ from rest_framework.decorators import action
|
|||||||
from rest_framework.fields import BooleanField
|
from rest_framework.fields import BooleanField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import ModelSerializer
|
|
||||||
|
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
from authentik.core.api.utils import PassiveSerializer
|
||||||
from authentik.events.api.tasks import SystemTaskSerializer
|
from authentik.events.api.tasks import SystemTaskSerializer
|
||||||
@ -48,24 +47,8 @@ class OutgoingSyncProviderStatusMixin:
|
|||||||
uid=slugify(provider.name),
|
uid=slugify(provider.name),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with provider.sync_lock as lock_acquired:
|
status = {
|
||||||
status = {
|
"tasks": tasks,
|
||||||
"tasks": tasks,
|
"is_running": provider.sync_lock.locked(),
|
||||||
# If we could not acquire the lock, it means a task is using it, and thus is running
|
}
|
||||||
"is_running": not lock_acquired,
|
|
||||||
}
|
|
||||||
return Response(SyncStatusSerializer(status).data)
|
return Response(SyncStatusSerializer(status).data)
|
||||||
|
|
||||||
|
|
||||||
class OutgoingSyncConnectionCreateMixin:
|
|
||||||
"""Mixin for connection objects that fetches remote data upon creation"""
|
|
||||||
|
|
||||||
def perform_create(self, serializer: ModelSerializer):
|
|
||||||
super().perform_create(serializer)
|
|
||||||
try:
|
|
||||||
instance = serializer.instance
|
|
||||||
client = instance.provider.client_for_model(instance.__class__)
|
|
||||||
client.update_single_attribute(instance)
|
|
||||||
instance.save()
|
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
|
@ -3,18 +3,10 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from deepmerge import always_merger
|
|
||||||
from django.db import DatabaseError
|
from django.db import DatabaseError
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.expression.exceptions import (
|
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException
|
||||||
PropertyMappingExpressionException,
|
|
||||||
SkipObjectException,
|
|
||||||
)
|
|
||||||
from authentik.events.models import Event, EventAction
|
|
||||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
|
||||||
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
|
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
@ -36,7 +28,6 @@ class BaseOutgoingSyncClient[
|
|||||||
provider: TProvider
|
provider: TProvider
|
||||||
connection_type: type[TConnection]
|
connection_type: type[TConnection]
|
||||||
connection_type_query: str
|
connection_type_query: str
|
||||||
mapper: PropertyMappingManager
|
|
||||||
|
|
||||||
can_discover = False
|
can_discover = False
|
||||||
|
|
||||||
@ -79,34 +70,9 @@ class BaseOutgoingSyncClient[
|
|||||||
"""Delete object from destination"""
|
"""Delete object from destination"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def to_schema(self, obj: TModel, connection: TConnection | None, **defaults) -> TSchema:
|
def to_schema(self, obj: TModel, creating: bool) -> TSchema:
|
||||||
"""Convert object to destination schema"""
|
"""Convert object to destination schema"""
|
||||||
raw_final_object = {}
|
raise NotImplementedError()
|
||||||
try:
|
|
||||||
eval_kwargs = {
|
|
||||||
"request": None,
|
|
||||||
"provider": self.provider,
|
|
||||||
"connection": connection,
|
|
||||||
obj._meta.model_name: obj,
|
|
||||||
}
|
|
||||||
eval_kwargs.setdefault("user", None)
|
|
||||||
for value in self.mapper.iter_eval(**eval_kwargs):
|
|
||||||
always_merger.merge(raw_final_object, value)
|
|
||||||
except SkipObjectException as exc:
|
|
||||||
raise exc from exc
|
|
||||||
except PropertyMappingExpressionException as exc:
|
|
||||||
# Value error can be raised when assigning invalid data to an attribute
|
|
||||||
Event.new(
|
|
||||||
EventAction.CONFIGURATION_ERROR,
|
|
||||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
|
||||||
mapping=exc.mapping,
|
|
||||||
).save()
|
|
||||||
raise StopSync(exc, obj, exc.mapping) from exc
|
|
||||||
if not raw_final_object:
|
|
||||||
raise StopSync(ValueError("No mappings configured"), obj)
|
|
||||||
for key, value in defaults.items():
|
|
||||||
raw_final_object.setdefault(key, value)
|
|
||||||
return raw_final_object
|
|
||||||
|
|
||||||
def discover(self):
|
def discover(self):
|
||||||
"""Optional method. Can be used to implement a "discovery" where
|
"""Optional method. Can be used to implement a "discovery" where
|
||||||
@ -114,8 +80,3 @@ class BaseOutgoingSyncClient[
|
|||||||
pre-link any users/groups in the remote system with the respective
|
pre-link any users/groups in the remote system with the respective
|
||||||
object in authentik based on a common identifier"""
|
object in authentik based on a common identifier"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def update_single_attribute(self, connection: TConnection):
|
|
||||||
"""Update connection attributes on a connection object, when the connection
|
|
||||||
is manually created"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
from typing import Any, Self
|
from typing import Any, Self
|
||||||
|
|
||||||
import pglock
|
from django.core.cache import cache
|
||||||
from django.db import connection
|
|
||||||
from django.db.models import Model, QuerySet, TextChoices
|
from django.db.models import Model, QuerySet, TextChoices
|
||||||
|
from redis.lock import Lock
|
||||||
|
|
||||||
from authentik.core.models import Group, User
|
from authentik.core.models import Group, User
|
||||||
|
from authentik.lib.sync.outgoing import PAGE_TIMEOUT
|
||||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||||
|
|
||||||
|
|
||||||
@ -31,10 +32,10 @@ class OutgoingSyncProvider(Model):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sync_lock(self) -> pglock.advisory:
|
def sync_lock(self) -> Lock:
|
||||||
"""Postgres lock for syncing SCIM to prevent multiple parallel syncs happening"""
|
"""Redis lock to prevent multiple parallel syncs happening"""
|
||||||
return pglock.advisory(
|
return Lock(
|
||||||
lock_id=f"goauthentik.io/{connection.schema_name}/providers/outgoing-sync/{str(self.pk)}",
|
cache.client.get_client(),
|
||||||
timeout=0,
|
name=f"goauthentik.io/providers/outgoing-sync/{str(self.pk)}",
|
||||||
side_effect=pglock.Return,
|
timeout=(60 * 60 * PAGE_TIMEOUT) * 3,
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import asdict
|
|
||||||
|
|
||||||
from celery.exceptions import Retry
|
from celery.exceptions import Retry
|
||||||
from celery.result import allow_join_result
|
from celery.result import allow_join_result
|
||||||
@ -14,7 +13,6 @@ from authentik.core.models import Group, User
|
|||||||
from authentik.events.logs import LogEvent
|
from authentik.events.logs import LogEvent
|
||||||
from authentik.events.models import TaskStatus
|
from authentik.events.models import TaskStatus
|
||||||
from authentik.events.system_tasks import SystemTask
|
from authentik.events.system_tasks import SystemTask
|
||||||
from authentik.events.utils import sanitize_item
|
|
||||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
|
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
|
||||||
from authentik.lib.sync.outgoing.base import Direction
|
from authentik.lib.sync.outgoing.base import Direction
|
||||||
from authentik.lib.sync.outgoing.exceptions import (
|
from authentik.lib.sync.outgoing.exceptions import (
|
||||||
@ -66,16 +64,17 @@ class SyncTasks:
|
|||||||
).first()
|
).first()
|
||||||
if not provider:
|
if not provider:
|
||||||
return
|
return
|
||||||
|
lock = provider.sync_lock
|
||||||
|
if lock.locked():
|
||||||
|
self.logger.debug("Sync locked, skipping task", source=provider.name)
|
||||||
|
return
|
||||||
task.set_uid(slugify(provider.name))
|
task.set_uid(slugify(provider.name))
|
||||||
messages = []
|
messages = []
|
||||||
messages.append(_("Starting full provider sync"))
|
messages.append(_("Starting full provider sync"))
|
||||||
self.logger.debug("Starting provider sync")
|
self.logger.debug("Starting provider sync")
|
||||||
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
|
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
|
||||||
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
|
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
|
||||||
with allow_join_result(), provider.sync_lock as lock_acquired:
|
with allow_join_result(), lock:
|
||||||
if not lock_acquired:
|
|
||||||
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
for page in users_paginator.page_range:
|
for page in users_paginator.page_range:
|
||||||
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
|
messages.append(_("Syncing page %(page)d of users" % {"page": page}))
|
||||||
@ -84,7 +83,7 @@ class SyncTasks:
|
|||||||
time_limit=PAGE_TIMEOUT,
|
time_limit=PAGE_TIMEOUT,
|
||||||
soft_time_limit=PAGE_TIMEOUT,
|
soft_time_limit=PAGE_TIMEOUT,
|
||||||
).get():
|
).get():
|
||||||
messages.append(LogEvent(**msg))
|
messages.append(msg)
|
||||||
for page in groups_paginator.page_range:
|
for page in groups_paginator.page_range:
|
||||||
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
|
messages.append(_("Syncing page %(page)d of groups" % {"page": page}))
|
||||||
for msg in sync_objects.apply_async(
|
for msg in sync_objects.apply_async(
|
||||||
@ -92,7 +91,7 @@ class SyncTasks:
|
|||||||
time_limit=PAGE_TIMEOUT,
|
time_limit=PAGE_TIMEOUT,
|
||||||
soft_time_limit=PAGE_TIMEOUT,
|
soft_time_limit=PAGE_TIMEOUT,
|
||||||
).get():
|
).get():
|
||||||
messages.append(LogEvent(**msg))
|
messages.append(msg)
|
||||||
except TransientSyncException as exc:
|
except TransientSyncException as exc:
|
||||||
self.logger.warning("transient sync exception", exc=exc)
|
self.logger.warning("transient sync exception", exc=exc)
|
||||||
raise task.retry(exc=exc) from exc
|
raise task.retry(exc=exc) from exc
|
||||||
@ -126,70 +125,61 @@ class SyncTasks:
|
|||||||
try:
|
try:
|
||||||
client.write(obj)
|
client.write(obj)
|
||||||
except SkipObjectException:
|
except SkipObjectException:
|
||||||
self.logger.debug("skipping object due to SkipObject", obj=obj)
|
|
||||||
continue
|
continue
|
||||||
except BadRequestSyncException as exc:
|
except BadRequestSyncException as exc:
|
||||||
self.logger.warning("failed to sync object", exc=exc, obj=obj)
|
self.logger.warning("failed to sync object", exc=exc, obj=obj)
|
||||||
messages.append(
|
messages.append(
|
||||||
asdict(
|
LogEvent(
|
||||||
LogEvent(
|
_(
|
||||||
_(
|
(
|
||||||
(
|
"Failed to sync {object_type} {object_name} "
|
||||||
"Failed to sync {object_type} {object_name} "
|
"due to error: {error}"
|
||||||
"due to error: {error}"
|
).format_map(
|
||||||
).format_map(
|
{
|
||||||
{
|
"object_type": obj._meta.verbose_name,
|
||||||
"object_type": obj._meta.verbose_name,
|
"object_name": str(obj),
|
||||||
"object_name": str(obj),
|
"error": str(exc),
|
||||||
"error": str(exc),
|
}
|
||||||
}
|
)
|
||||||
)
|
),
|
||||||
),
|
log_level="warning",
|
||||||
log_level="warning",
|
logger="",
|
||||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
attributes={"arguments": exc.args[1:]},
|
||||||
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except TransientSyncException as exc:
|
except TransientSyncException as exc:
|
||||||
self.logger.warning("failed to sync object", exc=exc, user=obj)
|
self.logger.warning("failed to sync object", exc=exc, user=obj)
|
||||||
messages.append(
|
messages.append(
|
||||||
asdict(
|
LogEvent(
|
||||||
LogEvent(
|
_(
|
||||||
_(
|
(
|
||||||
(
|
"Failed to sync {object_type} {object_name} "
|
||||||
"Failed to sync {object_type} {object_name} "
|
"due to transient error: {error}"
|
||||||
"due to transient error: {error}"
|
).format_map(
|
||||||
).format_map(
|
{
|
||||||
{
|
"object_type": obj._meta.verbose_name,
|
||||||
"object_type": obj._meta.verbose_name,
|
"object_name": str(obj),
|
||||||
"object_name": str(obj),
|
"error": str(exc),
|
||||||
"error": str(exc),
|
}
|
||||||
}
|
)
|
||||||
)
|
),
|
||||||
),
|
log_level="warning",
|
||||||
log_level="warning",
|
logger="",
|
||||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
|
||||||
attributes={"obj": sanitize_item(obj)},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except StopSync as exc:
|
except StopSync as exc:
|
||||||
self.logger.warning("Stopping sync", exc=exc)
|
self.logger.warning("Stopping sync", exc=exc)
|
||||||
messages.append(
|
messages.append(
|
||||||
asdict(
|
LogEvent(
|
||||||
LogEvent(
|
_(
|
||||||
_(
|
"Stopping sync due to error: {error}".format_map(
|
||||||
"Stopping sync due to error: {error}".format_map(
|
{
|
||||||
{
|
"error": exc.detail(),
|
||||||
"error": exc.detail(),
|
}
|
||||||
}
|
)
|
||||||
)
|
),
|
||||||
),
|
log_level="warning",
|
||||||
log_level="warning",
|
logger="",
|
||||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
|
||||||
attributes={"obj": sanitize_item(obj)},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
@ -169,9 +169,3 @@ class TestConfig(TestCase):
|
|||||||
self.assertEqual(config.get("cache.timeout_flows"), "32m")
|
self.assertEqual(config.get("cache.timeout_flows"), "32m")
|
||||||
self.assertEqual(config.get("cache.timeout_policies"), "3920ns")
|
self.assertEqual(config.get("cache.timeout_policies"), "3920ns")
|
||||||
self.assertEqual(config.get("cache.timeout_reputation"), "298382us")
|
self.assertEqual(config.get("cache.timeout_reputation"), "298382us")
|
||||||
|
|
||||||
def test_get_keys(self):
|
|
||||||
"""Test get_keys"""
|
|
||||||
config = ConfigLoader()
|
|
||||||
config.set("foo.bar", "baz")
|
|
||||||
self.assertEqual(list(config.get_keys("foo")), ["bar"])
|
|
||||||
|
@ -12,7 +12,7 @@ from authentik.lib.config import CONFIG
|
|||||||
SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST"
|
SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST"
|
||||||
|
|
||||||
|
|
||||||
def all_subclasses[T](cls: T, sort=True) -> list[T] | set[T]:
|
def all_subclasses(cls, sort=True):
|
||||||
"""Recursively return all subclassess of cls"""
|
"""Recursively return all subclassess of cls"""
|
||||||
classes = set(cls.__subclasses__()).union(
|
classes = set(cls.__subclasses__()).union(
|
||||||
[s for c in cls.__subclasses__() for s in all_subclasses(c, sort=sort)]
|
[s for c in cls.__subclasses__() for s in all_subclasses(c, sort=sort)]
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
"""Serializer validators"""
|
"""Serializer validators"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from django.core.validators import URLValidator
|
||||||
|
from django.utils.regex_helper import _lazy_re_compile
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework.exceptions import ValidationError
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
@ -29,3 +33,48 @@ class RequiredTogetherValidator:
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<{self.__class__.__name__}(fields={smart_repr(self.fields)})>"
|
return f"<{self.__class__.__name__}(fields={smart_repr(self.fields)})>"
|
||||||
|
|
||||||
|
|
||||||
|
class DomainlessURLValidator(URLValidator):
|
||||||
|
"""Subclass of URLValidator which doesn't check the domain
|
||||||
|
(to allow hostnames without domain)"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.host_re = "(" + self.hostname_re + self.domain_re + "|localhost)"
|
||||||
|
self.regex = _lazy_re_compile(
|
||||||
|
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
|
||||||
|
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
|
||||||
|
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
|
||||||
|
r"(?::\d{2,5})?" # port
|
||||||
|
r"(?:[/?#][^\s]*)?" # resource path
|
||||||
|
r"\Z",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
self.schemes = ["http", "https", "blank"] + list(self.schemes)
|
||||||
|
|
||||||
|
def __call__(self, value: str):
|
||||||
|
# Check if the scheme is valid.
|
||||||
|
scheme = value.split("://")[0].lower()
|
||||||
|
if scheme not in self.schemes:
|
||||||
|
value = "default" + value
|
||||||
|
super().__call__(value)
|
||||||
|
|
||||||
|
|
||||||
|
class DomainlessFormattedURLValidator(DomainlessURLValidator):
|
||||||
|
"""URL validator which allows for python format strings"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.formatter_re = r"([%\(\)a-zA-Z])*"
|
||||||
|
self.host_re = "(" + self.formatter_re + self.hostname_re + self.domain_re + "|localhost)"
|
||||||
|
self.regex = _lazy_re_compile(
|
||||||
|
r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately
|
||||||
|
r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication
|
||||||
|
r"(?:" + self.ipv4_re + "|" + self.ipv6_re + "|" + self.host_re + ")"
|
||||||
|
r"(?::\d{2,5})?" # port
|
||||||
|
r"(?:[/?#][^\s]*)?" # resource path
|
||||||
|
r"\Z",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
self.schemes = ["http", "https", "blank"] + list(self.schemes)
|
||||||
|
@ -6,7 +6,7 @@ from django_filters.filters import ModelMultipleChoiceFilter
|
|||||||
from django_filters.filterset import FilterSet
|
from django_filters.filterset import FilterSet
|
||||||
from drf_spectacular.utils import extend_schema
|
from drf_spectacular.utils import extend_schema
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import BooleanField, CharField, DateTimeField, SerializerMethodField
|
from rest_framework.fields import BooleanField, CharField, DateTimeField
|
||||||
from rest_framework.relations import PrimaryKeyRelatedField
|
from rest_framework.relations import PrimaryKeyRelatedField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
@ -18,7 +18,6 @@ from authentik.core.api.providers import ProviderSerializer
|
|||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||||
from authentik.core.models import Provider
|
from authentik.core.models import Provider
|
||||||
from authentik.enterprise.license import LicenseKey
|
|
||||||
from authentik.enterprise.providers.rac.models import RACProvider
|
from authentik.enterprise.providers.rac.models import RACProvider
|
||||||
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
||||||
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
|
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
|
||||||
@ -118,12 +117,8 @@ class OutpostHealthSerializer(PassiveSerializer):
|
|||||||
uid = CharField(read_only=True)
|
uid = CharField(read_only=True)
|
||||||
last_seen = DateTimeField(read_only=True)
|
last_seen = DateTimeField(read_only=True)
|
||||||
version = CharField(read_only=True)
|
version = CharField(read_only=True)
|
||||||
golang_version = CharField(read_only=True)
|
|
||||||
openssl_enabled = BooleanField(read_only=True)
|
|
||||||
openssl_version = CharField(read_only=True)
|
|
||||||
fips_enabled = SerializerMethodField()
|
|
||||||
|
|
||||||
version_should = CharField(read_only=True)
|
version_should = CharField(read_only=True)
|
||||||
|
|
||||||
version_outdated = BooleanField(read_only=True)
|
version_outdated = BooleanField(read_only=True)
|
||||||
|
|
||||||
build_hash = CharField(read_only=True, required=False)
|
build_hash = CharField(read_only=True, required=False)
|
||||||
@ -131,12 +126,6 @@ class OutpostHealthSerializer(PassiveSerializer):
|
|||||||
|
|
||||||
hostname = CharField(read_only=True, required=False)
|
hostname = CharField(read_only=True, required=False)
|
||||||
|
|
||||||
def get_fips_enabled(self, obj: dict) -> bool | None:
|
|
||||||
"""Get FIPS enabled"""
|
|
||||||
if not LicenseKey.get_total().is_valid():
|
|
||||||
return None
|
|
||||||
return obj["fips_enabled"]
|
|
||||||
|
|
||||||
|
|
||||||
class OutpostFilter(FilterSet):
|
class OutpostFilter(FilterSet):
|
||||||
"""Filter for Outposts"""
|
"""Filter for Outposts"""
|
||||||
@ -184,10 +173,6 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
|
|||||||
"version_should": state.version_should,
|
"version_should": state.version_should,
|
||||||
"version_outdated": state.version_outdated,
|
"version_outdated": state.version_outdated,
|
||||||
"build_hash": state.build_hash,
|
"build_hash": state.build_hash,
|
||||||
"golang_version": state.golang_version,
|
|
||||||
"openssl_enabled": state.openssl_enabled,
|
|
||||||
"openssl_version": state.openssl_version,
|
|
||||||
"fips_enabled": state.fips_enabled,
|
|
||||||
"hostname": state.hostname,
|
"hostname": state.hostname,
|
||||||
"build_hash_should": get_build_hash(),
|
"build_hash_should": get_build_hash(),
|
||||||
}
|
}
|
||||||
|
@ -15,12 +15,9 @@ from rest_framework.response import Response
|
|||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.object_types import TypesMixin
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import (
|
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||||
MetaNameSerializer,
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
PassiveSerializer,
|
|
||||||
)
|
|
||||||
from authentik.outposts.models import (
|
from authentik.outposts.models import (
|
||||||
DockerServiceConnection,
|
DockerServiceConnection,
|
||||||
KubernetesServiceConnection,
|
KubernetesServiceConnection,
|
||||||
@ -60,7 +57,6 @@ class ServiceConnectionStateSerializer(PassiveSerializer):
|
|||||||
|
|
||||||
|
|
||||||
class ServiceConnectionViewSet(
|
class ServiceConnectionViewSet(
|
||||||
TypesMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
mixins.RetrieveModelMixin,
|
||||||
mixins.DestroyModelMixin,
|
mixins.DestroyModelMixin,
|
||||||
UsedByMixin,
|
UsedByMixin,
|
||||||
@ -74,6 +70,23 @@ class ServiceConnectionViewSet(
|
|||||||
search_fields = ["name"]
|
search_fields = ["name"]
|
||||||
filterset_fields = ["name"]
|
filterset_fields = ["name"]
|
||||||
|
|
||||||
|
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||||
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
|
def types(self, request: Request) -> Response:
|
||||||
|
"""Get all creatable service connection types"""
|
||||||
|
data = []
|
||||||
|
for subclass in all_subclasses(self.queryset.model):
|
||||||
|
subclass: OutpostServiceConnection
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": subclass._meta.verbose_name,
|
||||||
|
"description": subclass.__doc__,
|
||||||
|
"component": subclass().component,
|
||||||
|
"model_name": subclass._meta.model_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
|
||||||
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
|
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
|
||||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||||
def state(self, request: Request, pk: str) -> Response:
|
def state(self, request: Request, pk: str) -> Response:
|
||||||
|
@ -121,10 +121,6 @@ class OutpostConsumer(JsonWebsocketConsumer):
|
|||||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||||
state.version = msg.args.pop("version", None)
|
state.version = msg.args.pop("version", None)
|
||||||
state.build_hash = msg.args.pop("buildHash", "")
|
state.build_hash = msg.args.pop("buildHash", "")
|
||||||
state.golang_version = msg.args.pop("golangVersion", "")
|
|
||||||
state.openssl_enabled = msg.args.pop("opensslEnabled", False)
|
|
||||||
state.openssl_version = msg.args.pop("opensslVersion", "")
|
|
||||||
state.fips_enabled = msg.args.pop("fipsEnabled", False)
|
|
||||||
state.args.update(msg.args)
|
state.args.update(msg.args)
|
||||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||||
return
|
return
|
||||||
|
@ -124,6 +124,7 @@ class KubernetesObjectReconciler(Generic[T]):
|
|||||||
self.update(current, reference)
|
self.update(current, reference)
|
||||||
self.logger.debug("Updating")
|
self.logger.debug("Updating")
|
||||||
except (OpenApiException, HTTPError) as exc:
|
except (OpenApiException, HTTPError) as exc:
|
||||||
|
|
||||||
if isinstance(exc, ApiException) and exc.status == 422: # noqa: PLR2004
|
if isinstance(exc, ApiException) and exc.status == 422: # noqa: PLR2004
|
||||||
self.logger.debug("Failed to update current, triggering re-create")
|
self.logger.debug("Failed to update current, triggering re-create")
|
||||||
self._recreate(current=current, reference=reference)
|
self._recreate(current=current, reference=reference)
|
||||||
|
18
authentik/outposts/migrations/0022_outpost_deleted_at.py
Normal file
18
authentik/outposts/migrations/0022_outpost_deleted_at.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Generated by Django 5.0.4 on 2024-04-23 21:00
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("authentik_outposts", "0021_alter_outpost_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="outpost",
|
||||||
|
name="deleted_at",
|
||||||
|
field=models.DateTimeField(blank=True, null=True),
|
||||||
|
),
|
||||||
|
]
|
@ -33,7 +33,7 @@ from authentik.core.models import (
|
|||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.models import InheritanceForeignKey, SerializerModel
|
from authentik.lib.models import InheritanceForeignKey, SerializerModel, SoftDeleteModel
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
from authentik.lib.utils.errors import exception_to_string
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
from authentik.outposts.controllers.k8s.utils import get_namespace
|
from authentik.outposts.controllers.k8s.utils import get_namespace
|
||||||
@ -131,7 +131,7 @@ class OutpostServiceConnection(models.Model):
|
|||||||
verbose_name = _("Outpost Service-Connection")
|
verbose_name = _("Outpost Service-Connection")
|
||||||
verbose_name_plural = _("Outpost Service-Connections")
|
verbose_name_plural = _("Outpost Service-Connections")
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self):
|
||||||
return f"Outpost service connection {self.name}"
|
return f"Outpost service connection {self.name}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -241,7 +241,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection):
|
|||||||
return "ak-service-connection-kubernetes-form"
|
return "ak-service-connection-kubernetes-form"
|
||||||
|
|
||||||
|
|
||||||
class Outpost(SerializerModel, ManagedModel):
|
class Outpost(SoftDeleteModel, SerializerModel, ManagedModel):
|
||||||
"""Outpost instance which manages a service user and token"""
|
"""Outpost instance which manages a service user and token"""
|
||||||
|
|
||||||
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
||||||
@ -434,10 +434,6 @@ class OutpostState:
|
|||||||
version: str | None = field(default=None)
|
version: str | None = field(default=None)
|
||||||
version_should: Version = field(default=OUR_VERSION)
|
version_should: Version = field(default=OUR_VERSION)
|
||||||
build_hash: str = field(default="")
|
build_hash: str = field(default="")
|
||||||
golang_version: str = field(default="")
|
|
||||||
openssl_enabled: bool = field(default=False)
|
|
||||||
openssl_version: str = field(default="")
|
|
||||||
fips_enabled: bool = field(default=False)
|
|
||||||
hostname: str = field(default="")
|
hostname: str = field(default="")
|
||||||
args: dict = field(default_factory=dict)
|
args: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
@ -2,13 +2,14 @@
|
|||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
|
from django.db.models.signals import m2m_changed, post_save, pre_save
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.brands.models import Brand
|
from authentik.brands.models import Brand
|
||||||
from authentik.core.models import Provider
|
from authentik.core.models import Provider
|
||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
|
from authentik.lib.models import post_soft_delete
|
||||||
from authentik.lib.utils.reflection import class_to_path
|
from authentik.lib.utils.reflection import class_to_path
|
||||||
from authentik.outposts.models import Outpost, OutpostServiceConnection
|
from authentik.outposts.models import Outpost, OutpostServiceConnection
|
||||||
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
|
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
|
||||||
@ -67,9 +68,7 @@ def post_save_update(sender, instance: Model, created: bool, **_):
|
|||||||
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
|
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
|
||||||
|
|
||||||
|
|
||||||
@receiver(pre_delete, sender=Outpost)
|
@receiver(post_soft_delete, sender=Outpost)
|
||||||
def pre_delete_cleanup(sender, instance: Outpost, **_):
|
def outpost_cleanup(sender, instance: Outpost, **_):
|
||||||
"""Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
|
"""Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
|
||||||
instance.user.delete()
|
outpost_controller.delay(instance.pk.hex, action="down")
|
||||||
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
|
|
||||||
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
|
|
||||||
|
@ -129,17 +129,14 @@ def outpost_controller_all():
|
|||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||||
def outpost_controller(
|
def outpost_controller(self: SystemTask, outpost_pk: str, action: str = "up"):
|
||||||
self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
|
|
||||||
):
|
|
||||||
"""Create/update/monitor/delete the deployment of an Outpost"""
|
"""Create/update/monitor/delete the deployment of an Outpost"""
|
||||||
logs = []
|
logs = []
|
||||||
if from_cache:
|
outpost: Outpost = None
|
||||||
outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
|
if action == "up":
|
||||||
LOGGER.debug("Getting outpost from cache to delete")
|
outpost = Outpost.objects.filter(pk=outpost_pk).first()
|
||||||
else:
|
elif action == "down":
|
||||||
outpost: Outpost = Outpost.objects.filter(pk=outpost_pk).first()
|
outpost = Outpost.deleted.filter(pk=outpost_pk).first()
|
||||||
LOGGER.debug("Getting outpost from DB")
|
|
||||||
if not outpost:
|
if not outpost:
|
||||||
LOGGER.warning("No outpost")
|
LOGGER.warning("No outpost")
|
||||||
return
|
return
|
||||||
@ -155,9 +152,10 @@ def outpost_controller(
|
|||||||
except (ControllerException, ServiceConnectionInvalid) as exc:
|
except (ControllerException, ServiceConnectionInvalid) as exc:
|
||||||
self.set_error(exc)
|
self.set_error(exc)
|
||||||
else:
|
else:
|
||||||
if from_cache:
|
|
||||||
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
|
|
||||||
self.set_status(TaskStatus.SUCCESSFUL, *logs)
|
self.set_status(TaskStatus.SUCCESSFUL, *logs)
|
||||||
|
finally:
|
||||||
|
if outpost.deleted_at:
|
||||||
|
outpost.force_delete()
|
||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||||
|
@ -13,13 +13,10 @@ from rest_framework.viewsets import GenericViewSet
|
|||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.api.applications import user_app_cache_key
|
from authentik.core.api.applications import user_app_cache_key
|
||||||
from authentik.core.api.object_types import TypesMixin
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import (
|
from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer
|
||||||
CacheSerializer,
|
|
||||||
MetaNameSerializer,
|
|
||||||
)
|
|
||||||
from authentik.events.logs import LogEventSerializer, capture_logs
|
from authentik.events.logs import LogEventSerializer, capture_logs
|
||||||
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer
|
from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer
|
||||||
from authentik.policies.models import Policy, PolicyBinding
|
from authentik.policies.models import Policy, PolicyBinding
|
||||||
from authentik.policies.process import PolicyProcess
|
from authentik.policies.process import PolicyProcess
|
||||||
@ -72,7 +69,6 @@ class PolicySerializer(ModelSerializer, MetaNameSerializer):
|
|||||||
|
|
||||||
|
|
||||||
class PolicyViewSet(
|
class PolicyViewSet(
|
||||||
TypesMixin,
|
|
||||||
mixins.RetrieveModelMixin,
|
mixins.RetrieveModelMixin,
|
||||||
mixins.DestroyModelMixin,
|
mixins.DestroyModelMixin,
|
||||||
UsedByMixin,
|
UsedByMixin,
|
||||||
@ -93,6 +89,23 @@ class PolicyViewSet(
|
|||||||
def get_queryset(self): # pragma: no cover
|
def get_queryset(self): # pragma: no cover
|
||||||
return Policy.objects.select_subclasses().prefetch_related("bindings", "promptstage_set")
|
return Policy.objects.select_subclasses().prefetch_related("bindings", "promptstage_set")
|
||||||
|
|
||||||
|
@extend_schema(responses={200: TypeCreateSerializer(many=True)})
|
||||||
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
|
def types(self, request: Request) -> Response:
|
||||||
|
"""Get all creatable policy types"""
|
||||||
|
data = []
|
||||||
|
for subclass in all_subclasses(self.queryset.model):
|
||||||
|
subclass: Policy
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": subclass._meta.verbose_name,
|
||||||
|
"description": subclass.__doc__,
|
||||||
|
"component": subclass().component,
|
||||||
|
"model_name": subclass._meta.model_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
|
||||||
@permission_required(None, ["authentik_policies.view_policy_cache"])
|
@permission_required(None, ["authentik_policies.view_policy_cache"])
|
||||||
@extend_schema(responses={200: CacheSerializer(many=False)})
|
@extend_schema(responses={200: CacheSerializer(many=False)})
|
||||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
|
@ -102,7 +102,7 @@ class EventMatcherPolicy(Policy):
|
|||||||
result = checker(request, event)
|
result = checker(request, event)
|
||||||
if result is None:
|
if result is None:
|
||||||
continue
|
continue
|
||||||
LOGGER.debug(
|
LOGGER.info(
|
||||||
"Event matcher check result",
|
"Event matcher check result",
|
||||||
checker=checker.__name__,
|
checker=checker.__name__,
|
||||||
result=result,
|
result=result,
|
||||||
|
@ -96,42 +96,16 @@ class TestEvaluator(TestCase):
|
|||||||
execution_logging=True,
|
execution_logging=True,
|
||||||
expression="ak_message(request.http_request.path)\nreturn True",
|
expression="ak_message(request.http_request.path)\nreturn True",
|
||||||
)
|
)
|
||||||
expr2 = ExpressionPolicy.objects.create(
|
tmpl = f"""
|
||||||
name=generate_id(),
|
ak_message(request.http_request.path)
|
||||||
execution_logging=True,
|
res = ak_call_policy('{expr.name}')
|
||||||
expression=f"""
|
ak_message(request.http_request.path)
|
||||||
ak_message(request.http_request.path)
|
for msg in res.messages:
|
||||||
res = ak_call_policy('{expr.name}')
|
ak_message(msg)
|
||||||
ak_message(request.http_request.path)
|
"""
|
||||||
for msg in res.messages:
|
evaluator = PolicyEvaluator("test")
|
||||||
ak_message(msg)
|
evaluator.set_policy_request(self.request)
|
||||||
""",
|
res = evaluator.evaluate(tmpl)
|
||||||
)
|
|
||||||
proc = PolicyProcess(PolicyBinding(policy=expr2), request=self.request, connection=None)
|
|
||||||
res = proc.profiling_wrapper()
|
|
||||||
self.assertEqual(res.messages, ("/", "/", "/"))
|
|
||||||
|
|
||||||
def test_call_policy_test_like(self):
|
|
||||||
"""test ak_call_policy without `obj` set, as if it was when testing policies"""
|
|
||||||
expr = ExpressionPolicy.objects.create(
|
|
||||||
name=generate_id(),
|
|
||||||
execution_logging=True,
|
|
||||||
expression="ak_message(request.http_request.path)\nreturn True",
|
|
||||||
)
|
|
||||||
expr2 = ExpressionPolicy.objects.create(
|
|
||||||
name=generate_id(),
|
|
||||||
execution_logging=True,
|
|
||||||
expression=f"""
|
|
||||||
ak_message(request.http_request.path)
|
|
||||||
res = ak_call_policy('{expr.name}')
|
|
||||||
ak_message(request.http_request.path)
|
|
||||||
for msg in res.messages:
|
|
||||||
ak_message(msg)
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
self.request.obj = None
|
|
||||||
proc = PolicyProcess(PolicyBinding(policy=expr2), request=self.request, connection=None)
|
|
||||||
res = proc.profiling_wrapper()
|
|
||||||
self.assertEqual(res.messages, ("/", "/", "/"))
|
self.assertEqual(res.messages, ("/", "/", "/"))
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,8 +128,8 @@ class PolicyProcess(PROCESS_CLASS):
|
|||||||
binding_order=self.binding.order,
|
binding_order=self.binding.order,
|
||||||
binding_target_type=self.binding.target_type,
|
binding_target_type=self.binding.target_type,
|
||||||
binding_target_name=self.binding.target_name,
|
binding_target_name=self.binding.target_name,
|
||||||
object_pk=str(self.request.obj.pk) if self.request.obj else "",
|
object_pk=str(self.request.obj.pk),
|
||||||
object_type=class_to_path(self.request.obj.__class__) if self.request.obj else "",
|
object_type=class_to_path(self.request.obj.__class__),
|
||||||
mode="execute_process",
|
mode="execute_process",
|
||||||
).time(),
|
).time(),
|
||||||
):
|
):
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from authentik.blueprints.apps import ManagedAppConfig
|
from authentik.blueprints.apps import ManagedAppConfig
|
||||||
|
|
||||||
|
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
|
||||||
|
|
||||||
|
|
||||||
class AuthentikPolicyReputationConfig(ManagedAppConfig):
|
class AuthentikPolicyReputationConfig(ManagedAppConfig):
|
||||||
"""Authentik reputation app config"""
|
"""Authentik reputation app config"""
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
# Generated by Django 5.0.6 on 2024-06-11 08:50
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("authentik_policies_reputation", "0006_reputation_ip_asn_data"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddIndex(
|
|
||||||
model_name="reputation",
|
|
||||||
index=models.Index(fields=["identifier"], name="authentik_p_identif_9434d7_idx"),
|
|
||||||
),
|
|
||||||
migrations.AddIndex(
|
|
||||||
model_name="reputation",
|
|
||||||
index=models.Index(fields=["ip"], name="authentik_p_ip_7ad0df_idx"),
|
|
||||||
),
|
|
||||||
migrations.AddIndex(
|
|
||||||
model_name="reputation",
|
|
||||||
index=models.Index(fields=["ip", "identifier"], name="authentik_p_ip_d779aa_idx"),
|
|
||||||
),
|
|
||||||
]
|
|
@ -96,8 +96,3 @@ class Reputation(ExpiringModel, SerializerModel):
|
|||||||
verbose_name = _("Reputation Score")
|
verbose_name = _("Reputation Score")
|
||||||
verbose_name_plural = _("Reputation Scores")
|
verbose_name_plural = _("Reputation Scores")
|
||||||
unique_together = ("identifier", "ip")
|
unique_together = ("identifier", "ip")
|
||||||
indexes = [
|
|
||||||
models.Index(fields=["identifier"]),
|
|
||||||
models.Index(fields=["ip"]),
|
|
||||||
models.Index(fields=["ip", "identifier"]),
|
|
||||||
]
|
|
||||||
|
11
authentik/policies/reputation/settings.py
Normal file
11
authentik/policies/reputation/settings.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
"""Reputation Settings"""
|
||||||
|
|
||||||
|
from celery.schedules import crontab
|
||||||
|
|
||||||
|
CELERY_BEAT_SCHEDULE = {
|
||||||
|
"policies_reputation_save": {
|
||||||
|
"task": "authentik.policies.reputation.tasks.save_reputation",
|
||||||
|
"schedule": crontab(minute="1-59/5"),
|
||||||
|
"options": {"queue": "authentik_scheduled"},
|
||||||
|
},
|
||||||
|
}
|
@ -1,35 +1,40 @@
|
|||||||
"""authentik reputation request signals"""
|
"""authentik reputation request signals"""
|
||||||
|
|
||||||
from django.contrib.auth.signals import user_logged_in
|
from django.contrib.auth.signals import user_logged_in
|
||||||
|
from django.core.cache import cache
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.signals import login_failed
|
from authentik.core.signals import login_failed
|
||||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
|
||||||
from authentik.policies.reputation.models import Reputation, reputation_expiry
|
from authentik.policies.reputation.tasks import save_reputation
|
||||||
from authentik.root.middleware import ClientIPMiddleware
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
from authentik.stages.identification.signals import identification_failed
|
from authentik.stages.identification.signals import identification_failed
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_reputation")
|
||||||
|
|
||||||
|
|
||||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||||
"""Update score for IP and User"""
|
"""Update score for IP and User"""
|
||||||
remote_ip = ClientIPMiddleware.get_client_ip(request)
|
remote_ip = ClientIPMiddleware.get_client_ip(request)
|
||||||
|
|
||||||
Reputation.objects.update_or_create(
|
try:
|
||||||
ip=remote_ip,
|
# We only update the cache here, as its faster than writing to the DB
|
||||||
identifier=identifier,
|
score = cache.get_or_set(
|
||||||
defaults={
|
CACHE_KEY_PREFIX + remote_ip + "/" + identifier,
|
||||||
"score": amount,
|
{"ip": remote_ip, "identifier": identifier, "score": 0},
|
||||||
"ip_geo_data": GEOIP_CONTEXT_PROCESSOR.city_dict(remote_ip) or {},
|
CACHE_TIMEOUT,
|
||||||
"ip_asn_data": ASN_CONTEXT_PROCESSOR.asn_dict(remote_ip) or {},
|
)
|
||||||
"expires": reputation_expiry(),
|
score["score"] += amount
|
||||||
},
|
cache.set(CACHE_KEY_PREFIX + remote_ip + "/" + identifier, score)
|
||||||
)
|
except ValueError as exc:
|
||||||
|
LOGGER.warning("failed to set reputation", exc=exc)
|
||||||
|
|
||||||
LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
|
LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
|
||||||
|
save_reputation.delay()
|
||||||
|
|
||||||
|
|
||||||
@receiver(login_failed)
|
@receiver(login_failed)
|
||||||
|
32
authentik/policies/reputation/tasks.py
Normal file
32
authentik/policies/reputation/tasks.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
"""Reputation tasks"""
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||||
|
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||||
|
from authentik.events.models import TaskStatus
|
||||||
|
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||||
|
from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
|
||||||
|
from authentik.policies.reputation.models import Reputation
|
||||||
|
from authentik.root.celery import CELERY_APP
|
||||||
|
|
||||||
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||||
|
@prefill_task
|
||||||
|
def save_reputation(self: SystemTask):
|
||||||
|
"""Save currently cached reputation to database"""
|
||||||
|
objects_to_update = []
|
||||||
|
for _, score in cache.get_many(cache.keys(CACHE_KEY_PREFIX + "*")).items():
|
||||||
|
rep, _ = Reputation.objects.get_or_create(
|
||||||
|
ip=score["ip"],
|
||||||
|
identifier=score["identifier"],
|
||||||
|
)
|
||||||
|
rep.ip_geo_data = GEOIP_CONTEXT_PROCESSOR.city_dict(score["ip"]) or {}
|
||||||
|
rep.ip_asn_data = ASN_CONTEXT_PROCESSOR.asn_dict(score["ip"]) or {}
|
||||||
|
rep.score = score["score"]
|
||||||
|
objects_to_update.append(rep)
|
||||||
|
Reputation.objects.bulk_update(objects_to_update, ["score", "ip_geo_data"])
|
||||||
|
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated Reputation")
|
@ -1,11 +1,14 @@
|
|||||||
"""test reputation signals and policy"""
|
"""test reputation signals and policy"""
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory, TestCase
|
||||||
|
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.policies.reputation.api import ReputationPolicySerializer
|
from authentik.policies.reputation.api import ReputationPolicySerializer
|
||||||
|
from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
|
||||||
from authentik.policies.reputation.models import Reputation, ReputationPolicy
|
from authentik.policies.reputation.models import Reputation, ReputationPolicy
|
||||||
|
from authentik.policies.reputation.tasks import save_reputation
|
||||||
from authentik.policies.types import PolicyRequest
|
from authentik.policies.types import PolicyRequest
|
||||||
from authentik.stages.password import BACKEND_INBUILT
|
from authentik.stages.password import BACKEND_INBUILT
|
||||||
from authentik.stages.password.stage import authenticate
|
from authentik.stages.password.stage import authenticate
|
||||||
@ -19,6 +22,8 @@ class TestReputationPolicy(TestCase):
|
|||||||
self.request = self.request_factory.get("/")
|
self.request = self.request_factory.get("/")
|
||||||
self.test_ip = "127.0.0.1"
|
self.test_ip = "127.0.0.1"
|
||||||
self.test_username = "test"
|
self.test_username = "test"
|
||||||
|
keys = cache.keys(CACHE_KEY_PREFIX + "*")
|
||||||
|
cache.delete_many(keys)
|
||||||
# We need a user for the one-to-one in userreputation
|
# We need a user for the one-to-one in userreputation
|
||||||
self.user = User.objects.create(username=self.test_username)
|
self.user = User.objects.create(username=self.test_username)
|
||||||
self.backends = [BACKEND_INBUILT]
|
self.backends = [BACKEND_INBUILT]
|
||||||
@ -29,6 +34,13 @@ class TestReputationPolicy(TestCase):
|
|||||||
authenticate(
|
authenticate(
|
||||||
self.request, self.backends, username=self.test_username, password=self.test_username
|
self.request, self.backends, username=self.test_username, password=self.test_username
|
||||||
)
|
)
|
||||||
|
# Test value in cache
|
||||||
|
self.assertEqual(
|
||||||
|
cache.get(CACHE_KEY_PREFIX + self.test_ip + "/" + self.test_username),
|
||||||
|
{"ip": "127.0.0.1", "identifier": "test", "score": -1},
|
||||||
|
)
|
||||||
|
# Save cache and check db values
|
||||||
|
save_reputation.delay().get()
|
||||||
self.assertEqual(Reputation.objects.get(ip=self.test_ip).score, -1)
|
self.assertEqual(Reputation.objects.get(ip=self.test_ip).score, -1)
|
||||||
|
|
||||||
def test_user_reputation(self):
|
def test_user_reputation(self):
|
||||||
@ -37,6 +49,13 @@ class TestReputationPolicy(TestCase):
|
|||||||
authenticate(
|
authenticate(
|
||||||
self.request, self.backends, username=self.test_username, password=self.test_username
|
self.request, self.backends, username=self.test_username, password=self.test_username
|
||||||
)
|
)
|
||||||
|
# Test value in cache
|
||||||
|
self.assertEqual(
|
||||||
|
cache.get(CACHE_KEY_PREFIX + self.test_ip + "/" + self.test_username),
|
||||||
|
{"ip": "127.0.0.1", "identifier": "test", "score": -1},
|
||||||
|
)
|
||||||
|
# Save cache and check db values
|
||||||
|
save_reputation.delay().get()
|
||||||
self.assertEqual(Reputation.objects.get(identifier=self.test_username).score, -1)
|
self.assertEqual(Reputation.objects.get(identifier=self.test_username).score, -1)
|
||||||
|
|
||||||
def test_policy(self):
|
def test_policy(self):
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.templatetags.static import static
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
|
|
||||||
@ -91,10 +90,6 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
|||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-provider-ldap-form"
|
return "ak-provider-ldap-form"
|
||||||
|
|
||||||
@property
|
|
||||||
def icon_url(self) -> str | None:
|
|
||||||
return static("authentik/sources/ldap.png")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serializer(self) -> type[Serializer]:
|
def serializer(self) -> type[Serializer]:
|
||||||
from authentik.providers.ldap.api import LDAPProviderSerializer
|
from authentik.providers.ldap.api import LDAPProviderSerializer
|
||||||
|
@ -8,7 +8,7 @@ from rest_framework.fields import CharField
|
|||||||
from rest_framework.serializers import ValidationError
|
from rest_framework.serializers import ValidationError
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.property_mappings import PropertyMappingSerializer
|
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.providers.oauth2.models import ScopeMapping
|
from authentik.providers.oauth2.models import ScopeMapping
|
||||||
|
|
||||||
|
@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
|
|||||||
from dacite.core import from_dict
|
from dacite.core import from_dict
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.templatetags.static import static
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from jwt import encode
|
from jwt import encode
|
||||||
@ -263,10 +262,6 @@ class OAuth2Provider(Provider):
|
|||||||
LOGGER.warning("Failed to format launch url", exc=exc)
|
LOGGER.warning("Failed to format launch url", exc=exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
|
||||||
def icon_url(self) -> str | None:
|
|
||||||
return static("authentik/sources/openidconnect.svg")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
return "ak-provider-oauth2-form"
|
return "ak-provider-oauth2-form"
|
||||||
|
@ -15,6 +15,7 @@ from authentik.core.expression.exceptions import PropertyMappingExpressionExcept
|
|||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.flows.challenge import PermissionDict
|
from authentik.flows.challenge import PermissionDict
|
||||||
from authentik.providers.oauth2.constants import (
|
from authentik.providers.oauth2.constants import (
|
||||||
|
SCOPE_AUTHENTIK_API,
|
||||||
SCOPE_GITHUB_ORG_READ,
|
SCOPE_GITHUB_ORG_READ,
|
||||||
SCOPE_GITHUB_USER,
|
SCOPE_GITHUB_USER,
|
||||||
SCOPE_GITHUB_USER_EMAIL,
|
SCOPE_GITHUB_USER_EMAIL,
|
||||||
@ -56,6 +57,7 @@ class UserInfoView(View):
|
|||||||
SCOPE_GITHUB_USER_READ: _("GitHub Compatibility: Access your User Information"),
|
SCOPE_GITHUB_USER_READ: _("GitHub Compatibility: Access your User Information"),
|
||||||
SCOPE_GITHUB_USER_EMAIL: _("GitHub Compatibility: Access you Email addresses"),
|
SCOPE_GITHUB_USER_EMAIL: _("GitHub Compatibility: Access you Email addresses"),
|
||||||
SCOPE_GITHUB_ORG_READ: _("GitHub Compatibility: Access your Groups"),
|
SCOPE_GITHUB_ORG_READ: _("GitHub Compatibility: Access your Groups"),
|
||||||
|
SCOPE_AUTHENTIK_API: _("authentik API Access on behalf of your user"),
|
||||||
}
|
}
|
||||||
for scope in scopes:
|
for scope in scopes:
|
||||||
if scope in special_scope_map:
|
if scope in special_scope_map:
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user