Compare commits

..

17 Commits

Author SHA1 Message Date
1dc4fbbb2b Attempting coercion to ESM with Vite & Wdio 2024-08-09 10:55:44 -07:00
3332de267d Testing needs to be able to import from dependent packages. 2024-08-09 10:53:20 -07:00
ab366d0ec2 Testing needs to be able to import from dependent packages. 2024-08-09 10:50:51 -07:00
ac162582aa Modernization continues. 2024-08-09 10:43:28 -07:00
7d82e029d5 wdio does not need the storybook cssimport hack. 2024-08-09 10:37:17 -07:00
9b40ecb023 web: restricted all eslints to local checks only and blocked them from cache analysis 2024-08-09 09:56:55 -07:00
0cc0fdaae3 Not ready for primetime. 2024-08-09 09:35:46 -07:00
b55b168718 web: move common into its own package.
```
$ mkdir ./packages/common
$ git mv ./src/common ./packages/common/src
```

... and then added all of the boilerplate needed to drive with Wireit, build with ESlint, typecheck
with TSC, and then spell check documentation and comments, security checks of package.json and
package-lock.json, format.

... and _then_ fix all of the minor, nitpicky things ESLint 9 found in the package.

... and _then_ wire the whole thing into our build so that we can find it as a package, removing
it as an alias from the base package definition and turning it into a workspace.  Although it is
a workspace package, it's currently configured to build completely independently.

It could be published as an independent NPM package, although I don't recommended that at this time.

I've wanted to break the UI up into smaller, more digestible chunks for awhile, but was always
reluctant to, since I didn't want to mess with other teams' mental models of the code layout.
@Beryju, seeing the success of the Simple Flow Executor as an independent package, thought it might
be worthwhile to see what effort it took to break the graph of our independent apps (User, Flow, and
Admin) and their dependencies (Common <- Elements <- Components, Common <- Locales) into packages.

Turns out, it's not too bad.  It's going to be fiddly for awhile until things settle down, but
overall the experiment has been a success.

The `tsconfig.json` doesn't refer to the base because we want this to build independently; tooling
will be needed to ensure all of our `tsconfig` files in the future will be consistent across all
packages.

- We can use the ESLint boilerplate as-is.
- We have to run TSC as a separate (but fortunately parallel) build step, as client code will need
  the built types. Final builds will be fractionally slower, but Wireit can detect when a monorepo
  package is unchanged and can skip rebuilding `common` if it's not needed, so the development loop
  will be faster.
- The ESBuild boilerplate is different for libraries with UI, libraries without UI (like this one),
  and apps, and we'll have to have three different routines for them. Once we are building
  independent _apps_, getting them into the `dist` folder will be an interesting challenge; we may
  end up with two different builds, one to bundle it in *in the app*, and another to bundle it *for
  Django*. That's mostly an issue of targeting and integration, and shouldn't take too much time.
- Spelling, formatting, and package checking aren't affected.
- `Locales` is our biggest challenge, as usual. I have found only [one article on it
  anywhere](https://medium.com/tech-at-zet/streamlining-localization-in-a-monorepo-using-i18n-js-e7c521ff69d4),
  and it recommends creating a single package in which to keep all of the localizations and the
  localization machinery. That seems like a sound approach, but we haven't (yet) gotten there.

`common` is a bit of a junk drawer: there are global utilities in there, there are app-specific
helpers, there are plug-in specific helpers, and so on. Figuring out exactly what does what and
making more specific packages may be in our future.
2024-08-09 08:38:30 -07:00
c46dc8f290 Not sure how that happened. 2024-08-08 16:10:07 -07:00
e48da3520c Fix type checking issues at the TSC level. 2024-08-08 16:03:37 -07:00
1ec4652c60 Fix dependent types needed before attempting typecheck. 2024-08-08 15:51:09 -07:00
e375646705 Made linting the subpackages a requirement of success. 2024-08-08 15:47:54 -07:00
b84652d9d3 Fix eslint so it only lints the local package. Other packages have their own responsibilities. 2024-08-08 15:44:43 -07:00
74b8da28ca Added common:build" to the list of dependencies for building, which is what you want, right? 2024-08-08 15:32:11 -07:00
9084c7c6b4 web: all the basic commands are working: build, build types, lint source, lint lockfile, lint packagefile, lint types, lint spelling, format source, format package. 2024-08-08 15:08:05 -07:00
7a0b227b46 Interim commit 2024-08-08 14:25:14 -07:00
cc9128fd46 Move begun; sfe cleanup completed. 2024-08-08 11:14:50 -07:00
639 changed files with 20036 additions and 42911 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2024.8.4 current_version = 2024.6.3
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

View File

@ -29,9 +29,9 @@ outputs:
imageTags: imageTags:
description: "Docker image tags" description: "Docker image tags"
value: ${{ steps.ev.outputs.imageTags }} value: ${{ steps.ev.outputs.imageTags }}
attestImageNames: imageNames:
description: "Docker image names used for attestation" description: "Docker image names"
value: ${{ steps.ev.outputs.attestImageNames }} value: ${{ steps.ev.outputs.imageNames }}
imageMainTag: imageMainTag:
description: "Docker image main tag" description: "Docker image main tag"
value: ${{ steps.ev.outputs.imageMainTag }} value: ${{ steps.ev.outputs.imageMainTag }}

View File

@ -51,24 +51,15 @@ else:
] ]
image_main_tag = image_tags[0].split(":")[-1] image_main_tag = image_tags[0].split(":")[-1]
image_tags_rendered = ",".join(image_tags)
image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags))
def get_attest_image_names(image_with_tags: list[str]):
"""Attestation only for GHCR"""
image_tags = []
for image_name in set(name.split(":")[0] for name in image_with_tags):
if not image_name.startswith("ghcr.io"):
continue
image_tags.append(image_name)
return ",".join(set(image_tags))
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print(f"shouldBuild={should_build}", file=_output) print(f"shouldBuild={should_build}", file=_output)
print(f"sha={sha}", file=_output) print(f"sha={sha}", file=_output)
print(f"version={version}", file=_output) print(f"version={version}", file=_output)
print(f"prerelease={prerelease}", file=_output) print(f"prerelease={prerelease}", file=_output)
print(f"imageTags={','.join(image_tags)}", file=_output) print(f"imageTags={image_tags_rendered}", file=_output)
print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output) print(f"imageNames={image_names_rendered}", file=_output)
print(f"imageMainTag={image_main_tag}", file=_output) print(f"imageMainTag={image_main_tag}", file=_output)
print(f"imageMainName={image_tags[0]}", file=_output) print(f"imageMainName={image_tags[0]}", file=_output)

View File

@ -58,10 +58,6 @@ updates:
patterns: patterns:
- "@rollup/*" - "@rollup/*"
- "rollup-*" - "rollup-*"
swc:
patterns:
- "@swc/*"
- "swc-*"
wdio: wdio:
patterns: patterns:
- "@wdio/*" - "@wdio/*"

View File

@ -261,7 +261,7 @@ jobs:
id: attest id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }} if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
pr-comment: pr-comment:

View File

@ -31,7 +31,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v6 uses: golangci/golangci-lint-action@v6
with: with:
version: latest version: v1.54.2
args: --timeout 5000s --verbose args: --timeout 5000s --verbose
skip-cache: true skip-cache: true
test-unittest: test-unittest:
@ -115,7 +115,7 @@ jobs:
id: attest id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }} if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
build-binary: build-binary:

View File

@ -92,4 +92,4 @@ jobs:
run: make gen-client-ts run: make gen-client-ts
- name: test - name: test
working-directory: web/ working-directory: web/
run: npm run test || exit 0 run: npm run test

View File

@ -51,14 +51,12 @@ jobs:
secrets: | secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }} tags: ${{ steps.ev.outputs.imageTags }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
- uses: actions/attest-build-provenance@v1 - uses: actions/attest-build-provenance@v1
id: attest id: attest
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
build-outpost: build-outpost:
@ -113,8 +111,6 @@ jobs:
id: push id: push
with: with:
push: true push: true
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }} tags: ${{ steps.ev.outputs.imageTags }}
file: ${{ matrix.type }}.Dockerfile file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
@ -122,7 +118,7 @@ jobs:
- uses: actions/attest-build-provenance@v1 - uses: actions/attest-build-provenance@v1
id: attest id: attest
with: with:
subject-name: ${{ steps.ev.outputs.attestImageNames }} subject-name: ${{ steps.ev.outputs.imageNames }}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
build-outpost-binary: build-outpost-binary:

View File

