Compare commits
3 Commits
version/20
...
docs-certs
Author | SHA1 | Date | |
---|---|---|---|
d660a392b9 | |||
f530ce5e02 | |||
d4012df59d |
@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 2024.8.3
|
||||
current_version = 2024.6.3
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
||||
|
@ -29,9 +29,9 @@ outputs:
|
||||
imageTags:
|
||||
description: "Docker image tags"
|
||||
value: ${{ steps.ev.outputs.imageTags }}
|
||||
attestImageNames:
|
||||
description: "Docker image names used for attestation"
|
||||
value: ${{ steps.ev.outputs.attestImageNames }}
|
||||
imageNames:
|
||||
description: "Docker image names"
|
||||
value: ${{ steps.ev.outputs.imageNames }}
|
||||
imageMainTag:
|
||||
description: "Docker image main tag"
|
||||
value: ${{ steps.ev.outputs.imageMainTag }}
|
||||
|
@ -51,24 +51,15 @@ else:
|
||||
]
|
||||
|
||||
image_main_tag = image_tags[0].split(":")[-1]
|
||||
|
||||
|
||||
def get_attest_image_names(image_with_tags: list[str]):
|
||||
"""Attestation only for GHCR"""
|
||||
image_tags = []
|
||||
for image_name in set(name.split(":")[0] for name in image_with_tags):
|
||||
if not image_name.startswith("ghcr.io"):
|
||||
continue
|
||||
image_tags.append(image_name)
|
||||
return ",".join(set(image_tags))
|
||||
|
||||
image_tags_rendered = ",".join(image_tags)
|
||||
image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags))
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
|
||||
print(f"shouldBuild={should_build}", file=_output)
|
||||
print(f"sha={sha}", file=_output)
|
||||
print(f"version={version}", file=_output)
|
||||
print(f"prerelease={prerelease}", file=_output)
|
||||
print(f"imageTags={','.join(image_tags)}", file=_output)
|
||||
print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output)
|
||||
print(f"imageTags={image_tags_rendered}", file=_output)
|
||||
print(f"imageNames={image_names_rendered}", file=_output)
|
||||
print(f"imageMainTag={image_main_tag}", file=_output)
|
||||
print(f"imageMainName={image_tags[0]}", file=_output)
|
||||
|
4
.github/dependabot.yml
vendored
4
.github/dependabot.yml
vendored
@ -58,10 +58,6 @@ updates:
|
||||
patterns:
|
||||
- "@rollup/*"
|
||||
- "rollup-*"
|
||||
swc:
|
||||
patterns:
|
||||
- "@swc/*"
|
||||
- "swc-*"
|
||||
wdio:
|
||||
patterns:
|
||||
- "@wdio/*"
|
||||
|
2
.github/workflows/ci-main.yml
vendored
2
.github/workflows/ci-main.yml
vendored
@ -261,7 +261,7 @@ jobs:
|
||||
id: attest
|
||||
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
pr-comment:
|
||||
|
4
.github/workflows/ci-outpost.yml
vendored
4
.github/workflows/ci-outpost.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: latest
|
||||
version: v1.54.2
|
||||
args: --timeout 5000s --verbose
|
||||
skip-cache: true
|
||||
test-unittest:
|
||||
@ -115,7 +115,7 @@ jobs:
|
||||
id: attest
|
||||
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
build-binary:
|
||||
|
2
.github/workflows/ci-web.yml
vendored
2
.github/workflows/ci-web.yml
vendored
@ -92,4 +92,4 @@ jobs:
|
||||
run: make gen-client-ts
|
||||
- name: test
|
||||
working-directory: web/
|
||||
run: npm run test || exit 0
|
||||
run: npm run test
|
||||
|
8
.github/workflows/release-publish.yml
vendored
8
.github/workflows/release-publish.yml
vendored
@ -51,14 +51,12 @@ jobs:
|
||||
secrets: |
|
||||
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
|
||||
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
|
||||
build-args: |
|
||||
VERSION=${{ github.ref }}
|
||||
tags: ${{ steps.ev.outputs.imageTags }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- uses: actions/attest-build-provenance@v1
|
||||
id: attest
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
build-outpost:
|
||||
@ -113,8 +111,6 @@ jobs:
|
||||
id: push
|
||||
with:
|
||||
push: true
|
||||
build-args: |
|
||||
VERSION=${{ github.ref }}
|
||||
tags: ${{ steps.ev.outputs.imageTags }}
|
||||
file: ${{ matrix.type }}.Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
@ -122,7 +118,7 @@ jobs:
|
||||
- uses: actions/attest-build-provenance@v1
|
||||
id: attest
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
build-outpost-binary:
|
||||
|
23
Dockerfile
23
Dockerfile
@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build website
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS website-builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as website-builder
|
||||
|
||||
ENV NODE_ENV=production
|
||||
|
||||
@ -20,7 +20,7 @@ COPY ./SECURITY.md /work/
|
||||
RUN npm run build-bundled
|
||||
|
||||
# Stage 2: Build webui
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS web-builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as web-builder
|
||||
|
||||
ARG GIT_BUILD_HASH
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
@ -43,7 +43,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
|
||||
RUN npm run build
|
||||
|
||||
# Stage 3: Build go proxy
|
||||
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.23-fips-bookworm AS go-builder
|
||||
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
@ -80,7 +80,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
go build -o /go/authentik ./cmd/server
|
||||
|
||||
# Stage 4: MaxMind GeoIP
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 AS geoip
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
|
||||
|
||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||
ENV GEOIPUPDATE_VERBOSE="1"
|
||||
@ -96,9 +96,6 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
# Stage 5: Python dependencies
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.5-slim-bookworm-fips-full AS python-deps
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
|
||||
WORKDIR /ak-root/poetry
|
||||
|
||||
ENV VENV_PATH="/ak-root/venv" \
|
||||
@ -126,15 +123,15 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
|
||||
# Stage 6: Run
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.5-slim-bookworm-fips-full AS final-image
|
||||
|
||||
ARG VERSION
|
||||
ARG GIT_BUILD_HASH
|
||||
ARG VERSION
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
|
||||
LABEL org.opencontainers.image.url=https://goauthentik.io
|
||||
LABEL org.opencontainers.image.description="goauthentik.io Main server image, see https://goauthentik.io for more info."
|
||||
LABEL org.opencontainers.image.source=https://github.com/goauthentik/authentik
|
||||
LABEL org.opencontainers.image.version=${VERSION}
|
||||
LABEL org.opencontainers.image.revision=${GIT_BUILD_HASH}
|
||||
LABEL org.opencontainers.image.url https://goauthentik.io
|
||||
LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info.
|
||||
LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
|
||||
LABEL org.opencontainers.image.version ${VERSION}
|
||||
LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
|
||||
|
||||
WORKDIR /
|
||||
|
||||
|
5
Makefile
5
Makefile
@ -43,7 +43,7 @@ help: ## Show this help
|
||||
sort
|
||||
@echo ""
|
||||
|
||||
go-test:
|
||||
test-go:
|
||||
go test -timeout 0 -v -race -cover ./...
|
||||
|
||||
test-docker: ## Run all tests in a docker-compose
|
||||
@ -210,9 +210,6 @@ web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting is
|
||||
web-install: ## Install the necessary libraries to build the Authentik UI
|
||||
cd web && npm ci
|
||||
|
||||
web-test: ## Run tests for the Authentik UI
|
||||
cd web && npm run test
|
||||
|
||||
web-watch: ## Build and watch the Authentik UI for changes, updating automatically
|
||||
rm -rf web/dist/
|
||||
mkdir web/dist/
|
||||
|
@ -15,9 +15,7 @@
|
||||
|
||||
## What is authentik?
|
||||
|
||||
authentik is an open-source Identity Provider that emphasizes flexibility and versatility, with support for a wide set of protocols.
|
||||
|
||||
Our [enterprise offer](https://goauthentik.io/pricing) can also be used as a self-hosted replacement for large-scale deployments of Okta/Auth0, Entra ID, Ping Identity, or other legacy IdPs for employees and B2B2C use.
|
||||
authentik is an open-source Identity Provider that emphasizes flexibility and versatility. It can be seamlessly integrated into existing environments to support new protocols. authentik is also a great solution for implementing sign-up, recovery, and other similar features in your application, saving you the hassle of dealing with them.
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from os import environ
|
||||
|
||||
__version__ = "2024.8.3"
|
||||
__version__ = "2024.6.3"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
@ -12,7 +12,6 @@ from rest_framework.views import APIView
|
||||
from authentik import __version__, get_build_hash
|
||||
from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.outposts.models import Outpost
|
||||
|
||||
|
||||
class VersionSerializer(PassiveSerializer):
|
||||
@ -23,7 +22,6 @@ class VersionSerializer(PassiveSerializer):
|
||||
version_latest_valid = SerializerMethodField()
|
||||
build_hash = SerializerMethodField()
|
||||
outdated = SerializerMethodField()
|
||||
outpost_outdated = SerializerMethodField()
|
||||
|
||||
def get_build_hash(self, _) -> str:
|
||||
"""Get build hash, if version is not latest or released"""
|
||||
@ -49,15 +47,6 @@ class VersionSerializer(PassiveSerializer):
|
||||
"""Check if we're running the latest version"""
|
||||
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
|
||||
|
||||
def get_outpost_outdated(self, _) -> bool:
|
||||
"""Check if any outpost is outdated/has a version mismatch"""
|
||||
any_outdated = False
|
||||
for outpost in Outpost.objects.all():
|
||||
for state in outpost.state:
|
||||
if state.version_outdated:
|
||||
any_outdated = True
|
||||
return any_outdated
|
||||
|
||||
|
||||
class VersionView(APIView):
|
||||
"""Get running and latest version."""
|
||||
|
@ -30,10 +30,8 @@ from authentik.core.api.utils import (
|
||||
PassiveSerializer,
|
||||
)
|
||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
||||
from authentik.core.expression.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import Group, PropertyMapping, User
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.api.exec import PolicyTestSerializer
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
@ -164,15 +162,12 @@ class PropertyMappingViewSet(
|
||||
|
||||
response_data = {"successful": True, "result": ""}
|
||||
try:
|
||||
result = mapping.evaluate(dry_run=True, **context)
|
||||
result = mapping.evaluate(**context)
|
||||
response_data["result"] = dumps(
|
||||
sanitize_item(result), indent=(4 if format_result else None)
|
||||
)
|
||||
except PropertyMappingExpressionException as exc:
|
||||
response_data["result"] = exception_to_string(exc.exc)
|
||||
response_data["successful"] = False
|
||||
except Exception as exc:
|
||||
response_data["result"] = exception_to_string(exc)
|
||||
response_data["result"] = str(exc)
|
||||
response_data["successful"] = False
|
||||
response = PropertyMappingTestResultSerializer(response_data)
|
||||
return Response(response.data)
|
||||
|
@ -14,7 +14,6 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
|
||||
class DeleteAction(Enum):
|
||||
@ -54,7 +53,7 @@ class UsedByMixin:
|
||||
@extend_schema(
|
||||
responses={200: UsedBySerializer(many=True)},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def used_by(self, request: Request, *args, **kwargs) -> Response:
|
||||
"""Get a list of all objects that use this object"""
|
||||
model: Model = self.get_object()
|
||||
|
@ -678,10 +678,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
if not request.tenant.impersonation:
|
||||
LOGGER.debug("User attempted to impersonate", user=request.user)
|
||||
return Response(status=401)
|
||||
user_to_be = self.get_object()
|
||||
if not request.user.has_perm("impersonate", user_to_be):
|
||||
if not request.user.has_perm("impersonate"):
|
||||
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
|
||||
return Response(status=401)
|
||||
user_to_be = self.get_object()
|
||||
if user_to_be.pk == self.request.user.pk:
|
||||
LOGGER.debug("User attempted to impersonate themselves", user=request.user)
|
||||
return Response(status=401)
|
||||
|
@ -9,11 +9,10 @@ class Command(TenantCommand):
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument("--type", type=str, required=True)
|
||||
parser.add_argument("--all", action="store_true", default=False)
|
||||
parser.add_argument("usernames", nargs="*", type=str)
|
||||
parser.add_argument("--all", action="store_true")
|
||||
parser.add_argument("usernames", nargs="+", type=str)
|
||||
|
||||
def handle_per_tenant(self, **options):
|
||||
print(options)
|
||||
new_type = UserTypes(options["type"])
|
||||
qs = (
|
||||
User.objects.exclude_anonymous()
|
||||
@ -23,9 +22,6 @@ class Command(TenantCommand):
|
||||
if options["usernames"] and options["all"]:
|
||||
self.stderr.write("--all and usernames specified, only one can be specified")
|
||||
return
|
||||
if not options["usernames"] and not options["all"]:
|
||||
self.stderr.write("--all or usernames must be specified")
|
||||
return
|
||||
if options["usernames"] and not options["all"]:
|
||||
qs = qs.filter(username__in=options["usernames"])
|
||||
updated = qs.update(type=new_type)
|
||||
|
@ -466,6 +466,8 @@ class ApplicationQuerySet(QuerySet):
|
||||
def with_provider(self) -> "QuerySet[Application]":
|
||||
qs = self.select_related("provider")
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
if LOOKUP_SEP in subclass:
|
||||
continue
|
||||
qs = qs.select_related(f"provider__{subclass}")
|
||||
return qs
|
||||
|
||||
@ -543,24 +545,15 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
if not self.provider:
|
||||
return None
|
||||
|
||||
candidates = []
|
||||
base_class = Provider
|
||||
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
|
||||
parent = self.provider
|
||||
for level in subclass.split(LOOKUP_SEP):
|
||||
try:
|
||||
parent = getattr(parent, level)
|
||||
except AttributeError:
|
||||
break
|
||||
if parent in candidates:
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
# We don't care about recursion, skip nested models
|
||||
if LOOKUP_SEP in subclass:
|
||||
continue
|
||||
idx = subclass.count(LOOKUP_SEP)
|
||||
if type(parent) is not base_class:
|
||||
idx += 1
|
||||
candidates.insert(idx, parent)
|
||||
if not candidates:
|
||||
return None
|
||||
return candidates[-1]
|
||||
try:
|
||||
return getattr(self.provider, subclass)
|
||||
except AttributeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name)
|
||||
@ -908,7 +901,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
||||
except ControlFlowException as exc:
|
||||
raise exc
|
||||
except Exception as exc:
|
||||
raise PropertyMappingExpressionException(exc, self) from exc
|
||||
raise PropertyMappingExpressionException(self, exc) from exc
|
||||
|
||||
def __str__(self):
|
||||
return f"Property Mapping {self.name}"
|
||||
|
@ -9,12 +9,9 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
from authentik.providers.saml.models import SAMLProvider
|
||||
|
||||
|
||||
class TestApplicationsAPI(APITestCase):
|
||||
@ -225,31 +222,3 @@ class TestApplicationsAPI(APITestCase):
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_get_provider(self):
|
||||
"""Ensure that proxy providers (at the time of writing that is the only provider
|
||||
that inherits from another proxy type (OAuth) instead of inheriting from the root
|
||||
provider class) is correctly looked up and selected from the database"""
|
||||
slug = generate_id()
|
||||
provider = ProxyProvider.objects.create(name=generate_id())
|
||||
Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=slug,
|
||||
provider=provider,
|
||||
)
|
||||
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
|
||||
self.assertEqual(
|
||||
Application.objects.with_provider().get(slug=slug).get_provider(), provider
|
||||
)
|
||||
|
||||
slug = generate_id()
|
||||
provider = SAMLProvider.objects.create(name=generate_id())
|
||||
Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=slug,
|
||||
provider=provider,
|
||||
)
|
||||
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
|
||||
self.assertEqual(
|
||||
Application.objects.with_provider().get(slug=slug).get_provider(), provider
|
||||
)
|
||||
|
@ -3,10 +3,10 @@
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from guardian.shortcuts import assign_perm
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
|
||||
from authentik.core.models import User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.other_user = create_test_user()
|
||||
self.other_user = User.objects.create(username="to-impersonate")
|
||||
self.user = create_test_admin_user()
|
||||
|
||||
def test_impersonate_simple(self):
|
||||
@ -44,26 +44,6 @@ class TestImpersonation(APITestCase):
|
||||
self.assertEqual(response_body["user"]["username"], self.user.username)
|
||||
self.assertNotIn("original", response_body)
|
||||
|
||||
def test_impersonate_scoped(self):
|
||||
"""Test impersonation with scoped permissions"""
|
||||
new_user = create_test_user()
|
||||
assign_perm("authentik_core.impersonate", new_user, self.other_user)
|
||||
assign_perm("authentik_core.view_user", new_user, self.other_user)
|
||||
self.client.force_login(new_user)
|
||||
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:user-impersonate",
|
||||
kwargs={"pk": self.other_user.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
|
||||
response = self.client.get(reverse("authentik_api:user-me"))
|
||||
response_body = loads(response.content.decode())
|
||||
self.assertEqual(response_body["user"]["username"], self.other_user.username)
|
||||
self.assertEqual(response_body["original"]["username"], new_user.username)
|
||||
|
||||
def test_impersonate_denied(self):
|
||||
"""test impersonation without permissions"""
|
||||
self.client.force_login(self.other_user)
|
||||
|
@ -35,7 +35,6 @@ from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.rbac.decorators import permission_required
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -266,7 +265,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||
],
|
||||
responses={200: CertificateDataSerializer(many=False)},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def view_certificate(self, request: Request, pk: str) -> Response:
|
||||
"""Return certificate-key pairs certificate and log access"""
|
||||
certificate: CertificateKeyPair = self.get_object()
|
||||
@ -296,7 +295,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||
],
|
||||
responses={200: CertificateDataSerializer(many=False)},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def view_private_key(self, request: Request, pk: str) -> Response:
|
||||
"""Return certificate-key pairs private key and log access"""
|
||||
certificate: CertificateKeyPair = self.get_object()
|
||||
|
@ -214,46 +214,6 @@ class TestCrypto(APITestCase):
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertIn("Content-Disposition", response)
|
||||
|
||||
def test_certificate_download_denied(self):
|
||||
"""Test certificate export (download)"""
|
||||
self.client.logout()
|
||||
keypair = create_test_cert()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-certificate",
|
||||
kwargs={"pk": keypair.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-certificate",
|
||||
kwargs={"pk": keypair.pk},
|
||||
),
|
||||
data={"download": True},
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
|
||||
def test_private_key_download_denied(self):
|
||||
"""Test private_key export (download)"""
|
||||
self.client.logout()
|
||||
keypair = create_test_cert()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-private-key",
|
||||
kwargs={"pk": keypair.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-private-key",
|
||||
kwargs={"pk": keypair.pk},
|
||||
),
|
||||
data={"download": True},
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
|
||||
def test_used_by(self):
|
||||
"""Test used_by endpoint"""
|
||||
self.client.force_login(create_test_admin_user())
|
||||
@ -286,26 +246,6 @@ class TestCrypto(APITestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_used_by_denied(self):
|
||||
"""Test used_by endpoint"""
|
||||
self.client.logout()
|
||||
keypair = create_test_cert()
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://localhost",
|
||||
signing_key=keypair,
|
||||
)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-used-by",
|
||||
kwargs={"pk": keypair.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
|
||||
def test_discovery(self):
|
||||
"""Test certificate discovery"""
|
||||
name = generate_id()
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""Enterprise API Views"""
|
||||
|
||||
from dataclasses import asdict
|
||||
from datetime import timedelta
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, inline_serializer
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
@ -18,7 +19,7 @@ from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import User, UserTypes
|
||||
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
|
||||
from authentik.enterprise.models import License
|
||||
from authentik.enterprise.models import License, LicenseUsageStatus
|
||||
from authentik.rbac.decorators import permission_required
|
||||
from authentik.tenants.utils import get_unique_identifier
|
||||
|
||||
@ -29,7 +30,7 @@ class EnterpriseRequiredMixin:
|
||||
|
||||
def validate(self, attrs: dict) -> dict:
|
||||
"""Check that a valid license exists"""
|
||||
if not LicenseKey.cached_summary().status.is_valid:
|
||||
if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED:
|
||||
raise ValidationError(_("Enterprise is required to create/update this object."))
|
||||
return super().validate(attrs)
|
||||
|
||||
@ -86,7 +87,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
|
||||
},
|
||||
)
|
||||
@action(detail=False, methods=["GET"])
|
||||
def install_id(self, request: Request) -> Response:
|
||||
def get_install_id(self, request: Request) -> Response:
|
||||
"""Get install_id"""
|
||||
return Response(
|
||||
data={
|
||||
@ -99,22 +100,12 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
|
||||
responses={
|
||||
200: LicenseSummarySerializer(),
|
||||
},
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="cached",
|
||||
location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.BOOL,
|
||||
default=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
|
||||
def summary(self, request: Request) -> Response:
|
||||
"""Get the total license status"""
|
||||
summary = LicenseKey.cached_summary()
|
||||
if request.query_params.get("cached", "true").lower() == "false":
|
||||
summary = LicenseKey.get_total().summary()
|
||||
response = LicenseSummarySerializer(instance=summary)
|
||||
response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
|
||||
response.is_valid(raise_exception=True)
|
||||
return Response(response.data)
|
||||
|
||||
@permission_required(None, ["authentik_enterprise.view_license"])
|
||||
|
@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
|
||||
"""Actual enterprise check, cached"""
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
return LicenseKey.cached_summary().status.is_valid
|
||||
return LicenseKey.cached_summary().status
|
||||
|
@ -20,7 +20,6 @@ from rest_framework.fields import (
|
||||
ChoiceField,
|
||||
DateTimeField,
|
||||
IntegerField,
|
||||
ListField,
|
||||
)
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
@ -56,7 +55,6 @@ class LicenseFlags(Enum):
|
||||
"""License flags"""
|
||||
|
||||
TRIAL = "trial"
|
||||
NON_PRODUCTION = "non_production"
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -67,7 +65,6 @@ class LicenseSummary:
|
||||
external_users: int
|
||||
status: LicenseUsageStatus
|
||||
latest_valid: datetime
|
||||
license_flags: list[LicenseFlags]
|
||||
|
||||
|
||||
class LicenseSummarySerializer(PassiveSerializer):
|
||||
@ -77,7 +74,6 @@ class LicenseSummarySerializer(PassiveSerializer):
|
||||
external_users = IntegerField(required=True)
|
||||
status = ChoiceField(choices=LicenseUsageStatus.choices)
|
||||
latest_valid = DateTimeField()
|
||||
license_flags = ListField(child=ChoiceField(choices=tuple(x.value for x in LicenseFlags)))
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -90,7 +86,7 @@ class LicenseKey:
|
||||
name: str
|
||||
internal_users: int = 0
|
||||
external_users: int = 0
|
||||
license_flags: list[LicenseFlags] = field(default_factory=list)
|
||||
flags: list[LicenseFlags] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def validate(jwt: str, check_expiry=True) -> "LicenseKey":
|
||||
@ -117,13 +113,10 @@ class LicenseKey:
|
||||
our_cert.public_key(),
|
||||
algorithms=["ES512"],
|
||||
audience=get_license_aud(),
|
||||
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
|
||||
options={"verify_exp": check_expiry},
|
||||
),
|
||||
)
|
||||
except PyJWTError:
|
||||
unverified = decode(jwt, options={"verify_signature": False})
|
||||
if unverified["aud"] != get_license_aud():
|
||||
raise ValidationError("Invalid Install ID in license") from None
|
||||
raise ValidationError("Unable to verify license") from None
|
||||
return body
|
||||
|
||||
@ -137,8 +130,9 @@ class LicenseKey:
|
||||
exp_ts = int(mktime(lic.expiry.timetuple()))
|
||||
if total.exp == 0:
|
||||
total.exp = exp_ts
|
||||
total.exp = max(total.exp, exp_ts)
|
||||
total.license_flags.extend(lic.status.license_flags)
|
||||
if exp_ts <= total.exp:
|
||||
total.exp = exp_ts
|
||||
total.flags.extend(lic.status.flags)
|
||||
return total
|
||||
|
||||
@staticmethod
|
||||
@ -222,7 +216,6 @@ class LicenseKey:
|
||||
internal_users=self.internal_users,
|
||||
external_users=self.external_users,
|
||||
status=status,
|
||||
license_flags=self.license_flags,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import (
|
||||
google_workspace_sync,
|
||||
google_workspace_sync_objects,
|
||||
)
|
||||
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
|
||||
|
||||
@ -55,4 +52,3 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_single_task = google_workspace_sync
|
||||
sync_objects_task = google_workspace_sync_objects
|
||||
|
@ -181,7 +181,7 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-google-workspace-form"
|
||||
return "ak-property-mapping-google-workspace-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import (
|
||||
microsoft_entra_sync,
|
||||
microsoft_entra_sync_objects,
|
||||
)
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
|
||||
|
||||
@ -53,4 +50,3 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_single_task = microsoft_entra_sync
|
||||
sync_objects_task = microsoft_entra_sync_objects
|
||||
|
@ -170,7 +170,7 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-microsoft-entra-form"
|
||||
return "ak-property-mapping-microsoft-entra-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-12 12:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_rac", "0004_alter_connectiontoken_expires"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="racpropertymapping",
|
||||
options={
|
||||
"verbose_name": "RAC Provider Property Mapping",
|
||||
"verbose_name_plural": "RAC Provider Property Mappings",
|
||||
},
|
||||
),
|
||||
]
|
@ -125,7 +125,7 @@ class RACPropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-rac-form"
|
||||
return "ak-property-mapping-rac-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -136,8 +136,8 @@ class RACPropertyMapping(PropertyMapping):
|
||||
return RACPropertyMappingSerializer
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Provider Property Mapping")
|
||||
verbose_name_plural = _("RAC Provider Property Mappings")
|
||||
verbose_name = _("RAC Property Mapping")
|
||||
verbose_name_plural = _("RAC Property Mappings")
|
||||
|
||||
|
||||
class ConnectionToken(ExpiringModel):
|
||||
|
@ -44,7 +44,7 @@ websocket_urlpatterns = [
|
||||
|
||||
api_urlpatterns = [
|
||||
("providers/rac", RACProviderViewSet),
|
||||
("propertymappings/provider/rac", RACPropertyMappingViewSet),
|
||||
("propertymappings/rac", RACPropertyMappingViewSet),
|
||||
("rac/endpoints", EndpointViewSet),
|
||||
("rac/connection_tokens", ConnectionTokenViewSet),
|
||||
]
|
||||
|
@ -3,7 +3,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models.signals import post_delete, post_save, pre_save
|
||||
from django.db.models.signals import post_save, pre_save
|
||||
from django.dispatch import receiver
|
||||
from django.utils.timezone import get_current_timezone
|
||||
|
||||
@ -27,9 +27,3 @@ def post_save_license(sender: type[License], instance: License, **_):
|
||||
"""Trigger license usage calculation when license is saved"""
|
||||
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
|
||||
enterprise_update_usage.delay()
|
||||
|
||||
|
||||
@receiver(post_delete, sender=License)
|
||||
def post_delete_license(sender: type[License], instance: License, **_):
|
||||
"""Clear license cache when license is deleted"""
|
||||
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
|
||||
|
@ -69,5 +69,8 @@ class NotificationViewSet(
|
||||
@action(detail=False, methods=["post"])
|
||||
def mark_all_seen(self, request: Request) -> Response:
|
||||
"""Mark all the user's notifications as seen"""
|
||||
Notification.objects.filter(user=request.user, seen=False).update(seen=True)
|
||||
notifications = Notification.objects.filter(user=request.user)
|
||||
for notification in notifications:
|
||||
notification.seen = True
|
||||
Notification.objects.bulk_update(notifications, ["seen"])
|
||||
return Response({}, status=204)
|
||||
|
@ -49,7 +49,6 @@ from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
from authentik.tenants.models import Tenant
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
DISCORD_FIELD_LIMIT = 25
|
||||
@ -59,11 +58,7 @@ NOTIFICATION_SUMMARY_LENGTH = 75
|
||||
def default_event_duration():
|
||||
"""Default duration an Event is saved.
|
||||
This is used as a fallback when no brand is available"""
|
||||
try:
|
||||
tenant = get_current_tenant()
|
||||
return now() + timedelta_from_string(tenant.event_retention)
|
||||
except Tenant.DoesNotExist:
|
||||
return now() + timedelta(days=365)
|
||||
return now() + timedelta(days=365)
|
||||
|
||||
|
||||
def default_brand():
|
||||
@ -250,6 +245,12 @@ class Event(SerializerModel, ExpiringModel):
|
||||
if QS_QUERY in self.context["http_request"]["args"]:
|
||||
wrapped = self.context["http_request"]["args"][QS_QUERY]
|
||||
self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
|
||||
if hasattr(request, "tenant"):
|
||||
tenant: Tenant = request.tenant
|
||||
# Because self.created only gets set on save, we can't use it's value here
|
||||
# hence we set self.created to now and then use it
|
||||
self.created = now()
|
||||
self.expires = self.created + timedelta_from_string(tenant.event_retention)
|
||||
if hasattr(request, "brand"):
|
||||
brand: Brand = request.brand
|
||||
self.brand = sanitize_dict(model_to_dict(brand))
|
||||
|
@ -6,7 +6,6 @@ from django.db.models import Model
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.core.models import default_token_key
|
||||
from authentik.events.models import default_event_duration
|
||||
from authentik.lib.utils.reflection import get_apps
|
||||
|
||||
|
||||
@ -21,7 +20,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable:
|
||||
allowed = 0
|
||||
# Token-like objects need to lookup the current tenant to get the default token length
|
||||
for field in test_model._meta.fields:
|
||||
if field.default in [default_token_key, default_event_duration]:
|
||||
if field.default == default_token_key:
|
||||
allowed += 1
|
||||
with self.assertNumQueries(allowed):
|
||||
str(test_model())
|
||||
|
@ -2,8 +2,7 @@
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.models import (
|
||||
@ -11,7 +10,6 @@ from authentik.events.models import (
|
||||
EventAction,
|
||||
Notification,
|
||||
NotificationRule,
|
||||
NotificationSeverity,
|
||||
NotificationTransport,
|
||||
NotificationWebhookMapping,
|
||||
TransportMode,
|
||||
@ -22,7 +20,7 @@ from authentik.policies.exceptions import PolicyException
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
||||
|
||||
class TestEventsNotifications(APITestCase):
|
||||
class TestEventsNotifications(TestCase):
|
||||
"""Test Event Notifications"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
@ -133,15 +131,3 @@ class TestEventsNotifications(APITestCase):
|
||||
Notification.objects.all().delete()
|
||||
Event.new(EventAction.CUSTOM_PREFIX).save()
|
||||
self.assertEqual(Notification.objects.first().body, "foo")
|
||||
|
||||
def test_api_mark_all_seen(self):
|
||||
"""Test mark_all_seen"""
|
||||
self.client.force_login(self.user)
|
||||
|
||||
Notification.objects.create(
|
||||
severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
|
||||
)
|
||||
|
||||
response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
|
||||
self.assertEqual(response.status_code, 204)
|
||||
self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())
|
||||
|
@ -37,7 +37,6 @@ from authentik.lib.utils.file import (
|
||||
)
|
||||
from authentik.lib.views import bad_request_message
|
||||
from authentik.rbac.decorators import permission_required
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -282,7 +281,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
|
||||
400: OpenApiResponse(description="Flow not applicable"),
|
||||
},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def execute(self, request: Request, slug: str):
|
||||
"""Execute flow for current user"""
|
||||
# Because we pre-plan the flow here, and not in the planner, we need to manually clear
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import re
|
||||
import socket
|
||||
from collections.abc import Iterable
|
||||
from ipaddress import ip_address, ip_network
|
||||
from textwrap import indent
|
||||
from types import CodeType
|
||||
@ -27,12 +28,6 @@ from authentik.stages.authenticator import devices_for_user
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
ARG_SANITIZE = re.compile(r"[:.-]")
|
||||
|
||||
|
||||
def sanitize_arg(arg_name: str) -> str:
|
||||
return re.sub(ARG_SANITIZE, "_", arg_name)
|
||||
|
||||
|
||||
class BaseEvaluator:
|
||||
"""Validate and evaluate python-based expressions"""
|
||||
@ -182,9 +177,9 @@ class BaseEvaluator:
|
||||
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
|
||||
return proc.profiling_wrapper()
|
||||
|
||||
def wrap_expression(self, expression: str) -> str:
|
||||
def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
|
||||
"""Wrap expression in a function, call it, and save the result as `result`"""
|
||||
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
|
||||
handler_signature = ",".join(params)
|
||||
full_expression = ""
|
||||
full_expression += f"def handler({handler_signature}):\n"
|
||||
full_expression += indent(expression, " ")
|
||||
@ -193,8 +188,8 @@ class BaseEvaluator:
|
||||
|
||||
def compile(self, expression: str) -> CodeType:
|
||||
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
|
||||
expression = self.wrap_expression(expression)
|
||||
return compile(expression, self._filename, "exec")
|
||||
param_keys = self._context.keys()
|
||||
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
|
||||
|
||||
def evaluate(self, expression_source: str) -> Any:
|
||||
"""Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
|
||||
@ -210,7 +205,7 @@ class BaseEvaluator:
|
||||
self.handle_error(exc, expression_source)
|
||||
raise exc
|
||||
try:
|
||||
_locals = {sanitize_arg(x): y for x, y in self._context.items()}
|
||||
_locals = self._context
|
||||
# Yes this is an exec, yes it is potentially bad. Since we limit what variables are
|
||||
# available here, and these policies can only be edited by admins, this is a risk
|
||||
# we're willing to take.
|
||||
|
@ -1,19 +1,16 @@
|
||||
from celery import Task
|
||||
from collections.abc import Callable
|
||||
|
||||
from django.utils.text import slugify
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField, CharField, ChoiceField
|
||||
from rest_framework.fields import BooleanField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
from authentik.events.logs import LogEvent, LogEventSerializer
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
|
||||
class SyncStatusSerializer(PassiveSerializer):
|
||||
@ -23,29 +20,10 @@ class SyncStatusSerializer(PassiveSerializer):
|
||||
tasks = SystemTaskSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class SyncObjectSerializer(PassiveSerializer):
|
||||
"""Sync object serializer"""
|
||||
|
||||
sync_object_model = ChoiceField(
|
||||
choices=(
|
||||
(class_to_path(User), "user"),
|
||||
(class_to_path(Group), "group"),
|
||||
)
|
||||
)
|
||||
sync_object_id = CharField()
|
||||
|
||||
|
||||
class SyncObjectResultSerializer(PassiveSerializer):
|
||||
"""Result of a single object sync"""
|
||||
|
||||
messages = LogEventSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class OutgoingSyncProviderStatusMixin:
|
||||
"""Common API Endpoints for Outgoing sync providers"""
|
||||
|
||||
sync_single_task: type[Task] = None
|
||||
sync_objects_task: type[Task] = None
|
||||
sync_single_task: Callable = None
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
@ -58,7 +36,7 @@ class OutgoingSyncProviderStatusMixin:
|
||||
detail=True,
|
||||
pagination_class=None,
|
||||
url_path="sync/status",
|
||||
filter_backends=[ObjectFilter],
|
||||
filter_backends=[],
|
||||
)
|
||||
def sync_status(self, request: Request, pk: int) -> Response:
|
||||
"""Get provider's sync status"""
|
||||
@ -77,30 +55,6 @@ class OutgoingSyncProviderStatusMixin:
|
||||
}
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
|
||||
@extend_schema(
|
||||
request=SyncObjectSerializer,
|
||||
responses={200: SyncObjectResultSerializer()},
|
||||
)
|
||||
@action(
|
||||
methods=["POST"],
|
||||
detail=True,
|
||||
pagination_class=None,
|
||||
url_path="sync/object",
|
||||
filter_backends=[ObjectFilter],
|
||||
)
|
||||
def sync_object(self, request: Request, pk: int) -> Response:
|
||||
"""Sync/Re-sync a single user/group object"""
|
||||
provider: OutgoingSyncProvider = self.get_object()
|
||||
params = SyncObjectSerializer(data=request.data)
|
||||
params.is_valid(raise_exception=True)
|
||||
res: list[LogEvent] = self.sync_objects_task.delay(
|
||||
params.validated_data["sync_object_model"],
|
||||
page=1,
|
||||
provider_pk=provider.pk,
|
||||
pk=params.validated_data["sync_object_id"],
|
||||
).get()
|
||||
return Response(SyncObjectResultSerializer(instance={"messages": res}).data)
|
||||
|
||||
|
||||
class OutgoingSyncConnectionCreateMixin:
|
||||
"""Mixin for connection objects that fetches remote data upon creation"""
|
||||
|
@ -105,7 +105,7 @@ class SyncTasks:
|
||||
return
|
||||
task.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
def sync_objects(self, object_type: str, page: int, provider_pk: int, **filter):
|
||||
def sync_objects(self, object_type: str, page: int, provider_pk: int):
|
||||
_object_type = path_to_class(object_type)
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
@ -120,7 +120,7 @@ class SyncTasks:
|
||||
client = provider.client_for_model(_object_type)
|
||||
except TransientSyncException:
|
||||
return messages
|
||||
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
|
||||
paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
|
||||
if client.can_discover:
|
||||
self.logger.debug("starting discover")
|
||||
client.discover()
|
||||
|
@ -30,11 +30,6 @@ class TestHTTP(TestCase):
|
||||
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
|
||||
|
||||
def test_forward_for_invalid(self):
|
||||
"""Test invalid forward for"""
|
||||
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar")
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip)
|
||||
|
||||
def test_fake_outpost(self):
|
||||
"""Test faked IP which is overridden by an outpost"""
|
||||
token = Token.objects.create(
|
||||
@ -58,17 +53,6 @@ class TestHTTP(TestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||
# Invalid, not a real IP
|
||||
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
self.user.save()
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
**{
|
||||
ClientIPMiddleware.outpost_remote_ip_header: "foobar",
|
||||
ClientIPMiddleware.outpost_token_header: token.key,
|
||||
},
|
||||
)
|
||||
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||
# Valid
|
||||
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||
self.user.save()
|
||||
|
@ -26,6 +26,7 @@ from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
|
||||
from authentik.outposts.models import (
|
||||
Outpost,
|
||||
OutpostConfig,
|
||||
OutpostState,
|
||||
OutpostType,
|
||||
default_outpost_config,
|
||||
)
|
||||
@ -181,6 +182,7 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
|
||||
outpost: Outpost = self.get_object()
|
||||
states = []
|
||||
for state in outpost.state:
|
||||
state: OutpostState
|
||||
states.append(
|
||||
{
|
||||
"uid": state.uid,
|
||||
|
@ -26,7 +26,6 @@ from authentik.outposts.models import (
|
||||
KubernetesServiceConnection,
|
||||
OutpostServiceConnection,
|
||||
)
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
|
||||
class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer):
|
||||
@ -76,7 +75,7 @@ class ServiceConnectionViewSet(
|
||||
filterset_fields = ["name"]
|
||||
|
||||
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def state(self, request: Request, pk: str) -> Response:
|
||||
"""Get the service connection's state"""
|
||||
connection = self.get_object()
|
||||
|
@ -451,7 +451,7 @@ class OutpostState:
|
||||
return False
|
||||
if self.build_hash != get_build_hash():
|
||||
return False
|
||||
return parse(self.version) != OUR_VERSION
|
||||
return parse(self.version) < OUR_VERSION
|
||||
|
||||
@staticmethod
|
||||
def for_outpost(outpost: Outpost) -> list["OutpostState"]:
|
||||
|
@ -214,7 +214,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
|
||||
if not hasattr(instance, field_name):
|
||||
continue
|
||||
|
||||
LOGGER.debug("triggering outpost update from field", field=field.name)
|
||||
LOGGER.debug("triggering outpost update from from field", field=field.name)
|
||||
# Because the Outpost Model has an M2M to Provider,
|
||||
# we have to iterate over the entire QS
|
||||
for reverse in getattr(instance, field_name).all():
|
||||
|
@ -36,7 +36,7 @@ def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||
if not created:
|
||||
reputation.score = F("score") + amount
|
||||
reputation.save()
|
||||
LOGGER.info("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
|
||||
LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
|
||||
|
||||
|
||||
@receiver(login_failed)
|
||||
|
@ -2,25 +2,15 @@
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models.query import Q
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django_filters.filters import BooleanFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField, CharField, ListField, SerializerMethodField
|
||||
from rest_framework.fields import CharField, ListField, SerializerMethodField
|
||||
from rest_framework.mixins import ListModelMixin
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import Application
|
||||
from authentik.policies.api.exec import PolicyTestResultSerializer
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.types import PolicyResult
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.providers.ldap.models import LDAPProvider
|
||||
|
||||
|
||||
@ -33,6 +23,7 @@ class LDAPProviderSerializer(ProviderSerializer):
|
||||
model = LDAPProvider
|
||||
fields = ProviderSerializer.Meta.fields + [
|
||||
"base_dn",
|
||||
"search_group",
|
||||
"certificate",
|
||||
"tls_server_name",
|
||||
"uid_start_number",
|
||||
@ -64,6 +55,8 @@ class LDAPProviderFilter(FilterSet):
|
||||
"name": ["iexact"],
|
||||
"authorization_flow__slug": ["iexact"],
|
||||
"base_dn": ["iexact"],
|
||||
"search_group__group_uuid": ["iexact"],
|
||||
"search_group__name": ["iexact"],
|
||||
"certificate__kp_uuid": ["iexact"],
|
||||
"certificate__name": ["iexact"],
|
||||
"tls_server_name": ["iexact"],
|
||||
@ -102,6 +95,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
|
||||
"base_dn",
|
||||
"bind_flow_slug",
|
||||
"application_slug",
|
||||
"search_group",
|
||||
"certificate",
|
||||
"tls_server_name",
|
||||
"uid_start_number",
|
||||
@ -122,33 +116,3 @@ class LDAPOutpostConfigViewSet(ListModelMixin, GenericViewSet):
|
||||
ordering = ["name"]
|
||||
search_fields = ["name"]
|
||||
filterset_fields = ["name"]
|
||||
|
||||
class LDAPCheckAccessSerializer(PassiveSerializer):
|
||||
has_search_permission = BooleanField(required=False)
|
||||
access = PolicyTestResultSerializer()
|
||||
|
||||
@extend_schema(
|
||||
request=None,
|
||||
parameters=[OpenApiParameter("app_slug", OpenApiTypes.STR)],
|
||||
responses={
|
||||
200: LDAPCheckAccessSerializer(),
|
||||
},
|
||||
operation_id="outposts_ldap_access_check",
|
||||
)
|
||||
@action(detail=True)
|
||||
def check_access(self, request: Request, pk) -> Response:
|
||||
"""Check access to a single application by slug"""
|
||||
provider = get_object_or_404(LDAPProvider, pk=pk)
|
||||
application = get_object_or_404(Application, slug=request.query_params["app_slug"])
|
||||
engine = PolicyEngine(application, request.user, request)
|
||||
engine.use_cache = False
|
||||
engine.build()
|
||||
result = engine.result
|
||||
access_response = PolicyResult(result.passing)
|
||||
response = self.LDAPCheckAccessSerializer(
|
||||
instance={
|
||||
"has_search_permission": request.user.has_perm("search_full_directory", provider),
|
||||
"access": access_response,
|
||||
}
|
||||
)
|
||||
return Response(response.data)
|
||||
|
@ -1,66 +0,0 @@
|
||||
# Generated by Django 5.0.7 on 2024-07-25 14:59
|
||||
from django.apps.registry import Apps
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
from authentik.core.models import User
|
||||
from django.apps import apps as real_apps
|
||||
from django.contrib.auth.management import create_permissions
|
||||
from guardian.shortcuts import UserObjectPermission
|
||||
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
# Permissions are only created _after_ migrations are run
|
||||
# - https://github.com/django/django/blob/43cdfa8b20e567a801b7d0a09ec67ddd062d5ea4/django/contrib/auth/apps.py#L19
|
||||
# - https://stackoverflow.com/a/72029063/1870445
|
||||
create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
|
||||
|
||||
LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
|
||||
Permission = apps.get_model("auth", "Permission")
|
||||
UserObjectPermission = apps.get_model("guardian", "UserObjectPermission")
|
||||
ContentType = apps.get_model("contenttypes", "ContentType")
|
||||
|
||||
new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory")
|
||||
ct = ContentType.objects.using(db_alias).get(
|
||||
app_label="authentik_providers_ldap",
|
||||
model="ldapprovider",
|
||||
)
|
||||
|
||||
for provider in LDAPProvider.objects.using(db_alias).all():
|
||||
if not provider.search_group:
|
||||
continue
|
||||
for user in provider.search_group.users.using(db_alias).all():
|
||||
UserObjectPermission.objects.using(db_alias).create(
|
||||
user=user,
|
||||
permission=new_prem,
|
||||
object_pk=provider.pk,
|
||||
content_type=ct,
|
||||
)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
|
||||
("guardian", "0002_generic_permissions_index"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="ldapprovider",
|
||||
options={
|
||||
"permissions": [("search_full_directory", "Search full LDAP directory")],
|
||||
"verbose_name": "LDAP Provider",
|
||||
"verbose_name_plural": "LDAP Providers",
|
||||
},
|
||||
),
|
||||
migrations.RunPython(migrate_search_group),
|
||||
migrations.RemoveField(
|
||||
model_name="ldapprovider",
|
||||
name="search_group",
|
||||
),
|
||||
]
|
@ -7,7 +7,7 @@ from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import BackchannelProvider
|
||||
from authentik.core.models import BackchannelProvider, Group
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.outposts.models import OutpostModel
|
||||
|
||||
@ -27,6 +27,17 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
||||
help_text=_("DN under which objects are accessible."),
|
||||
)
|
||||
|
||||
search_group = models.ForeignKey(
|
||||
Group,
|
||||
null=True,
|
||||
default=None,
|
||||
on_delete=models.SET_DEFAULT,
|
||||
help_text=_(
|
||||
"Users in this group can do search queries. "
|
||||
"If not set, every user can execute search queries."
|
||||
),
|
||||
)
|
||||
|
||||
tls_server_name = models.TextField(
|
||||
default="",
|
||||
blank=True,
|
||||
@ -102,6 +113,3 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
||||
class Meta:
|
||||
verbose_name = _("LDAP Provider")
|
||||
verbose_name_plural = _("LDAP Providers")
|
||||
permissions = [
|
||||
("search_full_directory", _("Search full LDAP directory")),
|
||||
]
|
||||
|
@ -105,7 +105,7 @@ class ScopeMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-scope-form"
|
||||
return "ak-property-mapping-scope-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -29,6 +29,7 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
self.app = Application.objects.create(
|
||||
name=generate_id(), slug=generate_id(), provider=self.provider
|
||||
)
|
||||
self.app.save()
|
||||
self.user = create_test_admin_user()
|
||||
self.auth = b64encode(
|
||||
f"{self.provider.client_id}:{self.provider.client_secret}".encode()
|
||||
@ -113,41 +114,6 @@ class TesOAuth2Introspection(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_introspect_invalid_provider(self):
|
||||
"""Test introspection (mismatched provider and token)"""
|
||||
provider: OAuth2Provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="",
|
||||
signing_key=create_test_cert(),
|
||||
)
|
||||
auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
|
||||
|
||||
token: AccessToken = AccessToken.objects.create(
|
||||
provider=self.provider,
|
||||
user=self.user,
|
||||
token=generate_id(),
|
||||
auth_time=timezone.now(),
|
||||
_scope="openid user profile",
|
||||
_id_token=json.dumps(
|
||||
asdict(
|
||||
IDToken("foo", "bar"),
|
||||
)
|
||||
),
|
||||
)
|
||||
res = self.client.post(
|
||||
reverse("authentik_providers_oauth2:token-introspection"),
|
||||
HTTP_AUTHORIZATION=f"Basic {auth}",
|
||||
data={"token": token.token},
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertJSONEqual(
|
||||
res.content.decode(),
|
||||
{
|
||||
"active": False,
|
||||
},
|
||||
)
|
||||
|
||||
def test_introspect_invalid_auth(self):
|
||||
"""Test introspect (invalid auth)"""
|
||||
res = self.client.post(
|
||||
|
@ -62,7 +62,7 @@ urlpatterns = [
|
||||
|
||||
api_urlpatterns = [
|
||||
("providers/oauth2", OAuth2ProviderViewSet),
|
||||
("propertymappings/provider/scope", ScopeMappingViewSet),
|
||||
("propertymappings/scope", ScopeMappingViewSet),
|
||||
("oauth2/authorization_codes", AuthorizationCodeViewSet),
|
||||
("oauth2/refresh_tokens", RefreshTokenViewSet),
|
||||
("oauth2/access_tokens", AccessTokenViewSet),
|
||||
|
@ -46,10 +46,10 @@ class TokenIntrospectionParams:
|
||||
if not provider:
|
||||
raise TokenIntrospectionError
|
||||
|
||||
access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first()
|
||||
access_token = AccessToken.objects.filter(token=raw_token).first()
|
||||
if access_token:
|
||||
return TokenIntrospectionParams(access_token, provider)
|
||||
refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first()
|
||||
refresh_token = RefreshToken.objects.filter(token=raw_token).first()
|
||||
if refresh_token:
|
||||
return TokenIntrospectionParams(refresh_token, provider)
|
||||
LOGGER.debug("Token does not exist", token=raw_token)
|
||||
|
@ -433,21 +433,20 @@ class TokenParams:
|
||||
app = Application.objects.filter(provider=self.provider).first()
|
||||
if not app or not app.provider:
|
||||
raise TokenError("invalid_grant")
|
||||
with audit_ignore():
|
||||
self.user, _ = User.objects.update_or_create(
|
||||
# trim username to ensure the entire username is max 150 chars
|
||||
# (22 chars being the length of the "template")
|
||||
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
|
||||
defaults={
|
||||
"attributes": {
|
||||
USER_ATTRIBUTE_GENERATED: True,
|
||||
},
|
||||
"last_login": timezone.now(),
|
||||
"name": f"Autogenerated user from application {app.name} (client credentials)",
|
||||
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
|
||||
"type": UserTypes.SERVICE_ACCOUNT,
|
||||
self.user, _ = User.objects.update_or_create(
|
||||
# trim username to ensure the entire username is max 150 chars
|
||||
# (22 chars being the length of the "template")
|
||||
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
|
||||
defaults={
|
||||
"attributes": {
|
||||
USER_ATTRIBUTE_GENERATED: True,
|
||||
},
|
||||
)
|
||||
"last_login": timezone.now(),
|
||||
"name": f"Autogenerated user from application {app.name} (client credentials)",
|
||||
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
|
||||
"type": UserTypes.SERVICE_ACCOUNT,
|
||||
},
|
||||
)
|
||||
self.__check_policy_access(app, request)
|
||||
|
||||
Event.new(
|
||||
|
@ -28,7 +28,7 @@ class ProxyDockerController(DockerController):
|
||||
labels = super()._get_labels()
|
||||
labels["traefik.enable"] = "true"
|
||||
labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
|
||||
f"({' || '.join([f'Host({host})' for host in hosts])})"
|
||||
f"({' || '.join([f'Host(`{host}`)' for host in hosts])})"
|
||||
f" && PathPrefix(`/outpost.goauthentik.io`)"
|
||||
)
|
||||
labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"
|
||||
|
@ -154,7 +154,6 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
|
||||
responses={
|
||||
200: RadiusCheckAccessSerializer(),
|
||||
},
|
||||
operation_id="outposts_radius_access_check",
|
||||
)
|
||||
@action(detail=True)
|
||||
def check_access(self, request: Request, pk) -> Response:
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-12 12:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_radius", "0003_radiusproviderpropertymapping"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="radiusproviderpropertymapping",
|
||||
options={
|
||||
"verbose_name": "Radius Provider Property Mapping",
|
||||
"verbose_name_plural": "Radius Provider Property Mappings",
|
||||
},
|
||||
),
|
||||
]
|
@ -70,7 +70,7 @@ class RadiusProviderPropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-radius-form"
|
||||
return "ak-property-mapping-radius-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -81,8 +81,8 @@ class RadiusProviderPropertyMapping(PropertyMapping):
|
||||
return RadiusProviderPropertyMappingSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"Radius Provider Property Mapping {self.name}"
|
||||
return f"Radius Property Mapping {self.name}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Radius Provider Property Mapping")
|
||||
verbose_name_plural = _("Radius Provider Property Mappings")
|
||||
verbose_name = _("Radius Property Mapping")
|
||||
verbose_name_plural = _("Radius Property Mappings")
|
||||
|
@ -7,7 +7,7 @@ from authentik.providers.radius.api.providers import (
|
||||
)
|
||||
|
||||
api_urlpatterns = [
|
||||
("propertymappings/provider/radius", RadiusProviderPropertyMappingViewSet),
|
||||
("propertymappings/radius", RadiusProviderPropertyMappingViewSet),
|
||||
("outposts/radius", RadiusOutpostConfigViewSet, "radiusprovideroutpost"),
|
||||
("providers/radius", RadiusProviderViewSet),
|
||||
]
|
||||
|
@ -133,17 +133,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return "-"
|
||||
|
||||
def validate(self, attrs: dict):
|
||||
if attrs.get("signing_kp"):
|
||||
if not attrs.get("sign_assertion") and not attrs.get("sign_response"):
|
||||
raise ValidationError(
|
||||
_(
|
||||
"With a signing keypair selected, at least one of 'Sign assertion' "
|
||||
"and 'Sign Response' must be selected."
|
||||
)
|
||||
)
|
||||
return super().validate(attrs)
|
||||
|
||||
class Meta:
|
||||
model = SAMLProvider
|
||||
fields = ProviderSerializer.Meta.fields + [
|
||||
@ -159,9 +148,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"signature_algorithm",
|
||||
"signing_kp",
|
||||
"verification_kp",
|
||||
"encryption_kp",
|
||||
"sign_assertion",
|
||||
"sign_response",
|
||||
"sp_binding",
|
||||
"default_relay_state",
|
||||
"url_download_metadata",
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-12 12:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_saml", "0014_alter_samlprovider_digest_algorithm_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="samlpropertymapping",
|
||||
options={
|
||||
"verbose_name": "SAML Provider Property Mapping",
|
||||
"verbose_name_plural": "SAML Provider Property Mappings",
|
||||
},
|
||||
),
|
||||
]
|
@ -1,39 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-15 14:52
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_crypto", "0004_alter_certificatekeypair_name"),
|
||||
("authentik_providers_saml", "0015_alter_samlpropertymapping_options"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="samlprovider",
|
||||
name="encryption_kp",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
help_text="When selected, incoming assertions are encrypted by the IdP using the public key of the encryption keypair. The assertion is decrypted by the SP using the the private key.",
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="+",
|
||||
to="authentik_crypto.certificatekeypair",
|
||||
verbose_name="Encryption Keypair",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="samlprovider",
|
||||
name="sign_assertion",
|
||||
field=models.BooleanField(default=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="samlprovider",
|
||||
name="sign_response",
|
||||
field=models.BooleanField(default=False),
|
||||
),
|
||||
]
|
@ -144,28 +144,11 @@ class SAMLProvider(Provider):
|
||||
on_delete=models.SET_NULL,
|
||||
verbose_name=_("Signing Keypair"),
|
||||
)
|
||||
encryption_kp = models.ForeignKey(
|
||||
CertificateKeyPair,
|
||||
default=None,
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text=_(
|
||||
"When selected, incoming assertions are encrypted by the IdP using the public "
|
||||
"key of the encryption keypair. The assertion is decrypted by the SP using the "
|
||||
"the private key."
|
||||
),
|
||||
on_delete=models.SET_NULL,
|
||||
verbose_name=_("Encryption Keypair"),
|
||||
related_name="+",
|
||||
)
|
||||
|
||||
default_relay_state = models.TextField(
|
||||
default="", blank=True, help_text=_("Default relay_state value for IDP-initiated logins")
|
||||
)
|
||||
|
||||
sign_assertion = models.BooleanField(default=True)
|
||||
sign_response = models.BooleanField(default=False)
|
||||
|
||||
@property
|
||||
def launch_url(self) -> str | None:
|
||||
"""Use IDP-Initiated SAML flow as launch URL"""
|
||||
@ -208,7 +191,7 @@ class SAMLPropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-saml-form"
|
||||
return "ak-property-mapping-saml-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -221,8 +204,8 @@ class SAMLPropertyMapping(PropertyMapping):
|
||||
return f"{self.name} ({name})"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("SAML Provider Property Mapping")
|
||||
verbose_name_plural = _("SAML Provider Property Mappings")
|
||||
verbose_name = _("SAML Property Mapping")
|
||||
verbose_name_plural = _("SAML Property Mappings")
|
||||
|
||||
|
||||
class SAMLProviderImportModel(CreatableType, Provider):
|
||||
|
@ -18,11 +18,7 @@ from authentik.providers.saml.processors.authn_request_parser import AuthNReques
|
||||
from authentik.providers.saml.utils import get_random_id
|
||||
from authentik.providers.saml.utils.time import get_time_string
|
||||
from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME
|
||||
from authentik.sources.saml.exceptions import (
|
||||
InvalidEncryption,
|
||||
InvalidSignature,
|
||||
UnsupportedNameIDFormat,
|
||||
)
|
||||
from authentik.sources.saml.exceptions import InvalidSignature, UnsupportedNameIDFormat
|
||||
from authentik.sources.saml.processors.constants import (
|
||||
DIGEST_ALGORITHM_TRANSLATION_MAP,
|
||||
NS_MAP,
|
||||
@ -260,17 +256,9 @@ class AssertionProcessor:
|
||||
assertion,
|
||||
xmlsec.constants.TransformExclC14N,
|
||||
sign_algorithm_transform,
|
||||
ns=xmlsec.constants.DSigNs,
|
||||
ns="ds", # type: ignore
|
||||
)
|
||||
assertion.append(signature)
|
||||
if self.provider.encryption_kp:
|
||||
encryption = xmlsec.template.encrypted_data_create(
|
||||
assertion,
|
||||
xmlsec.constants.TransformAes128Cbc,
|
||||
self._assertion_id,
|
||||
ns=xmlsec.constants.DSigNs,
|
||||
)
|
||||
assertion.append(encryption)
|
||||
|
||||
assertion.append(self.get_assertion_subject())
|
||||
assertion.append(self.get_assertion_conditions())
|
||||
@ -298,86 +286,41 @@ class AssertionProcessor:
|
||||
response.append(self.get_assertion())
|
||||
return response
|
||||
|
||||
def _sign(self, element: Element):
|
||||
"""Sign an XML element based on the providers' configured signing settings"""
|
||||
digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
|
||||
self.provider.digest_algorithm, xmlsec.constants.TransformSha1
|
||||
)
|
||||
xmlsec.tree.add_ids(element, ["ID"])
|
||||
signature_node = xmlsec.tree.find_node(element, xmlsec.constants.NodeSignature)
|
||||
ref = xmlsec.template.add_reference(
|
||||
signature_node,
|
||||
digest_algorithm_transform,
|
||||
uri="#" + self._assertion_id,
|
||||
)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
|
||||
key_info = xmlsec.template.ensure_key_info(signature_node)
|
||||
xmlsec.template.add_x509_data(key_info)
|
||||
|
||||
ctx = xmlsec.SignatureContext()
|
||||
|
||||
key = xmlsec.Key.from_memory(
|
||||
self.provider.signing_kp.key_data,
|
||||
xmlsec.constants.KeyDataFormatPem,
|
||||
None,
|
||||
)
|
||||
key.load_cert_from_memory(
|
||||
self.provider.signing_kp.certificate_data,
|
||||
xmlsec.constants.KeyDataFormatCertPem,
|
||||
)
|
||||
ctx.key = key
|
||||
try:
|
||||
ctx.sign(signature_node)
|
||||
except xmlsec.Error as exc:
|
||||
raise InvalidSignature() from exc
|
||||
|
||||
def _encrypt(self, element: Element, parent: Element):
|
||||
"""Encrypt SAMLResponse EncryptedAssertion Element"""
|
||||
manager = xmlsec.KeysManager()
|
||||
key = xmlsec.Key.from_memory(
|
||||
self.provider.encryption_kp.key_data,
|
||||
xmlsec.constants.KeyDataFormatPem,
|
||||
)
|
||||
key.load_cert_from_memory(
|
||||
self.provider.encryption_kp.certificate_data,
|
||||
xmlsec.constants.KeyDataFormatCertPem,
|
||||
)
|
||||
|
||||
manager.add_key(key)
|
||||
encryption_context = xmlsec.EncryptionContext(manager)
|
||||
encryption_context.key = xmlsec.Key.generate(
|
||||
xmlsec.constants.KeyDataAes, 128, xmlsec.constants.KeyDataTypeSession
|
||||
)
|
||||
|
||||
container = SubElement(parent, f"{{{NS_SAML_ASSERTION}}}EncryptedAssertion")
|
||||
enc_data = xmlsec.template.encrypted_data_create(
|
||||
container, xmlsec.Transform.AES128, type=xmlsec.EncryptionType.ELEMENT, ns="xenc"
|
||||
)
|
||||
xmlsec.template.encrypted_data_ensure_cipher_value(enc_data)
|
||||
key_info = xmlsec.template.encrypted_data_ensure_key_info(enc_data, ns="ds")
|
||||
enc_key = xmlsec.template.add_encrypted_key(key_info, xmlsec.Transform.RSA_OAEP)
|
||||
xmlsec.template.encrypted_data_ensure_cipher_value(enc_key)
|
||||
|
||||
try:
|
||||
enc_data = encryption_context.encrypt_xml(enc_data, element)
|
||||
except xmlsec.Error as exc:
|
||||
raise InvalidEncryption() from exc
|
||||
|
||||
parent.remove(enc_data)
|
||||
container.append(enc_data)
|
||||
|
||||
def build_response(self) -> str:
|
||||
"""Build string XML Response and sign if signing is enabled."""
|
||||
root_response = self.get_response()
|
||||
if self.provider.signing_kp:
|
||||
if self.provider.sign_assertion:
|
||||
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
|
||||
self._sign(assertion)
|
||||
if self.provider.sign_response:
|
||||
response = root_response.xpath("//samlp:Response", namespaces=NS_MAP)[0]
|
||||
self._sign(response)
|
||||
if self.provider.encryption_kp:
|
||||
digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
|
||||
self.provider.digest_algorithm, xmlsec.constants.TransformSha1
|
||||
)
|
||||
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
|
||||
self._encrypt(assertion, root_response)
|
||||
xmlsec.tree.add_ids(assertion, ["ID"])
|
||||
signature_node = xmlsec.tree.find_node(assertion, xmlsec.constants.NodeSignature)
|
||||
ref = xmlsec.template.add_reference(
|
||||
signature_node,
|
||||
digest_algorithm_transform,
|
||||
uri="#" + self._assertion_id,
|
||||
)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
|
||||
key_info = xmlsec.template.ensure_key_info(signature_node)
|
||||
xmlsec.template.add_x509_data(key_info)
|
||||
|
||||
ctx = xmlsec.SignatureContext()
|
||||
|
||||
key = xmlsec.Key.from_memory(
|
||||
self.provider.signing_kp.key_data,
|
||||
xmlsec.constants.KeyDataFormatPem,
|
||||
None,
|
||||
)
|
||||
key.load_cert_from_memory(
|
||||
self.provider.signing_kp.certificate_data,
|
||||
xmlsec.constants.KeyDataFormatCertPem,
|
||||
)
|
||||
ctx.key = key
|
||||
try:
|
||||
ctx.sign(signature_node)
|
||||
except xmlsec.Error as exc:
|
||||
raise InvalidSignature() from exc
|
||||
|
||||
return etree.tostring(root_response).decode("utf-8") # nosec
|
||||
|
@ -126,7 +126,7 @@ class MetadataProcessor:
|
||||
entity_descriptor,
|
||||
xmlsec.constants.TransformExclC14N,
|
||||
sign_algorithm_transform,
|
||||
ns=xmlsec.constants.DSigNs,
|
||||
ns="ds", # type: ignore
|
||||
)
|
||||
entity_descriptor.append(signature)
|
||||
|
||||
|
@ -8,7 +8,7 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.flows.models import FlowDesignation
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import load_fixture
|
||||
@ -29,52 +29,12 @@ class TestSAMLProviderAPI(APITestCase):
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
Application.objects.create(name=generate_id(), provider=provider, slug=generate_id())
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
def test_create_validate_signing_kp(self):
|
||||
"""Test create"""
|
||||
cert = create_test_cert()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:samlprovider-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"authorization_flow": create_test_flow().pk,
|
||||
"acs_url": "http://localhost",
|
||||
"signing_kp": cert.pk,
|
||||
},
|
||||
)
|
||||
self.assertEqual(400, response.status_code)
|
||||
self.assertJSONEqual(
|
||||
response.content,
|
||||
{
|
||||
"non_field_errors": [
|
||||
(
|
||||
"With a signing keypair selected, at least one "
|
||||
"of 'Sign assertion' and 'Sign Response' must be selected."
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:samlprovider-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"authorization_flow": create_test_flow().pk,
|
||||
"acs_url": "http://localhost",
|
||||
"signing_kp": cert.pk,
|
||||
"sign_assertion": True,
|
||||
},
|
||||
)
|
||||
self.assertEqual(201, response.status_code)
|
||||
|
||||
def test_metadata(self):
|
||||
"""Test metadata export (normal)"""
|
||||
self.client.logout()
|
||||
|
@ -78,12 +78,12 @@ class TestAuthNRequest(TestCase):
|
||||
|
||||
@apply_blueprint("system/providers-saml.yaml")
|
||||
def setUp(self):
|
||||
self.cert = create_test_cert()
|
||||
cert = create_test_cert()
|
||||
self.provider: SAMLProvider = SAMLProvider.objects.create(
|
||||
authorization_flow=create_test_flow(),
|
||||
acs_url="http://testserver/source/saml/provider/acs/",
|
||||
signing_kp=self.cert,
|
||||
verification_kp=self.cert,
|
||||
signing_kp=cert,
|
||||
verification_kp=cert,
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
@ -91,8 +91,8 @@ class TestAuthNRequest(TestCase):
|
||||
slug="provider",
|
||||
issuer="authentik",
|
||||
pre_authentication_flow=create_test_flow(),
|
||||
signing_kp=self.cert,
|
||||
verification_kp=self.cert,
|
||||
signing_kp=cert,
|
||||
verification_kp=cert,
|
||||
)
|
||||
|
||||
def test_signed_valid(self):
|
||||
@ -112,34 +112,7 @@ class TestAuthNRequest(TestCase):
|
||||
self.assertEqual(parsed_request.id, request_proc.request_id)
|
||||
self.assertEqual(parsed_request.relay_state, "test_state")
|
||||
|
||||
def test_request_encrypt(self):
|
||||
"""Test full SAML Request/Response flow, fully encrypted"""
|
||||
self.provider.encryption_kp = self.cert
|
||||
self.provider.save()
|
||||
self.source.encryption_kp = self.cert
|
||||
self.source.save()
|
||||
http_request = get_request("/")
|
||||
|
||||
# First create an AuthNRequest
|
||||
request_proc = RequestProcessor(self.source, http_request, "test_state")
|
||||
request = request_proc.build_auth_n()
|
||||
|
||||
# To get an assertion we need a parsed request (parsed by provider)
|
||||
parsed_request = AuthNRequestParser(self.provider).parse(
|
||||
b64encode(request.encode()).decode(), "test_state"
|
||||
)
|
||||
# Now create a response and convert it to string (provider)
|
||||
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
|
||||
response = response_proc.build_response()
|
||||
|
||||
# Now parse the response (source)
|
||||
http_request.POST = QueryDict(mutable=True)
|
||||
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
|
||||
|
||||
response_parser = ResponseProcessor(self.source, http_request)
|
||||
response_parser.parse()
|
||||
|
||||
def test_request_signed(self):
|
||||
def test_request_full_signed(self):
|
||||
"""Test full SAML Request/Response flow, fully signed"""
|
||||
http_request = get_request("/")
|
||||
|
||||
@ -162,32 +135,6 @@ class TestAuthNRequest(TestCase):
|
||||
response_parser = ResponseProcessor(self.source, http_request)
|
||||
response_parser.parse()
|
||||
|
||||
def test_request_signed_both(self):
|
||||
"""Test full SAML Request/Response flow, fully signed"""
|
||||
self.provider.sign_assertion = True
|
||||
self.provider.sign_response = True
|
||||
self.provider.save()
|
||||
http_request = get_request("/")
|
||||
|
||||
# First create an AuthNRequest
|
||||
request_proc = RequestProcessor(self.source, http_request, "test_state")
|
||||
request = request_proc.build_auth_n()
|
||||
|
||||
# To get an assertion we need a parsed request (parsed by provider)
|
||||
parsed_request = AuthNRequestParser(self.provider).parse(
|
||||
b64encode(request.encode()).decode(), "test_state"
|
||||
)
|
||||
# Now create a response and convert it to string (provider)
|
||||
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
|
||||
response = response_proc.build_response()
|
||||
|
||||
# Now parse the response (source)
|
||||
http_request.POST = QueryDict(mutable=True)
|
||||
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
|
||||
|
||||
response_parser = ResponseProcessor(self.source, http_request)
|
||||
response_parser.parse()
|
||||
|
||||
def test_request_id_invalid(self):
|
||||
"""Test generated AuthNRequest with invalid request ID"""
|
||||
http_request = get_request("/")
|
||||
|
@ -54,11 +54,7 @@ class TestServiceProviderMetadataParser(TestCase):
|
||||
request = self.factory.get("/")
|
||||
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
|
||||
|
||||
schema = etree.XMLSchema(
|
||||
etree.parse(
|
||||
source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()
|
||||
) # nosec
|
||||
)
|
||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
|
||||
self.assertTrue(schema.validate(metadata))
|
||||
|
||||
def test_schema_want_authn_requests_signed(self):
|
||||
|
@ -47,9 +47,7 @@ class TestSchema(TestCase):
|
||||
|
||||
metadata = lxml_from_string(request)
|
||||
|
||||
schema = etree.XMLSchema(
|
||||
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
|
||||
)
|
||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
|
||||
self.assertTrue(schema.validate(metadata))
|
||||
|
||||
def test_response_schema(self):
|
||||
@ -70,7 +68,5 @@ class TestSchema(TestCase):
|
||||
|
||||
metadata = lxml_from_string(response)
|
||||
|
||||
schema = etree.XMLSchema(
|
||||
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
|
||||
)
|
||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
|
||||
self.assertTrue(schema.validate(metadata))
|
||||
|
@ -44,6 +44,6 @@ urlpatterns = [
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
("propertymappings/provider/saml", SAMLPropertyMappingViewSet),
|
||||
("propertymappings/saml", SAMLPropertyMappingViewSet),
|
||||
("providers/saml", SAMLProviderViewSet),
|
||||
]
|
||||
|
@ -6,7 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
|
||||
|
||||
class SCIMProviderSerializer(ProviderSerializer):
|
||||
@ -42,4 +42,3 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
|
||||
search_fields = ["name", "url"]
|
||||
ordering = ["name", "url"]
|
||||
sync_single_task = scim_sync
|
||||
sync_objects_task = scim_sync_objects
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Group client"""
|
||||
|
||||
from itertools import batched
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydanticscim.group import GroupMember
|
||||
from pydanticscim.responses import PatchOp, PatchOperation
|
||||
@ -58,22 +56,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
if not scim_group.externalId:
|
||||
scim_group.externalId = str(obj.pk)
|
||||
|
||||
if not self._config.patch.supported:
|
||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||
connections = SCIMProviderUser.objects.filter(
|
||||
provider=self.provider, user__pk__in=users
|
||||
)
|
||||
members = []
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.scim_id,
|
||||
)
|
||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||
connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
|
||||
members = []
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.scim_id,
|
||||
)
|
||||
if members:
|
||||
scim_group.members = members
|
||||
else:
|
||||
del scim_group.members
|
||||
)
|
||||
if members:
|
||||
scim_group.members = members
|
||||
return scim_group
|
||||
|
||||
def delete(self, obj: Group):
|
||||
@ -100,19 +93,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
scim_id = response.get("id")
|
||||
if not scim_id or scim_id == "":
|
||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||
connection = SCIMProviderGroup.objects.create(
|
||||
return SCIMProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, scim_id=scim_id
|
||||
)
|
||||
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||
self._patch_add_users(group, users)
|
||||
return connection
|
||||
|
||||
def update(self, group: Group, connection: SCIMProviderGroup):
|
||||
"""Update existing group"""
|
||||
scim_group = self.to_schema(group, connection)
|
||||
scim_group.id = connection.scim_id
|
||||
try:
|
||||
self._request(
|
||||
return self._request(
|
||||
"PUT",
|
||||
f"/Groups/{connection.scim_id}",
|
||||
json=scim_group.model_dump(
|
||||
@ -120,8 +110,6 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||
return self._patch_add_users(group, users)
|
||||
except NotFoundSyncException:
|
||||
# Resource missing is handled by self.write, which will re-create the group
|
||||
raise
|
||||
@ -164,18 +152,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
group_id: str,
|
||||
*ops: PatchOperation,
|
||||
):
|
||||
chunk_size = self._config.bulk.maxOperations
|
||||
if chunk_size < 1:
|
||||
chunk_size = len(ops)
|
||||
for chunk in batched(ops, chunk_size):
|
||||
req = PatchRequest(Operations=list(chunk))
|
||||
self._request(
|
||||
"PATCH",
|
||||
f"/Groups/{group_id}",
|
||||
json=req.model_dump(
|
||||
mode="json",
|
||||
),
|
||||
)
|
||||
req = PatchRequest(Operations=ops)
|
||||
self._request(
|
||||
"PATCH",
|
||||
f"/Groups/{group_id}",
|
||||
json=req.model_dump(
|
||||
mode="json",
|
||||
),
|
||||
)
|
||||
|
||||
def _patch_add_users(self, group: Group, users_set: set[int]):
|
||||
"""Add users in users_set to group"""
|
||||
@ -196,14 +180,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
*[
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
value=[{"value": x}],
|
||||
)
|
||||
for x in user_ids
|
||||
],
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
value=[{"value": x} for x in user_ids],
|
||||
),
|
||||
)
|
||||
|
||||
def _patch_remove_users(self, group: Group, users_set: set[int]):
|
||||
@ -225,12 +206,9 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
*[
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
||||
value=[{"value": x}],
|
||||
)
|
||||
for x in user_ids
|
||||
],
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
||||
value=[{"value": x} for x in user_ids],
|
||||
),
|
||||
)
|
||||
|
@ -1,11 +1,9 @@
|
||||
"""Custom SCIM schemas"""
|
||||
|
||||
from pydantic import Field
|
||||
from pydanticscim.group import Group as BaseGroup
|
||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
||||
from pydanticscim.responses import SCIMError as BaseSCIMError
|
||||
from pydanticscim.service_provider import Bulk as BaseBulk
|
||||
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import (
|
||||
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
||||
)
|
||||
@ -31,16 +29,10 @@ class Group(BaseGroup):
|
||||
meta: dict | None = None
|
||||
|
||||
|
||||
class Bulk(BaseBulk):
|
||||
|
||||
maxOperations: int = Field()
|
||||
|
||||
|
||||
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||
"""ServiceProviderConfig with fallback"""
|
||||
|
||||
_is_fallback: bool | None = False
|
||||
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
|
||||
|
||||
@property
|
||||
def is_fallback(self) -> bool:
|
||||
@ -53,7 +45,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||
"""Get default configuration, which doesn't support any optional features as fallback"""
|
||||
return ServiceProviderConfiguration(
|
||||
patch=Patch(supported=False),
|
||||
bulk=Bulk(supported=False, maxOperations=0),
|
||||
bulk=Bulk(supported=False),
|
||||
filter=Filter(supported=False),
|
||||
changePassword=ChangePassword(supported=False),
|
||||
sort=Sort(supported=False),
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-12 12:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_scim", "0008_rename_scimgroup_scimprovidergroup_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="scimmapping",
|
||||
options={
|
||||
"verbose_name": "SCIM Provider Mapping",
|
||||
"verbose_name_plural": "SCIM Provider Mappings",
|
||||
},
|
||||
),
|
||||
]
|
@ -133,7 +133,7 @@ class SCIMMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-scim-form"
|
||||
return "ak-property-mapping-scim-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -142,8 +142,8 @@ class SCIMMapping(PropertyMapping):
|
||||
return SCIMMappingSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"SCIM Provider Mapping {self.name}"
|
||||
return f"SCIM Mapping {self.name}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("SCIM Provider Mapping")
|
||||
verbose_name_plural = _("SCIM Provider Mappings")
|
||||
verbose_name = _("SCIM Mapping")
|
||||
verbose_name_plural = _("SCIM Mappings")
|
||||
|
@ -13,5 +13,5 @@ api_urlpatterns = [
|
||||
("providers/scim", SCIMProviderViewSet),
|
||||
("providers/scim_users", SCIMProviderUserViewSet),
|
||||
("providers/scim_groups", SCIMProviderGroupViewSet),
|
||||
("propertymappings/provider/scim", SCIMMappingViewSet),
|
||||
("propertymappings/scim", SCIMMappingViewSet),
|
||||
]
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.auth.management import _get_all_permissions
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.db import models
|
||||
from django.db.transaction import atomic
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@ -10,26 +10,28 @@ from guardian.shortcuts import assign_perm
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.reflection import get_apps
|
||||
|
||||
|
||||
def get_permission_choices():
|
||||
all_perms = []
|
||||
for app in get_apps():
|
||||
for model in app.get_models():
|
||||
for perm, _desc in _get_all_permissions(model._meta):
|
||||
all_perms.append((model, perm))
|
||||
return sorted(
|
||||
[
|
||||
(
|
||||
f"{model._meta.app_label}.{perm}",
|
||||
f"{model._meta.app_label}.{perm}",
|
||||
)
|
||||
for model, perm in all_perms
|
||||
]
|
||||
def get_permissions():
|
||||
return (
|
||||
Permission.objects.all()
|
||||
.select_related("content_type")
|
||||
.filter(
|
||||
content_type__app_label__startswith="authentik",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_permission_choices() -> list[tuple[str, str]]:
|
||||
return [
|
||||
(
|
||||
f"{x.content_type.app_label}.{x.codename}",
|
||||
f"{x.content_type.app_label}.{x.codename}",
|
||||
)
|
||||
for x in get_permissions()
|
||||
]
|
||||
|
||||
|
||||
class Role(SerializerModel):
|
||||
"""RBAC role, which can have different permissions (both global and per-object) attached
|
||||
to it."""
|
||||
|
@ -87,11 +87,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
|
||||
|
||||
def _get_startup_tasks_default_tenant() -> list[Callable]:
|
||||
"""Get all tasks to be run on startup for the default tenant"""
|
||||
from authentik.outposts.tasks import outpost_connection_discovery
|
||||
|
||||
return [
|
||||
outpost_connection_discovery,
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def _get_startup_tasks_all_tenants() -> list[Callable]:
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from hashlib import sha512
|
||||
from ipaddress import ip_address
|
||||
from time import perf_counter, time
|
||||
from typing import Any
|
||||
|
||||
@ -175,7 +174,6 @@ class ClientIPMiddleware:
|
||||
|
||||
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
|
||||
self.get_response = get_response
|
||||
self.logger = get_logger().bind()
|
||||
|
||||
def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str:
|
||||
"""Attempt to get the client's IP by checking common HTTP Headers.
|
||||
@ -187,16 +185,11 @@ class ClientIPMiddleware:
|
||||
"HTTP_X_FORWARDED_FOR",
|
||||
"REMOTE_ADDR",
|
||||
)
|
||||
try:
|
||||
for _header in headers:
|
||||
if _header in meta:
|
||||
ips: list[str] = meta.get(_header).split(",")
|
||||
# Ensure the IP parses as a valid IP
|
||||
return str(ip_address(ips[0].strip()))
|
||||
return self.default_ip
|
||||
except ValueError as exc:
|
||||
self.logger.debug("Invalid remote IP", exc=exc)
|
||||
return self.default_ip
|
||||
for _header in headers:
|
||||
if _header in meta:
|
||||
ips: list[str] = meta.get(_header).split(",")
|
||||
return ips[0].strip()
|
||||
return self.default_ip
|
||||
|
||||
# FIXME: this should probably not be in `root` but rather in a middleware in `outposts`
|
||||
# but for now it's fine
|
||||
@ -233,11 +226,7 @@ class ClientIPMiddleware:
|
||||
Scope.get_isolation_scope().set_user(user)
|
||||
# Set the outpost service account on the request
|
||||
setattr(request, self.request_attr_outpost_user, user)
|
||||
try:
|
||||
return str(ip_address(delegated_ip))
|
||||
except ValueError as exc:
|
||||
self.logger.debug("Invalid remote IP from Outpost", exc=exc)
|
||||
return None
|
||||
return delegated_ip
|
||||
|
||||
def _get_client_ip(self, request: HttpRequest | None) -> str:
|
||||
"""Attempt to get the client's IP by checking common HTTP Headers.
|
||||
|
@ -9,7 +9,6 @@ import orjson
|
||||
from celery.schedules import crontab
|
||||
from django.conf import ImproperlyConfigured
|
||||
from sentry_sdk import set_tag
|
||||
from xmlsec import enable_debug_trace
|
||||
|
||||
from authentik import __version__
|
||||
from authentik.lib.config import CONFIG, redis_url
|
||||
@ -521,7 +520,6 @@ if DEBUG:
|
||||
"rest_framework.renderers.BrowsableAPIRenderer"
|
||||
)
|
||||
SHARED_APPS.insert(SHARED_APPS.index("django.contrib.staticfiles"), "daphne")
|
||||
enable_debug_trace(True)
|
||||
|
||||
TENANT_APPS.append("authentik.core")
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""authentik storage backends"""
|
||||
|
||||
import os
|
||||
from urllib.parse import parse_qsl, urlsplit
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import SuspiciousOperation
|
||||
@ -111,34 +110,3 @@ class S3Storage(BaseS3Storage):
|
||||
if self.querystring_auth:
|
||||
return url
|
||||
return self._strip_signing_parameters(url)
|
||||
|
||||
def _strip_signing_parameters(self, url):
|
||||
# Boto3 does not currently support generating URLs that are unsigned. Instead
|
||||
# we take the signed URLs and strip any querystring params related to signing
|
||||
# and expiration.
|
||||
# Note that this may end up with URLs that are still invalid, especially if
|
||||
# params are passed in that only work with signed URLs, e.g. response header
|
||||
# params.
|
||||
# The code attempts to strip all query parameters that match names of known
|
||||
# parameters from v2 and v4 signatures, regardless of the actual signature
|
||||
# version used.
|
||||
split_url = urlsplit(url)
|
||||
qs = parse_qsl(split_url.query, keep_blank_values=True)
|
||||
blacklist = {
|
||||
"x-amz-algorithm",
|
||||
"x-amz-credential",
|
||||
"x-amz-date",
|
||||
"x-amz-expires",
|
||||
"x-amz-signedheaders",
|
||||
"x-amz-signature",
|
||||
"x-amz-security-token",
|
||||
"awsaccesskeyid",
|
||||
"expires",
|
||||
"signature",
|
||||
}
|
||||
filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist)
|
||||
# Note: Parameters that did not have a value in the original query string will
|
||||
# have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar=
|
||||
joined_qs = ("=".join(keyval) for keyval in filtered_qs)
|
||||
split_url = split_url._replace(query="&".join(joined_qs))
|
||||
return split_url.geturl()
|
||||
|
@ -3,7 +3,6 @@
|
||||
from typing import Any
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
@ -40,8 +39,9 @@ class LDAPSourceSerializer(SourceSerializer):
|
||||
"""Get cached source connectivity"""
|
||||
return cache.get(CACHE_KEY_STATUS + source.slug, None)
|
||||
|
||||
def validate_sync_users_password(self, sync_users_password: bool) -> bool:
|
||||
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Check that only a single source has password_sync on"""
|
||||
sync_users_password = attrs.get("sync_users_password", True)
|
||||
if sync_users_password:
|
||||
sources = LDAPSource.objects.filter(sync_users_password=True)
|
||||
if self.instance:
|
||||
@ -49,31 +49,11 @@ class LDAPSourceSerializer(SourceSerializer):
|
||||
if sources.exists():
|
||||
raise ValidationError(
|
||||
{
|
||||
"sync_users_password": _(
|
||||
"sync_users_password": (
|
||||
"Only a single LDAP Source with password synchronization is allowed"
|
||||
)
|
||||
}
|
||||
)
|
||||
return sync_users_password
|
||||
|
||||
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate property mappings with sync_ flags"""
|
||||
types = ["user", "group"]
|
||||
for type in types:
|
||||
toggle_value = attrs.get(f"sync_{type}s", False)
|
||||
mappings_field = f"{type}_property_mappings"
|
||||
mappings_value = attrs.get(mappings_field, [])
|
||||
if toggle_value and len(mappings_value) == 0:
|
||||
raise ValidationError(
|
||||
{
|
||||
mappings_field: _(
|
||||
(
|
||||
"When 'Sync {type}s' is enabled, '{type}s property "
|
||||
"mappings' cannot be empty."
|
||||
).format(type=type)
|
||||
)
|
||||
}
|
||||
)
|
||||
return super().validate(attrs)
|
||||
|
||||
class Meta:
|
||||
@ -186,12 +166,11 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
|
||||
for sync_class in SYNC_CLASSES:
|
||||
class_name = sync_class.name()
|
||||
all_objects.setdefault(class_name, [])
|
||||
for page in sync_class(source).get_objects(size_limit=10):
|
||||
for obj in page:
|
||||
obj: dict
|
||||
obj.pop("raw_attributes", None)
|
||||
obj.pop("raw_dn", None)
|
||||
all_objects[class_name].append(obj)
|
||||
for obj in sync_class(source).get_objects(size_limit=10):
|
||||
obj: dict
|
||||
obj.pop("raw_attributes", None)
|
||||
obj.pop("raw_dn", None)
|
||||
all_objects[class_name].append(obj)
|
||||
return Response(data=all_objects)
|
||||
|
||||
|
||||
|
@ -290,7 +290,7 @@ class LDAPSourcePropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-source-ldap-form"
|
||||
return "ak-property-mapping-ldap-source-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -26,16 +26,17 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
|
||||
"""Ensure that source is synced on save (if enabled)"""
|
||||
if not instance.enabled:
|
||||
return
|
||||
ldap_connectivity_check.delay(instance.pk)
|
||||
# Don't sync sources when they don't have any property mappings. This will only happen if:
|
||||
# - the user forgets to set them or
|
||||
# - the source is newly created, this is the first save event
|
||||
# and the mappings are created with an m2m event
|
||||
if instance.sync_users and not instance.user_property_mappings.exists():
|
||||
return
|
||||
if instance.sync_groups and not instance.group_property_mappings.exists():
|
||||
if (
|
||||
not instance.user_property_mappings.exists()
|
||||
or not instance.group_property_mappings.exists()
|
||||
):
|
||||
return
|
||||
ldap_sync_single.delay(instance.pk)
|
||||
ldap_connectivity_check.delay(instance.pk)
|
||||
|
||||
|
||||
@receiver(password_validate)
|
||||
|
@ -38,11 +38,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
search_base=self.base_dn_groups,
|
||||
search_filter=self._source.group_object_filter,
|
||||
search_scope=SUBTREE,
|
||||
attributes=[
|
||||
ALL_ATTRIBUTES,
|
||||
ALL_OPERATIONAL_ATTRIBUTES,
|
||||
self._source.object_uniqueness_field,
|
||||
],
|
||||
attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -57,9 +53,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
continue
|
||||
attributes = group.get("attributes", {})
|
||||
group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
|
||||
if not attributes.get(self._source.object_uniqueness_field):
|
||||
if self._source.object_uniqueness_field not in attributes:
|
||||
self.message(
|
||||
f"Uniqueness field not found/not set in attributes: '{group_dn}'",
|
||||
f"Cannot find uniqueness field in attributes: '{group_dn}'",
|
||||
attributes=attributes.keys(),
|
||||
dn=group_dn,
|
||||
)
|
||||
|
@ -40,11 +40,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
search_base=self.base_dn_users,
|
||||
search_filter=self._source.user_object_filter,
|
||||
search_scope=SUBTREE,
|
||||
attributes=[
|
||||
ALL_ATTRIBUTES,
|
||||
ALL_OPERATIONAL_ATTRIBUTES,
|
||||
self._source.object_uniqueness_field,
|
||||
],
|
||||
attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -59,9 +55,9 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
|
||||
continue
|
||||
attributes = user.get("attributes", {})
|
||||
user_dn = flatten(user.get("entryDN", user.get("dn")))
|
||||
if not attributes.get(self._source.object_uniqueness_field):
|
||||
if self._source.object_uniqueness_field not in attributes:
|
||||
self.message(
|
||||
f"Uniqueness field not found/not set in attributes: '{user_dn}'",
|
||||
f"Cannot find uniqueness field in attributes: '{user_dn}'",
|
||||
attributes=attributes.keys(),
|
||||
dn=user_dn,
|
||||
)
|
||||
|
4
authentik/sources/ldap/sync/vendor/ms_ad.py
vendored
4
authentik/sources/ldap/sync/vendor/ms_ad.py
vendored
@ -78,9 +78,7 @@ class MicrosoftActiveDirectory(BaseLDAPSynchronizer):
|
||||
# /useraccountcontrol-manipulate-account-properties
|
||||
uac_bit = attributes.get("userAccountControl", 512)
|
||||
uac = UserAccountControl(uac_bit)
|
||||
is_active = (
|
||||
UserAccountControl.ACCOUNTDISABLE not in uac and UserAccountControl.LOCKOUT not in uac
|
||||
)
|
||||
is_active = UserAccountControl.ACCOUNTDISABLE not in uac
|
||||
if is_active != user.is_active:
|
||||
user.is_active = is_active
|
||||
user.save()
|
||||
|
@ -50,35 +50,3 @@ class LDAPAPITests(APITestCase):
|
||||
}
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
|
||||
def test_sync_users_mapping_empty(self):
|
||||
"""Check that when sync_users is enabled, property mappings must be set"""
|
||||
serializer = LDAPSourceSerializer(
|
||||
data={
|
||||
"name": "foo",
|
||||
"slug": " foo",
|
||||
"server_uri": "ldaps://1.2.3.4",
|
||||
"bind_cn": "",
|
||||
"bind_password": LDAP_PASSWORD,
|
||||
"base_dn": "dc=foo",
|
||||
"sync_users": True,
|
||||
"user_property_mappings": [],
|
||||
}
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
|
||||
def test_sync_groups_mapping_empty(self):
|
||||
"""Check that when sync_groups is enabled, property mappings must be set"""
|
||||
serializer = LDAPSourceSerializer(
|
||||
data={
|
||||
"name": "foo",
|
||||
"slug": " foo",
|
||||
"server_uri": "ldaps://1.2.3.4",
|
||||
"bind_cn": "",
|
||||
"bind_password": LDAP_PASSWORD,
|
||||
"base_dn": "dc=foo",
|
||||
"sync_groups": True,
|
||||
"group_property_mappings": [],
|
||||
}
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
|
@ -268,7 +268,7 @@ class OAuthSourcePropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-source-oauth-form"
|
||||
return "ak-property-mapping-oauth-source-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -123,7 +123,7 @@ class PlexSourcePropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-source-plex-form"
|
||||
return "ak-property-mapping-plex-source-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -299,7 +299,7 @@ class SAMLSourcePropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-source-saml-form"
|
||||
return "ak-property-mapping-saml-source-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -6,14 +6,12 @@ NS_SAML_PROTOCOL = "urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
NS_SAML_ASSERTION = "urn:oasis:names:tc:SAML:2.0:assertion"
|
||||
NS_SAML_METADATA = "urn:oasis:names:tc:SAML:2.0:metadata"
|
||||
NS_SIGNATURE = "http://www.w3.org/2000/09/xmldsig#"
|
||||
NS_ENC = "http://www.w3.org/2001/04/xmlenc#"
|
||||
|
||||
NS_MAP = {
|
||||
"samlp": NS_SAML_PROTOCOL,
|
||||
"saml": NS_SAML_ASSERTION,
|
||||
"ds": NS_SIGNATURE,
|
||||
"md": NS_SAML_METADATA,
|
||||
"xenc": NS_ENC,
|
||||
}
|
||||
|
||||
SAML_NAME_ID_FORMAT_EMAIL = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
|
||||
|
@ -76,7 +76,7 @@ class RequestProcessor:
|
||||
auth_n_request,
|
||||
xmlsec.constants.TransformExclC14N,
|
||||
sign_algorithm_transform,
|
||||
ns=xmlsec.constants.DSigNs,
|
||||
ns="ds", # type: ignore
|
||||
)
|
||||
auth_n_request.append(signature)
|
||||
|
||||
|
@ -30,9 +30,7 @@ class TestMetadataProcessor(TestCase):
|
||||
xml = MetadataProcessor(self.source, request).build_entity_descriptor()
|
||||
metadata = lxml_from_string(xml)
|
||||
|
||||
schema = etree.XMLSchema(
|
||||
etree.parse("schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()) # nosec
|
||||
)
|
||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
|
||||
self.assertTrue(schema.validate(metadata))
|
||||
|
||||
def test_metadata_consistent(self):
|
||||
|
@ -85,7 +85,7 @@ class SCIMSourcePropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-source-scim-form"
|
||||
return "ak-property-mapping-scim-source-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -14,9 +14,7 @@ class Migration(migrations.Migration):
|
||||
migrations.AddField(
|
||||
model_name="duodevice",
|
||||
name="created",
|
||||
field=models.DateTimeField(
|
||||
auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0, tzinfo=datetime.UTC)
|
||||
),
|
||||
field=models.DateTimeField(auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0)),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
|
@ -14,9 +14,7 @@ class Migration(migrations.Migration):
|
||||
migrations.AddField(
|
||||
model_name="smsdevice",
|
||||
name="created",
|
||||
field=models.DateTimeField(
|
||||
auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0, tzinfo=datetime.UTC)
|
||||
),
|
||||
field=models.DateTimeField(auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0)),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
|
@ -14,9 +14,7 @@ class Migration(migrations.Migration):
|
||||
migrations.AddField(
|
||||
model_name="staticdevice",
|
||||
name="created",
|
||||
field=models.DateTimeField(
|
||||
auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0, tzinfo=datetime.UTC)
|
||||
),
|
||||
field=models.DateTimeField(auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0)),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
|
@ -14,9 +14,7 @@ class Migration(migrations.Migration):
|
||||
migrations.AddField(
|
||||
model_name="totpdevice",
|
||||
name="created",
|
||||
field=models.DateTimeField(
|
||||
auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0, tzinfo=datetime.UTC)
|
||||
),
|
||||
field=models.DateTimeField(auto_now_add=True, default=datetime.datetime(1, 1, 1, 0, 0)),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user