@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1 # syntax=docker/dockerfile:1
# Stage 1: Build website # Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS website-builder FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as website-builder
ENV NODE_ENV=production ENV NODE_ENV=production
@ -20,7 +20,7 @@ COPY ./SECURITY.md /work/
RUN npm run build-bundled RUN npm run build-bundled
# Stage 2: Build webui # Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS web-builder FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as web-builder
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
@ -43,7 +43,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build RUN npm run build
# Stage 3: Build go proxy # Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.23-fips-bookworm AS go-builder FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
ARG TARGETOS ARG TARGETOS
ARG TARGETARCH ARG TARGETARCH
@ -80,7 +80,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
go build -o /go/authentik ./cmd/server go build -o /go/authentik ./cmd/server
# Stage 4: MaxMind GeoIP # Stage 4: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 AS geoip FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN" ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
ENV GEOIPUPDATE_VERBOSE="1" ENV GEOIPUPDATE_VERBOSE="1"
@ -94,10 +94,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 5: Python dependencies # Stage 5: Python dependencies
FROM ghcr.io/goauthentik/fips-python:3.12.5-slim-bookworm-fips-full AS python-deps FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps
ARG TARGETARCH
ARG TARGETVARIANT
WORKDIR /ak-root/poetry WORKDIR /ak-root/poetry
@ -124,17 +121,17 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
pip install --force-reinstall /wheels/*" pip install --force-reinstall /wheels/*"
# Stage 6: Run # Stage 6: Run
FROM ghcr.io/goauthentik/fips-python:3.12.5-slim-bookworm-fips-full AS final-image FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image
ARG VERSION
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ARG VERSION
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
LABEL org.opencontainers.image.url=https://goauthentik.io LABEL org.opencontainers.image.url https://goauthentik.io
LABEL org.opencontainers.image.description="goauthentik.io Main server image, see https://goauthentik.io for more info." LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info.
LABEL org.opencontainers.image.source=https://github.com/goauthentik/authentik LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version=${VERSION} LABEL org.opencontainers.image.version ${VERSION}
LABEL org.opencontainers.image.revision=${GIT_BUILD_HASH} LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
WORKDIR / WORKDIR /

View File

@ -43,7 +43,7 @@ help: ## Show this help
sort sort
@echo "" @echo ""
go-test: test-go:
go test -timeout 0 -v -race -cover ./... go test -timeout 0 -v -race -cover ./...
test-docker: ## Run all tests in a docker-compose test-docker: ## Run all tests in a docker-compose
@ -210,9 +210,6 @@ web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting is
web-install: ## Install the necessary libraries to build the Authentik UI web-install: ## Install the necessary libraries to build the Authentik UI
cd web && npm ci cd web && npm ci
web-test: ## Run tests for the Authentik UI
cd web && npm run test
web-watch: ## Build and watch the Authentik UI for changes, updating automatically web-watch: ## Build and watch the Authentik UI for changes, updating automatically
rm -rf web/dist/ rm -rf web/dist/
mkdir web/dist/ mkdir web/dist/

View File

@ -15,9 +15,7 @@
## What is authentik? ## What is authentik?
authentik is an open-source Identity Provider that emphasizes flexibility and versatility, with support for a wide set of protocols. authentik is an open-source Identity Provider that emphasizes flexibility and versatility. It can be seamlessly integrated into existing environments to support new protocols. authentik is also a great solution for implementing sign-up, recovery, and other similar features in your application, saving you the hassle of dealing with them.
Our [enterprise offer](https://goauthentik.io/pricing) can also be used as a self-hosted replacement for large-scale deployments of Okta/Auth0, Entra ID, Ping Identity, or other legacy IdPs for employees and B2B2C use.
## Installation ## Installation

View File

@ -2,7 +2,7 @@
from os import environ from os import environ
__version__ = "2024.8.4" __version__ = "2024.6.3"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -73,7 +73,7 @@ class SystemInfoSerializer(PassiveSerializer):
"authentik_version": get_full_version(), "authentik_version": get_full_version(),
"environment": get_env(), "environment": get_env(),
"openssl_fips_enabled": ( "openssl_fips_enabled": (
backend._fips_enabled if LicenseKey.get_total().status().is_valid else None backend._fips_enabled if LicenseKey.get_total().is_valid() else None
), ),
"openssl_version": OPENSSL_VERSION, "openssl_version": OPENSSL_VERSION,
"platform": platform.platform(), "platform": platform.platform(),

View File

@ -12,7 +12,6 @@ from rest_framework.views import APIView
from authentik import __version__, get_build_hash from authentik import __version__, get_build_hash
from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.outposts.models import Outpost
class VersionSerializer(PassiveSerializer): class VersionSerializer(PassiveSerializer):
@ -23,7 +22,6 @@ class VersionSerializer(PassiveSerializer):
version_latest_valid = SerializerMethodField() version_latest_valid = SerializerMethodField()
build_hash = SerializerMethodField() build_hash = SerializerMethodField()
outdated = SerializerMethodField() outdated = SerializerMethodField()
outpost_outdated = SerializerMethodField()
def get_build_hash(self, _) -> str: def get_build_hash(self, _) -> str:
"""Get build hash, if version is not latest or released""" """Get build hash, if version is not latest or released"""
@ -49,15 +47,6 @@ class VersionSerializer(PassiveSerializer):
"""Check if we're running the latest version""" """Check if we're running the latest version"""
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance)) return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
def get_outpost_outdated(self, _) -> bool:
"""Check if any outpost is outdated/has a version mismatch"""
any_outdated = False
for outpost in Outpost.objects.all():
for state in outpost.state:
if state.version_outdated:
any_outdated = True
return any_outdated
class VersionView(APIView): class VersionView(APIView):
"""Get running and latest version.""" """Get running and latest version."""

View File

@ -51,11 +51,9 @@ class BlueprintInstanceSerializer(ModelSerializer):
context = self.instance.context if self.instance else {} context = self.instance.context if self.instance else {}
valid, logs = Importer.from_string(content, context).validate() valid, logs = Importer.from_string(content, context).validate()
if not valid: if not valid:
text_logs = "\n".join([x["event"] for x in logs])
raise ValidationError( raise ValidationError(
[ _("Failed to validate blueprint: {logs}".format_map({"logs": text_logs}))
_("Failed to validate blueprint"),
*[f"- {x.event}" for x in logs],
]
) )
return content return content

View File

@ -78,5 +78,5 @@ class TestBlueprintsV1API(APITestCase):
self.assertEqual(res.status_code, 400) self.assertEqual(res.status_code, 400)
self.assertJSONEqual( self.assertJSONEqual(
res.content.decode(), res.content.decode(),
{"content": ["Failed to validate blueprint", "- Invalid blueprint version"]}, {"content": ["Failed to validate blueprint: Invalid blueprint version"]},
) )

View File

@ -171,7 +171,7 @@ class Importer:
def default_context(self): def default_context(self):
"""Default context""" """Default context"""
return { return {
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid, "goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid(),
"goauthentik.io/rbac/models": rbac_models(), "goauthentik.io/rbac/models": rbac_models(),
} }
@ -429,7 +429,7 @@ class Importer:
orig_import = deepcopy(self._import) orig_import = deepcopy(self._import)
if self._import.version != 1: if self._import.version != 1:
self.logger.warning("Invalid blueprint version") self.logger.warning("Invalid blueprint version")
return False, [LogEvent("Invalid blueprint version", log_level="warning", logger=None)] return False, [{"event": "Invalid blueprint version"}]
with ( with (
transaction_rollback(), transaction_rollback(),
capture_logs() as logs, capture_logs() as logs,

View File

@ -30,10 +30,8 @@ from authentik.core.api.utils import (
PassiveSerializer, PassiveSerializer,
) )
from authentik.core.expression.evaluator import PropertyMappingEvaluator from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group, PropertyMapping, User from authentik.core.models import Group, PropertyMapping, User
from authentik.events.utils import sanitize_item from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.api.exec import PolicyTestSerializer from authentik.policies.api.exec import PolicyTestSerializer
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
@ -164,15 +162,12 @@ class PropertyMappingViewSet(
response_data = {"successful": True, "result": ""} response_data = {"successful": True, "result": ""}
try: try:
result = mapping.evaluate(dry_run=True, **context) result = mapping.evaluate(**context)
response_data["result"] = dumps( response_data["result"] = dumps(
sanitize_item(result), indent=(4 if format_result else None) sanitize_item(result), indent=(4 if format_result else None)
) )
except PropertyMappingExpressionException as exc:
response_data["result"] = exception_to_string(exc.exc)
response_data["successful"] = False
except Exception as exc: except Exception as exc:
response_data["result"] = exception_to_string(exc) response_data["result"] = str(exc)
response_data["successful"] = False response_data["successful"] = False
response = PropertyMappingTestResultSerializer(response_data) response = PropertyMappingTestResultSerializer(response_data)
return Response(response.data) return Response(response.data)

View File

@ -14,7 +14,6 @@ from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.rbac.filters import ObjectFilter
class DeleteAction(Enum): class DeleteAction(Enum):
@ -54,7 +53,7 @@ class UsedByMixin:
@extend_schema( @extend_schema(
responses={200: UsedBySerializer(many=True)}, responses={200: UsedBySerializer(many=True)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def used_by(self, request: Request, *args, **kwargs) -> Response: def used_by(self, request: Request, *args, **kwargs) -> Response:
"""Get a list of all objects that use this object""" """Get a list of all objects that use this object"""
model: Model = self.get_object() model: Model = self.get_object()

View File

@ -678,13 +678,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
if not request.tenant.impersonation: if not request.tenant.impersonation:
LOGGER.debug("User attempted to impersonate", user=request.user) LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401) return Response(status=401)
user_to_be = self.get_object() if not request.user.has_perm("impersonate"):
# Check both object-level perms and global perms
if not request.user.has_perm(
"authentik_core.impersonate", user_to_be
) and not request.user.has_perm("authentik_core.impersonate"):
LOGGER.debug("User attempted to impersonate without permissions", user=request.user) LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return Response(status=401) return Response(status=401)
user_to_be = self.get_object()
if user_to_be.pk == self.request.user.pk: if user_to_be.pk == self.request.user.pk:
LOGGER.debug("User attempted to impersonate themselves", user=request.user) LOGGER.debug("User attempted to impersonate themselves", user=request.user)
return Response(status=401) return Response(status=401)

View File

@ -9,11 +9,10 @@ class Command(TenantCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("--type", type=str, required=True) parser.add_argument("--type", type=str, required=True)
parser.add_argument("--all", action="store_true", default=False) parser.add_argument("--all", action="store_true")
parser.add_argument("usernames", nargs="*", type=str) parser.add_argument("usernames", nargs="+", type=str)
def handle_per_tenant(self, **options): def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"]) new_type = UserTypes(options["type"])
qs = ( qs = (
User.objects.exclude_anonymous() User.objects.exclude_anonymous()
@ -23,9 +22,6 @@ class Command(TenantCommand):
if options["usernames"] and options["all"]: if options["usernames"] and options["all"]:
self.stderr.write("--all and usernames specified, only one can be specified") self.stderr.write("--all and usernames specified, only one can be specified")
return return
if not options["usernames"] and not options["all"]:
self.stderr.write("--all or usernames must be specified")
return
if options["usernames"] and not options["all"]: if options["usernames"] and not options["all"]:
qs = qs.filter(username__in=options["usernames"]) qs = qs.filter(username__in=options["usernames"])
updated = qs.update(type=new_type) updated = qs.update(type=new_type)

View File

@ -466,6 +466,8 @@ class ApplicationQuerySet(QuerySet):
def with_provider(self) -> "QuerySet[Application]": def with_provider(self) -> "QuerySet[Application]":
qs = self.select_related("provider") qs = self.select_related("provider")
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
if LOOKUP_SEP in subclass:
continue
qs = qs.select_related(f"provider__{subclass}") qs = qs.select_related(f"provider__{subclass}")
return qs return qs
@ -543,24 +545,15 @@ class Application(SerializerModel, PolicyBindingModel):
if not self.provider: if not self.provider:
return None return None
candidates = [] for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
base_class = Provider # We don't care about recursion, skip nested models
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class): if LOOKUP_SEP in subclass:
parent = self.provider
for level in subclass.split(LOOKUP_SEP):
try:
parent = getattr(parent, level)
except AttributeError:
break
if parent in candidates:
continue continue
idx = subclass.count(LOOKUP_SEP) try:
if type(parent) is not base_class: return getattr(self.provider, subclass)
idx += 1 except AttributeError:
candidates.insert(idx, parent) pass
if not candidates: return None
return None
return candidates[-1]
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)
@ -908,7 +901,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
except ControlFlowException as exc: except ControlFlowException as exc:
raise exc raise exc
except Exception as exc: except Exception as exc:
raise PropertyMappingExpressionException(exc, self) from exc raise PropertyMappingExpressionException(self, exc) from exc
def __str__(self): def __str__(self):
return f"Property Mapping {self.name}" return f"Property Mapping {self.name}"

View File

@ -9,12 +9,9 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import OAuth2Provider from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.proxy.models import ProxyProvider
from authentik.providers.saml.models import SAMLProvider
class TestApplicationsAPI(APITestCase): class TestApplicationsAPI(APITestCase):
@ -225,31 +222,3 @@ class TestApplicationsAPI(APITestCase):
], ],
}, },
) )
def test_get_provider(self):
"""Ensure that proxy providers (at the time of writing that is the only provider
that inherits from another proxy type (OAuth) instead of inheriting from the root
provider class) is correctly looked up and selected from the database"""
slug = generate_id()
provider = ProxyProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)
slug = generate_id()
provider = SAMLProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)

View File

@ -3,10 +3,10 @@
from json import loads from json import loads
from django.urls import reverse from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user, create_test_user from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.other_user = create_test_user() self.other_user = User.objects.create(username="to-impersonate")
self.user = create_test_admin_user() self.user = create_test_admin_user()
def test_impersonate_simple(self): def test_impersonate_simple(self):
@ -44,46 +44,6 @@ class TestImpersonation(APITestCase):
self.assertEqual(response_body["user"]["username"], self.user.username) self.assertEqual(response_body["user"]["username"], self.user.username)
self.assertNotIn("original", response_body) self.assertNotIn("original", response_body)
def test_impersonate_global(self):
"""Test impersonation with global permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user)
assign_perm("authentik_core.view_user", new_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_scoped(self):
"""Test impersonation with scoped permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user, self.other_user)
assign_perm("authentik_core.view_user", new_user, self.other_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_denied(self): def test_impersonate_denied(self):
"""test impersonation without permissions""" """test impersonation without permissions"""
self.client.force_login(self.other_user) self.client.force_login(self.other_user)

View File

@ -35,7 +35,6 @@ from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger() LOGGER = get_logger()
@ -266,7 +265,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
], ],
responses={200: CertificateDataSerializer(many=False)}, responses={200: CertificateDataSerializer(many=False)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def view_certificate(self, request: Request, pk: str) -> Response: def view_certificate(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs certificate and log access""" """Return certificate-key pairs certificate and log access"""
certificate: CertificateKeyPair = self.get_object() certificate: CertificateKeyPair = self.get_object()
@ -296,7 +295,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
], ],
responses={200: CertificateDataSerializer(many=False)}, responses={200: CertificateDataSerializer(many=False)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def view_private_key(self, request: Request, pk: str) -> Response: def view_private_key(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs private key and log access""" """Return certificate-key pairs private key and log access"""
certificate: CertificateKeyPair = self.get_object() certificate: CertificateKeyPair = self.get_object()

View File

@ -214,46 +214,6 @@ class TestCrypto(APITestCase):
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response) self.assertIn("Content-Disposition", response)
def test_certificate_download_denied(self):
"""Test certificate export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_private_key_download_denied(self):
"""Test private_key export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_used_by(self): def test_used_by(self):
"""Test used_by endpoint""" """Test used_by endpoint"""
self.client.force_login(create_test_admin_user()) self.client.force_login(create_test_admin_user())
@ -286,26 +246,6 @@ class TestCrypto(APITestCase):
], ],
) )
def test_used_by_denied(self):
"""Test used_by endpoint"""
self.client.logout()
keypair = create_test_cert()
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://localhost",
signing_key=keypair,
)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-used-by",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
def test_discovery(self): def test_discovery(self):
"""Test certificate discovery""" """Test certificate discovery"""
name = generate_id() name = generate_id()

View File

@ -1,11 +1,12 @@
"""Enterprise API Views""" """Enterprise API Views"""
from dataclasses import asdict
from datetime import timedelta from datetime import timedelta
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, inline_serializer from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, IntegerField from rest_framework.fields import CharField, IntegerField
@ -29,7 +30,7 @@ class EnterpriseRequiredMixin:
def validate(self, attrs: dict) -> dict: def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists""" """Check that a valid license exists"""
if not LicenseKey.cached_summary().status.is_valid: if not LicenseKey.cached_summary().has_license:
raise ValidationError(_("Enterprise is required to create/update this object.")) raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs) return super().validate(attrs)
@ -86,7 +87,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
}, },
) )
@action(detail=False, methods=["GET"]) @action(detail=False, methods=["GET"])
def install_id(self, request: Request) -> Response: def get_install_id(self, request: Request) -> Response:
"""Get install_id""" """Get install_id"""
return Response( return Response(
data={ data={
@ -99,22 +100,12 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
responses={ responses={
200: LicenseSummarySerializer(), 200: LicenseSummarySerializer(),
}, },
parameters=[
OpenApiParameter(
name="cached",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.BOOL,
default=True,
)
],
) )
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated]) @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
def summary(self, request: Request) -> Response: def summary(self, request: Request) -> Response:
"""Get the total license status""" """Get the total license status"""
summary = LicenseKey.cached_summary() response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
if request.query_params.get("cached", "true").lower() == "false": response.is_valid(raise_exception=True)
summary = LicenseKey.get_total().summary()
response = LicenseSummarySerializer(instance=summary)
return Response(response.data) return Response(response.data)
@permission_required(None, ["authentik_enterprise.view_license"]) @permission_required(None, ["authentik_enterprise.view_license"])
@ -137,7 +128,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
forecast_for_months = 12 forecast_for_months = 12
response = LicenseForecastSerializer( response = LicenseForecastSerializer(
data={ data={
"internal_users": LicenseKey.get_internal_user_count(), "internal_users": LicenseKey.get_default_user_count(),
"external_users": LicenseKey.get_external_user_count(), "external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months), "forecasted_internal_users": (internal_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months), "forecasted_external_users": (external_in_last_month * forecast_for_months),

View File

@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
"""Actual enterprise check, cached""" """Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
return LicenseKey.cached_summary().status.is_valid return LicenseKey.cached_summary().valid

View File

@ -3,37 +3,24 @@
from base64 import b64decode from base64 import b64decode
from binascii import Error from binascii import Error
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from time import mktime from time import mktime
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import DaciteError, from_dict from dacite import from_dict
from django.core.cache import cache from django.core.cache import cache
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.fields import ( from rest_framework.fields import BooleanField, DateTimeField, IntegerField
ChoiceField,
DateTimeField,
IntegerField,
ListField,
)
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes from authentik.core.models import User, UserTypes
from authentik.enterprise.models import ( from authentik.enterprise.models import License, LicenseUsage
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_EXPIRY_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.tenants.utils import get_unique_identifier from authentik.tenants.utils import get_unique_identifier
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
@ -55,9 +42,6 @@ def get_license_aud() -> str:
class LicenseFlags(Enum): class LicenseFlags(Enum):
"""License flags""" """License flags"""
TRIAL = "trial"
NON_PRODUCTION = "non_production"
@dataclass @dataclass
class LicenseSummary: class LicenseSummary:
@ -65,9 +49,12 @@ class LicenseSummary:
internal_users: int internal_users: int
external_users: int external_users: int
status: LicenseUsageStatus valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
latest_valid: datetime latest_valid: datetime
license_flags: list[LicenseFlags] has_license: bool
class LicenseSummarySerializer(PassiveSerializer): class LicenseSummarySerializer(PassiveSerializer):
@ -75,9 +62,12 @@ class LicenseSummarySerializer(PassiveSerializer):
internal_users = IntegerField(required=True) internal_users = IntegerField(required=True)
external_users = IntegerField(required=True) external_users = IntegerField(required=True)
status = ChoiceField(choices=LicenseUsageStatus.choices) valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField() latest_valid = DateTimeField()
license_flags = ListField(child=ChoiceField(choices=tuple(x.value for x in LicenseFlags))) has_license = BooleanField()
@dataclass @dataclass
@ -90,10 +80,10 @@ class LicenseKey:
name: str name: str
internal_users: int = 0 internal_users: int = 0
external_users: int = 0 external_users: int = 0
license_flags: list[LicenseFlags] = field(default_factory=list) flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod @staticmethod
def validate(jwt: str, check_expiry=True) -> "LicenseKey": def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT""" """Validate the license from a given JWT"""
try: try:
headers = get_unverified_header(jwt) headers = get_unverified_header(jwt)
@ -117,28 +107,26 @@ class LicenseKey:
our_cert.public_key(), our_cert.public_key(),
algorithms=["ES512"], algorithms=["ES512"],
audience=get_license_aud(), audience=get_license_aud(),
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
), ),
) )
except PyJWTError: except PyJWTError:
unverified = decode(jwt, options={"verify_signature": False})
if unverified["aud"] != get_license_aud():
raise ValidationError("Invalid Install ID in license") from None
raise ValidationError("Unable to verify license") from None raise ValidationError("Unable to verify license") from None
return body return body
@staticmethod @staticmethod
def get_total() -> "LicenseKey": def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses""" """Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in License.objects.all(): for lic in active_licenses:
total.internal_users += lic.internal_users total.internal_users += lic.internal_users
total.external_users += lic.external_users total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple())) exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0: if total.exp == 0:
total.exp = exp_ts total.exp = exp_ts
total.exp = max(total.exp, exp_ts) if exp_ts <= total.exp:
total.license_flags.extend(lic.status.license_flags) total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total return total
@staticmethod @staticmethod
@ -147,7 +135,7 @@ class LicenseKey:
return User.objects.all().exclude_anonymous().exclude(is_active=False) return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod @staticmethod
def get_internal_user_count(): def get_default_user_count():
"""Get current default user count""" """Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@ -156,73 +144,59 @@ class LicenseKey:
"""Get current external user count""" """Get current external user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count() return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()
def _last_valid_date(self): def is_valid(self) -> bool:
last_valid_date = ( """Check if the given license body covers all users
LicenseUsage.objects.order_by("-record_date")
.filter(status=LicenseUsageStatus.VALID)
.first()
)
if not last_valid_date:
return datetime.fromtimestamp(0, UTC)
return last_valid_date.record_date
def status(self) -> LicenseUsageStatus: Only checks the current count, no historical data is checked"""
"""Check if the given license body covers all users, and is valid.""" default_users = self.get_default_user_count()
last_valid = self._last_valid_date() if default_users > self.internal_users:
if self.exp == 0 and not License.objects.exists(): return False
return LicenseUsageStatus.UNLICENSED active_users = self.get_external_user_count()
_now = now() if active_users > self.external_users:
# Check limit-exceeded based status return False
internal_users = self.get_internal_user_count() return True
external_users = self.get_external_user_count()
if internal_users > self.internal_users or external_users > self.external_users:
if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
return LicenseUsageStatus.READ_ONLY
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_USER
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
# Check expiry based status
if datetime.fromtimestamp(self.exp, UTC) < _now:
if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta(
weeks=THRESHOLD_READ_ONLY_WEEKS
):
return LicenseUsageStatus.READ_ONLY
return LicenseUsageStatus.EXPIRED
# Expiry warning
if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta(
weeks=THRESHOLD_WARNING_EXPIRY_WEEKS
):
return LicenseUsageStatus.EXPIRY_SOON
return LicenseUsageStatus.VALID
def record_usage(self): def record_usage(self):
"""Capture the current validity status and metrics and save them""" """Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8) threshold = now() - timedelta(hours=8)
usage = ( if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first() LicenseUsage.objects.create(
) user_count=self.get_default_user_count(),
if not usage:
usage = LicenseUsage.objects.create(
internal_user_count=self.get_internal_user_count(),
external_user_count=self.get_external_user_count(), external_user_count=self.get_external_user_count(),
status=self.status(), within_limits=self.is_valid(),
) )
summary = asdict(self.summary()) summary = asdict(self.summary())
# Also cache the latest summary for the middleware # Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE) cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return usage return summary
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
def summary(self) -> LicenseSummary: def summary(self) -> LicenseSummary:
"""Summary of license status""" """Summary of license status"""
status = self.status() has_license = License.objects.all().count() > 0
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(self.exp) latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary( return LicenseSummary(
show_admin_warning=show_admin_warning and has_license,
show_user_warning=show_user_warning and has_license,
read_only=read_only and has_license,
latest_valid=latest_valid, latest_valid=latest_valid,
internal_users=self.internal_users, internal_users=self.internal_users,
external_users=self.external_users, external_users=self.external_users,
status=status, valid=self.is_valid(),
license_flags=self.license_flags, has_license=has_license,
) )
@staticmethod @staticmethod
@ -231,8 +205,4 @@ class LicenseKey:
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE) summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary: if not summary:
return LicenseKey.get_total().summary() return LicenseKey.get_total().summary()
try: return from_dict(LicenseSummary, summary)
return from_dict(LicenseSummary, summary)
except DaciteError:
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
return LicenseKey.get_total().summary()

View File

@ -8,7 +8,6 @@ from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsageStatus
from authentik.flows.views.executor import FlowExecutorView from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
@ -44,7 +43,7 @@ class EnterpriseMiddleware:
cached_status = LicenseKey.cached_summary() cached_status = LicenseKey.cached_summary()
if not cached_status: if not cached_status:
return True return True
if cached_status.status == LicenseUsageStatus.READ_ONLY: if cached_status.read_only:
return False return False
return True return True
@ -54,10 +53,10 @@ class EnterpriseMiddleware:
if request.method.lower() in ["get", "head", "options", "trace"]: if request.method.lower() in ["get", "head", "options", "trace"]:
return True return True
# Always allow requests to manage licenses # Always allow requests to manage licenses
if request.resolver_match._func_path == class_to_path(LicenseViewSet): if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
return True return True
# Flow executor is mounted as an API path but explicitly allowed # Flow executor is mounted as an API path but explicitly allowed
if request.resolver_match._func_path == class_to_path(FlowExecutorView): if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
return True return True
# Only apply these restrictions to the API # Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names: if "authentik_api" not in request.resolver_match.app_names:

View File

@ -1,68 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-08 14:15
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage")
db_alias = schema_editor.connection.alias
for usage in LicenseUsage.objects.using(db_alias).all():
usage.status = "valid" if usage.within_limits else "limit_exceeded_admin"
usage.save()
class Migration(migrations.Migration):
dependencies = [
("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"),
]
operations = [
migrations.AddField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
default=None,
null=True,
),
preserve_default=False,
),
migrations.RunPython(migrate_license_usage),
migrations.RemoveField(
model_name="licenseusage",
name="within_limits",
),
migrations.AlterField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
),
preserve_default=False,
),
migrations.RenameField(
model_name="licenseusage",
old_name="user_count",
new_name="internal_user_count",
),
]

View File

@ -17,17 +17,6 @@ if TYPE_CHECKING:
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
def usage_expiry():
"""Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
THRESHOLD_WARNING_ADMIN_WEEKS = 2
THRESHOLD_WARNING_USER_WEEKS = 4
THRESHOLD_WARNING_EXPIRY_WEEKS = 2
THRESHOLD_READ_ONLY_WEEKS = 6
class License(SerializerModel): class License(SerializerModel):
"""An authentik enterprise license""" """An authentik enterprise license"""
@ -50,7 +39,7 @@ class License(SerializerModel):
"""Get parsed license status""" """Get parsed license status"""
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
return LicenseKey.validate(self.key, check_expiry=False) return LicenseKey.validate(self.key)
class Meta: class Meta:
indexes = (HashIndex(fields=("key",)),) indexes = (HashIndex(fields=("key",)),)
@ -58,23 +47,9 @@ class License(SerializerModel):
verbose_name_plural = _("Licenses") verbose_name_plural = _("Licenses")
class LicenseUsageStatus(models.TextChoices): def usage_expiry():
"""License states an instance/tenant can be in""" """Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
UNLICENSED = "unlicensed"
VALID = "valid"
EXPIRED = "expired"
EXPIRY_SOON = "expiry_soon"
# User limit exceeded, 2 week threshold, show message in admin interface
LIMIT_EXCEEDED_ADMIN = "limit_exceeded_admin"
# User limit exceeded, 4 week threshold, show message in user interface
LIMIT_EXCEEDED_USER = "limit_exceeded_user"
READ_ONLY = "read_only"
@property
def is_valid(self) -> bool:
"""Quickly check if a license is valid"""
return self in [LicenseUsageStatus.VALID, LicenseUsageStatus.EXPIRY_SOON]
class LicenseUsage(ExpiringModel): class LicenseUsage(ExpiringModel):
@ -84,9 +59,9 @@ class LicenseUsage(ExpiringModel):
usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
internal_user_count = models.BigIntegerField() user_count = models.BigIntegerField()
external_user_count = models.BigIntegerField() external_user_count = models.BigIntegerField()
status = models.TextField(choices=LicenseUsageStatus.choices) within_limits = models.BooleanField()
record_date = models.DateTimeField(auto_now_add=True) record_date = models.DateTimeField(auto_now_add=True)

View File

@ -13,7 +13,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
def check_license(self): def check_license(self):
"""Check license""" """Check license"""
if not LicenseKey.get_total().status().is_valid: if not LicenseKey.get_total().is_valid():
return PolicyResult(False, _("Enterprise required to access this feature.")) return PolicyResult(False, _("Enterprise required to access this feature."))
if self.request.user.type != UserTypes.INTERNAL: if self.request.user.type != UserTypes.INTERNAL:
return PolicyResult(False, _("Feature only accessible for internal users.")) return PolicyResult(False, _("Feature only accessible for internal users."))

View File

@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import ( from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
google_workspace_sync,
google_workspace_sync_objects,
)
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@ -55,4 +52,3 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
search_fields = ["name"] search_fields = ["name"]
ordering = ["name"] ordering = ["name"]
sync_single_task = google_workspace_sync sync_single_task = google_workspace_sync
sync_objects_task = google_workspace_sync_objects

View File

@ -181,7 +181,7 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-google-workspace-form" return "ak-property-mapping-google-workspace-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:

View File

@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import ( from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
microsoft_entra_sync,
microsoft_entra_sync_objects,
)
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@ -53,4 +50,3 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
search_fields = ["name"] search_fields = ["name"]
ordering = ["name"] ordering = ["name"]
sync_single_task = microsoft_entra_sync sync_single_task = microsoft_entra_sync
sync_objects_task = microsoft_entra_sync_objects

View File

@ -170,7 +170,7 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-microsoft-entra-form" return "ak-property-mapping-microsoft-entra-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_rac", "0004_alter_connectiontoken_expires"),
]
operations = [
migrations.AlterModelOptions(
name="racpropertymapping",
options={
"verbose_name": "RAC Provider Property Mapping",
"verbose_name_plural": "RAC Provider Property Mappings",
},
),
]

View File

@ -125,7 +125,7 @@ class RACPropertyMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-rac-form" return "ak-property-mapping-rac-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:
@ -136,8 +136,8 @@ class RACPropertyMapping(PropertyMapping):
return RACPropertyMappingSerializer return RACPropertyMappingSerializer
class Meta: class Meta:
verbose_name = _("RAC Provider Property Mapping") verbose_name = _("RAC Property Mapping")
verbose_name_plural = _("RAC Provider Property Mappings") verbose_name_plural = _("RAC Property Mappings")
class ConnectionToken(ExpiringModel): class ConnectionToken(ExpiringModel):

View File

@ -44,7 +44,7 @@ websocket_urlpatterns = [
api_urlpatterns = [ api_urlpatterns = [
("providers/rac", RACProviderViewSet), ("providers/rac", RACProviderViewSet),
("propertymappings/provider/rac", RACPropertyMappingViewSet), ("propertymappings/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet), ("rac/endpoints", EndpointViewSet),
("rac/connection_tokens", ConnectionTokenViewSet), ("rac/connection_tokens", ConnectionTokenViewSet),
] ]

View File

@ -3,7 +3,7 @@
from datetime import datetime from datetime import datetime
from django.core.cache import cache from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_save from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver from django.dispatch import receiver
from django.utils.timezone import get_current_timezone from django.utils.timezone import get_current_timezone
@ -27,9 +27,3 @@ def post_save_license(sender: type[License], instance: License, **_):
"""Trigger license usage calculation when license is saved""" """Trigger license usage calculation when license is saved"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
enterprise_update_usage.delay() enterprise_update_usage.delay()
@receiver(post_delete, sender=License)
def post_delete_license(sender: type[License], instance: License, **_):
"""Clear license cache when license is deleted"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)

View File

@ -9,26 +9,10 @@ from django.utils.timezone import now
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import ( from authentik.enterprise.models import License
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
# Valid license expiry _exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
# Valid license expiry, expires soon
expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
# Invalid license expiry, recently expired
expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
# Invalid license expiry, expired longer ago
expiry_expired_read_only = int(
mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
)
class TestEnterpriseLicense(TestCase): class TestEnterpriseLicense(TestCase):
@ -39,7 +23,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",
exp=expiry_valid, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, internal_users=100,
external_users=100, external_users=100,
@ -49,7 +33,7 @@ class TestEnterpriseLicense(TestCase):
def test_valid(self): def test_valid(self):
"""Check license verification""" """Check license verification"""
lic = License.objects.create(key=generate_id()) lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.status().is_valid) self.assertTrue(lic.status.is_valid())
self.assertEqual(lic.internal_users, 100) self.assertEqual(lic.internal_users, 100)
def test_invalid(self): def test_invalid(self):
@ -62,7 +46,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",
exp=expiry_valid, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, internal_users=100,
external_users=100, external_users=100,
@ -72,186 +56,11 @@ class TestEnterpriseLicense(TestCase):
def test_valid_multiple(self): def test_valid_multiple(self):
"""Check license verification""" """Check license verification"""
lic = License.objects.create(key=generate_id()) lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.status().is_valid) self.assertTrue(lic.status.is_valid())
lic2 = License.objects.create(key=generate_id()) lic2 = License.objects.create(key=generate_id())
self.assertTrue(lic2.status.status().is_valid) self.assertTrue(lic2.status.is_valid())
total = LicenseKey.get_total() total = LicenseKey.get_total()
self.assertEqual(total.internal_users, 200) self.assertEqual(total.internal_users, 200)
self.assertEqual(total.external_users, 200) self.assertEqual(total.external_users, 200)
self.assertEqual(total.exp, expiry_valid) self.assertEqual(total.exp, _exp)
self.assertTrue(total.status().is_valid) self.assertTrue(total.is_valid())
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_user_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_admin_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired_read_only,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_expired(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_soon,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_soon(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON)

View File

@ -1,217 +0,0 @@
"""read only tests"""
from datetime import timedelta
from unittest.mock import MagicMock, patch
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.flows.models import (
FlowDesignation,
FlowStageBinding,
)
from authentik.flows.tests import FlowTestCase
from authentik.lib.generators import generate_id
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.password import BACKEND_INBUILT
from authentik.stages.password.models import PasswordStage
from authentik.stages.user_login.models import UserLoginStage
class TestReadOnly(FlowTestCase):
"""Test read_only"""
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_login(self):
"""Test flow, ensure login is still possible with read only mode"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
flow = create_test_flow(
FlowDesignation.AUTHENTICATION,
)
ident_stage = IdentificationStage.objects.create(
name=generate_id(),
user_fields=[UserFields.E_MAIL],
pretend_user_exists=False,
)
FlowStageBinding.objects.create(
target=flow,
stage=ident_stage,
order=0,
)
password_stage = PasswordStage.objects.create(
name=generate_id(), backends=[BACKEND_INBUILT]
)
FlowStageBinding.objects.create(
target=flow,
stage=password_stage,
order=1,
)
login_stage = UserLoginStage.objects.create(
name=generate_id(),
)
FlowStageBinding.objects.create(
target=flow,
stage=login_stage,
order=2,
)
user = create_test_user()
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(exec_url)
self.assertStageResponse(
response,
flow,
component="ak-stage-identification",
password_fields=False,
primary_action="Log in",
sources=[],
show_source_labels=False,
user_fields=[UserFields.E_MAIL],
)
response = self.client.post(exec_url, {"uid_field": user.email}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-password")
response = self.client.post(exec_url, {"password": user.username}, follow=True)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_licenses(self):
"""Test that managing licenses is still possible"""
license = License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Reading is always allowed
response = self.client.get(reverse("authentik_api:license-list"))
self.assertEqual(response.status_code, 200)
# Writing should also be allowed
response = self.client.patch(
reverse("authentik_api:license-detail", kwargs={"pk": license.pk})
)
self.assertEqual(response.status_code, 200)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_flows(self):
"""Test flow"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Read only is still allowed
response = self.client.get(reverse("authentik_api:flow-list"))
self.assertEqual(response.status_code, 200)
flow = create_test_flow()
# Writing is not
response = self.client.patch(
reverse("authentik_api:flow-detail", kwargs={"slug": flow.slug})
)
self.assertJSONEqual(
response.content,
{"detail": "Request denied due to expired/invalid license.", "code": "denied_license"},
)
self.assertEqual(response.status_code, 400)

View File

@ -69,5 +69,8 @@ class NotificationViewSet(
@action(detail=False, methods=["post"]) @action(detail=False, methods=["post"])
def mark_all_seen(self, request: Request) -> Response: def mark_all_seen(self, request: Request) -> Response:
"""Mark all the user's notifications as seen""" """Mark all the user's notifications as seen"""
Notification.objects.filter(user=request.user, seen=False).update(seen=True) notifications = Notification.objects.filter(user=request.user)
for notification in notifications:
notification.seen = True
Notification.objects.bulk_update(notifications, ["seen"])
return Response({}, status=204) return Response({}, status=204)

View File

@ -49,7 +49,6 @@ from authentik.policies.models import PolicyBindingModel
from authentik.root.middleware import ClientIPMiddleware from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger() LOGGER = get_logger()
DISCORD_FIELD_LIMIT = 25 DISCORD_FIELD_LIMIT = 25
@ -59,11 +58,7 @@ NOTIFICATION_SUMMARY_LENGTH = 75
def default_event_duration(): def default_event_duration():
"""Default duration an Event is saved. """Default duration an Event is saved.
This is used as a fallback when no brand is available""" This is used as a fallback when no brand is available"""
try: return now() + timedelta(days=365)
tenant = get_current_tenant()
return now() + timedelta_from_string(tenant.event_retention)
except Tenant.DoesNotExist:
return now() + timedelta(days=365)
def default_brand(): def default_brand():
@ -250,6 +245,12 @@ class Event(SerializerModel, ExpiringModel):
if QS_QUERY in self.context["http_request"]["args"]: if QS_QUERY in self.context["http_request"]["args"]:
wrapped = self.context["http_request"]["args"][QS_QUERY] wrapped = self.context["http_request"]["args"][QS_QUERY]
self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped)) self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
if hasattr(request, "tenant"):
tenant: Tenant = request.tenant
# Because self.created only gets set on save, we can't use it's value here
# hence we set self.created to now and then use it
self.created = now()
self.expires = self.created + timedelta_from_string(tenant.event_retention)
if hasattr(request, "brand"): if hasattr(request, "brand"):
brand: Brand = request.brand brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand)) self.brand = sanitize_dict(model_to_dict(brand))

View File

@ -6,7 +6,6 @@ from django.db.models import Model
from django.test import TestCase from django.test import TestCase
from authentik.core.models import default_token_key from authentik.core.models import default_token_key
from authentik.events.models import default_event_duration
from authentik.lib.utils.reflection import get_apps from authentik.lib.utils.reflection import get_apps
@ -21,7 +20,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable:
allowed = 0 allowed = 0
# Token-like objects need to lookup the current tenant to get the default token length # Token-like objects need to lookup the current tenant to get the default token length
for field in test_model._meta.fields: for field in test_model._meta.fields:
if field.default in [default_token_key, default_event_duration]: if field.default == default_token_key:
allowed += 1 allowed += 1
with self.assertNumQueries(allowed): with self.assertNumQueries(allowed):
str(test_model()) str(test_model())

View File

@ -2,8 +2,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from django.urls import reverse from django.test import TestCase
from rest_framework.test import APITestCase
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.events.models import ( from authentik.events.models import (
@ -11,7 +10,6 @@ from authentik.events.models import (
EventAction, EventAction,
Notification, Notification,
NotificationRule, NotificationRule,
NotificationSeverity,
NotificationTransport, NotificationTransport,
NotificationWebhookMapping, NotificationWebhookMapping,
TransportMode, TransportMode,
@ -22,7 +20,7 @@ from authentik.policies.exceptions import PolicyException
from authentik.policies.models import PolicyBinding from authentik.policies.models import PolicyBinding
class TestEventsNotifications(APITestCase): class TestEventsNotifications(TestCase):
"""Test Event Notifications""" """Test Event Notifications"""
def setUp(self) -> None: def setUp(self) -> None:
@ -133,15 +131,3 @@ class TestEventsNotifications(APITestCase):
Notification.objects.all().delete() Notification.objects.all().delete()
Event.new(EventAction.CUSTOM_PREFIX).save() Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(Notification.objects.first().body, "foo") self.assertEqual(Notification.objects.first().body, "foo")
def test_api_mark_all_seen(self):
"""Test mark_all_seen"""
self.client.force_login(self.user)
Notification.objects.create(
severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
)
response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
self.assertEqual(response.status_code, 204)
self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())

View File

@ -37,7 +37,6 @@ from authentik.lib.utils.file import (
) )
from authentik.lib.views import bad_request_message from authentik.lib.views import bad_request_message
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger() LOGGER = get_logger()
@ -282,7 +281,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
400: OpenApiResponse(description="Flow not applicable"), 400: OpenApiResponse(description="Flow not applicable"),
}, },
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def execute(self, request: Request, slug: str): def execute(self, request: Request, slug: str):
"""Execute flow for current user""" """Execute flow for current user"""
# Because we pre-plan the flow here, and not in the planner, we need to manually clear # Because we pre-plan the flow here, and not in the planner, we need to manually clear

View File

@ -2,6 +2,7 @@
import re import re
import socket import socket
from collections.abc import Iterable
from ipaddress import ip_address, ip_network from ipaddress import ip_address, ip_network
from textwrap import indent from textwrap import indent
from types import CodeType from types import CodeType
@ -27,12 +28,6 @@ from authentik.stages.authenticator import devices_for_user
LOGGER = get_logger() LOGGER = get_logger()
ARG_SANITIZE = re.compile(r"[:.-]")
def sanitize_arg(arg_name: str) -> str:
return re.sub(ARG_SANITIZE, "_", arg_name)
class BaseEvaluator: class BaseEvaluator:
"""Validate and evaluate python-based expressions""" """Validate and evaluate python-based expressions"""
@ -182,9 +177,9 @@ class BaseEvaluator:
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None) proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
return proc.profiling_wrapper() return proc.profiling_wrapper()
def wrap_expression(self, expression: str) -> str: def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
"""Wrap expression in a function, call it, and save the result as `result`""" """Wrap expression in a function, call it, and save the result as `result`"""
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys()) handler_signature = ",".join(params)
full_expression = "" full_expression = ""
full_expression += f"def handler({handler_signature}):\n" full_expression += f"def handler({handler_signature}):\n"
full_expression += indent(expression, " ") full_expression += indent(expression, " ")
@ -193,8 +188,8 @@ class BaseEvaluator:
def compile(self, expression: str) -> CodeType: def compile(self, expression: str) -> CodeType:
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect.""" """Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
expression = self.wrap_expression(expression) param_keys = self._context.keys()
return compile(expression, self._filename, "exec") return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
def evaluate(self, expression_source: str) -> Any: def evaluate(self, expression_source: str) -> Any:
"""Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised. """Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
@ -210,7 +205,7 @@ class BaseEvaluator:
self.handle_error(exc, expression_source) self.handle_error(exc, expression_source)
raise exc raise exc
try: try:
_locals = {sanitize_arg(x): y for x, y in self._context.items()} _locals = self._context
# Yes this is an exec, yes it is potentially bad. Since we limit what variables are # Yes this is an exec, yes it is potentially bad. Since we limit what variables are
# available here, and these policies can only be edited by admins, this is a risk # available here, and these policies can only be edited by admins, this is a risk
# we're willing to take. # we're willing to take.

View File

@ -1,19 +1,16 @@
from celery import Task from collections.abc import Callable
from django.utils.text import slugify from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, ChoiceField from rest_framework.fields import BooleanField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from authentik.core.api.utils import ModelSerializer, PassiveSerializer from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import Group, User
from authentik.events.api.tasks import SystemTaskSerializer from authentik.events.api.tasks import SystemTaskSerializer
from authentik.events.logs import LogEvent, LogEventSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path
from authentik.rbac.filters import ObjectFilter
class SyncStatusSerializer(PassiveSerializer): class SyncStatusSerializer(PassiveSerializer):
@ -23,29 +20,10 @@ class SyncStatusSerializer(PassiveSerializer):
tasks = SystemTaskSerializer(many=True, read_only=True) tasks = SystemTaskSerializer(many=True, read_only=True)
class SyncObjectSerializer(PassiveSerializer):
"""Sync object serializer"""
sync_object_model = ChoiceField(
choices=(
(class_to_path(User), "user"),
(class_to_path(Group), "group"),
)
)
sync_object_id = CharField()
class SyncObjectResultSerializer(PassiveSerializer):
"""Result of a single object sync"""
messages = LogEventSerializer(many=True, read_only=True)
class OutgoingSyncProviderStatusMixin: class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers""" """Common API Endpoints for Outgoing sync providers"""
sync_single_task: type[Task] = None sync_single_task: Callable = None
sync_objects_task: type[Task] = None
@extend_schema( @extend_schema(
responses={ responses={
@ -58,7 +36,7 @@ class OutgoingSyncProviderStatusMixin:
detail=True, detail=True,
pagination_class=None, pagination_class=None,
url_path="sync/status", url_path="sync/status",
filter_backends=[ObjectFilter], filter_backends=[],
) )
def sync_status(self, request: Request, pk: int) -> Response: def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status""" """Get provider's sync status"""
@ -77,30 +55,6 @@ class OutgoingSyncProviderStatusMixin:
} }
return Response(SyncStatusSerializer(status).data) return Response(SyncStatusSerializer(status).data)
@extend_schema(
request=SyncObjectSerializer,
responses={200: SyncObjectResultSerializer()},
)
@action(
methods=["POST"],
detail=True,
pagination_class=None,
url_path="sync/object",
filter_backends=[ObjectFilter],
)
def sync_object(self, request: Request, pk: int) -> Response:
"""Sync/Re-sync a single user/group object"""
provider: OutgoingSyncProvider = self.get_object()
params = SyncObjectSerializer(data=request.data)
params.is_valid(raise_exception=True)
res: list[LogEvent] = self.sync_objects_task.delay(
params.validated_data["sync_object_model"],
page=1,
provider_pk=provider.pk,
pk=params.validated_data["sync_object_id"],
).get()
return Response(SyncObjectResultSerializer(instance={"messages": res}).data)
class OutgoingSyncConnectionCreateMixin: class OutgoingSyncConnectionCreateMixin:
"""Mixin for connection objects that fetches remote data upon creation""" """Mixin for connection objects that fetches remote data upon creation"""

View File

@ -105,7 +105,7 @@ class SyncTasks:
return return
task.set_status(TaskStatus.SUCCESSFUL, *messages) task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects(self, object_type: str, page: int, provider_pk: int, **filter): def sync_objects(self, object_type: str, page: int, provider_pk: int):
_object_type = path_to_class(object_type) _object_type = path_to_class(object_type)
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
@ -120,7 +120,7 @@ class SyncTasks:
client = provider.client_for_model(_object_type) client = provider.client_for_model(_object_type)
except TransientSyncException: except TransientSyncException:
return messages return messages
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
if client.can_discover: if client.can_discover:
self.logger.debug("starting discover") self.logger.debug("starting discover")
client.discover() client.discover()

View File

@ -30,11 +30,6 @@ class TestHTTP(TestCase):
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2") request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2") self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
def test_forward_for_invalid(self):
"""Test invalid forward for"""
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="foobar")
self.assertEqual(ClientIPMiddleware.get_client_ip(request), ClientIPMiddleware.default_ip)
def test_fake_outpost(self): def test_fake_outpost(self):
"""Test faked IP which is overridden by an outpost""" """Test faked IP which is overridden by an outpost"""
token = Token.objects.create( token = Token.objects.create(
@ -58,17 +53,6 @@ class TestHTTP(TestCase):
}, },
) )
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1") self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
# Invalid, not a real IP
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
self.user.save()
request = self.factory.get(
"/",
**{
ClientIPMiddleware.outpost_remote_ip_header: "foobar",
ClientIPMiddleware.outpost_token_header: token.key,
},
)
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
# Valid # Valid
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
self.user.save() self.user.save()

View File

@ -21,14 +21,7 @@ class DebugSession(Session):
def send(self, req: PreparedRequest, *args, **kwargs): def send(self, req: PreparedRequest, *args, **kwargs):
request_id = str(uuid4()) request_id = str(uuid4())
LOGGER.debug( LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
"HTTP request sent",
uid=request_id,
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
)
resp = super().send(req, *args, **kwargs) resp = super().send(req, *args, **kwargs)
LOGGER.debug( LOGGER.debug(
"HTTP response received", "HTTP response received",

View File

@ -26,6 +26,7 @@ from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
from authentik.outposts.models import ( from authentik.outposts.models import (
Outpost, Outpost,
OutpostConfig, OutpostConfig,
OutpostState,
OutpostType, OutpostType,
default_outpost_config, default_outpost_config,
) )
@ -139,7 +140,7 @@ class OutpostHealthSerializer(PassiveSerializer):
def get_fips_enabled(self, obj: dict) -> bool | None: def get_fips_enabled(self, obj: dict) -> bool | None:
"""Get FIPS enabled""" """Get FIPS enabled"""
if not LicenseKey.get_total().status().is_valid: if not LicenseKey.get_total().is_valid():
return None return None
return obj["fips_enabled"] return obj["fips_enabled"]
@ -181,6 +182,7 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
outpost: Outpost = self.get_object() outpost: Outpost = self.get_object()
states = [] states = []
for state in outpost.state: for state in outpost.state:
state: OutpostState
states.append( states.append(
{ {
"uid": state.uid, "uid": state.uid,

View File

@ -26,7 +26,6 @@ from authentik.outposts.models import (
KubernetesServiceConnection, KubernetesServiceConnection,
OutpostServiceConnection, OutpostServiceConnection,
) )
from authentik.rbac.filters import ObjectFilter
class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer): class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer):
@ -76,7 +75,7 @@ class ServiceConnectionViewSet(
filterset_fields = ["name"] filterset_fields = ["name"]
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)}) @extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def state(self, request: Request, pk: str) -> Response: def state(self, request: Request, pk: str) -> Response:
"""Get the service connection's state""" """Get the service connection's state"""
connection = self.get_object() connection = self.get_object()

View File

@ -451,7 +451,7 @@ class OutpostState:
return False return False
if self.build_hash != get_build_hash(): if self.build_hash != get_build_hash():
return False return False
return parse(self.version) != OUR_VERSION return parse(self.version) < OUR_VERSION
@staticmethod @staticmethod
def for_outpost(outpost: Outpost) -> list["OutpostState"]: def for_outpost(outpost: Outpost) -> list["OutpostState"]:

View File

@ -214,7 +214,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
if not hasattr(instance, field_name): if not hasattr(instance, field_name):
continue continue
LOGGER.debug("triggering outpost update from field", field=field.name) LOGGER.debug("triggering outpost update from from field", field=field.name)
# Because the Outpost Model has an M2M to Provider, # Because the Outpost Model has an M2M to Provider,
# we have to iterate over the entire QS # we have to iterate over the entire QS
for reverse in getattr(instance, field_name).all(): for reverse in getattr(instance, field_name).all():

View File

@ -108,7 +108,7 @@ class EventMatcherPolicy(Policy):
result=result, result=result,
) )
matches.append(result) matches.append(result)
passing = all(x.passing for x in matches) passing = any(x.passing for x in matches)
messages = chain(*[x.messages for x in matches]) messages = chain(*[x.messages for x in matches])
result = PolicyResult(passing, *messages) result = PolicyResult(passing, *messages)
result.source_results = matches result.source_results = matches

View File

@ -77,24 +77,11 @@ class TestEventMatcherPolicy(TestCase):
request = PolicyRequest(get_anonymous_user()) request = PolicyRequest(get_anonymous_user())
request.context["event"] = event request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.5", app="foo" client_ip="1.2.3.5", app="bar"
) )
response = policy.passes(request) response = policy.passes(request)
self.assertFalse(response.passing) self.assertFalse(response.passing)
def test_multiple(self):
"""Test multiple"""
event = Event.new(EventAction.LOGIN)
event.app = "foo"
event.client_ip = "1.2.3.4"
request = PolicyRequest(get_anonymous_user())
request.context["event"] = event
policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(
client_ip="1.2.3.4", app="foo"
)
response = policy.passes(request)
self.assertTrue(response.passing)
def test_invalid(self): def test_invalid(self):
"""Test passing event""" """Test passing event"""
request = PolicyRequest(get_anonymous_user()) request = PolicyRequest(get_anonymous_user())

View File

@ -36,7 +36,7 @@ def update_score(request: HttpRequest, identifier: str, amount: int):
if not created: if not created:
reputation.score = F("score") + amount reputation.score = F("score") + amount
reputation.save() reputation.save()
LOGGER.info("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip) LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
@receiver(login_failed) @receiver(login_failed)

View File

@ -2,25 +2,15 @@
from django.db.models import QuerySet from django.db.models import QuerySet
from django.db.models.query import Q from django.db.models.query import Q
from django.shortcuts import get_object_or_404
from django_filters.filters import BooleanFilter from django_filters.filters import BooleanFilter
from django_filters.filterset import FilterSet from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes from rest_framework.fields import CharField, ListField, SerializerMethodField
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, ListField, SerializerMethodField
from rest_framework.mixins import ListModelMixin from rest_framework.mixins import ListModelMixin
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.core.api.providers import ProviderSerializer from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer, PassiveSerializer from authentik.core.api.utils import ModelSerializer
from authentik.core.models import Application
from authentik.policies.api.exec import PolicyTestResultSerializer
from authentik.policies.engine import PolicyEngine
from authentik.policies.types import PolicyResult
from authentik.providers.ldap.models import LDAPProvider from authentik.providers.ldap.models import LDAPProvider
@ -33,6 +23,7 @@ class LDAPProviderSerializer(ProviderSerializer):
model = LDAPProvider model = LDAPProvider
fields = ProviderSerializer.Meta.fields + [ fields = ProviderSerializer.Meta.fields + [
"base_dn", "base_dn",
"search_group",
"certificate", "certificate",
"tls_server_name", "tls_server_name",
"uid_start_number", "uid_start_number",
@ -64,6 +55,8 @@ class LDAPProviderFilter(FilterSet):
"name": ["iexact"], "name": ["iexact"],
"authorization_flow__slug": ["iexact"], "authorization_flow__slug": ["iexact"],
"base_dn": ["iexact"], "base_dn": ["iexact"],
"search_group__group_uuid": ["iexact"],
"search_group__name": ["iexact"],
"certificate__kp_uuid": ["iexact"], "certificate__kp_uuid": ["iexact"],
"certificate__name": ["iexact"], "certificate__name": ["iexact"],
"tls_server_name": ["iexact"], "tls_server_name": ["iexact"],
@ -102,6 +95,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
"base_dn", "base_dn",
"bind_flow_slug", "bind_flow_slug",
"application_slug", "application_slug",
"search_group",
"certificate", "certificate",
"tls_server_name", "tls_server_name",
"uid_start_number", "uid_start_number",
@ -122,33 +116,3 @@ class LDAPOutpostConfigViewSet(ListModelMixin, GenericViewSet):
ordering = ["name"] ordering = ["name"]
search_fields = ["name"] search_fields = ["name"]
filterset_fields = ["name"] filterset_fields = ["name"]
class LDAPCheckAccessSerializer(PassiveSerializer):
has_search_permission = BooleanField(required=False)
access = PolicyTestResultSerializer()
@extend_schema(
request=None,
parameters=[OpenApiParameter("app_slug", OpenApiTypes.STR)],
responses={
200: LDAPCheckAccessSerializer(),
},
operation_id="outposts_ldap_access_check",
)
@action(detail=True)
def check_access(self, request: Request, pk) -> Response:
"""Check access to a single application by slug"""
provider = get_object_or_404(LDAPProvider, pk=pk)
application = get_object_or_404(Application, slug=request.query_params["app_slug"])
engine = PolicyEngine(application, request.user, request)
engine.use_cache = False
engine.build()
result = engine.result
access_response = PolicyResult(result.passing)
response = self.LDAPCheckAccessSerializer(
instance={
"has_search_permission": request.user.has_perm("search_full_directory", provider),
"access": access_response,
}
)
return Response(response.data)

View File

@ -1,66 +0,0 @@
# Generated by Django 5.0.7 on 2024-07-25 14:59
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db import migrations
def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from authentik.core.models import User
from django.apps import apps as real_apps
from django.contrib.auth.management import create_permissions
from guardian.shortcuts import UserObjectPermission
db_alias = schema_editor.connection.alias
# Permissions are only created _after_ migrations are run
# - https://github.com/django/django/blob/43cdfa8b20e567a801b7d0a09ec67ddd062d5ea4/django/contrib/auth/apps.py#L19
# - https://stackoverflow.com/a/72029063/1870445
create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
Permission = apps.get_model("auth", "Permission")
UserObjectPermission = apps.get_model("guardian", "UserObjectPermission")
ContentType = apps.get_model("contenttypes", "ContentType")
new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory")
ct = ContentType.objects.using(db_alias).get(
app_label="authentik_providers_ldap",
model="ldapprovider",
)
for provider in LDAPProvider.objects.using(db_alias).all():
if not provider.search_group:
continue
for user in provider.search_group.users.using(db_alias).all():
UserObjectPermission.objects.using(db_alias).create(
user=user,
permission=new_prem,
object_pk=provider.pk,
content_type=ct,
)
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
("guardian", "0002_generic_permissions_index"),
]
operations = [
migrations.AlterModelOptions(
name="ldapprovider",
options={
"permissions": [("search_full_directory", "Search full LDAP directory")],
"verbose_name": "LDAP Provider",
"verbose_name_plural": "LDAP Providers",
},
),
migrations.RunPython(migrate_search_group),
migrations.RemoveField(
model_name="ldapprovider",
name="search_group",
),
]

View File

@ -7,7 +7,7 @@ from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from authentik.core.models import BackchannelProvider from authentik.core.models import BackchannelProvider, Group
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.outposts.models import OutpostModel from authentik.outposts.models import OutpostModel
@ -27,6 +27,17 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
help_text=_("DN under which objects are accessible."), help_text=_("DN under which objects are accessible."),
) )
search_group = models.ForeignKey(
Group,
null=True,
default=None,
on_delete=models.SET_DEFAULT,
help_text=_(
"Users in this group can do search queries. "
"If not set, every user can execute search queries."
),
)
tls_server_name = models.TextField( tls_server_name = models.TextField(
default="", default="",
blank=True, blank=True,
@ -102,6 +113,3 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
class Meta: class Meta:
verbose_name = _("LDAP Provider") verbose_name = _("LDAP Provider")
verbose_name_plural = _("LDAP Providers") verbose_name_plural = _("LDAP Providers")
permissions = [
("search_full_directory", _("Search full LDAP directory")),
]

View File

@ -105,7 +105,7 @@ class ScopeMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-scope-form" return "ak-property-mapping-scope-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:

View File

@ -29,6 +29,7 @@ class TesOAuth2Introspection(OAuthTestCase):
self.app = Application.objects.create( self.app = Application.objects.create(
name=generate_id(), slug=generate_id(), provider=self.provider name=generate_id(), slug=generate_id(), provider=self.provider
) )
self.app.save()
self.user = create_test_admin_user() self.user = create_test_admin_user()
self.auth = b64encode( self.auth = b64encode(
f"{self.provider.client_id}:{self.provider.client_secret}".encode() f"{self.provider.client_id}:{self.provider.client_secret}".encode()
@ -113,41 +114,6 @@ class TesOAuth2Introspection(OAuthTestCase):
}, },
) )
def test_introspect_invalid_provider(self):
"""Test introspection (mismatched provider and token)"""
provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris="",
signing_key=create_test_cert(),
)
auth = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
token: AccessToken = AccessToken.objects.create(
provider=self.provider,
user=self.user,
token=generate_id(),
auth_time=timezone.now(),
_scope="openid user profile",
_id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
)
res = self.client.post(
reverse("authentik_providers_oauth2:token-introspection"),
HTTP_AUTHORIZATION=f"Basic {auth}",
data={"token": token.token},
)
self.assertEqual(res.status_code, 200)
self.assertJSONEqual(
res.content.decode(),
{
"active": False,
},
)
def test_introspect_invalid_auth(self): def test_introspect_invalid_auth(self):
"""Test introspect (invalid auth)""" """Test introspect (invalid auth)"""
res = self.client.post( res = self.client.post(

View File

@ -62,7 +62,7 @@ urlpatterns = [
api_urlpatterns = [ api_urlpatterns = [
("providers/oauth2", OAuth2ProviderViewSet), ("providers/oauth2", OAuth2ProviderViewSet),
("propertymappings/provider/scope", ScopeMappingViewSet), ("propertymappings/scope", ScopeMappingViewSet),
("oauth2/authorization_codes", AuthorizationCodeViewSet), ("oauth2/authorization_codes", AuthorizationCodeViewSet),
("oauth2/refresh_tokens", RefreshTokenViewSet), ("oauth2/refresh_tokens", RefreshTokenViewSet),
("oauth2/access_tokens", AccessTokenViewSet), ("oauth2/access_tokens", AccessTokenViewSet),

View File

@ -46,10 +46,10 @@ class TokenIntrospectionParams:
if not provider: if not provider:
raise TokenIntrospectionError raise TokenIntrospectionError
access_token = AccessToken.objects.filter(token=raw_token, provider=provider).first() access_token = AccessToken.objects.filter(token=raw_token).first()
if access_token: if access_token:
return TokenIntrospectionParams(access_token, provider) return TokenIntrospectionParams(access_token, provider)
refresh_token = RefreshToken.objects.filter(token=raw_token, provider=provider).first() refresh_token = RefreshToken.objects.filter(token=raw_token).first()
if refresh_token: if refresh_token:
return TokenIntrospectionParams(refresh_token, provider) return TokenIntrospectionParams(refresh_token, provider)
LOGGER.debug("Token does not exist", token=raw_token) LOGGER.debug("Token does not exist", token=raw_token)

View File

@ -433,20 +433,20 @@ class TokenParams:
app = Application.objects.filter(provider=self.provider).first() app = Application.objects.filter(provider=self.provider).first()
if not app or not app.provider: if not app or not app.provider:
raise TokenError("invalid_grant") raise TokenError("invalid_grant")
with audit_ignore(): self.user, _ = User.objects.update_or_create(
self.user, _ = User.objects.update_or_create( # trim username to ensure the entire username is max 150 chars
# trim username to ensure the entire username is max 150 chars # (22 chars being the length of the "template")
# (22 chars being the length of the "template") username=f"ak-{self.provider.name[:150-22]}-client_credentials",
username=f"ak-{self.provider.name[:150-22]}-client_credentials", defaults={
defaults={ "attributes": {
"last_login": timezone.now(), USER_ATTRIBUTE_GENERATED: True,
"name": f"Autogenerated user from application {app.name} (client credentials)",
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
"type": UserTypes.SERVICE_ACCOUNT,
}, },
) "last_login": timezone.now(),
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True "name": f"Autogenerated user from application {app.name} (client credentials)",
self.user.save() "path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
"type": UserTypes.SERVICE_ACCOUNT,
},
)
self.__check_policy_access(app, request) self.__check_policy_access(app, request)
Event.new( Event.new(
@ -470,6 +470,9 @@ class TokenParams:
self.user, created = User.objects.update_or_create( self.user, created = User.objects.update_or_create(
username=f"{self.provider.name}-{token.get('sub')}", username=f"{self.provider.name}-{token.get('sub')}",
defaults={ defaults={
"attributes": {
USER_ATTRIBUTE_GENERATED: True,
},
"last_login": timezone.now(), "last_login": timezone.now(),
"name": ( "name": (
f"Autogenerated user from application {app.name} (client credentials JWT)" f"Autogenerated user from application {app.name} (client credentials JWT)"
@ -478,8 +481,6 @@ class TokenParams:
"type": UserTypes.SERVICE_ACCOUNT, "type": UserTypes.SERVICE_ACCOUNT,
}, },
) )
self.user.attributes[USER_ATTRIBUTE_GENERATED] = True
self.user.save()
exp = token.get("exp") exp = token.get("exp")
if created and exp: if created and exp:
self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp self.user.attributes[USER_ATTRIBUTE_EXPIRES] = exp

View File

@ -28,7 +28,7 @@ class ProxyDockerController(DockerController):
labels = super()._get_labels() labels = super()._get_labels()
labels["traefik.enable"] = "true" labels["traefik.enable"] = "true"
labels[f"traefik.http.routers.{traefik_name}-router.rule"] = ( labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
f"({' || '.join([f'Host({host})' for host in hosts])})" f"({' || '.join([f'Host(`{host}`)' for host in hosts])})"
f" && PathPrefix(`/outpost.goauthentik.io`)" f" && PathPrefix(`/outpost.goauthentik.io`)"
) )
labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true" labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"

View File

@ -154,7 +154,6 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
responses={ responses={
200: RadiusCheckAccessSerializer(), 200: RadiusCheckAccessSerializer(),
}, },
operation_id="outposts_radius_access_check",
) )
@action(detail=True) @action(detail=True)
def check_access(self, request: Request, pk) -> Response: def check_access(self, request: Request, pk) -> Response:

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_radius", "0003_radiusproviderpropertymapping"),
]
operations = [
migrations.AlterModelOptions(
name="radiusproviderpropertymapping",
options={
"verbose_name": "Radius Provider Property Mapping",
"verbose_name_plural": "Radius Provider Property Mappings",
},
),
]

View File

@ -70,7 +70,7 @@ class RadiusProviderPropertyMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-radius-form" return "ak-property-mapping-radius-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:
@ -81,8 +81,8 @@ class RadiusProviderPropertyMapping(PropertyMapping):
return RadiusProviderPropertyMappingSerializer return RadiusProviderPropertyMappingSerializer
def __str__(self): def __str__(self):
return f"Radius Provider Property Mapping {self.name}" return f"Radius Property Mapping {self.name}"
class Meta: class Meta:
verbose_name = _("Radius Provider Property Mapping") verbose_name = _("Radius Property Mapping")
verbose_name_plural = _("Radius Provider Property Mappings") verbose_name_plural = _("Radius Property Mappings")

View File

@ -7,7 +7,7 @@ from authentik.providers.radius.api.providers import (
) )
api_urlpatterns = [ api_urlpatterns = [
("propertymappings/provider/radius", RadiusProviderPropertyMappingViewSet), ("propertymappings/radius", RadiusProviderPropertyMappingViewSet),
("outposts/radius", RadiusOutpostConfigViewSet, "radiusprovideroutpost"), ("outposts/radius", RadiusOutpostConfigViewSet, "radiusprovideroutpost"),
("providers/radius", RadiusProviderViewSet), ("providers/radius", RadiusProviderViewSet),
] ]

View File

@ -133,17 +133,6 @@ class SAMLProviderSerializer(ProviderSerializer):
except Provider.application.RelatedObjectDoesNotExist: except Provider.application.RelatedObjectDoesNotExist:
return "-" return "-"
def validate(self, attrs: dict):
if attrs.get("signing_kp"):
if not attrs.get("sign_assertion") and not attrs.get("sign_response"):
raise ValidationError(
_(
"With a signing keypair selected, at least one of 'Sign assertion' "
"and 'Sign Response' must be selected."
)
)
return super().validate(attrs)
class Meta: class Meta:
model = SAMLProvider model = SAMLProvider
fields = ProviderSerializer.Meta.fields + [ fields = ProviderSerializer.Meta.fields + [
@ -159,9 +148,6 @@ class SAMLProviderSerializer(ProviderSerializer):
"signature_algorithm", "signature_algorithm",
"signing_kp", "signing_kp",
"verification_kp", "verification_kp",
"encryption_kp",
"sign_assertion",
"sign_response",
"sp_binding", "sp_binding",
"default_relay_state", "default_relay_state",
"url_download_metadata", "url_download_metadata",

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_saml", "0014_alter_samlprovider_digest_algorithm_and_more"),
]
operations = [
migrations.AlterModelOptions(
name="samlpropertymapping",
options={
"verbose_name": "SAML Provider Property Mapping",
"verbose_name_plural": "SAML Provider Property Mappings",
},
),
]

View File

@ -1,39 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-15 14:52
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_crypto", "0004_alter_certificatekeypair_name"),
("authentik_providers_saml", "0015_alter_samlpropertymapping_options"),
]
operations = [
migrations.AddField(
model_name="samlprovider",
name="encryption_kp",
field=models.ForeignKey(
blank=True,
default=None,
help_text="When selected, incoming assertions are encrypted by the IdP using the public key of the encryption keypair. The assertion is decrypted by the SP using the the private key.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="+",
to="authentik_crypto.certificatekeypair",
verbose_name="Encryption Keypair",
),
),
migrations.AddField(
model_name="samlprovider",
name="sign_assertion",
field=models.BooleanField(default=True),
),
migrations.AddField(
model_name="samlprovider",
name="sign_response",
field=models.BooleanField(default=False),
),
]

View File

@ -144,28 +144,11 @@ class SAMLProvider(Provider):
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
verbose_name=_("Signing Keypair"), verbose_name=_("Signing Keypair"),
) )
encryption_kp = models.ForeignKey(
CertificateKeyPair,
default=None,
null=True,
blank=True,
help_text=_(
"When selected, incoming assertions are encrypted by the IdP using the public "
"key of the encryption keypair. The assertion is decrypted by the SP using the "
"the private key."
),
on_delete=models.SET_NULL,
verbose_name=_("Encryption Keypair"),
related_name="+",
)
default_relay_state = models.TextField( default_relay_state = models.TextField(
default="", blank=True, help_text=_("Default relay_state value for IDP-initiated logins") default="", blank=True, help_text=_("Default relay_state value for IDP-initiated logins")
) )
sign_assertion = models.BooleanField(default=True)
sign_response = models.BooleanField(default=False)
@property @property
def launch_url(self) -> str | None: def launch_url(self) -> str | None:
"""Use IDP-Initiated SAML flow as launch URL""" """Use IDP-Initiated SAML flow as launch URL"""
@ -208,7 +191,7 @@ class SAMLPropertyMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-saml-form" return "ak-property-mapping-saml-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:
@ -221,8 +204,8 @@ class SAMLPropertyMapping(PropertyMapping):
return f"{self.name} ({name})" return f"{self.name} ({name})"
class Meta: class Meta:
verbose_name = _("SAML Provider Property Mapping") verbose_name = _("SAML Property Mapping")
verbose_name_plural = _("SAML Provider Property Mappings") verbose_name_plural = _("SAML Property Mappings")
class SAMLProviderImportModel(CreatableType, Provider): class SAMLProviderImportModel(CreatableType, Provider):

View File

@ -18,11 +18,7 @@ from authentik.providers.saml.processors.authn_request_parser import AuthNReques
from authentik.providers.saml.utils import get_random_id from authentik.providers.saml.utils import get_random_id
from authentik.providers.saml.utils.time import get_time_string from authentik.providers.saml.utils.time import get_time_string
from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME
from authentik.sources.saml.exceptions import ( from authentik.sources.saml.exceptions import InvalidSignature, UnsupportedNameIDFormat
InvalidEncryption,
InvalidSignature,
UnsupportedNameIDFormat,
)
from authentik.sources.saml.processors.constants import ( from authentik.sources.saml.processors.constants import (
DIGEST_ALGORITHM_TRANSLATION_MAP, DIGEST_ALGORITHM_TRANSLATION_MAP,
NS_MAP, NS_MAP,
@ -50,7 +46,6 @@ class AssertionProcessor:
_issue_instant: str _issue_instant: str
_assertion_id: str _assertion_id: str
_response_id: str
_valid_not_before: str _valid_not_before: str
_session_not_on_or_after: str _session_not_on_or_after: str
@ -63,7 +58,6 @@ class AssertionProcessor:
self._issue_instant = get_time_string() self._issue_instant = get_time_string()
self._assertion_id = get_random_id() self._assertion_id = get_random_id()
self._response_id = get_random_id()
self._valid_not_before = get_time_string( self._valid_not_before = get_time_string(
timedelta_from_string(self.provider.assertion_valid_not_before) timedelta_from_string(self.provider.assertion_valid_not_before)
@ -132,9 +126,7 @@ class AssertionProcessor:
"""Generate AuthnStatement with AuthnContext and ContextClassRef Elements.""" """Generate AuthnStatement with AuthnContext and ContextClassRef Elements."""
auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement") auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement")
auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before
auth_n_statement.attrib["SessionIndex"] = sha256( auth_n_statement.attrib["SessionIndex"] = self._assertion_id
self.http_request.session.session_key.encode("ascii")
).hexdigest()
auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after auth_n_statement.attrib["SessionNotOnOrAfter"] = self._session_not_on_or_after
auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext") auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext")
@ -264,17 +256,9 @@ class AssertionProcessor:
assertion, assertion,
xmlsec.constants.TransformExclC14N, xmlsec.constants.TransformExclC14N,
sign_algorithm_transform, sign_algorithm_transform,
ns=xmlsec.constants.DSigNs, ns="ds", # type: ignore
) )
assertion.append(signature) assertion.append(signature)
if self.provider.encryption_kp:
encryption = xmlsec.template.encrypted_data_create(
assertion,
xmlsec.constants.TransformAes128Cbc,
self._assertion_id,
ns=xmlsec.constants.DSigNs,
)
assertion.append(encryption)
assertion.append(self.get_assertion_subject()) assertion.append(self.get_assertion_subject())
assertion.append(self.get_assertion_conditions()) assertion.append(self.get_assertion_conditions())
@ -289,7 +273,7 @@ class AssertionProcessor:
response.attrib["Version"] = "2.0" response.attrib["Version"] = "2.0"
response.attrib["IssueInstant"] = self._issue_instant response.attrib["IssueInstant"] = self._issue_instant
response.attrib["Destination"] = self.provider.acs_url response.attrib["Destination"] = self.provider.acs_url
response.attrib["ID"] = self._response_id response.attrib["ID"] = get_random_id()
if self.auth_n_request.id: if self.auth_n_request.id:
response.attrib["InResponseTo"] = self.auth_n_request.id response.attrib["InResponseTo"] = self.auth_n_request.id
@ -302,86 +286,41 @@ class AssertionProcessor:
response.append(self.get_assertion()) response.append(self.get_assertion())
return response return response
def _sign(self, element: Element):
"""Sign an XML element based on the providers' configured signing settings"""
digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
self.provider.digest_algorithm, xmlsec.constants.TransformSha1
)
xmlsec.tree.add_ids(element, ["ID"])
signature_node = xmlsec.tree.find_node(element, xmlsec.constants.NodeSignature)
ref = xmlsec.template.add_reference(
signature_node,
digest_algorithm_transform,
uri="#" + element.attrib["ID"],
)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
key_info = xmlsec.template.ensure_key_info(signature_node)
xmlsec.template.add_x509_data(key_info)
ctx = xmlsec.SignatureContext()
key = xmlsec.Key.from_memory(
self.provider.signing_kp.key_data,
xmlsec.constants.KeyDataFormatPem,
None,
)
key.load_cert_from_memory(
self.provider.signing_kp.certificate_data,
xmlsec.constants.KeyDataFormatCertPem,
)
ctx.key = key
try:
ctx.sign(signature_node)
except xmlsec.Error as exc:
raise InvalidSignature() from exc
def _encrypt(self, element: Element, parent: Element):
"""Encrypt SAMLResponse EncryptedAssertion Element"""
manager = xmlsec.KeysManager()
key = xmlsec.Key.from_memory(
self.provider.encryption_kp.key_data,
xmlsec.constants.KeyDataFormatPem,
)
key.load_cert_from_memory(
self.provider.encryption_kp.certificate_data,
xmlsec.constants.KeyDataFormatCertPem,
)
manager.add_key(key)
encryption_context = xmlsec.EncryptionContext(manager)
encryption_context.key = xmlsec.Key.generate(
xmlsec.constants.KeyDataAes, 128, xmlsec.constants.KeyDataTypeSession
)
container = SubElement(parent, f"{{{NS_SAML_ASSERTION}}}EncryptedAssertion")
enc_data = xmlsec.template.encrypted_data_create(
container, xmlsec.Transform.AES128, type=xmlsec.EncryptionType.ELEMENT, ns="xenc"
)
xmlsec.template.encrypted_data_ensure_cipher_value(enc_data)
key_info = xmlsec.template.encrypted_data_ensure_key_info(enc_data, ns="ds")
enc_key = xmlsec.template.add_encrypted_key(key_info, xmlsec.Transform.RSA_OAEP)
xmlsec.template.encrypted_data_ensure_cipher_value(enc_key)
try:
enc_data = encryption_context.encrypt_xml(enc_data, element)
except xmlsec.Error as exc:
raise InvalidEncryption() from exc
parent.remove(enc_data)
container.append(enc_data)
def build_response(self) -> str: def build_response(self) -> str:
"""Build string XML Response and sign if signing is enabled.""" """Build string XML Response and sign if signing is enabled."""
root_response = self.get_response() root_response = self.get_response()
if self.provider.signing_kp: if self.provider.signing_kp:
if self.provider.sign_assertion: digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0] self.provider.digest_algorithm, xmlsec.constants.TransformSha1
self._sign(assertion) )
if self.provider.sign_response:
response = root_response.xpath("//samlp:Response", namespaces=NS_MAP)[0]
self._sign(response)
if self.provider.encryption_kp:
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0] assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
self._encrypt(assertion, root_response) xmlsec.tree.add_ids(assertion, ["ID"])
signature_node = xmlsec.tree.find_node(assertion, xmlsec.constants.NodeSignature)
ref = xmlsec.template.add_reference(
signature_node,
digest_algorithm_transform,
uri="#" + self._assertion_id,
)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
key_info = xmlsec.template.ensure_key_info(signature_node)
xmlsec.template.add_x509_data(key_info)
ctx = xmlsec.SignatureContext()
key = xmlsec.Key.from_memory(
self.provider.signing_kp.key_data,
xmlsec.constants.KeyDataFormatPem,
None,
)
key.load_cert_from_memory(
self.provider.signing_kp.certificate_data,
xmlsec.constants.KeyDataFormatCertPem,
)
ctx.key = key
try:
ctx.sign(signature_node)
except xmlsec.Error as exc:
raise InvalidSignature() from exc
return etree.tostring(root_response).decode("utf-8") # nosec return etree.tostring(root_response).decode("utf-8") # nosec

View File

@ -126,7 +126,7 @@ class MetadataProcessor:
entity_descriptor, entity_descriptor,
xmlsec.constants.TransformExclC14N, xmlsec.constants.TransformExclC14N,
sign_algorithm_transform, sign_algorithm_transform,
ns=xmlsec.constants.DSigNs, ns="ds", # type: ignore
) )
entity_descriptor.append(signature) entity_descriptor.append(signature)

View File

@ -8,7 +8,7 @@ from rest_framework.test import APITestCase
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import FlowDesignation from authentik.flows.models import FlowDesignation
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture from authentik.lib.tests.utils import load_fixture
@ -29,52 +29,12 @@ class TestSAMLProviderAPI(APITestCase):
name=generate_id(), name=generate_id(),
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
) )
response = self.client.get(
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
)
self.assertEqual(200, response.status_code)
Application.objects.create(name=generate_id(), provider=provider, slug=generate_id()) Application.objects.create(name=generate_id(), provider=provider, slug=generate_id())
response = self.client.get( response = self.client.get(
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}), reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
def test_create_validate_signing_kp(self):
"""Test create"""
cert = create_test_cert()
response = self.client.post(
reverse("authentik_api:samlprovider-list"),
data={
"name": generate_id(),
"authorization_flow": create_test_flow().pk,
"acs_url": "http://localhost",
"signing_kp": cert.pk,
},
)
self.assertEqual(400, response.status_code)
self.assertJSONEqual(
response.content,
{
"non_field_errors": [
(
"With a signing keypair selected, at least one "
"of 'Sign assertion' and 'Sign Response' must be selected."
)
]
},
)
response = self.client.post(
reverse("authentik_api:samlprovider-list"),
data={
"name": generate_id(),
"authorization_flow": create_test_flow().pk,
"acs_url": "http://localhost",
"signing_kp": cert.pk,
"sign_assertion": True,
},
)
self.assertEqual(201, response.status_code)
def test_metadata(self): def test_metadata(self):
"""Test metadata export (normal)""" """Test metadata export (normal)"""
self.client.logout() self.client.logout()

View File

@ -78,12 +78,12 @@ class TestAuthNRequest(TestCase):
@apply_blueprint("system/providers-saml.yaml") @apply_blueprint("system/providers-saml.yaml")
def setUp(self): def setUp(self):
self.cert = create_test_cert() cert = create_test_cert()
self.provider: SAMLProvider = SAMLProvider.objects.create( self.provider: SAMLProvider = SAMLProvider.objects.create(
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
acs_url="http://testserver/source/saml/provider/acs/", acs_url="http://testserver/source/saml/provider/acs/",
signing_kp=self.cert, signing_kp=cert,
verification_kp=self.cert, verification_kp=cert,
) )
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all()) self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
self.provider.save() self.provider.save()
@ -91,8 +91,8 @@ class TestAuthNRequest(TestCase):
slug="provider", slug="provider",
issuer="authentik", issuer="authentik",
pre_authentication_flow=create_test_flow(), pre_authentication_flow=create_test_flow(),
signing_kp=self.cert, signing_kp=cert,
verification_kp=self.cert, verification_kp=cert,
) )
def test_signed_valid(self): def test_signed_valid(self):
@ -112,34 +112,7 @@ class TestAuthNRequest(TestCase):
self.assertEqual(parsed_request.id, request_proc.request_id) self.assertEqual(parsed_request.id, request_proc.request_id)
self.assertEqual(parsed_request.relay_state, "test_state") self.assertEqual(parsed_request.relay_state, "test_state")
def test_request_encrypt(self): def test_request_full_signed(self):
"""Test full SAML Request/Response flow, fully encrypted"""
self.provider.encryption_kp = self.cert
self.provider.save()
self.source.encryption_kp = self.cert
self.source.save()
http_request = get_request("/")
# First create an AuthNRequest
request_proc = RequestProcessor(self.source, http_request, "test_state")
request = request_proc.build_auth_n()
# To get an assertion we need a parsed request (parsed by provider)
parsed_request = AuthNRequestParser(self.provider).parse(
b64encode(request.encode()).decode(), "test_state"
)
# Now create a response and convert it to string (provider)
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
response = response_proc.build_response()
# Now parse the response (source)
http_request.POST = QueryDict(mutable=True)
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
response_parser = ResponseProcessor(self.source, http_request)
response_parser.parse()
def test_request_signed(self):
"""Test full SAML Request/Response flow, fully signed""" """Test full SAML Request/Response flow, fully signed"""
http_request = get_request("/") http_request = get_request("/")
@ -162,36 +135,6 @@ class TestAuthNRequest(TestCase):
response_parser = ResponseProcessor(self.source, http_request) response_parser = ResponseProcessor(self.source, http_request)
response_parser.parse() response_parser.parse()
def test_request_signed_both(self):
"""Test full SAML Request/Response flow, fully signed"""
self.provider.sign_assertion = True
self.provider.sign_response = True
self.provider.save()
http_request = get_request("/")
# First create an AuthNRequest
request_proc = RequestProcessor(self.source, http_request, "test_state")
request = request_proc.build_auth_n()
# To get an assertion we need a parsed request (parsed by provider)
parsed_request = AuthNRequestParser(self.provider).parse(
b64encode(request.encode()).decode(), "test_state"
)
# Now create a response and convert it to string (provider)
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
response = response_proc.build_response()
# Ensure both response and assertion ID are in the response twice (once as ID attribute,
# once as ds:Reference URI)
self.assertEqual(response.count(response_proc._assertion_id), 2)
self.assertEqual(response.count(response_proc._response_id), 2)
# Now parse the response (source)
http_request.POST = QueryDict(mutable=True)
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
response_parser = ResponseProcessor(self.source, http_request)
response_parser.parse()
def test_request_id_invalid(self): def test_request_id_invalid(self):
"""Test generated AuthNRequest with invalid request ID""" """Test generated AuthNRequest with invalid request ID"""
http_request = get_request("/") http_request = get_request("/")

View File

@ -54,11 +54,7 @@ class TestServiceProviderMetadataParser(TestCase):
request = self.factory.get("/") request = self.factory.get("/")
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor()) metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
schema = etree.XMLSchema( schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
etree.parse(
source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()
) # nosec
)
self.assertTrue(schema.validate(metadata)) self.assertTrue(schema.validate(metadata))
def test_schema_want_authn_requests_signed(self): def test_schema_want_authn_requests_signed(self):

View File

@ -47,9 +47,7 @@ class TestSchema(TestCase):
metadata = lxml_from_string(request) metadata = lxml_from_string(request)
schema = etree.XMLSchema( schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
)
self.assertTrue(schema.validate(metadata)) self.assertTrue(schema.validate(metadata))
def test_response_schema(self): def test_response_schema(self):
@ -70,7 +68,5 @@ class TestSchema(TestCase):
metadata = lxml_from_string(response) metadata = lxml_from_string(response)
schema = etree.XMLSchema( schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
)
self.assertTrue(schema.validate(metadata)) self.assertTrue(schema.validate(metadata))

View File

@ -44,6 +44,6 @@ urlpatterns = [
] ]
api_urlpatterns = [ api_urlpatterns = [
("propertymappings/provider/saml", SAMLPropertyMappingViewSet), ("propertymappings/saml", SAMLPropertyMappingViewSet),
("providers/saml", SAMLProviderViewSet), ("providers/saml", SAMLProviderViewSet),
] ]

View File

@ -6,7 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects from authentik.providers.scim.tasks import scim_sync
class SCIMProviderSerializer(ProviderSerializer): class SCIMProviderSerializer(ProviderSerializer):
@ -42,4 +42,3 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
search_fields = ["name", "url"] search_fields = ["name", "url"]
ordering = ["name", "url"] ordering = ["name", "url"]
sync_single_task = scim_sync sync_single_task = scim_sync
sync_objects_task = scim_sync_objects

View File

@ -1,11 +1,8 @@
"""Group client""" """Group client"""
from itertools import batched
from django.db import transaction
from pydantic import ValidationError from pydantic import ValidationError
from pydanticscim.group import GroupMember from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp from pydanticscim.responses import PatchOp, PatchOperation
from authentik.core.models import Group from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager from authentik.lib.sync.mapper import PropertyMappingManager
@ -20,7 +17,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import ( from authentik.providers.scim.clients.exceptions import (
SCIMRequestException, SCIMRequestException,
) )
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import ( from authentik.providers.scim.models import (
SCIMMapping, SCIMMapping,
@ -59,22 +56,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
if not scim_group.externalId: if not scim_group.externalId:
scim_group.externalId = str(obj.pk) scim_group.externalId = str(obj.pk)
if not self._config.patch.supported: users = list(obj.users.order_by("id").values_list("id", flat=True))
users = list(obj.users.order_by("id").values_list("id", flat=True)) connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
connections = SCIMProviderUser.objects.filter( members = []
provider=self.provider, user__pk__in=users for user in connections:
) members.append(
members = [] GroupMember(
for user in connections: value=user.scim_id,
members.append(
GroupMember(
value=user.scim_id,
)
) )
if members: )
scim_group.members = members if members:
else: scim_group.members = members
del scim_group.members
return scim_group return scim_group
def delete(self, obj: Group): def delete(self, obj: Group):
@ -101,53 +93,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
scim_id = response.get("id") scim_id = response.get("id")
if not scim_id or scim_id == "": if not scim_id or scim_id == "":
raise StopSync("SCIM Response with missing or invalid `id`") raise StopSync("SCIM Response with missing or invalid `id`")
connection = SCIMProviderGroup.objects.create( return SCIMProviderGroup.objects.create(
provider=self.provider, group=group, scim_id=scim_id provider=self.provider, group=group, scim_id=scim_id
) )
users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(connection, users)
return connection
def update(self, group: Group, connection: SCIMProviderGroup): def update(self, group: Group, connection: SCIMProviderGroup):
"""Update existing group""" """Update existing group"""
scim_group = self.to_schema(group, connection) scim_group = self.to_schema(group, connection)
scim_group.id = connection.scim_id scim_group.id = connection.scim_id
try: try:
if self._config.patch.supported: return self._request(
return self._update_patch(group, scim_group, connection)
return self._update_put(group, scim_group, connection)
except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
def _update_patch(
self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup
):
"""Update a group via PATCH request"""
# Patch group's attributes instead of replacing it and re-adding users if we can
self._request(
"PATCH",
f"/Groups/{connection.scim_id}",
json=PatchRequest(
Operations=[
PatchOperation(
op=PatchOp.replace,
path=None,
value=scim_group.model_dump(mode="json", exclude_unset=True),
)
]
).model_dump(
mode="json",
exclude_unset=True,
exclude_none=True,
),
)
return self.patch_compare_users(group)
def _update_put(self, group: Group, scim_group: SCIMGroupSchema, connection: SCIMProviderGroup):
"""Update a group via PUT request"""
try:
self._request(
"PUT", "PUT",
f"/Groups/{connection.scim_id}", f"/Groups/{connection.scim_id}",
json=scim_group.model_dump( json=scim_group.model_dump(
@ -155,25 +110,31 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
exclude_unset=True, exclude_unset=True,
), ),
) )
return self.patch_compare_users(group) except NotFoundSyncException:
# Resource missing is handled by self.write, which will re-create the group
raise
except (SCIMRequestException, ObjectExistsSyncException): except (SCIMRequestException, ObjectExistsSyncException):
# Some providers don't support PUT on groups, so this is mainly a fix for the initial # Some providers don't support PUT on groups, so this is mainly a fix for the initial
# sync, send patch add requests for all the users the group currently has # sync, send patch add requests for all the users the group currently has
return self._update_patch(group, scim_group, connection) users = list(group.users.order_by("id").values_list("id", flat=True))
self._patch_add_users(group, users)
# Also update the group name
return self._patch(
scim_group.id,
PatchOperation(
op=PatchOp.replace,
path="displayName",
value=scim_group.displayName,
),
)
def update_group(self, group: Group, action: Direction, users_set: set[int]): def update_group(self, group: Group, action: Direction, users_set: set[int]):
"""Update a group, either using PUT to replace it or PATCH if supported""" """Update a group, either using PUT to replace it or PATCH if supported"""
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
if self._config.patch.supported: if self._config.patch.supported:
if action == Direction.add: if action == Direction.add:
return self._patch_add_users(scim_group, users_set) return self._patch_add_users(group, users_set)
if action == Direction.remove: if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set) return self._patch_remove_users(group, users_set)
try: try:
return self.write(group) return self.write(group)
except SCIMRequestException as exc: except SCIMRequestException as exc:
@ -181,98 +142,35 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
# Assume that provider does not support PUT and also doesn't support # Assume that provider does not support PUT and also doesn't support
# ServiceProviderConfig, so try PATCH as a fallback # ServiceProviderConfig, so try PATCH as a fallback
if action == Direction.add: if action == Direction.add:
return self._patch_add_users(scim_group, users_set) return self._patch_add_users(group, users_set)
if action == Direction.remove: if action == Direction.remove:
return self._patch_remove_users(scim_group, users_set) return self._patch_remove_users(group, users_set)
raise exc raise exc
def _patch_chunked( def _patch(
self, self,
group_id: str, group_id: str,
*ops: PatchOperation, *ops: PatchOperation,
): ):
"""Helper function that chunks patch requests based on the maxOperations attribute. req = PatchRequest(Operations=ops)
This is not strictly according to specs but there's nothing in the schema that allows the self._request(
us to know what the maximum patch operations per request should be.""" "PATCH",
chunk_size = self._config.bulk.maxOperations f"/Groups/{group_id}",
if chunk_size < 1: json=req.model_dump(
chunk_size = len(ops) mode="json",
if len(ops) < 1: ),
return )
for chunk in batched(ops, chunk_size):
req = PatchRequest(Operations=list(chunk))
self._request(
"PATCH",
f"/Groups/{group_id}",
json=req.model_dump(
mode="json",
),
)
@transaction.atomic def _patch_add_users(self, group: Group, users_set: set[int]):
def patch_compare_users(self, group: Group): """Add users in users_set to group"""
"""Compare users with a SCIM group and add/remove any differences""" if len(users_set) < 1:
# Get scim group first return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first() scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group: if not scim_group:
self.logger.warning( self.logger.warning(
"could not sync group membership, group does not exist", group=group "could not sync group membership, group does not exist", group=group
) )
return return
# Get a list of all users in the authentik group
raw_users_should = list(group.users.order_by("id").values_list("id", flat=True))
# Lookup the SCIM IDs of the users
users_should: list[str] = list(
SCIMProviderUser.objects.filter(
user__pk__in=raw_users_should, provider=self.provider
).values_list("scim_id", flat=True)
)
if len(raw_users_should) != len(users_should):
self.logger.warning(
"User count mismatch, not all users in the group are synced to SCIM yet.",
group=group,
)
# Get current group status
current_group = SCIMGroupSchema.model_validate(
self._request("GET", f"/Groups/{scim_group.scim_id}")
)
users_to_add = []
users_to_remove = []
# Check users currently in group and if they shouldn't be in the group and remove them
for user in current_group.members or []:
if user.value not in users_should:
users_to_remove.append(user.value)
# Check users that should be in the group and add them
for user in users_should:
if len([x for x in current_group.members if x.value == user]) < 1:
users_to_add.append(user)
# Only send request if we need to make changes
if len(users_to_add) < 1 and len(users_to_remove) < 1:
return
return self._patch_chunked(
scim_group.scim_id,
*[
PatchOperation(
op=PatchOp.add,
path="members",
value=[{"value": x}],
)
for x in users_to_add
],
*[
PatchOperation(
op=PatchOp.remove,
path="members",
value=[{"value": x}],
)
for x in users_to_remove
],
)
def _patch_add_users(self, scim_group: SCIMProviderGroup, users_set: set[int]):
"""Add users in users_set to group"""
if len(users_set) < 1:
return
user_ids = list( user_ids = list(
SCIMProviderUser.objects.filter( SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider user__pk__in=users_set, provider=self.provider
@ -280,22 +178,25 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
) )
if len(user_ids) < 1: if len(user_ids) < 1:
return return
self._patch_chunked( self._patch(
scim_group.scim_id, scim_group.scim_id,
*[ PatchOperation(
PatchOperation( op=PatchOp.add,
op=PatchOp.add, path="members",
path="members", value=[{"value": x} for x in user_ids],
value=[{"value": x}], ),
)
for x in user_ids
],
) )
def _patch_remove_users(self, scim_group: SCIMProviderGroup, users_set: set[int]): def _patch_remove_users(self, group: Group, users_set: set[int]):
"""Remove users in users_set from group""" """Remove users in users_set from group"""
if len(users_set) < 1: if len(users_set) < 1:
return return
scim_group = SCIMProviderGroup.objects.filter(provider=self.provider, group=group).first()
if not scim_group:
self.logger.warning(
"could not sync group membership, group does not exist", group=group
)
return
user_ids = list( user_ids = list(
SCIMProviderUser.objects.filter( SCIMProviderUser.objects.filter(
user__pk__in=users_set, provider=self.provider user__pk__in=users_set, provider=self.provider
@ -303,14 +204,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
) )
if len(user_ids) < 1: if len(user_ids) < 1:
return return
self._patch_chunked( self._patch(
scim_group.scim_id, scim_group.scim_id,
*[ PatchOperation(
PatchOperation( op=PatchOp.remove,
op=PatchOp.remove, path="members",
path="members", value=[{"value": x} for x in user_ids],
value=[{"value": x}], ),
)
for x in user_ids
],
) )

View File

@ -1,12 +1,9 @@
"""Custom SCIM schemas""" """Custom SCIM schemas"""
from pydantic import Field
from pydanticscim.group import Group as BaseGroup from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
from pydanticscim.responses import PatchRequest as BasePatchRequest from pydanticscim.responses import PatchRequest as BasePatchRequest
from pydanticscim.responses import SCIMError as BaseSCIMError from pydanticscim.responses import SCIMError as BaseSCIMError
from pydanticscim.service_provider import Bulk as BaseBulk from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
from pydanticscim.service_provider import ( from pydanticscim.service_provider import (
ServiceProviderConfiguration as BaseServiceProviderConfiguration, ServiceProviderConfiguration as BaseServiceProviderConfiguration,
) )
@ -32,16 +29,10 @@ class Group(BaseGroup):
meta: dict | None = None meta: dict | None = None
class Bulk(BaseBulk):
maxOperations: int = Field()
class ServiceProviderConfiguration(BaseServiceProviderConfiguration): class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""ServiceProviderConfig with fallback""" """ServiceProviderConfig with fallback"""
_is_fallback: bool | None = False _is_fallback: bool | None = False
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
@property @property
def is_fallback(self) -> bool: def is_fallback(self) -> bool:
@ -54,7 +45,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
"""Get default configuration, which doesn't support any optional features as fallback""" """Get default configuration, which doesn't support any optional features as fallback"""
return ServiceProviderConfiguration( return ServiceProviderConfiguration(
patch=Patch(supported=False), patch=Patch(supported=False),
bulk=Bulk(supported=False, maxOperations=0), bulk=Bulk(supported=False),
filter=Filter(supported=False), filter=Filter(supported=False),
changePassword=ChangePassword(supported=False), changePassword=ChangePassword(supported=False),
sort=Sort(supported=False), sort=Sort(supported=False),
@ -69,12 +60,6 @@ class PatchRequest(BasePatchRequest):
schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",) schemas: tuple[str] = ("urn:ietf:params:scim:api:messages:2.0:PatchOp",)
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
path: str | None
class SCIMError(BaseSCIMError): class SCIMError(BaseSCIMError):
"""SCIM error with optional status code""" """SCIM error with optional status code"""

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_scim", "0008_rename_scimgroup_scimprovidergroup_and_more"),
]
operations = [
migrations.AlterModelOptions(
name="scimmapping",
options={
"verbose_name": "SCIM Provider Mapping",
"verbose_name_plural": "SCIM Provider Mappings",
},
),
]

View File

@ -133,7 +133,7 @@ class SCIMMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-scim-form" return "ak-property-mapping-scim-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:
@ -142,8 +142,8 @@ class SCIMMapping(PropertyMapping):
return SCIMMappingSerializer return SCIMMappingSerializer
def __str__(self): def __str__(self):
return f"SCIM Provider Mapping {self.name}" return f"SCIM Mapping {self.name}"
class Meta: class Meta:
verbose_name = _("SCIM Provider Mapping") verbose_name = _("SCIM Mapping")
verbose_name_plural = _("SCIM Provider Mappings") verbose_name_plural = _("SCIM Mappings")

View File

@ -252,118 +252,3 @@ class SCIMMembershipTests(TestCase):
], ],
}, },
) )
def test_member_add_save(self):
"""Test member add + save"""
config = ServiceProviderConfiguration.default()
config.patch.supported = True
user_scim_id = generate_id()
group_scim_id = generate_id()
uid = generate_id()
group = Group.objects.create(
name=uid,
)
user = User.objects.create(username=generate_id())
# Test initial sync of group creation
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.post(
"https://localhost/Users",
json={
"id": user_scim_id,
},
)
mocker.post(
"https://localhost/Groups",
json={
"id": group_scim_id,
},
)
self.configure()
sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual(
mocker.request_history[3].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [],
"active": True,
"externalId": user.uid,
"name": {"familyName": " ", "formatted": " ", "givenName": ""},
"displayName": "",
"userName": user.username,
},
)
self.assertJSONEqual(
mocker.request_history[5].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
)
with Mocker() as mocker:
mocker.get(
"https://localhost/ServiceProviderConfig",
json=config.model_dump(),
)
mocker.get(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
mocker.patch(
f"https://localhost/Groups/{group_scim_id}",
json={},
)
group.users.add(user)
group.save()
self.assertEqual(mocker.call_count, 5)
self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "PATCH")
self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "PATCH")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertJSONEqual(
mocker.request_history[1].body,
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [
{
"op": "add",
"path": "members",
"value": [{"value": user_scim_id}],
}
],
},
)
self.assertJSONEqual(
mocker.request_history[3].body,
{
"Operations": [
{
"op": "replace",
"value": {
"id": group_scim_id,
"displayName": group.name,
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
},
}
]
},
)

View File

@ -13,5 +13,5 @@ api_urlpatterns = [
("providers/scim", SCIMProviderViewSet), ("providers/scim", SCIMProviderViewSet),
("providers/scim_users", SCIMProviderUserViewSet), ("providers/scim_users", SCIMProviderUserViewSet),
("providers/scim_groups", SCIMProviderGroupViewSet), ("providers/scim_groups", SCIMProviderGroupViewSet),
("propertymappings/provider/scim", SCIMMappingViewSet), ("propertymappings/scim", SCIMMappingViewSet),
] ]

View File

@ -2,7 +2,7 @@
from uuid import uuid4 from uuid import uuid4
from django.contrib.auth.management import _get_all_permissions from django.contrib.auth.models import Permission
from django.db import models from django.db import models
from django.db.transaction import atomic from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -10,26 +10,28 @@ from guardian.shortcuts import assign_perm
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.reflection import get_apps
def get_permission_choices(): def get_permissions():
all_perms = [] return (
for app in get_apps(): Permission.objects.all()
for model in app.get_models(): .select_related("content_type")
for perm, _desc in _get_all_permissions(model._meta): .filter(
all_perms.append((model, perm)) content_type__app_label__startswith="authentik",
return sorted( )
[
(
f"{model._meta.app_label}.{perm}",
f"{model._meta.app_label}.{perm}",
)
for model, perm in all_perms
]
) )
def get_permission_choices() -> list[tuple[str, str]]:
return [
(
f"{x.content_type.app_label}.{x.codename}",
f"{x.content_type.app_label}.{x.codename}",
)
for x in get_permissions()
]
class Role(SerializerModel): class Role(SerializerModel):
"""RBAC role, which can have different permissions (both global and per-object) attached """RBAC role, which can have different permissions (both global and per-object) attached
to it.""" to it."""

View File

@ -87,11 +87,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
def _get_startup_tasks_default_tenant() -> list[Callable]: def _get_startup_tasks_default_tenant() -> list[Callable]:
"""Get all tasks to be run on startup for the default tenant""" """Get all tasks to be run on startup for the default tenant"""
from authentik.outposts.tasks import outpost_connection_discovery return []
return [
outpost_connection_discovery,
]
def _get_startup_tasks_all_tenants() -> list[Callable]: def _get_startup_tasks_all_tenants() -> list[Callable]:

View File

@ -2,7 +2,6 @@
from collections.abc import Callable from collections.abc import Callable
from hashlib import sha512 from hashlib import sha512
from ipaddress import ip_address
from time import perf_counter, time from time import perf_counter, time
from typing import Any from typing import Any
@ -175,7 +174,6 @@ class ClientIPMiddleware:
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response self.get_response = get_response
self.logger = get_logger().bind()
def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str: def _get_client_ip_from_meta(self, meta: dict[str, Any]) -> str:
"""Attempt to get the client's IP by checking common HTTP Headers. """Attempt to get the client's IP by checking common HTTP Headers.
@ -187,16 +185,11 @@ class ClientIPMiddleware:
"HTTP_X_FORWARDED_FOR", "HTTP_X_FORWARDED_FOR",
"REMOTE_ADDR", "REMOTE_ADDR",
) )
try: for _header in headers:
for _header in headers: if _header in meta:
if _header in meta: ips: list[str] = meta.get(_header).split(",")
ips: list[str] = meta.get(_header).split(",") return ips[0].strip()
# Ensure the IP parses as a valid IP return self.default_ip
return str(ip_address(ips[0].strip()))
return self.default_ip
except ValueError as exc:
self.logger.debug("Invalid remote IP", exc=exc)
return self.default_ip
# FIXME: this should probably not be in `root` but rather in a middleware in `outposts` # FIXME: this should probably not be in `root` but rather in a middleware in `outposts`
# but for now it's fine # but for now it's fine
@ -233,11 +226,7 @@ class ClientIPMiddleware:
Scope.get_isolation_scope().set_user(user) Scope.get_isolation_scope().set_user(user)
# Set the outpost service account on the request # Set the outpost service account on the request
setattr(request, self.request_attr_outpost_user, user) setattr(request, self.request_attr_outpost_user, user)
try: return delegated_ip
return str(ip_address(delegated_ip))
except ValueError as exc:
self.logger.debug("Invalid remote IP from Outpost", exc=exc)
return None
def _get_client_ip(self, request: HttpRequest | None) -> str: def _get_client_ip(self, request: HttpRequest | None) -> str:
"""Attempt to get the client's IP by checking common HTTP Headers. """Attempt to get the client's IP by checking common HTTP Headers.

View File

@ -9,7 +9,6 @@ import orjson
from celery.schedules import crontab from celery.schedules import crontab
from django.conf import ImproperlyConfigured from django.conf import ImproperlyConfigured
from sentry_sdk import set_tag from sentry_sdk import set_tag
from xmlsec import enable_debug_trace
from authentik import __version__ from authentik import __version__
from authentik.lib.config import CONFIG, redis_url from authentik.lib.config import CONFIG, redis_url
@ -521,7 +520,6 @@ if DEBUG:
"rest_framework.renderers.BrowsableAPIRenderer" "rest_framework.renderers.BrowsableAPIRenderer"
) )
SHARED_APPS.insert(SHARED_APPS.index("django.contrib.staticfiles"), "daphne") SHARED_APPS.insert(SHARED_APPS.index("django.contrib.staticfiles"), "daphne")
enable_debug_trace(True)
TENANT_APPS.append("authentik.core") TENANT_APPS.append("authentik.core")

View File

@ -1,7 +1,6 @@
"""authentik storage backends""" """authentik storage backends"""
import os import os
from urllib.parse import parse_qsl, urlsplit
from django.conf import settings from django.conf import settings
from django.core.exceptions import SuspiciousOperation from django.core.exceptions import SuspiciousOperation
@ -111,34 +110,3 @@ class S3Storage(BaseS3Storage):
if self.querystring_auth: if self.querystring_auth:
return url return url
return self._strip_signing_parameters(url) return self._strip_signing_parameters(url)
def _strip_signing_parameters(self, url):
# Boto3 does not currently support generating URLs that are unsigned. Instead
# we take the signed URLs and strip any querystring params related to signing
# and expiration.
# Note that this may end up with URLs that are still invalid, especially if
# params are passed in that only work with signed URLs, e.g. response header
# params.
# The code attempts to strip all query parameters that match names of known
# parameters from v2 and v4 signatures, regardless of the actual signature
# version used.
split_url = urlsplit(url)
qs = parse_qsl(split_url.query, keep_blank_values=True)
blacklist = {
"x-amz-algorithm",
"x-amz-credential",
"x-amz-date",
"x-amz-expires",
"x-amz-signedheaders",
"x-amz-signature",
"x-amz-security-token",
"awsaccesskeyid",
"expires",
"signature",
}
filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist)
# Note: Parameters that did not have a value in the original query string will
# have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar=
joined_qs = ("=".join(keyval) for keyval in filtered_qs)
split_url = split_url._replace(query="&".join(joined_qs))
return split_url.geturl()

View File

@ -3,7 +3,6 @@
from typing import Any from typing import Any
from django.core.cache import cache from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema, inline_serializer from drf_spectacular.utils import extend_schema, inline_serializer
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action from rest_framework.decorators import action
@ -40,8 +39,9 @@ class LDAPSourceSerializer(SourceSerializer):
"""Get cached source connectivity""" """Get cached source connectivity"""
return cache.get(CACHE_KEY_STATUS + source.slug, None) return cache.get(CACHE_KEY_STATUS + source.slug, None)
def validate_sync_users_password(self, sync_users_password: bool) -> bool: def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Check that only a single source has password_sync on""" """Check that only a single source has password_sync on"""
sync_users_password = attrs.get("sync_users_password", True)
if sync_users_password: if sync_users_password:
sources = LDAPSource.objects.filter(sync_users_password=True) sources = LDAPSource.objects.filter(sync_users_password=True)
if self.instance: if self.instance:
@ -49,31 +49,11 @@ class LDAPSourceSerializer(SourceSerializer):
if sources.exists(): if sources.exists():
raise ValidationError( raise ValidationError(
{ {
"sync_users_password": _( "sync_users_password": (
"Only a single LDAP Source with password synchronization is allowed" "Only a single LDAP Source with password synchronization is allowed"
) )
} }
) )
return sync_users_password
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Validate property mappings with sync_ flags"""
types = ["user", "group"]
for type in types:
toggle_value = attrs.get(f"sync_{type}s", False)
mappings_field = f"{type}_property_mappings"
mappings_value = attrs.get(mappings_field, [])
if toggle_value and len(mappings_value) == 0:
raise ValidationError(
{
mappings_field: _(
(
"When 'Sync {type}s' is enabled, '{type}s property "
"mappings' cannot be empty."
).format(type=type)
)
}
)
return super().validate(attrs) return super().validate(attrs)
class Meta: class Meta:
@ -186,12 +166,11 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
for sync_class in SYNC_CLASSES: for sync_class in SYNC_CLASSES:
class_name = sync_class.name() class_name = sync_class.name()
all_objects.setdefault(class_name, []) all_objects.setdefault(class_name, [])
for page in sync_class(source).get_objects(size_limit=10): for obj in sync_class(source).get_objects(size_limit=10):
for obj in page: obj: dict
obj: dict obj.pop("raw_attributes", None)
obj.pop("raw_attributes", None) obj.pop("raw_dn", None)
obj.pop("raw_dn", None) all_objects[class_name].append(obj)
all_objects[class_name].append(obj)
return Response(data=all_objects) return Response(data=all_objects)

View File

@ -290,7 +290,7 @@ class LDAPSourcePropertyMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-source-ldap-form" return "ak-property-mapping-ldap-source-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:

View File

@ -26,16 +26,17 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
"""Ensure that source is synced on save (if enabled)""" """Ensure that source is synced on save (if enabled)"""
if not instance.enabled: if not instance.enabled:
return return
ldap_connectivity_check.delay(instance.pk)
# Don't sync sources when they don't have any property mappings. This will only happen if: # Don't sync sources when they don't have any property mappings. This will only happen if:
# - the user forgets to set them or # - the user forgets to set them or
# - the source is newly created, this is the first save event # - the source is newly created, this is the first save event
# and the mappings are created with an m2m event # and the mappings are created with an m2m event
if instance.sync_users and not instance.user_property_mappings.exists(): if (
return not instance.user_property_mappings.exists()
if instance.sync_groups and not instance.group_property_mappings.exists(): or not instance.group_property_mappings.exists()
):
return return
ldap_sync_single.delay(instance.pk) ldap_sync_single.delay(instance.pk)
ldap_connectivity_check.delay(instance.pk)
@receiver(password_validate) @receiver(password_validate)

View File

@ -38,11 +38,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,
search_scope=SUBTREE, search_scope=SUBTREE,
attributes=[ attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
ALL_ATTRIBUTES,
ALL_OPERATIONAL_ATTRIBUTES,
self._source.object_uniqueness_field,
],
**kwargs, **kwargs,
) )
@ -57,9 +53,9 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
continue continue
attributes = group.get("attributes", {}) attributes = group.get("attributes", {})
group_dn = flatten(flatten(group.get("entryDN", group.get("dn")))) group_dn = flatten(flatten(group.get("entryDN", group.get("dn"))))
if not attributes.get(self._source.object_uniqueness_field): if self._source.object_uniqueness_field not in attributes:
self.message( self.message(
f"Uniqueness field not found/not set in attributes: '{group_dn}'", f"Cannot find uniqueness field in attributes: '{group_dn}'",
attributes=attributes.keys(), attributes=attributes.keys(),
dn=group_dn, dn=group_dn,
) )

Some files were not shown because too many files have changed in this diff Show More