Compare commits
17 Commits
analytics
...
web/packag
Author | SHA1 | Date | |
---|---|---|---|
1dc4fbbb2b | |||
3332de267d | |||
ab366d0ec2 | |||
ac162582aa | |||
7d82e029d5 | |||
9b40ecb023 | |||
0cc0fdaae3 | |||
b55b168718 | |||
c46dc8f290 | |||
e48da3520c | |||
1ec4652c60 | |||
e375646705 | |||
b84652d9d3 | |||
74b8da28ca | |||
9084c7c6b4 | |||
7a0b227b46 | |||
cc9128fd46 |
@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 2024.8.2
|
||||
current_version = 2024.6.3
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
||||
|
@ -29,9 +29,9 @@ outputs:
|
||||
imageTags:
|
||||
description: "Docker image tags"
|
||||
value: ${{ steps.ev.outputs.imageTags }}
|
||||
attestImageNames:
|
||||
description: "Docker image names used for attestation"
|
||||
value: ${{ steps.ev.outputs.attestImageNames }}
|
||||
imageNames:
|
||||
description: "Docker image names"
|
||||
value: ${{ steps.ev.outputs.imageNames }}
|
||||
imageMainTag:
|
||||
description: "Docker image main tag"
|
||||
value: ${{ steps.ev.outputs.imageMainTag }}
|
||||
|
@ -51,24 +51,15 @@ else:
|
||||
]
|
||||
|
||||
image_main_tag = image_tags[0].split(":")[-1]
|
||||
|
||||
|
||||
def get_attest_image_names(image_with_tags: list[str]):
|
||||
"""Attestation only for GHCR"""
|
||||
image_tags = []
|
||||
for image_name in set(name.split(":")[0] for name in image_with_tags):
|
||||
if not image_name.startswith("ghcr.io"):
|
||||
continue
|
||||
image_tags.append(image_name)
|
||||
return ",".join(set(image_tags))
|
||||
|
||||
image_tags_rendered = ",".join(image_tags)
|
||||
image_names_rendered = ",".join(set(name.split(":")[0] for name in image_tags))
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
|
||||
print(f"shouldBuild={should_build}", file=_output)
|
||||
print(f"sha={sha}", file=_output)
|
||||
print(f"version={version}", file=_output)
|
||||
print(f"prerelease={prerelease}", file=_output)
|
||||
print(f"imageTags={','.join(image_tags)}", file=_output)
|
||||
print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output)
|
||||
print(f"imageTags={image_tags_rendered}", file=_output)
|
||||
print(f"imageNames={image_names_rendered}", file=_output)
|
||||
print(f"imageMainTag={image_main_tag}", file=_output)
|
||||
print(f"imageMainName={image_tags[0]}", file=_output)
|
||||
|
10
.github/dependabot.yml
vendored
10
.github/dependabot.yml
vendored
@ -44,11 +44,9 @@ updates:
|
||||
- "babel-*"
|
||||
eslint:
|
||||
patterns:
|
||||
- "@eslint/*"
|
||||
- "@typescript-eslint/*"
|
||||
- "eslint-*"
|
||||
- "eslint"
|
||||
- "typescript-eslint"
|
||||
- "eslint-*"
|
||||
storybook:
|
||||
patterns:
|
||||
- "@storybook/*"
|
||||
@ -56,16 +54,10 @@ updates:
|
||||
esbuild:
|
||||
patterns:
|
||||
- "@esbuild/*"
|
||||
- "esbuild*"
|
||||
rollup:
|
||||
patterns:
|
||||
- "@rollup/*"
|
||||
- "rollup-*"
|
||||
- "rollup*"
|
||||
swc:
|
||||
patterns:
|
||||
- "@swc/*"
|
||||
- "swc-*"
|
||||
wdio:
|
||||
patterns:
|
||||
- "@wdio/*"
|
||||
|
2
.github/workflows/api-ts-publish.yml
vendored
2
.github/workflows/api-ts-publish.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
||||
run: |
|
||||
export VERSION=`node -e 'console.log(require("../gen-ts-api/package.json").version)'`
|
||||
npm i @goauthentik/api@$VERSION
|
||||
- uses: peter-evans/create-pull-request@v7
|
||||
- uses: peter-evans/create-pull-request@v6
|
||||
id: cpr
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
20
.github/workflows/ci-main.yml
vendored
20
.github/workflows/ci-main.yml
vendored
@ -120,12 +120,6 @@ jobs:
|
||||
with:
|
||||
flags: unit
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
- if: ${{ !cancelled() }}
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
flags: unit
|
||||
file: unittest.xml
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
test-integration:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
@ -144,12 +138,6 @@ jobs:
|
||||
with:
|
||||
flags: integration
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
- if: ${{ !cancelled() }}
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
flags: integration
|
||||
file: unittest.xml
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
test-e2e:
|
||||
name: test-e2e (${{ matrix.job.name }})
|
||||
runs-on: ubuntu-latest
|
||||
@ -202,12 +190,6 @@ jobs:
|
||||
with:
|
||||
flags: e2e
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
- if: ${{ !cancelled() }}
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
flags: e2e
|
||||
file: unittest.xml
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
ci-core-mark:
|
||||
needs:
|
||||
- lint
|
||||
@ -279,7 +261,7 @@ jobs:
|
||||
id: attest
|
||||
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
pr-comment:
|
||||
|
4
.github/workflows/ci-outpost.yml
vendored
4
.github/workflows/ci-outpost.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: latest
|
||||
version: v1.54.2
|
||||
args: --timeout 5000s --verbose
|
||||
skip-cache: true
|
||||
test-unittest:
|
||||
@ -115,7 +115,7 @@ jobs:
|
||||
id: attest
|
||||
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
build-binary:
|
||||
|
3
.github/workflows/ci-web.yml
vendored
3
.github/workflows/ci-web.yml
vendored
@ -45,6 +45,7 @@ jobs:
|
||||
- working-directory: ${{ matrix.project }}/
|
||||
run: |
|
||||
npm ci
|
||||
${{ matrix.extra_setup }}
|
||||
- name: Generate API
|
||||
run: make gen-client-ts
|
||||
- name: Lint
|
||||
@ -91,4 +92,4 @@ jobs:
|
||||
run: make gen-client-ts
|
||||
- name: test
|
||||
working-directory: web/
|
||||
run: npm run test || exit 0
|
||||
run: npm run test
|
||||
|
@ -24,7 +24,7 @@ jobs:
|
||||
- name: Setup authentik env
|
||||
uses: ./.github/actions/setup
|
||||
- run: poetry run ak update_webauthn_mds
|
||||
- uses: peter-evans/create-pull-request@v7
|
||||
- uses: peter-evans/create-pull-request@v6
|
||||
id: cpr
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
|
2
.github/workflows/image-compress.yml
vendored
2
.github/workflows/image-compress.yml
vendored
@ -42,7 +42,7 @@ jobs:
|
||||
with:
|
||||
githubToken: ${{ steps.generate_token.outputs.token }}
|
||||
compressOnly: ${{ github.event_name != 'pull_request' }}
|
||||
- uses: peter-evans/create-pull-request@v7
|
||||
- uses: peter-evans/create-pull-request@v6
|
||||
if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}"
|
||||
id: cpr
|
||||
with:
|
||||
|
8
.github/workflows/release-publish.yml
vendored
8
.github/workflows/release-publish.yml
vendored
@ -51,14 +51,12 @@ jobs:
|
||||
secrets: |
|
||||
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
|
||||
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
|
||||
build-args: |
|
||||
VERSION=${{ github.ref }}
|
||||
tags: ${{ steps.ev.outputs.imageTags }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- uses: actions/attest-build-provenance@v1
|
||||
id: attest
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
build-outpost:
|
||||
@ -113,8 +111,6 @@ jobs:
|
||||
id: push
|
||||
with:
|
||||
push: true
|
||||
build-args: |
|
||||
VERSION=${{ github.ref }}
|
||||
tags: ${{ steps.ev.outputs.imageTags }}
|
||||
file: ${{ matrix.type }}.Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
@ -122,7 +118,7 @@ jobs:
|
||||
- uses: actions/attest-build-provenance@v1
|
||||
id: attest
|
||||
with:
|
||||
subject-name: ${{ steps.ev.outputs.attestImageNames }}
|
||||
subject-name: ${{ steps.ev.outputs.imageNames }}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
build-outpost-binary:
|
||||
|
@ -32,7 +32,7 @@ jobs:
|
||||
poetry run ak compilemessages
|
||||
make web-check-compile
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ steps.generate_token.outputs.token }}
|
||||
branch: extract-compile-backend-translation
|
||||
|
27
Dockerfile
27
Dockerfile
@ -1,7 +1,7 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build website
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS website-builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as website-builder
|
||||
|
||||
ENV NODE_ENV=production
|
||||
|
||||
@ -20,7 +20,7 @@ COPY ./SECURITY.md /work/
|
||||
RUN npm run build-bundled
|
||||
|
||||
# Stage 2: Build webui
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS web-builder
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 as web-builder
|
||||
|
||||
ARG GIT_BUILD_HASH
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
@ -43,7 +43,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
|
||||
RUN npm run build
|
||||
|
||||
# Stage 3: Build go proxy
|
||||
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.23-fips-bookworm AS go-builder
|
||||
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
|
||||
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
@ -80,7 +80,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
go build -o /go/authentik ./cmd/server
|
||||
|
||||
# Stage 4: MaxMind GeoIP
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 AS geoip
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
|
||||
|
||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||
ENV GEOIPUPDATE_VERBOSE="1"
|
||||
@ -94,10 +94,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||
|
||||
# Stage 5: Python dependencies
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS python-deps
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps
|
||||
|
||||
WORKDIR /ak-root/poetry
|
||||
|
||||
@ -124,17 +121,17 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
|
||||
pip install --force-reinstall /wheels/*"
|
||||
|
||||
# Stage 6: Run
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS final-image
|
||||
FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image
|
||||
|
||||
ARG VERSION
|
||||
ARG GIT_BUILD_HASH
|
||||
ARG VERSION
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
|
||||
LABEL org.opencontainers.image.url=https://goauthentik.io
|
||||
LABEL org.opencontainers.image.description="goauthentik.io Main server image, see https://goauthentik.io for more info."
|
||||
LABEL org.opencontainers.image.source=https://github.com/goauthentik/authentik
|
||||
LABEL org.opencontainers.image.version=${VERSION}
|
||||
LABEL org.opencontainers.image.revision=${GIT_BUILD_HASH}
|
||||
LABEL org.opencontainers.image.url https://goauthentik.io
|
||||
LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info.
|
||||
LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
|
||||
LABEL org.opencontainers.image.version ${VERSION}
|
||||
LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
|
||||
|
||||
WORKDIR /
|
||||
|
||||
|
5
Makefile
5
Makefile
@ -43,7 +43,7 @@ help: ## Show this help
|
||||
sort
|
||||
@echo ""
|
||||
|
||||
go-test:
|
||||
test-go:
|
||||
go test -timeout 0 -v -race -cover ./...
|
||||
|
||||
test-docker: ## Run all tests in a docker-compose
|
||||
@ -210,9 +210,6 @@ web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting is
|
||||
web-install: ## Install the necessary libraries to build the Authentik UI
|
||||
cd web && npm ci
|
||||
|
||||
web-test: ## Run tests for the Authentik UI
|
||||
cd web && npm run test
|
||||
|
||||
web-watch: ## Build and watch the Authentik UI for changes, updating automatically
|
||||
rm -rf web/dist/
|
||||
mkdir web/dist/
|
||||
|
@ -15,9 +15,7 @@
|
||||
|
||||
## What is authentik?
|
||||
|
||||
authentik is an open-source Identity Provider that emphasizes flexibility and versatility, with support for a wide set of protocols.
|
||||
|
||||
Our [enterprise offer](https://goauthentik.io/pricing) can also be used as a self-hosted replacement for large-scale deployments of Okta/Auth0, Entra ID, Ping Identity, or other legacy IdPs for employees and B2B2C use.
|
||||
authentik is an open-source Identity Provider that emphasizes flexibility and versatility. It can be seamlessly integrated into existing environments to support new protocols. authentik is also a great solution for implementing sign-up, recovery, and other similar features in your application, saving you the hassle of dealing with them.
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -20,8 +20,8 @@ Even if the issue is not a CVE, we still greatly appreciate your help in hardeni
|
||||
|
||||
| Version | Supported |
|
||||
| -------- | --------- |
|
||||
| 2024.4.x | ✅ |
|
||||
| 2024.6.x | ✅ |
|
||||
| 2024.8.x | ✅ |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from os import environ
|
||||
|
||||
__version__ = "2024.8.2"
|
||||
__version__ = "2024.6.3"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
@ -1,20 +0,0 @@
|
||||
"""authentik admin analytics"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
|
||||
def get_analytics_description() -> dict[str, str]:
|
||||
return {
|
||||
"worker_count": _("Number of running workers"),
|
||||
}
|
||||
|
||||
|
||||
def get_analytics_data() -> dict[str, Any]:
|
||||
worker_count = len(CELERY_APP.control.ping(timeout=0.5))
|
||||
return {
|
||||
"worker_count": worker_count,
|
||||
}
|
@ -73,7 +73,7 @@ class SystemInfoSerializer(PassiveSerializer):
|
||||
"authentik_version": get_full_version(),
|
||||
"environment": get_env(),
|
||||
"openssl_fips_enabled": (
|
||||
backend._fips_enabled if LicenseKey.get_total().status().is_valid else None
|
||||
backend._fips_enabled if LicenseKey.get_total().is_valid() else None
|
||||
),
|
||||
"openssl_version": OPENSSL_VERSION,
|
||||
"platform": platform.platform(),
|
||||
|
@ -12,7 +12,6 @@ from rest_framework.views import APIView
|
||||
from authentik import __version__, get_build_hash
|
||||
from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.outposts.models import Outpost
|
||||
|
||||
|
||||
class VersionSerializer(PassiveSerializer):
|
||||
@ -23,7 +22,6 @@ class VersionSerializer(PassiveSerializer):
|
||||
version_latest_valid = SerializerMethodField()
|
||||
build_hash = SerializerMethodField()
|
||||
outdated = SerializerMethodField()
|
||||
outpost_outdated = SerializerMethodField()
|
||||
|
||||
def get_build_hash(self, _) -> str:
|
||||
"""Get build hash, if version is not latest or released"""
|
||||
@ -49,15 +47,6 @@ class VersionSerializer(PassiveSerializer):
|
||||
"""Check if we're running the latest version"""
|
||||
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
|
||||
|
||||
def get_outpost_outdated(self, _) -> bool:
|
||||
"""Check if any outpost is outdated/has a version mismatch"""
|
||||
any_outdated = False
|
||||
for outpost in Outpost.objects.all():
|
||||
for state in outpost.state:
|
||||
if state.version_outdated:
|
||||
any_outdated = True
|
||||
return any_outdated
|
||||
|
||||
|
||||
class VersionView(APIView):
|
||||
"""Get running and latest version."""
|
||||
|
@ -1,8 +1,10 @@
|
||||
"""authentik admin tasks"""
|
||||
|
||||
import re
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.core.validators import URLValidator
|
||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from packaging.version import parse
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
@ -19,6 +21,8 @@ LOGGER = get_logger()
|
||||
VERSION_NULL = "0.0.0"
|
||||
VERSION_CACHE_KEY = "authentik_latest_version"
|
||||
VERSION_CACHE_TIMEOUT = 8 * 60 * 60 # 8 hours
|
||||
# Chop of the first ^ because we want to search the entire string
|
||||
URL_FINDER = URLValidator.regex.pattern[1:]
|
||||
LOCAL_VERSION = parse(__version__)
|
||||
|
||||
|
||||
@ -74,16 +78,10 @@ def update_latest_version(self: SystemTask):
|
||||
context__new_version=upstream_version,
|
||||
).exists():
|
||||
return
|
||||
Event.new(
|
||||
EventAction.UPDATE_AVAILABLE,
|
||||
message=_(
|
||||
"New version {version} available!".format(
|
||||
version=upstream_version,
|
||||
)
|
||||
),
|
||||
new_version=upstream_version,
|
||||
changelog=data.get("stable", {}).get("changelog_url"),
|
||||
).save()
|
||||
event_dict = {"new_version": upstream_version}
|
||||
if match := re.search(URL_FINDER, data.get("stable", {}).get("changelog", "")):
|
||||
event_dict["message"] = f"Changelog: {match.group()}"
|
||||
Event.new(EventAction.UPDATE_AVAILABLE, **event_dict).save()
|
||||
except (RequestException, IndexError) as exc:
|
||||
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
|
||||
self.set_error(exc)
|
||||
|
@ -17,7 +17,6 @@ RESPONSE_VALID = {
|
||||
"stable": {
|
||||
"version": "99999999.9999999",
|
||||
"changelog": "See https://goauthentik.io/test",
|
||||
"changelog_url": "https://goauthentik.io/test",
|
||||
"reason": "bugfix",
|
||||
},
|
||||
}
|
||||
@ -36,7 +35,7 @@ class TestAdminTasks(TestCase):
|
||||
Event.objects.filter(
|
||||
action=EventAction.UPDATE_AVAILABLE,
|
||||
context__new_version="99999999.9999999",
|
||||
context__message="New version 99999999.9999999 available!",
|
||||
context__message="Changelog: https://goauthentik.io/test",
|
||||
).exists()
|
||||
)
|
||||
# test that a consecutive check doesn't create a duplicate event
|
||||
@ -46,7 +45,7 @@ class TestAdminTasks(TestCase):
|
||||
Event.objects.filter(
|
||||
action=EventAction.UPDATE_AVAILABLE,
|
||||
context__new_version="99999999.9999999",
|
||||
context__message="New version 99999999.9999999 available!",
|
||||
context__message="Changelog: https://goauthentik.io/test",
|
||||
)
|
||||
),
|
||||
1,
|
||||
|
@ -1,54 +0,0 @@
|
||||
"""authentik analytics api"""
|
||||
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from rest_framework.fields import CharField, DictField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ViewSet
|
||||
|
||||
from authentik.analytics.utils import get_analytics_data, get_analytics_description
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.rbac.permissions import HasPermission
|
||||
|
||||
|
||||
class AnalyticsDescriptionSerializer(PassiveSerializer):
|
||||
label = CharField()
|
||||
desc = CharField()
|
||||
|
||||
|
||||
class AnalyticsDescriptionViewSet(ViewSet):
|
||||
"""Read-only view of analytics descriptions"""
|
||||
|
||||
permission_classes = [HasPermission("authentik_rbac.view_system_settings")]
|
||||
|
||||
@extend_schema(responses={200: AnalyticsDescriptionSerializer})
|
||||
def list(self, request: Request) -> Response:
|
||||
"""Read-only view of analytics descriptions"""
|
||||
data = []
|
||||
for label, desc in get_analytics_description().items():
|
||||
data.append({"label": label, "desc": desc})
|
||||
return Response(AnalyticsDescriptionSerializer(data, many=True).data)
|
||||
|
||||
|
||||
class AnalyticsDataViewSet(ViewSet):
|
||||
"""Read-only view of analytics descriptions"""
|
||||
|
||||
permission_classes = [HasPermission("authentik_rbac.edit_system_settings")]
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
200: inline_serializer(
|
||||
name="AnalyticsData",
|
||||
fields={
|
||||
"data": DictField(),
|
||||
},
|
||||
)
|
||||
}
|
||||
)
|
||||
def list(self, request: Request) -> Response:
|
||||
"""Read-only view of analytics descriptions"""
|
||||
return Response(
|
||||
{
|
||||
"data": get_analytics_data(force=True),
|
||||
}
|
||||
)
|
@ -1,12 +0,0 @@
|
||||
"""authentik analytics app config"""
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
|
||||
|
||||
class AuthentikAdminConfig(ManagedAppConfig):
|
||||
"""authentik analytics app config"""
|
||||
|
||||
name = "authentik.analytics"
|
||||
label = "authentik_analytics"
|
||||
verbose_name = "authentik Analytics"
|
||||
default = True
|
@ -1,19 +0,0 @@
|
||||
"""authentik analytics mixins"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class AnalyticsMixin:
|
||||
@classmethod
|
||||
def get_analytics_description(cls) -> dict[str, str]:
|
||||
object_name = _(cls._meta.verbose_name)
|
||||
count_desc = _("Number of {object_name} objects".format_map({"object_name": object_name}))
|
||||
return {
|
||||
"count": count_desc,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_analytics_data(cls) -> dict[str, Any]:
|
||||
return {"count": cls.objects.all().count()}
|
@ -1,17 +0,0 @@
|
||||
"""authentik admin settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"analytics_send": {
|
||||
"task": "authentik.analytics.tasks.send_analytics",
|
||||
"schedule": crontab(
|
||||
minute=fqdn_rand("analytics_send"),
|
||||
hour=fqdn_rand("analytics_send", stop=24),
|
||||
day_of_week=fqdn_rand("analytics_send", 7),
|
||||
),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
}
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
"""authentik admin tasks"""
|
||||
|
||||
import orjson
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.analytics.utils import get_analytics_data
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def send_analytics(self: SystemTask):
|
||||
"""Send analytics"""
|
||||
for tenant in Tenant.objects.filter(ready=True):
|
||||
data = get_analytics_data(current_tenant=tenant)
|
||||
if not tenant.analytics_enabled or not data:
|
||||
self.set_status(TaskStatus.WARNING, "Analytics disabled. Nothing was sent.")
|
||||
return
|
||||
try:
|
||||
response = get_http_session().post(
|
||||
"https://customers.goauthentik.io/api/analytics/post/", json=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
self.set_status(
|
||||
TaskStatus.SUCCESSFUL,
|
||||
"Successfully sent analytics",
|
||||
orjson.dumps(
|
||||
data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS | orjson.OPT_UTC_Z
|
||||
).decode(),
|
||||
)
|
||||
Event.new(
|
||||
EventAction.ANALYTICS_SENT,
|
||||
message=_("Analytics sent"),
|
||||
analytics_data=data,
|
||||
).save()
|
||||
except (RequestException, IndexError) as exc:
|
||||
self.set_error(exc)
|
@ -1,76 +0,0 @@
|
||||
"""authentik analytics tests"""
|
||||
|
||||
from json import loads
|
||||
from requests_mock import Mocker
|
||||
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
from authentik import __version__
|
||||
from authentik.analytics.tasks import send_analytics
|
||||
from authentik.analytics.utils import get_analytics_apps_data, get_analytics_apps_description, get_analytics_data, get_analytics_description, get_analytics_models_data, get_analytics_models_description
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
|
||||
class TestAnalytics(TestCase):
|
||||
"""test analytics api"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.user = User.objects.create(username=generate_id())
|
||||
self.group = Group.objects.create(name=generate_id(), is_superuser=True)
|
||||
self.group.users.add(self.user)
|
||||
self.client.force_login(self.user)
|
||||
self.tenant = get_current_tenant()
|
||||
|
||||
def test_description_api(self):
|
||||
"""Test Version API"""
|
||||
response = self.client.get(reverse("authentik_api:analytics-description-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
loads(response.content)
|
||||
|
||||
def test_data_api(self):
|
||||
"""Test Version API"""
|
||||
response = self.client.get(reverse("authentik_api:analytics-data-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["data"]["version"], __version__)
|
||||
|
||||
def test_sending_enabled(self):
|
||||
"""Test analytics sending"""
|
||||
self.tenant.analytics_enabled = True
|
||||
self.tenant.save()
|
||||
with Mocker() as mocker:
|
||||
mocker.post("https://customers.goauthentik.io/api/analytics/post/", status_code=200)
|
||||
send_analytics.delay().get()
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.ANALYTICS_SENT
|
||||
).exists()
|
||||
)
|
||||
|
||||
def test_sending_disabled(self):
|
||||
"""Test analytics sending"""
|
||||
self.tenant.analytics_enabled = False
|
||||
self.tenant.save()
|
||||
send_analytics.delay().get()
|
||||
self.assertFalse(
|
||||
Event.objects.filter(
|
||||
action=EventAction.ANALYTICS_SENT
|
||||
).exists()
|
||||
)
|
||||
|
||||
def test_description_data_match_apps(self):
|
||||
"""Test description and data keys match"""
|
||||
description = get_analytics_apps_description()
|
||||
data = get_analytics_apps_data()
|
||||
self.assertEqual(data.keys(), description.keys())
|
||||
|
||||
def test_description_data_match_models(self):
|
||||
"""Test description and data keys match"""
|
||||
description = get_analytics_models_description()
|
||||
data = get_analytics_models_data()
|
||||
self.assertEqual(data.keys(), description.keys())
|
@ -1,8 +0,0 @@
|
||||
"""API URLs"""
|
||||
|
||||
from authentik.analytics.api import AnalyticsDataViewSet, AnalyticsDescriptionViewSet
|
||||
|
||||
api_urlpatterns = [
|
||||
("analytics/description", AnalyticsDescriptionViewSet, "analytics-description"),
|
||||
("analytics/data", AnalyticsDataViewSet, "analytics-data"),
|
||||
]
|
@ -1,112 +0,0 @@
|
||||
"""authentik analytics utils"""
|
||||
|
||||
from hashlib import sha256
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
from structlog import get_logger
|
||||
|
||||
from authentik import get_full_version
|
||||
from authentik.analytics.models import AnalyticsMixin
|
||||
from authentik.lib.utils.reflection import get_apps
|
||||
from authentik.root.install_id import get_install_id
|
||||
from authentik.tenants.models import Tenant
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def get_analytics_apps() -> dict:
|
||||
modules = {}
|
||||
for _authentik_app in get_apps():
|
||||
try:
|
||||
module = import_module(f"{_authentik_app.name}.analytics")
|
||||
except ModuleNotFoundError:
|
||||
continue
|
||||
except ImportError as exc:
|
||||
LOGGER.warning(
|
||||
"Could not import app's analytics", app_name=_authentik_app.name, exc=exc
|
||||
)
|
||||
continue
|
||||
if not hasattr(module, "get_analytics_description") or not hasattr(
|
||||
module, "get_analytics_data"
|
||||
):
|
||||
LOGGER.debug(
|
||||
"App does not define API URLs",
|
||||
app_name=_authentik_app.name,
|
||||
)
|
||||
continue
|
||||
modules[_authentik_app.label] = module
|
||||
return modules
|
||||
|
||||
|
||||
def get_analytics_apps_description() -> dict[str, str]:
|
||||
result = {}
|
||||
for app_label, module in get_analytics_apps().items():
|
||||
for k, v in module.get_analytics_description().items():
|
||||
result[f"{app_label}/app/{k}"] = v
|
||||
return result
|
||||
|
||||
|
||||
def get_analytics_apps_data() -> dict[str, Any]:
|
||||
result = {}
|
||||
for app_label, module in get_analytics_apps().items():
|
||||
for k, v in module.get_analytics_data().items():
|
||||
result[f"{app_label}/app/{k}"] = v
|
||||
return result
|
||||
|
||||
|
||||
def get_analytics_models() -> list[AnalyticsMixin]:
|
||||
def get_subclasses(cls):
|
||||
for subclass in cls.__subclasses__():
|
||||
if subclass.__subclasses__():
|
||||
yield from get_subclasses(subclass)
|
||||
elif not subclass._meta.abstract:
|
||||
yield subclass
|
||||
|
||||
return list(get_subclasses(AnalyticsMixin))
|
||||
|
||||
|
||||
def get_analytics_models_description() -> dict[str, str]:
|
||||
result = {}
|
||||
for model in get_analytics_models():
|
||||
for k, v in model.get_analytics_description().items():
|
||||
result[f"{model._meta.app_label}/models/{model._meta.object_name}/{k}"] = v
|
||||
return result
|
||||
|
||||
|
||||
def get_analytics_models_data() -> dict[str, Any]:
|
||||
result = {}
|
||||
for model in get_analytics_models():
|
||||
for k, v in model.get_analytics_data().items():
|
||||
result[f"{model._meta.app_label}/models/{model._meta.object_name}/{k}"] = v
|
||||
return result
|
||||
|
||||
|
||||
def get_analytics_description() -> dict[str, str]:
|
||||
return {
|
||||
**get_analytics_apps_description(),
|
||||
**get_analytics_models_description(),
|
||||
}
|
||||
|
||||
|
||||
def get_analytics_data(current_tenant: Tenant | None = None, force: bool = False) -> dict[str, Any]:
|
||||
current_tenant = current_tenant or get_current_tenant()
|
||||
if not current_tenant.analytics_enabled and not force:
|
||||
return {}
|
||||
data = {
|
||||
**get_analytics_apps_data(),
|
||||
**get_analytics_models_data(),
|
||||
}
|
||||
to_remove = []
|
||||
for key in data.keys():
|
||||
if key not in current_tenant.analytics_sources:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del data[key]
|
||||
return {
|
||||
**data,
|
||||
"install_id_hash": sha256(get_install_id().encode()).hexdigest(),
|
||||
"tenant_hash": sha256(current_tenant.tenant_uuid.bytes).hexdigest(),
|
||||
"version": get_full_version(),
|
||||
}
|
@ -171,7 +171,7 @@ class Importer:
|
||||
def default_context(self):
|
||||
"""Default context"""
|
||||
return {
|
||||
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid,
|
||||
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid(),
|
||||
"goauthentik.io/rbac/models": rbac_models(),
|
||||
}
|
||||
|
||||
|
@ -30,10 +30,8 @@ from authentik.core.api.utils import (
|
||||
PassiveSerializer,
|
||||
)
|
||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
||||
from authentik.core.expression.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import Group, PropertyMapping, User
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.api.exec import PolicyTestSerializer
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
@ -164,15 +162,12 @@ class PropertyMappingViewSet(
|
||||
|
||||
response_data = {"successful": True, "result": ""}
|
||||
try:
|
||||
result = mapping.evaluate(dry_run=True, **context)
|
||||
result = mapping.evaluate(**context)
|
||||
response_data["result"] = dumps(
|
||||
sanitize_item(result), indent=(4 if format_result else None)
|
||||
)
|
||||
except PropertyMappingExpressionException as exc:
|
||||
response_data["result"] = exception_to_string(exc.exc)
|
||||
response_data["successful"] = False
|
||||
except Exception as exc:
|
||||
response_data["result"] = exception_to_string(exc)
|
||||
response_data["result"] = str(exc)
|
||||
response_data["successful"] = False
|
||||
response = PropertyMappingTestResultSerializer(response_data)
|
||||
return Response(response.data)
|
||||
|
@ -14,7 +14,6 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
|
||||
class DeleteAction(Enum):
|
||||
@ -54,7 +53,7 @@ class UsedByMixin:
|
||||
@extend_schema(
|
||||
responses={200: UsedBySerializer(many=True)},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def used_by(self, request: Request, *args, **kwargs) -> Response:
|
||||
"""Get a list of all objects that use this object"""
|
||||
model: Model = self.get_object()
|
||||
|
@ -678,10 +678,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||
if not request.tenant.impersonation:
|
||||
LOGGER.debug("User attempted to impersonate", user=request.user)
|
||||
return Response(status=401)
|
||||
user_to_be = self.get_object()
|
||||
if not request.user.has_perm("impersonate", user_to_be):
|
||||
if not request.user.has_perm("impersonate"):
|
||||
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
|
||||
return Response(status=401)
|
||||
user_to_be = self.get_object()
|
||||
if user_to_be.pk == self.request.user.pk:
|
||||
LOGGER.debug("User attempted to impersonate themselves", user=request.user)
|
||||
return Response(status=401)
|
||||
|
@ -9,11 +9,10 @@ class Command(TenantCommand):
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument("--type", type=str, required=True)
|
||||
parser.add_argument("--all", action="store_true", default=False)
|
||||
parser.add_argument("usernames", nargs="*", type=str)
|
||||
parser.add_argument("--all", action="store_true")
|
||||
parser.add_argument("usernames", nargs="+", type=str)
|
||||
|
||||
def handle_per_tenant(self, **options):
|
||||
print(options)
|
||||
new_type = UserTypes(options["type"])
|
||||
qs = (
|
||||
User.objects.exclude_anonymous()
|
||||
@ -23,9 +22,6 @@ class Command(TenantCommand):
|
||||
if options["usernames"] and options["all"]:
|
||||
self.stderr.write("--all and usernames specified, only one can be specified")
|
||||
return
|
||||
if not options["usernames"] and not options["all"]:
|
||||
self.stderr.write("--all or usernames must be specified")
|
||||
return
|
||||
if options["usernames"] and not options["all"]:
|
||||
qs = qs.filter(username__in=options["usernames"])
|
||||
updated = qs.update(type=new_type)
|
||||
|
@ -23,7 +23,6 @@ from model_utils.managers import InheritanceManager
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.analytics.models import AnalyticsMixin
|
||||
from authentik.blueprints.models import ManagedModel
|
||||
from authentik.core.expression.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.types import UILoginButton, UserSettingSerializer
|
||||
@ -169,7 +168,7 @@ class GroupQuerySet(CTEQuerySet):
|
||||
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
|
||||
|
||||
|
||||
class Group(SerializerModel, AttributesMixin, AnalyticsMixin):
|
||||
class Group(SerializerModel, AttributesMixin):
|
||||
"""Group model which supports a basic hierarchy and has attributes"""
|
||||
|
||||
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
@ -259,7 +258,7 @@ class UserManager(DjangoUserManager):
|
||||
return self.get_queryset().exclude_anonymous()
|
||||
|
||||
|
||||
class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser, AnalyticsMixin):
|
||||
class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
|
||||
"""authentik User model, based on django's contrib auth user model."""
|
||||
|
||||
uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
|
||||
@ -377,7 +376,7 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser, An
|
||||
return get_avatar(self)
|
||||
|
||||
|
||||
class Provider(SerializerModel, AnalyticsMixin):
|
||||
class Provider(SerializerModel):
|
||||
"""Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application"""
|
||||
|
||||
name = models.TextField(unique=True)
|
||||
@ -467,11 +466,13 @@ class ApplicationQuerySet(QuerySet):
|
||||
def with_provider(self) -> "QuerySet[Application]":
|
||||
qs = self.select_related("provider")
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
if LOOKUP_SEP in subclass:
|
||||
continue
|
||||
qs = qs.select_related(f"provider__{subclass}")
|
||||
return qs
|
||||
|
||||
|
||||
class Application(SerializerModel, PolicyBindingModel, AnalyticsMixin):
|
||||
class Application(SerializerModel, PolicyBindingModel):
|
||||
"""Every Application which uses authentik for authentication/identification/authorization
|
||||
needs an Application record. Other authentication types can subclass this Model to
|
||||
add custom fields and other properties"""
|
||||
@ -544,24 +545,15 @@ class Application(SerializerModel, PolicyBindingModel, AnalyticsMixin):
|
||||
if not self.provider:
|
||||
return None
|
||||
|
||||
candidates = []
|
||||
base_class = Provider
|
||||
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class):
|
||||
parent = self.provider
|
||||
for level in subclass.split(LOOKUP_SEP):
|
||||
try:
|
||||
parent = getattr(parent, level)
|
||||
except AttributeError:
|
||||
break
|
||||
if parent in candidates:
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
# We don't care about recursion, skip nested models
|
||||
if LOOKUP_SEP in subclass:
|
||||
continue
|
||||
idx = subclass.count(LOOKUP_SEP)
|
||||
if type(parent) is not base_class:
|
||||
idx += 1
|
||||
candidates.insert(idx, parent)
|
||||
if not candidates:
|
||||
return None
|
||||
return candidates[-1]
|
||||
try:
|
||||
return getattr(self.provider, subclass)
|
||||
except AttributeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name)
|
||||
@ -604,7 +596,7 @@ class SourceGroupMatchingModes(models.TextChoices):
|
||||
)
|
||||
|
||||
|
||||
class Source(ManagedModel, SerializerModel, PolicyBindingModel, AnalyticsMixin):
|
||||
class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
|
||||
|
||||
name = models.TextField(help_text=_("Source's display Name."))
|
||||
@ -736,7 +728,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel, AnalyticsMixin):
|
||||
]
|
||||
|
||||
|
||||
class UserSourceConnection(SerializerModel, CreatedUpdatedModel, AnalyticsMixin):
|
||||
class UserSourceConnection(SerializerModel, CreatedUpdatedModel):
|
||||
"""Connection between User and Source."""
|
||||
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
@ -756,7 +748,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel, AnalyticsMixin)
|
||||
unique_together = (("user", "source"),)
|
||||
|
||||
|
||||
class GroupSourceConnection(SerializerModel, CreatedUpdatedModel, AnalyticsMixin):
|
||||
class GroupSourceConnection(SerializerModel, CreatedUpdatedModel):
|
||||
"""Connection between Group and Source."""
|
||||
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
@ -880,7 +872,7 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
|
||||
).save()
|
||||
|
||||
|
||||
class PropertyMapping(SerializerModel, ManagedModel, AnalyticsMixin):
|
||||
class PropertyMapping(SerializerModel, ManagedModel):
|
||||
"""User-defined key -> x mapping which can be used by providers to expose extra data."""
|
||||
|
||||
pm_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
@ -909,7 +901,7 @@ class PropertyMapping(SerializerModel, ManagedModel, AnalyticsMixin):
|
||||
except ControlFlowException as exc:
|
||||
raise exc
|
||||
except Exception as exc:
|
||||
raise PropertyMappingExpressionException(exc, self) from exc
|
||||
raise PropertyMappingExpressionException(self, exc) from exc
|
||||
|
||||
def __str__(self):
|
||||
return f"Property Mapping {self.name}"
|
||||
|
@ -9,12 +9,9 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
from authentik.providers.saml.models import SAMLProvider
|
||||
|
||||
|
||||
class TestApplicationsAPI(APITestCase):
|
||||
@ -225,31 +222,3 @@ class TestApplicationsAPI(APITestCase):
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_get_provider(self):
|
||||
"""Ensure that proxy providers (at the time of writing that is the only provider
|
||||
that inherits from another proxy type (OAuth) instead of inheriting from the root
|
||||
provider class) is correctly looked up and selected from the database"""
|
||||
slug = generate_id()
|
||||
provider = ProxyProvider.objects.create(name=generate_id())
|
||||
Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=slug,
|
||||
provider=provider,
|
||||
)
|
||||
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
|
||||
self.assertEqual(
|
||||
Application.objects.with_provider().get(slug=slug).get_provider(), provider
|
||||
)
|
||||
|
||||
slug = generate_id()
|
||||
provider = SAMLProvider.objects.create(name=generate_id())
|
||||
Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=slug,
|
||||
provider=provider,
|
||||
)
|
||||
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
|
||||
self.assertEqual(
|
||||
Application.objects.with_provider().get(slug=slug).get_provider(), provider
|
||||
)
|
||||
|
@ -3,10 +3,10 @@
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from guardian.shortcuts import assign_perm
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
|
||||
from authentik.core.models import User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.other_user = create_test_user()
|
||||
self.other_user = User.objects.create(username="to-impersonate")
|
||||
self.user = create_test_admin_user()
|
||||
|
||||
def test_impersonate_simple(self):
|
||||
@ -44,26 +44,6 @@ class TestImpersonation(APITestCase):
|
||||
self.assertEqual(response_body["user"]["username"], self.user.username)
|
||||
self.assertNotIn("original", response_body)
|
||||
|
||||
def test_impersonate_scoped(self):
|
||||
"""Test impersonation with scoped permissions"""
|
||||
new_user = create_test_user()
|
||||
assign_perm("authentik_core.impersonate", new_user, self.other_user)
|
||||
assign_perm("authentik_core.view_user", new_user, self.other_user)
|
||||
self.client.force_login(new_user)
|
||||
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:user-impersonate",
|
||||
kwargs={"pk": self.other_user.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
|
||||
response = self.client.get(reverse("authentik_api:user-me"))
|
||||
response_body = loads(response.content.decode())
|
||||
self.assertEqual(response_body["user"]["username"], self.other_user.username)
|
||||
self.assertEqual(response_body["original"]["username"], new_user.username)
|
||||
|
||||
def test_impersonate_denied(self):
|
||||
"""test impersonation without permissions"""
|
||||
self.client.force_login(self.other_user)
|
||||
|
@ -35,7 +35,6 @@ from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.rbac.decorators import permission_required
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -266,7 +265,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||
],
|
||||
responses={200: CertificateDataSerializer(many=False)},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def view_certificate(self, request: Request, pk: str) -> Response:
|
||||
"""Return certificate-key pairs certificate and log access"""
|
||||
certificate: CertificateKeyPair = self.get_object()
|
||||
@ -296,7 +295,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||
],
|
||||
responses={200: CertificateDataSerializer(many=False)},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def view_private_key(self, request: Request, pk: str) -> Response:
|
||||
"""Return certificate-key pairs private key and log access"""
|
||||
certificate: CertificateKeyPair = self.get_object()
|
||||
|
@ -214,46 +214,6 @@ class TestCrypto(APITestCase):
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertIn("Content-Disposition", response)
|
||||
|
||||
def test_certificate_download_denied(self):
|
||||
"""Test certificate export (download)"""
|
||||
self.client.logout()
|
||||
keypair = create_test_cert()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-certificate",
|
||||
kwargs={"pk": keypair.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-certificate",
|
||||
kwargs={"pk": keypair.pk},
|
||||
),
|
||||
data={"download": True},
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
|
||||
def test_private_key_download_denied(self):
|
||||
"""Test private_key export (download)"""
|
||||
self.client.logout()
|
||||
keypair = create_test_cert()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-private-key",
|
||||
kwargs={"pk": keypair.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-view-private-key",
|
||||
kwargs={"pk": keypair.pk},
|
||||
),
|
||||
data={"download": True},
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
|
||||
def test_used_by(self):
|
||||
"""Test used_by endpoint"""
|
||||
self.client.force_login(create_test_admin_user())
|
||||
@ -286,26 +246,6 @@ class TestCrypto(APITestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_used_by_denied(self):
|
||||
"""Test used_by endpoint"""
|
||||
self.client.logout()
|
||||
keypair = create_test_cert()
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
client_secret=generate_key(),
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris="http://localhost",
|
||||
signing_key=keypair,
|
||||
)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:certificatekeypair-used-by",
|
||||
kwargs={"pk": keypair.pk},
|
||||
)
|
||||
)
|
||||
self.assertEqual(403, response.status_code)
|
||||
|
||||
def test_discovery(self):
|
||||
"""Test certificate discovery"""
|
||||
name = generate_id()
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""Enterprise API Views"""
|
||||
|
||||
from dataclasses import asdict
|
||||
from datetime import timedelta
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, inline_serializer
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import CharField, IntegerField
|
||||
@ -29,7 +30,7 @@ class EnterpriseRequiredMixin:
|
||||
|
||||
def validate(self, attrs: dict) -> dict:
|
||||
"""Check that a valid license exists"""
|
||||
if not LicenseKey.cached_summary().status.is_valid:
|
||||
if not LicenseKey.cached_summary().has_license:
|
||||
raise ValidationError(_("Enterprise is required to create/update this object."))
|
||||
return super().validate(attrs)
|
||||
|
||||
@ -86,7 +87,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
|
||||
},
|
||||
)
|
||||
@action(detail=False, methods=["GET"])
|
||||
def install_id(self, request: Request) -> Response:
|
||||
def get_install_id(self, request: Request) -> Response:
|
||||
"""Get install_id"""
|
||||
return Response(
|
||||
data={
|
||||
@ -99,22 +100,12 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
|
||||
responses={
|
||||
200: LicenseSummarySerializer(),
|
||||
},
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="cached",
|
||||
location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.BOOL,
|
||||
default=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
|
||||
def summary(self, request: Request) -> Response:
|
||||
"""Get the total license status"""
|
||||
summary = LicenseKey.cached_summary()
|
||||
if request.query_params.get("cached", "true").lower() == "false":
|
||||
summary = LicenseKey.get_total().summary()
|
||||
response = LicenseSummarySerializer(instance=summary)
|
||||
response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
|
||||
response.is_valid(raise_exception=True)
|
||||
return Response(response.data)
|
||||
|
||||
@permission_required(None, ["authentik_enterprise.view_license"])
|
||||
@ -137,7 +128,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
|
||||
forecast_for_months = 12
|
||||
response = LicenseForecastSerializer(
|
||||
data={
|
||||
"internal_users": LicenseKey.get_internal_user_count(),
|
||||
"internal_users": LicenseKey.get_default_user_count(),
|
||||
"external_users": LicenseKey.get_external_user_count(),
|
||||
"forecasted_internal_users": (internal_in_last_month * forecast_for_months),
|
||||
"forecasted_external_users": (external_in_last_month * forecast_for_months),
|
||||
|
@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
|
||||
"""Actual enterprise check, cached"""
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
return LicenseKey.cached_summary().status.is_valid
|
||||
return LicenseKey.cached_summary().valid
|
||||
|
@ -3,37 +3,24 @@
|
||||
from base64 import b64decode
|
||||
from binascii import Error
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from time import mktime
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
|
||||
from dacite import DaciteError, from_dict
|
||||
from dacite import from_dict
|
||||
from django.core.cache import cache
|
||||
from django.db.models.query import QuerySet
|
||||
from django.utils.timezone import now
|
||||
from jwt import PyJWTError, decode, get_unverified_header
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import (
|
||||
ChoiceField,
|
||||
DateTimeField,
|
||||
IntegerField,
|
||||
ListField,
|
||||
)
|
||||
from rest_framework.fields import BooleanField, DateTimeField, IntegerField
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.core.models import User, UserTypes
|
||||
from authentik.enterprise.models import (
|
||||
THRESHOLD_READ_ONLY_WEEKS,
|
||||
THRESHOLD_WARNING_ADMIN_WEEKS,
|
||||
THRESHOLD_WARNING_EXPIRY_WEEKS,
|
||||
THRESHOLD_WARNING_USER_WEEKS,
|
||||
License,
|
||||
LicenseUsage,
|
||||
LicenseUsageStatus,
|
||||
)
|
||||
from authentik.enterprise.models import License, LicenseUsage
|
||||
from authentik.tenants.utils import get_unique_identifier
|
||||
|
||||
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
|
||||
@ -55,9 +42,6 @@ def get_license_aud() -> str:
|
||||
class LicenseFlags(Enum):
|
||||
"""License flags"""
|
||||
|
||||
TRIAL = "trial"
|
||||
NON_PRODUCTION = "non_production"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LicenseSummary:
|
||||
@ -65,9 +49,12 @@ class LicenseSummary:
|
||||
|
||||
internal_users: int
|
||||
external_users: int
|
||||
status: LicenseUsageStatus
|
||||
valid: bool
|
||||
show_admin_warning: bool
|
||||
show_user_warning: bool
|
||||
read_only: bool
|
||||
latest_valid: datetime
|
||||
license_flags: list[LicenseFlags]
|
||||
has_license: bool
|
||||
|
||||
|
||||
class LicenseSummarySerializer(PassiveSerializer):
|
||||
@ -75,9 +62,12 @@ class LicenseSummarySerializer(PassiveSerializer):
|
||||
|
||||
internal_users = IntegerField(required=True)
|
||||
external_users = IntegerField(required=True)
|
||||
status = ChoiceField(choices=LicenseUsageStatus.choices)
|
||||
valid = BooleanField()
|
||||
show_admin_warning = BooleanField()
|
||||
show_user_warning = BooleanField()
|
||||
read_only = BooleanField()
|
||||
latest_valid = DateTimeField()
|
||||
license_flags = ListField(child=ChoiceField(choices=tuple(x.value for x in LicenseFlags)))
|
||||
has_license = BooleanField()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -90,10 +80,10 @@ class LicenseKey:
|
||||
name: str
|
||||
internal_users: int = 0
|
||||
external_users: int = 0
|
||||
license_flags: list[LicenseFlags] = field(default_factory=list)
|
||||
flags: list[LicenseFlags] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def validate(jwt: str, check_expiry=True) -> "LicenseKey":
|
||||
def validate(jwt: str) -> "LicenseKey":
|
||||
"""Validate the license from a given JWT"""
|
||||
try:
|
||||
headers = get_unverified_header(jwt)
|
||||
@ -117,28 +107,26 @@ class LicenseKey:
|
||||
our_cert.public_key(),
|
||||
algorithms=["ES512"],
|
||||
audience=get_license_aud(),
|
||||
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
|
||||
),
|
||||
)
|
||||
except PyJWTError:
|
||||
unverified = decode(jwt, options={"verify_signature": False})
|
||||
if unverified["aud"] != get_license_aud():
|
||||
raise ValidationError("Invalid Install ID in license") from None
|
||||
raise ValidationError("Unable to verify license") from None
|
||||
return body
|
||||
|
||||
@staticmethod
|
||||
def get_total() -> "LicenseKey":
|
||||
"""Get a summarized version of all (not expired) licenses"""
|
||||
active_licenses = License.objects.filter(expiry__gte=now())
|
||||
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
|
||||
for lic in License.objects.all():
|
||||
for lic in active_licenses:
|
||||
total.internal_users += lic.internal_users
|
||||
total.external_users += lic.external_users
|
||||
exp_ts = int(mktime(lic.expiry.timetuple()))
|
||||
if total.exp == 0:
|
||||
total.exp = exp_ts
|
||||
total.exp = max(total.exp, exp_ts)
|
||||
total.license_flags.extend(lic.status.license_flags)
|
||||
if exp_ts <= total.exp:
|
||||
total.exp = exp_ts
|
||||
total.flags.extend(lic.status.flags)
|
||||
return total
|
||||
|
||||
@staticmethod
|
||||
@ -147,7 +135,7 @@ class LicenseKey:
|
||||
return User.objects.all().exclude_anonymous().exclude(is_active=False)
|
||||
|
||||
@staticmethod
|
||||
def get_internal_user_count():
|
||||
def get_default_user_count():
|
||||
"""Get current default user count"""
|
||||
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
|
||||
|
||||
@ -156,73 +144,59 @@ class LicenseKey:
|
||||
"""Get current external user count"""
|
||||
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()
|
||||
|
||||
def _last_valid_date(self):
|
||||
last_valid_date = (
|
||||
LicenseUsage.objects.order_by("-record_date")
|
||||
.filter(status=LicenseUsageStatus.VALID)
|
||||
.first()
|
||||
)
|
||||
if not last_valid_date:
|
||||
return datetime.fromtimestamp(0, UTC)
|
||||
return last_valid_date.record_date
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the given license body covers all users
|
||||
|
||||
def status(self) -> LicenseUsageStatus:
|
||||
"""Check if the given license body covers all users, and is valid."""
|
||||
last_valid = self._last_valid_date()
|
||||
if self.exp == 0 and not License.objects.exists():
|
||||
return LicenseUsageStatus.UNLICENSED
|
||||
_now = now()
|
||||
# Check limit-exceeded based status
|
||||
internal_users = self.get_internal_user_count()
|
||||
external_users = self.get_external_user_count()
|
||||
if internal_users > self.internal_users or external_users > self.external_users:
|
||||
if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
|
||||
return LicenseUsageStatus.READ_ONLY
|
||||
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
|
||||
return LicenseUsageStatus.LIMIT_EXCEEDED_USER
|
||||
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
|
||||
return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
|
||||
# Check expiry based status
|
||||
if datetime.fromtimestamp(self.exp, UTC) < _now:
|
||||
if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta(
|
||||
weeks=THRESHOLD_READ_ONLY_WEEKS
|
||||
):
|
||||
return LicenseUsageStatus.READ_ONLY
|
||||
return LicenseUsageStatus.EXPIRED
|
||||
# Expiry warning
|
||||
if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta(
|
||||
weeks=THRESHOLD_WARNING_EXPIRY_WEEKS
|
||||
):
|
||||
return LicenseUsageStatus.EXPIRY_SOON
|
||||
return LicenseUsageStatus.VALID
|
||||
Only checks the current count, no historical data is checked"""
|
||||
default_users = self.get_default_user_count()
|
||||
if default_users > self.internal_users:
|
||||
return False
|
||||
active_users = self.get_external_user_count()
|
||||
if active_users > self.external_users:
|
||||
return False
|
||||
return True
|
||||
|
||||
def record_usage(self):
|
||||
"""Capture the current validity status and metrics and save them"""
|
||||
threshold = now() - timedelta(hours=8)
|
||||
usage = (
|
||||
LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first()
|
||||
)
|
||||
if not usage:
|
||||
usage = LicenseUsage.objects.create(
|
||||
internal_user_count=self.get_internal_user_count(),
|
||||
if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
|
||||
LicenseUsage.objects.create(
|
||||
user_count=self.get_default_user_count(),
|
||||
external_user_count=self.get_external_user_count(),
|
||||
status=self.status(),
|
||||
within_limits=self.is_valid(),
|
||||
)
|
||||
summary = asdict(self.summary())
|
||||
# Also cache the latest summary for the middleware
|
||||
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
|
||||
return usage
|
||||
return summary
|
||||
|
||||
@staticmethod
|
||||
def last_valid_date() -> datetime:
|
||||
"""Get the last date the license was valid"""
|
||||
usage: LicenseUsage = (
|
||||
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
|
||||
)
|
||||
if not usage:
|
||||
return now()
|
||||
return usage.record_date
|
||||
|
||||
def summary(self) -> LicenseSummary:
|
||||
"""Summary of license status"""
|
||||
status = self.status()
|
||||
has_license = License.objects.all().count() > 0
|
||||
last_valid = LicenseKey.last_valid_date()
|
||||
show_admin_warning = last_valid < now() - timedelta(weeks=2)
|
||||
show_user_warning = last_valid < now() - timedelta(weeks=4)
|
||||
read_only = last_valid < now() - timedelta(weeks=6)
|
||||
latest_valid = datetime.fromtimestamp(self.exp)
|
||||
return LicenseSummary(
|
||||
show_admin_warning=show_admin_warning and has_license,
|
||||
show_user_warning=show_user_warning and has_license,
|
||||
read_only=read_only and has_license,
|
||||
latest_valid=latest_valid,
|
||||
internal_users=self.internal_users,
|
||||
external_users=self.external_users,
|
||||
status=status,
|
||||
license_flags=self.license_flags,
|
||||
valid=self.is_valid(),
|
||||
has_license=has_license,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -231,8 +205,4 @@ class LicenseKey:
|
||||
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
|
||||
if not summary:
|
||||
return LicenseKey.get_total().summary()
|
||||
try:
|
||||
return from_dict(LicenseSummary, summary)
|
||||
except DaciteError:
|
||||
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
|
||||
return LicenseKey.get_total().summary()
|
||||
return from_dict(LicenseSummary, summary)
|
||||
|
@ -8,7 +8,6 @@ from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.enterprise.api import LicenseViewSet
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.enterprise.models import LicenseUsageStatus
|
||||
from authentik.flows.views.executor import FlowExecutorView
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
|
||||
@ -44,7 +43,7 @@ class EnterpriseMiddleware:
|
||||
cached_status = LicenseKey.cached_summary()
|
||||
if not cached_status:
|
||||
return True
|
||||
if cached_status.status == LicenseUsageStatus.READ_ONLY:
|
||||
if cached_status.read_only:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -54,10 +53,10 @@ class EnterpriseMiddleware:
|
||||
if request.method.lower() in ["get", "head", "options", "trace"]:
|
||||
return True
|
||||
# Always allow requests to manage licenses
|
||||
if request.resolver_match._func_path == class_to_path(LicenseViewSet):
|
||||
if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
|
||||
return True
|
||||
# Flow executor is mounted as an API path but explicitly allowed
|
||||
if request.resolver_match._func_path == class_to_path(FlowExecutorView):
|
||||
if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
|
||||
return True
|
||||
# Only apply these restrictions to the API
|
||||
if "authentik_api" not in request.resolver_match.app_names:
|
||||
|
@ -1,68 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-08 14:15
|
||||
|
||||
from django.db import migrations, models
|
||||
from django.apps.registry import Apps
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage")
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
for usage in LicenseUsage.objects.using(db_alias).all():
|
||||
usage.status = "valid" if usage.within_limits else "limit_exceeded_admin"
|
||||
usage.save()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="licenseusage",
|
||||
name="status",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("unlicensed", "Unlicensed"),
|
||||
("valid", "Valid"),
|
||||
("expired", "Expired"),
|
||||
("expiry_soon", "Expiry Soon"),
|
||||
("limit_exceeded_admin", "Limit Exceeded Admin"),
|
||||
("limit_exceeded_user", "Limit Exceeded User"),
|
||||
("read_only", "Read Only"),
|
||||
],
|
||||
default=None,
|
||||
null=True,
|
||||
),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.RunPython(migrate_license_usage),
|
||||
migrations.RemoveField(
|
||||
model_name="licenseusage",
|
||||
name="within_limits",
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="licenseusage",
|
||||
name="status",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("unlicensed", "Unlicensed"),
|
||||
("valid", "Valid"),
|
||||
("expired", "Expired"),
|
||||
("expiry_soon", "Expiry Soon"),
|
||||
("limit_exceeded_admin", "Limit Exceeded Admin"),
|
||||
("limit_exceeded_user", "Limit Exceeded User"),
|
||||
("read_only", "Read Only"),
|
||||
],
|
||||
),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.RenameField(
|
||||
model_name="licenseusage",
|
||||
old_name="user_count",
|
||||
new_name="internal_user_count",
|
||||
),
|
||||
]
|
@ -17,17 +17,6 @@ if TYPE_CHECKING:
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
|
||||
def usage_expiry():
|
||||
"""Keep license usage records for 3 months"""
|
||||
return now() + timedelta(days=30 * 3)
|
||||
|
||||
|
||||
THRESHOLD_WARNING_ADMIN_WEEKS = 2
|
||||
THRESHOLD_WARNING_USER_WEEKS = 4
|
||||
THRESHOLD_WARNING_EXPIRY_WEEKS = 2
|
||||
THRESHOLD_READ_ONLY_WEEKS = 6
|
||||
|
||||
|
||||
class License(SerializerModel):
|
||||
"""An authentik enterprise license"""
|
||||
|
||||
@ -50,7 +39,7 @@ class License(SerializerModel):
|
||||
"""Get parsed license status"""
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
return LicenseKey.validate(self.key, check_expiry=False)
|
||||
return LicenseKey.validate(self.key)
|
||||
|
||||
class Meta:
|
||||
indexes = (HashIndex(fields=("key",)),)
|
||||
@ -58,23 +47,9 @@ class License(SerializerModel):
|
||||
verbose_name_plural = _("Licenses")
|
||||
|
||||
|
||||
class LicenseUsageStatus(models.TextChoices):
|
||||
"""License states an instance/tenant can be in"""
|
||||
|
||||
UNLICENSED = "unlicensed"
|
||||
VALID = "valid"
|
||||
EXPIRED = "expired"
|
||||
EXPIRY_SOON = "expiry_soon"
|
||||
# User limit exceeded, 2 week threshold, show message in admin interface
|
||||
LIMIT_EXCEEDED_ADMIN = "limit_exceeded_admin"
|
||||
# User limit exceeded, 4 week threshold, show message in user interface
|
||||
LIMIT_EXCEEDED_USER = "limit_exceeded_user"
|
||||
READ_ONLY = "read_only"
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Quickly check if a license is valid"""
|
||||
return self in [LicenseUsageStatus.VALID, LicenseUsageStatus.EXPIRY_SOON]
|
||||
def usage_expiry():
|
||||
"""Keep license usage records for 3 months"""
|
||||
return now() + timedelta(days=30 * 3)
|
||||
|
||||
|
||||
class LicenseUsage(ExpiringModel):
|
||||
@ -84,9 +59,9 @@ class LicenseUsage(ExpiringModel):
|
||||
|
||||
usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
|
||||
internal_user_count = models.BigIntegerField()
|
||||
user_count = models.BigIntegerField()
|
||||
external_user_count = models.BigIntegerField()
|
||||
status = models.TextField(choices=LicenseUsageStatus.choices)
|
||||
within_limits = models.BooleanField()
|
||||
|
||||
record_date = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
|
@ -13,7 +13,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
|
||||
|
||||
def check_license(self):
|
||||
"""Check license"""
|
||||
if not LicenseKey.get_total().status().is_valid:
|
||||
if not LicenseKey.get_total().is_valid():
|
||||
return PolicyResult(False, _("Enterprise required to access this feature."))
|
||||
if self.request.user.type != UserTypes.INTERNAL:
|
||||
return PolicyResult(False, _("Feature only accessible for internal users."))
|
||||
|
@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import (
|
||||
google_workspace_sync,
|
||||
google_workspace_sync_objects,
|
||||
)
|
||||
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
|
||||
|
||||
@ -55,4 +52,3 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_single_task = google_workspace_sync
|
||||
sync_objects_task = google_workspace_sync_objects
|
||||
|
@ -181,7 +181,7 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-google-workspace-form"
|
||||
return "ak-property-mapping-google-workspace-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import (
|
||||
microsoft_entra_sync,
|
||||
microsoft_entra_sync_objects,
|
||||
)
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
|
||||
|
||||
@ -53,4 +50,3 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_single_task = microsoft_entra_sync
|
||||
sync_objects_task = microsoft_entra_sync_objects
|
||||
|
@ -170,7 +170,7 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-microsoft-entra-form"
|
||||
return "ak-property-mapping-microsoft-entra-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-12 12:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_rac", "0004_alter_connectiontoken_expires"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="racpropertymapping",
|
||||
options={
|
||||
"verbose_name": "RAC Provider Property Mapping",
|
||||
"verbose_name_plural": "RAC Provider Property Mappings",
|
||||
},
|
||||
),
|
||||
]
|
@ -125,7 +125,7 @@ class RACPropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-rac-form"
|
||||
return "ak-property-mapping-rac-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -136,8 +136,8 @@ class RACPropertyMapping(PropertyMapping):
|
||||
return RACPropertyMappingSerializer
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Provider Property Mapping")
|
||||
verbose_name_plural = _("RAC Provider Property Mappings")
|
||||
verbose_name = _("RAC Property Mapping")
|
||||
verbose_name_plural = _("RAC Property Mappings")
|
||||
|
||||
|
||||
class ConnectionToken(ExpiringModel):
|
||||
|
@ -44,7 +44,7 @@ websocket_urlpatterns = [
|
||||
|
||||
api_urlpatterns = [
|
||||
("providers/rac", RACProviderViewSet),
|
||||
("propertymappings/provider/rac", RACPropertyMappingViewSet),
|
||||
("propertymappings/rac", RACPropertyMappingViewSet),
|
||||
("rac/endpoints", EndpointViewSet),
|
||||
("rac/connection_tokens", ConnectionTokenViewSet),
|
||||
]
|
||||
|
@ -3,7 +3,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models.signals import post_delete, post_save, pre_save
|
||||
from django.db.models.signals import post_save, pre_save
|
||||
from django.dispatch import receiver
|
||||
from django.utils.timezone import get_current_timezone
|
||||
|
||||
@ -27,9 +27,3 @@ def post_save_license(sender: type[License], instance: License, **_):
|
||||
"""Trigger license usage calculation when license is saved"""
|
||||
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
|
||||
enterprise_update_usage.delay()
|
||||
|
||||
|
||||
@receiver(post_delete, sender=License)
|
||||
def post_delete_license(sender: type[License], instance: License, **_):
|
||||
"""Clear license cache when license is deleted"""
|
||||
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
|
||||
|
@ -9,26 +9,10 @@ from django.utils.timezone import now
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.enterprise.models import (
|
||||
THRESHOLD_READ_ONLY_WEEKS,
|
||||
THRESHOLD_WARNING_ADMIN_WEEKS,
|
||||
THRESHOLD_WARNING_USER_WEEKS,
|
||||
License,
|
||||
LicenseUsage,
|
||||
LicenseUsageStatus,
|
||||
)
|
||||
from authentik.enterprise.models import License
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
# Valid license expiry
|
||||
expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
|
||||
# Valid license expiry, expires soon
|
||||
expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
|
||||
# Invalid license expiry, recently expired
|
||||
expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
|
||||
# Invalid license expiry, expired longer ago
|
||||
expiry_expired_read_only = int(
|
||||
mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
|
||||
)
|
||||
_exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
|
||||
|
||||
|
||||
class TestEnterpriseLicense(TestCase):
|
||||
@ -39,7 +23,7 @@ class TestEnterpriseLicense(TestCase):
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
exp=_exp,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
@ -49,7 +33,7 @@ class TestEnterpriseLicense(TestCase):
|
||||
def test_valid(self):
|
||||
"""Check license verification"""
|
||||
lic = License.objects.create(key=generate_id())
|
||||
self.assertTrue(lic.status.status().is_valid)
|
||||
self.assertTrue(lic.status.is_valid())
|
||||
self.assertEqual(lic.internal_users, 100)
|
||||
|
||||
def test_invalid(self):
|
||||
@ -62,7 +46,7 @@ class TestEnterpriseLicense(TestCase):
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
exp=_exp,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
@ -72,186 +56,11 @@ class TestEnterpriseLicense(TestCase):
|
||||
def test_valid_multiple(self):
|
||||
"""Check license verification"""
|
||||
lic = License.objects.create(key=generate_id())
|
||||
self.assertTrue(lic.status.status().is_valid)
|
||||
self.assertTrue(lic.status.is_valid())
|
||||
lic2 = License.objects.create(key=generate_id())
|
||||
self.assertTrue(lic2.status.status().is_valid)
|
||||
self.assertTrue(lic2.status.is_valid())
|
||||
total = LicenseKey.get_total()
|
||||
self.assertEqual(total.internal_users, 200)
|
||||
self.assertEqual(total.external_users, 200)
|
||||
self.assertEqual(total.exp, expiry_valid)
|
||||
self.assertTrue(total.status().is_valid)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_external_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_limit_exceeded_read_only(self):
|
||||
"""Check license verification"""
|
||||
License.objects.create(key=generate_id())
|
||||
usage = LicenseUsage.objects.create(
|
||||
internal_user_count=100,
|
||||
external_user_count=100,
|
||||
status=LicenseUsageStatus.VALID,
|
||||
)
|
||||
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
|
||||
usage.save(update_fields=["record_date"])
|
||||
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_external_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_limit_exceeded_user_warning(self):
|
||||
"""Check license verification"""
|
||||
License.objects.create(key=generate_id())
|
||||
usage = LicenseUsage.objects.create(
|
||||
internal_user_count=100,
|
||||
external_user_count=100,
|
||||
status=LicenseUsageStatus.VALID,
|
||||
)
|
||||
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS + 1)
|
||||
usage.save(update_fields=["record_date"])
|
||||
self.assertEqual(
|
||||
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_external_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_limit_exceeded_admin_warning(self):
|
||||
"""Check license verification"""
|
||||
License.objects.create(key=generate_id())
|
||||
usage = LicenseUsage.objects.create(
|
||||
internal_user_count=100,
|
||||
external_user_count=100,
|
||||
status=LicenseUsageStatus.VALID,
|
||||
)
|
||||
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS + 1)
|
||||
usage.save(update_fields=["record_date"])
|
||||
self.assertEqual(
|
||||
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_expired_read_only,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_expiry_read_only(self):
|
||||
"""Check license verification"""
|
||||
License.objects.create(key=generate_id())
|
||||
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_expired,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_expiry_expired(self):
|
||||
"""Check license verification"""
|
||||
License.objects.create(key=generate_id())
|
||||
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_soon,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_expiry_soon(self):
|
||||
"""Check license verification"""
|
||||
License.objects.create(key=generate_id())
|
||||
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON)
|
||||
self.assertEqual(total.exp, _exp)
|
||||
self.assertTrue(total.is_valid())
|
||||
|
@ -1,217 +0,0 @@
|
||||
"""read only tests"""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.enterprise.models import (
|
||||
THRESHOLD_READ_ONLY_WEEKS,
|
||||
License,
|
||||
LicenseUsage,
|
||||
LicenseUsageStatus,
|
||||
)
|
||||
from authentik.enterprise.tests.test_license import expiry_valid
|
||||
from authentik.flows.models import (
|
||||
FlowDesignation,
|
||||
FlowStageBinding,
|
||||
)
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.stages.identification.models import IdentificationStage, UserFields
|
||||
from authentik.stages.password import BACKEND_INBUILT
|
||||
from authentik.stages.password.models import PasswordStage
|
||||
from authentik.stages.user_login.models import UserLoginStage
|
||||
|
||||
|
||||
class TestReadOnly(FlowTestCase):
|
||||
"""Test read_only"""
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_external_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_login(self):
|
||||
"""Test flow, ensure login is still possible with read only mode"""
|
||||
License.objects.create(key=generate_id())
|
||||
usage = LicenseUsage.objects.create(
|
||||
internal_user_count=100,
|
||||
external_user_count=100,
|
||||
status=LicenseUsageStatus.VALID,
|
||||
)
|
||||
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
|
||||
usage.save(update_fields=["record_date"])
|
||||
|
||||
flow = create_test_flow(
|
||||
FlowDesignation.AUTHENTICATION,
|
||||
)
|
||||
|
||||
ident_stage = IdentificationStage.objects.create(
|
||||
name=generate_id(),
|
||||
user_fields=[UserFields.E_MAIL],
|
||||
pretend_user_exists=False,
|
||||
)
|
||||
FlowStageBinding.objects.create(
|
||||
target=flow,
|
||||
stage=ident_stage,
|
||||
order=0,
|
||||
)
|
||||
password_stage = PasswordStage.objects.create(
|
||||
name=generate_id(), backends=[BACKEND_INBUILT]
|
||||
)
|
||||
FlowStageBinding.objects.create(
|
||||
target=flow,
|
||||
stage=password_stage,
|
||||
order=1,
|
||||
)
|
||||
login_stage = UserLoginStage.objects.create(
|
||||
name=generate_id(),
|
||||
)
|
||||
FlowStageBinding.objects.create(
|
||||
target=flow,
|
||||
stage=login_stage,
|
||||
order=2,
|
||||
)
|
||||
|
||||
user = create_test_user()
|
||||
|
||||
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||
response = self.client.get(exec_url)
|
||||
self.assertStageResponse(
|
||||
response,
|
||||
flow,
|
||||
component="ak-stage-identification",
|
||||
password_fields=False,
|
||||
primary_action="Log in",
|
||||
sources=[],
|
||||
show_source_labels=False,
|
||||
user_fields=[UserFields.E_MAIL],
|
||||
)
|
||||
response = self.client.post(exec_url, {"uid_field": user.email}, follow=True)
|
||||
self.assertStageResponse(response, flow, component="ak-stage-password")
|
||||
response = self.client.post(exec_url, {"password": user.username}, follow=True)
|
||||
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_external_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_manage_licenses(self):
|
||||
"""Test that managing licenses is still possible"""
|
||||
license = License.objects.create(key=generate_id())
|
||||
usage = LicenseUsage.objects.create(
|
||||
internal_user_count=100,
|
||||
external_user_count=100,
|
||||
status=LicenseUsageStatus.VALID,
|
||||
)
|
||||
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
|
||||
usage.save(update_fields=["record_date"])
|
||||
|
||||
admin = create_test_admin_user()
|
||||
self.client.force_login(admin)
|
||||
|
||||
# Reading is always allowed
|
||||
response = self.client.get(reverse("authentik_api:license-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Writing should also be allowed
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:license-detail", kwargs={"pk": license.pk})
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=expiry_valid,
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.get_external_user_count",
|
||||
MagicMock(return_value=1000),
|
||||
)
|
||||
@patch(
|
||||
"authentik.enterprise.license.LicenseKey.record_usage",
|
||||
MagicMock(),
|
||||
)
|
||||
def test_manage_flows(self):
|
||||
"""Test flow"""
|
||||
License.objects.create(key=generate_id())
|
||||
usage = LicenseUsage.objects.create(
|
||||
internal_user_count=100,
|
||||
external_user_count=100,
|
||||
status=LicenseUsageStatus.VALID,
|
||||
)
|
||||
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
|
||||
usage.save(update_fields=["record_date"])
|
||||
|
||||
admin = create_test_admin_user()
|
||||
self.client.force_login(admin)
|
||||
|
||||
# Read only is still allowed
|
||||
response = self.client.get(reverse("authentik_api:flow-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
flow = create_test_flow()
|
||||
# Writing is not
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:flow-detail", kwargs={"slug": flow.slug})
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
response.content,
|
||||
{"detail": "Request denied due to expired/invalid license.", "code": "denied_license"},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
@ -69,5 +69,8 @@ class NotificationViewSet(
|
||||
@action(detail=False, methods=["post"])
|
||||
def mark_all_seen(self, request: Request) -> Response:
|
||||
"""Mark all the user's notifications as seen"""
|
||||
Notification.objects.filter(user=request.user, seen=False).update(seen=True)
|
||||
notifications = Notification.objects.filter(user=request.user)
|
||||
for notification in notifications:
|
||||
notification.seen = True
|
||||
Notification.objects.bulk_update(notifications, ["seen"])
|
||||
return Response({}, status=204)
|
||||
|
@ -1,49 +0,0 @@
|
||||
# Generated by Django 5.0.9 on 2024-09-25 11:06
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_events", "0007_event_authentik_e_action_9a9dd9_idx_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="event",
|
||||
name="action",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("login", "Login"),
|
||||
("login_failed", "Login Failed"),
|
||||
("logout", "Logout"),
|
||||
("user_write", "User Write"),
|
||||
("suspicious_request", "Suspicious Request"),
|
||||
("password_set", "Password Set"),
|
||||
("secret_view", "Secret View"),
|
||||
("secret_rotate", "Secret Rotate"),
|
||||
("invitation_used", "Invite Used"),
|
||||
("authorize_application", "Authorize Application"),
|
||||
("source_linked", "Source Linked"),
|
||||
("impersonation_started", "Impersonation Started"),
|
||||
("impersonation_ended", "Impersonation Ended"),
|
||||
("flow_execution", "Flow Execution"),
|
||||
("policy_execution", "Policy Execution"),
|
||||
("policy_exception", "Policy Exception"),
|
||||
("property_mapping_exception", "Property Mapping Exception"),
|
||||
("system_task_execution", "System Task Execution"),
|
||||
("system_task_exception", "System Task Exception"),
|
||||
("system_exception", "System Exception"),
|
||||
("configuration_error", "Configuration Error"),
|
||||
("model_created", "Model Created"),
|
||||
("model_updated", "Model Updated"),
|
||||
("model_deleted", "Model Deleted"),
|
||||
("email_sent", "Email Sent"),
|
||||
("analytics_sent", "Analytics Sent"),
|
||||
("update_available", "Update Available"),
|
||||
("custom_", "Custom Prefix"),
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
@ -49,7 +49,6 @@ from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
from authentik.tenants.models import Tenant
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
DISCORD_FIELD_LIMIT = 25
|
||||
@ -59,11 +58,7 @@ NOTIFICATION_SUMMARY_LENGTH = 75
|
||||
def default_event_duration():
|
||||
"""Default duration an Event is saved.
|
||||
This is used as a fallback when no brand is available"""
|
||||
try:
|
||||
tenant = get_current_tenant()
|
||||
return now() + timedelta_from_string(tenant.event_retention)
|
||||
except Tenant.DoesNotExist:
|
||||
return now() + timedelta(days=365)
|
||||
return now() + timedelta(days=365)
|
||||
|
||||
|
||||
def default_brand():
|
||||
@ -119,7 +114,6 @@ class EventAction(models.TextChoices):
|
||||
MODEL_DELETED = "model_deleted"
|
||||
EMAIL_SENT = "email_sent"
|
||||
|
||||
ANALYTICS_SENT = "analytics_sent"
|
||||
UPDATE_AVAILABLE = "update_available"
|
||||
|
||||
CUSTOM_PREFIX = "custom_"
|
||||
@ -251,6 +245,12 @@ class Event(SerializerModel, ExpiringModel):
|
||||
if QS_QUERY in self.context["http_request"]["args"]:
|
||||
wrapped = self.context["http_request"]["args"][QS_QUERY]
|
||||
self.context["http_request"]["args"] = cleanse_dict(QueryDict(wrapped))
|
||||
if hasattr(request, "tenant"):
|
||||
tenant: Tenant = request.tenant
|
||||
# Because self.created only gets set on save, we can't use it's value here
|
||||
# hence we set self.created to now and then use it
|
||||
self.created = now()
|
||||
self.expires = self.created + timedelta_from_string(tenant.event_retention)
|
||||
if hasattr(request, "brand"):
|
||||
brand: Brand = request.brand
|
||||
self.brand = sanitize_dict(model_to_dict(brand))
|
||||
|
@ -13,7 +13,7 @@ from authentik.events.apps import SYSTEM_TASK_STATUS
|
||||
from authentik.events.models import Event, EventAction, SystemTask
|
||||
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan
|
||||
from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
from authentik.stages.invitation.models import Invitation
|
||||
@ -38,9 +38,6 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
|
||||
# Save the login method used
|
||||
kwargs[PLAN_CONTEXT_METHOD] = flow_plan.context[PLAN_CONTEXT_METHOD]
|
||||
kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||
if PLAN_CONTEXT_OUTPOST in flow_plan.context:
|
||||
# Save outpost context
|
||||
kwargs[PLAN_CONTEXT_OUTPOST] = flow_plan.context[PLAN_CONTEXT_OUTPOST]
|
||||
event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user)
|
||||
request.session[SESSION_LOGIN_EVENT] = event
|
||||
|
||||
|
@ -6,7 +6,6 @@ from django.db.models import Model
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.core.models import default_token_key
|
||||
from authentik.events.models import default_event_duration
|
||||
from authentik.lib.utils.reflection import get_apps
|
||||
|
||||
|
||||
@ -21,7 +20,7 @@ def model_tester_factory(test_model: type[Model]) -> Callable:
|
||||
allowed = 0
|
||||
# Token-like objects need to lookup the current tenant to get the default token length
|
||||
for field in test_model._meta.fields:
|
||||
if field.default in [default_token_key, default_event_duration]:
|
||||
if field.default == default_token_key:
|
||||
allowed += 1
|
||||
with self.assertNumQueries(allowed):
|
||||
str(test_model())
|
||||
|
@ -2,8 +2,7 @@
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.models import (
|
||||
@ -11,7 +10,6 @@ from authentik.events.models import (
|
||||
EventAction,
|
||||
Notification,
|
||||
NotificationRule,
|
||||
NotificationSeverity,
|
||||
NotificationTransport,
|
||||
NotificationWebhookMapping,
|
||||
TransportMode,
|
||||
@ -22,7 +20,7 @@ from authentik.policies.exceptions import PolicyException
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
||||
|
||||
class TestEventsNotifications(APITestCase):
|
||||
class TestEventsNotifications(TestCase):
|
||||
"""Test Event Notifications"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
@ -133,15 +131,3 @@ class TestEventsNotifications(APITestCase):
|
||||
Notification.objects.all().delete()
|
||||
Event.new(EventAction.CUSTOM_PREFIX).save()
|
||||
self.assertEqual(Notification.objects.first().body, "foo")
|
||||
|
||||
def test_api_mark_all_seen(self):
|
||||
"""Test mark_all_seen"""
|
||||
self.client.force_login(self.user)
|
||||
|
||||
Notification.objects.create(
|
||||
severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
|
||||
)
|
||||
|
||||
response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
|
||||
self.assertEqual(response.status_code, 204)
|
||||
self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())
|
||||
|
@ -37,7 +37,6 @@ from authentik.lib.utils.file import (
|
||||
)
|
||||
from authentik.lib.views import bad_request_message
|
||||
from authentik.rbac.decorators import permission_required
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -282,7 +281,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
|
||||
400: OpenApiResponse(description="Flow not applicable"),
|
||||
},
|
||||
)
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def execute(self, request: Request, slug: str):
|
||||
"""Execute flow for current user"""
|
||||
# Because we pre-plan the flow here, and not in the planner, we need to manually clear
|
||||
|
@ -23,7 +23,6 @@ from authentik.flows.models import (
|
||||
in_memory_stage,
|
||||
)
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.outposts.models import Outpost
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
|
||||
@ -33,7 +32,6 @@ PLAN_CONTEXT_SSO = "is_sso"
|
||||
PLAN_CONTEXT_REDIRECT = "redirect"
|
||||
PLAN_CONTEXT_APPLICATION = "application"
|
||||
PLAN_CONTEXT_SOURCE = "source"
|
||||
PLAN_CONTEXT_OUTPOST = "outpost"
|
||||
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
|
||||
# was restored.
|
||||
PLAN_CONTEXT_IS_RESTORED = "is_restored"
|
||||
@ -145,23 +143,10 @@ class FlowPlanner:
|
||||
and not request.user.is_superuser
|
||||
):
|
||||
raise FlowNonApplicableException()
|
||||
outpost_user = ClientIPMiddleware.get_outpost_user(request)
|
||||
if self.flow.authentication == FlowAuthenticationRequirement.REQUIRE_OUTPOST:
|
||||
outpost_user = ClientIPMiddleware.get_outpost_user(request)
|
||||
if not outpost_user:
|
||||
raise FlowNonApplicableException()
|
||||
if outpost_user:
|
||||
outpost = Outpost.objects.filter(
|
||||
# TODO: Since Outpost and user are not directly connected, we have to look up a user
|
||||
# like this. This should ideally by in authentik/outposts/models.py
|
||||
pk=outpost_user.username.replace("ak-outpost-", "")
|
||||
).first()
|
||||
if outpost:
|
||||
return {
|
||||
PLAN_CONTEXT_OUTPOST: {
|
||||
"instance": outpost,
|
||||
}
|
||||
}
|
||||
return {}
|
||||
|
||||
def plan(self, request: HttpRequest, default_context: dict[str, Any] | None = None) -> FlowPlan:
|
||||
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
|
||||
@ -174,12 +159,11 @@ class FlowPlanner:
|
||||
self._logger.debug(
|
||||
"f(plan): starting planning process",
|
||||
)
|
||||
context = default_context or {}
|
||||
# Bit of a workaround here, if there is a pending user set in the default context
|
||||
# we use that user for our cache key
|
||||
# to make sure they don't get the generic response
|
||||
if context and PLAN_CONTEXT_PENDING_USER in context:
|
||||
user = context[PLAN_CONTEXT_PENDING_USER]
|
||||
if default_context and PLAN_CONTEXT_PENDING_USER in default_context:
|
||||
user = default_context[PLAN_CONTEXT_PENDING_USER]
|
||||
else:
|
||||
user = request.user
|
||||
# We only need to check the flow authentication if it's planned without a user
|
||||
@ -187,13 +171,14 @@ class FlowPlanner:
|
||||
# or if a flow is restarted due to `invalid_response_action` being set to
|
||||
# `restart_with_context`, which can only happen if the user was already authorized
|
||||
# to use the flow
|
||||
context.update(self._check_authentication(request))
|
||||
self._check_authentication(request)
|
||||
# First off, check the flow's direct policy bindings
|
||||
# to make sure the user even has access to the flow
|
||||
engine = PolicyEngine(self.flow, user, request)
|
||||
engine.use_cache = self.use_cache
|
||||
span.set_data("context", cleanse_dict(context))
|
||||
engine.request.context.update(context)
|
||||
if default_context:
|
||||
span.set_data("default_context", cleanse_dict(default_context))
|
||||
engine.request.context.update(default_context)
|
||||
engine.build()
|
||||
result = engine.result
|
||||
if not result.passing:
|
||||
@ -210,12 +195,12 @@ class FlowPlanner:
|
||||
key=cached_plan_key,
|
||||
)
|
||||
# Reset the context as this isn't factored into caching
|
||||
cached_plan.context = context
|
||||
cached_plan.context = default_context or {}
|
||||
return cached_plan
|
||||
self._logger.debug(
|
||||
"f(plan): building plan",
|
||||
)
|
||||
plan = self._build_plan(user, request, context)
|
||||
plan = self._build_plan(user, request, default_context)
|
||||
if self.use_cache:
|
||||
cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT)
|
||||
if not plan.bindings and not self.allow_empty_flows:
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import re
|
||||
import socket
|
||||
from collections.abc import Iterable
|
||||
from ipaddress import ip_address, ip_network
|
||||
from textwrap import indent
|
||||
from types import CodeType
|
||||
@ -27,12 +28,6 @@ from authentik.stages.authenticator import devices_for_user
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
ARG_SANITIZE = re.compile(r"[:.-]")
|
||||
|
||||
|
||||
def sanitize_arg(arg_name: str) -> str:
|
||||
return re.sub(ARG_SANITIZE, "_", arg_name)
|
||||
|
||||
|
||||
class BaseEvaluator:
|
||||
"""Validate and evaluate python-based expressions"""
|
||||
@ -182,9 +177,9 @@ class BaseEvaluator:
|
||||
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
|
||||
return proc.profiling_wrapper()
|
||||
|
||||
def wrap_expression(self, expression: str) -> str:
|
||||
def wrap_expression(self, expression: str, params: Iterable[str]) -> str:
|
||||
"""Wrap expression in a function, call it, and save the result as `result`"""
|
||||
handler_signature = ",".join(sanitize_arg(x) for x in self._context.keys())
|
||||
handler_signature = ",".join(params)
|
||||
full_expression = ""
|
||||
full_expression += f"def handler({handler_signature}):\n"
|
||||
full_expression += indent(expression, " ")
|
||||
@ -193,8 +188,8 @@ class BaseEvaluator:
|
||||
|
||||
def compile(self, expression: str) -> CodeType:
|
||||
"""Parse expression. Raises SyntaxError or ValueError if the syntax is incorrect."""
|
||||
expression = self.wrap_expression(expression)
|
||||
return compile(expression, self._filename, "exec")
|
||||
param_keys = self._context.keys()
|
||||
return compile(self.wrap_expression(expression, param_keys), self._filename, "exec")
|
||||
|
||||
def evaluate(self, expression_source: str) -> Any:
|
||||
"""Parse and evaluate expression. If the syntax is incorrect, a SyntaxError is raised.
|
||||
@ -210,7 +205,7 @@ class BaseEvaluator:
|
||||
self.handle_error(exc, expression_source)
|
||||
raise exc
|
||||
try:
|
||||
_locals = {sanitize_arg(x): y for x, y in self._context.items()}
|
||||
_locals = self._context
|
||||
# Yes this is an exec, yes it is potentially bad. Since we limit what variables are
|
||||
# available here, and these policies can only be edited by admins, this is a risk
|
||||
# we're willing to take.
|
||||
|
@ -1,19 +1,16 @@
|
||||
from celery import Task
|
||||
from collections.abc import Callable
|
||||
|
||||
from django.utils.text import slugify
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField, CharField, ChoiceField
|
||||
from rest_framework.fields import BooleanField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
from authentik.events.logs import LogEvent, LogEventSerializer
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
|
||||
class SyncStatusSerializer(PassiveSerializer):
|
||||
@ -23,29 +20,10 @@ class SyncStatusSerializer(PassiveSerializer):
|
||||
tasks = SystemTaskSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class SyncObjectSerializer(PassiveSerializer):
|
||||
"""Sync object serializer"""
|
||||
|
||||
sync_object_model = ChoiceField(
|
||||
choices=(
|
||||
(class_to_path(User), "user"),
|
||||
(class_to_path(Group), "group"),
|
||||
)
|
||||
)
|
||||
sync_object_id = CharField()
|
||||
|
||||
|
||||
class SyncObjectResultSerializer(PassiveSerializer):
|
||||
"""Result of a single object sync"""
|
||||
|
||||
messages = LogEventSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class OutgoingSyncProviderStatusMixin:
|
||||
"""Common API Endpoints for Outgoing sync providers"""
|
||||
|
||||
sync_single_task: type[Task] = None
|
||||
sync_objects_task: type[Task] = None
|
||||
sync_single_task: Callable = None
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
@ -58,7 +36,7 @@ class OutgoingSyncProviderStatusMixin:
|
||||
detail=True,
|
||||
pagination_class=None,
|
||||
url_path="sync/status",
|
||||
filter_backends=[ObjectFilter],
|
||||
filter_backends=[],
|
||||
)
|
||||
def sync_status(self, request: Request, pk: int) -> Response:
|
||||
"""Get provider's sync status"""
|
||||
@ -77,30 +55,6 @@ class OutgoingSyncProviderStatusMixin:
|
||||
}
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
|
||||
@extend_schema(
|
||||
request=SyncObjectSerializer,
|
||||
responses={200: SyncObjectResultSerializer()},
|
||||
)
|
||||
@action(
|
||||
methods=["POST"],
|
||||
detail=True,
|
||||
pagination_class=None,
|
||||
url_path="sync/object",
|
||||
filter_backends=[ObjectFilter],
|
||||
)
|
||||
def sync_object(self, request: Request, pk: int) -> Response:
|
||||
"""Sync/Re-sync a single user/group object"""
|
||||
provider: OutgoingSyncProvider = self.get_object()
|
||||
params = SyncObjectSerializer(data=request.data)
|
||||
params.is_valid(raise_exception=True)
|
||||
res: list[LogEvent] = self.sync_objects_task.delay(
|
||||
params.validated_data["sync_object_model"],
|
||||
page=1,
|
||||
provider_pk=provider.pk,
|
||||
pk=params.validated_data["sync_object_id"],
|
||||
).get()
|
||||
return Response(SyncObjectResultSerializer(instance={"messages": res}).data)
|
||||
|
||||
|
||||
class OutgoingSyncConnectionCreateMixin:
|
||||
"""Mixin for connection objects that fetches remote data upon creation"""
|
||||
|
@ -105,7 +105,7 @@ class SyncTasks:
|
||||
return
|
||||
task.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
def sync_objects(self, object_type: str, page: int, provider_pk: int, **filter):
|
||||
def sync_objects(self, object_type: str, page: int, provider_pk: int):
|
||||
_object_type = path_to_class(object_type)
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
@ -120,7 +120,7 @@ class SyncTasks:
|
||||
client = provider.client_for_model(_object_type)
|
||||
except TransientSyncException:
|
||||
return messages
|
||||
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
|
||||
paginator = Paginator(provider.get_object_qs(_object_type), PAGE_SIZE)
|
||||
if client.can_discover:
|
||||
self.logger.debug("starting discover")
|
||||
client.discover()
|
||||
|
@ -26,6 +26,7 @@ from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
|
||||
from authentik.outposts.models import (
|
||||
Outpost,
|
||||
OutpostConfig,
|
||||
OutpostState,
|
||||
OutpostType,
|
||||
default_outpost_config,
|
||||
)
|
||||
@ -139,7 +140,7 @@ class OutpostHealthSerializer(PassiveSerializer):
|
||||
|
||||
def get_fips_enabled(self, obj: dict) -> bool | None:
|
||||
"""Get FIPS enabled"""
|
||||
if not LicenseKey.get_total().status().is_valid:
|
||||
if not LicenseKey.get_total().is_valid():
|
||||
return None
|
||||
return obj["fips_enabled"]
|
||||
|
||||
@ -181,6 +182,7 @@ class OutpostViewSet(UsedByMixin, ModelViewSet):
|
||||
outpost: Outpost = self.get_object()
|
||||
states = []
|
||||
for state in outpost.state:
|
||||
state: OutpostState
|
||||
states.append(
|
||||
{
|
||||
"uid": state.uid,
|
||||
|
@ -26,7 +26,6 @@ from authentik.outposts.models import (
|
||||
KubernetesServiceConnection,
|
||||
OutpostServiceConnection,
|
||||
)
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
|
||||
class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer):
|
||||
@ -76,7 +75,7 @@ class ServiceConnectionViewSet(
|
||||
filterset_fields = ["name"]
|
||||
|
||||
@extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
|
||||
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
|
||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||
def state(self, request: Request, pk: str) -> Response:
|
||||
"""Get the service connection's state"""
|
||||
connection = self.get_object()
|
||||
|
@ -451,7 +451,7 @@ class OutpostState:
|
||||
return False
|
||||
if self.build_hash != get_build_hash():
|
||||
return False
|
||||
return parse(self.version) != OUR_VERSION
|
||||
return parse(self.version) < OUR_VERSION
|
||||
|
||||
@staticmethod
|
||||
def for_outpost(outpost: Outpost) -> list["OutpostState"]:
|
||||
|
@ -214,7 +214,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
|
||||
if not hasattr(instance, field_name):
|
||||
continue
|
||||
|
||||
LOGGER.debug("triggering outpost update from field", field=field.name)
|
||||
LOGGER.debug("triggering outpost update from from field", field=field.name)
|
||||
# Because the Outpost Model has an M2M to Provider,
|
||||
# we have to iterate over the entire QS
|
||||
for reverse in getattr(instance, field_name).all():
|
||||
|
@ -1,52 +0,0 @@
|
||||
# Generated by Django 5.0.9 on 2024-09-25 11:06
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_policies_event_matcher", "0023_alter_eventmatcherpolicy_action_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="eventmatcherpolicy",
|
||||
name="action",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("login", "Login"),
|
||||
("login_failed", "Login Failed"),
|
||||
("logout", "Logout"),
|
||||
("user_write", "User Write"),
|
||||
("suspicious_request", "Suspicious Request"),
|
||||
("password_set", "Password Set"),
|
||||
("secret_view", "Secret View"),
|
||||
("secret_rotate", "Secret Rotate"),
|
||||
("invitation_used", "Invite Used"),
|
||||
("authorize_application", "Authorize Application"),
|
||||
("source_linked", "Source Linked"),
|
||||
("impersonation_started", "Impersonation Started"),
|
||||
("impersonation_ended", "Impersonation Ended"),
|
||||
("flow_execution", "Flow Execution"),
|
||||
("policy_execution", "Policy Execution"),
|
||||
("policy_exception", "Policy Exception"),
|
||||
("property_mapping_exception", "Property Mapping Exception"),
|
||||
("system_task_execution", "System Task Execution"),
|
||||
("system_task_exception", "System Task Exception"),
|
||||
("system_exception", "System Exception"),
|
||||
("configuration_error", "Configuration Error"),
|
||||
("model_created", "Model Created"),
|
||||
("model_updated", "Model Updated"),
|
||||
("model_deleted", "Model Deleted"),
|
||||
("email_sent", "Email Sent"),
|
||||
("analytics_sent", "Analytics Sent"),
|
||||
("update_available", "Update Available"),
|
||||
("custom_", "Custom Prefix"),
|
||||
],
|
||||
default=None,
|
||||
help_text="Match created events with this action type. When left empty, all action types will be matched.",
|
||||
null=True,
|
||||
),
|
||||
),
|
||||
]
|
@ -36,7 +36,7 @@ def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||
if not created:
|
||||
reputation.score = F("score") + amount
|
||||
reputation.save()
|
||||
LOGGER.info("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
|
||||
LOGGER.debug("Updated score", amount=amount, for_user=identifier, for_ip=remote_ip)
|
||||
|
||||
|
||||
@receiver(login_failed)
|
||||
|
@ -2,25 +2,15 @@
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models.query import Q
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django_filters.filters import BooleanFilter
|
||||
from django_filters.filterset import FilterSet
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField, CharField, ListField, SerializerMethodField
|
||||
from rest_framework.fields import CharField, ListField, SerializerMethodField
|
||||
from rest_framework.mixins import ListModelMixin
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import Application
|
||||
from authentik.policies.api.exec import PolicyTestResultSerializer
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.types import PolicyResult
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.providers.ldap.models import LDAPProvider
|
||||
|
||||
|
||||
@ -33,6 +23,7 @@ class LDAPProviderSerializer(ProviderSerializer):
|
||||
model = LDAPProvider
|
||||
fields = ProviderSerializer.Meta.fields + [
|
||||
"base_dn",
|
||||
"search_group",
|
||||
"certificate",
|
||||
"tls_server_name",
|
||||
"uid_start_number",
|
||||
@ -64,6 +55,8 @@ class LDAPProviderFilter(FilterSet):
|
||||
"name": ["iexact"],
|
||||
"authorization_flow__slug": ["iexact"],
|
||||
"base_dn": ["iexact"],
|
||||
"search_group__group_uuid": ["iexact"],
|
||||
"search_group__name": ["iexact"],
|
||||
"certificate__kp_uuid": ["iexact"],
|
||||
"certificate__name": ["iexact"],
|
||||
"tls_server_name": ["iexact"],
|
||||
@ -102,6 +95,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
|
||||
"base_dn",
|
||||
"bind_flow_slug",
|
||||
"application_slug",
|
||||
"search_group",
|
||||
"certificate",
|
||||
"tls_server_name",
|
||||
"uid_start_number",
|
||||
@ -122,33 +116,3 @@ class LDAPOutpostConfigViewSet(ListModelMixin, GenericViewSet):
|
||||
ordering = ["name"]
|
||||
search_fields = ["name"]
|
||||
filterset_fields = ["name"]
|
||||
|
||||
class LDAPCheckAccessSerializer(PassiveSerializer):
|
||||
has_search_permission = BooleanField(required=False)
|
||||
access = PolicyTestResultSerializer()
|
||||
|
||||
@extend_schema(
|
||||
request=None,
|
||||
parameters=[OpenApiParameter("app_slug", OpenApiTypes.STR)],
|
||||
responses={
|
||||
200: LDAPCheckAccessSerializer(),
|
||||
},
|
||||
operation_id="outposts_ldap_access_check",
|
||||
)
|
||||
@action(detail=True)
|
||||
def check_access(self, request: Request, pk) -> Response:
|
||||
"""Check access to a single application by slug"""
|
||||
provider = get_object_or_404(LDAPProvider, pk=pk)
|
||||
application = get_object_or_404(Application, slug=request.query_params["app_slug"])
|
||||
engine = PolicyEngine(application, request.user, request)
|
||||
engine.use_cache = False
|
||||
engine.build()
|
||||
result = engine.result
|
||||
access_response = PolicyResult(result.passing)
|
||||
response = self.LDAPCheckAccessSerializer(
|
||||
instance={
|
||||
"has_search_permission": request.user.has_perm("search_full_directory", provider),
|
||||
"access": access_response,
|
||||
}
|
||||
)
|
||||
return Response(response.data)
|
||||
|
@ -1,66 +0,0 @@
|
||||
# Generated by Django 5.0.7 on 2024-07-25 14:59
|
||||
from django.apps.registry import Apps
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def migrate_search_group(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
from authentik.core.models import User
|
||||
from django.apps import apps as real_apps
|
||||
from django.contrib.auth.management import create_permissions
|
||||
from guardian.shortcuts import UserObjectPermission
|
||||
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
# Permissions are only created _after_ migrations are run
|
||||
# - https://github.com/django/django/blob/43cdfa8b20e567a801b7d0a09ec67ddd062d5ea4/django/contrib/auth/apps.py#L19
|
||||
# - https://stackoverflow.com/a/72029063/1870445
|
||||
create_permissions(real_apps.get_app_config("authentik_providers_ldap"), using=db_alias)
|
||||
|
||||
LDAPProvider = apps.get_model("authentik_providers_ldap", "ldapprovider")
|
||||
Permission = apps.get_model("auth", "Permission")
|
||||
UserObjectPermission = apps.get_model("guardian", "UserObjectPermission")
|
||||
ContentType = apps.get_model("contenttypes", "ContentType")
|
||||
|
||||
new_prem = Permission.objects.using(db_alias).get(codename="search_full_directory")
|
||||
ct = ContentType.objects.using(db_alias).get(
|
||||
app_label="authentik_providers_ldap",
|
||||
model="ldapprovider",
|
||||
)
|
||||
|
||||
for provider in LDAPProvider.objects.using(db_alias).all():
|
||||
if not provider.search_group:
|
||||
continue
|
||||
for user in provider.search_group.users.using(db_alias).all():
|
||||
UserObjectPermission.objects.using(db_alias).create(
|
||||
user=user,
|
||||
permission=new_prem,
|
||||
object_pk=provider.pk,
|
||||
content_type=ct,
|
||||
)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_ldap", "0003_ldapprovider_mfa_support_and_more"),
|
||||
("guardian", "0002_generic_permissions_index"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="ldapprovider",
|
||||
options={
|
||||
"permissions": [("search_full_directory", "Search full LDAP directory")],
|
||||
"verbose_name": "LDAP Provider",
|
||||
"verbose_name_plural": "LDAP Providers",
|
||||
},
|
||||
),
|
||||
migrations.RunPython(migrate_search_group),
|
||||
migrations.RemoveField(
|
||||
model_name="ldapprovider",
|
||||
name="search_group",
|
||||
),
|
||||
]
|
@ -7,7 +7,7 @@ from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import BackchannelProvider
|
||||
from authentik.core.models import BackchannelProvider, Group
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.outposts.models import OutpostModel
|
||||
|
||||
@ -27,6 +27,17 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
||||
help_text=_("DN under which objects are accessible."),
|
||||
)
|
||||
|
||||
search_group = models.ForeignKey(
|
||||
Group,
|
||||
null=True,
|
||||
default=None,
|
||||
on_delete=models.SET_DEFAULT,
|
||||
help_text=_(
|
||||
"Users in this group can do search queries. "
|
||||
"If not set, every user can execute search queries."
|
||||
),
|
||||
)
|
||||
|
||||
tls_server_name = models.TextField(
|
||||
default="",
|
||||
blank=True,
|
||||
@ -102,6 +113,3 @@ class LDAPProvider(OutpostModel, BackchannelProvider):
|
||||
class Meta:
|
||||
verbose_name = _("LDAP Provider")
|
||||
verbose_name_plural = _("LDAP Providers")
|
||||
permissions = [
|
||||
("search_full_directory", _("Search full LDAP directory")),
|
||||
]
|
||||
|
@ -105,7 +105,7 @@ class ScopeMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-scope-form"
|
||||
return "ak-property-mapping-scope-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
|
@ -62,7 +62,7 @@ urlpatterns = [
|
||||
|
||||
api_urlpatterns = [
|
||||
("providers/oauth2", OAuth2ProviderViewSet),
|
||||
("propertymappings/provider/scope", ScopeMappingViewSet),
|
||||
("propertymappings/scope", ScopeMappingViewSet),
|
||||
("oauth2/authorization_codes", AuthorizationCodeViewSet),
|
||||
("oauth2/refresh_tokens", RefreshTokenViewSet),
|
||||
("oauth2/access_tokens", AccessTokenViewSet),
|
||||
|
@ -433,21 +433,20 @@ class TokenParams:
|
||||
app = Application.objects.filter(provider=self.provider).first()
|
||||
if not app or not app.provider:
|
||||
raise TokenError("invalid_grant")
|
||||
with audit_ignore():
|
||||
self.user, _ = User.objects.update_or_create(
|
||||
# trim username to ensure the entire username is max 150 chars
|
||||
# (22 chars being the length of the "template")
|
||||
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
|
||||
defaults={
|
||||
"attributes": {
|
||||
USER_ATTRIBUTE_GENERATED: True,
|
||||
},
|
||||
"last_login": timezone.now(),
|
||||
"name": f"Autogenerated user from application {app.name} (client credentials)",
|
||||
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
|
||||
"type": UserTypes.SERVICE_ACCOUNT,
|
||||
self.user, _ = User.objects.update_or_create(
|
||||
# trim username to ensure the entire username is max 150 chars
|
||||
# (22 chars being the length of the "template")
|
||||
username=f"ak-{self.provider.name[:150-22]}-client_credentials",
|
||||
defaults={
|
||||
"attributes": {
|
||||
USER_ATTRIBUTE_GENERATED: True,
|
||||
},
|
||||
)
|
||||
"last_login": timezone.now(),
|
||||
"name": f"Autogenerated user from application {app.name} (client credentials)",
|
||||
"path": f"{USER_PATH_SYSTEM_PREFIX}/apps/{app.slug}",
|
||||
"type": UserTypes.SERVICE_ACCOUNT,
|
||||
},
|
||||
)
|
||||
self.__check_policy_access(app, request)
|
||||
|
||||
Event.new(
|
||||
|
@ -28,7 +28,7 @@ class ProxyDockerController(DockerController):
|
||||
labels = super()._get_labels()
|
||||
labels["traefik.enable"] = "true"
|
||||
labels[f"traefik.http.routers.{traefik_name}-router.rule"] = (
|
||||
f"({' || '.join([f'Host({host})' for host in hosts])})"
|
||||
f"({' || '.join([f'Host(`{host}`)' for host in hosts])})"
|
||||
f" && PathPrefix(`/outpost.goauthentik.io`)"
|
||||
)
|
||||
labels[f"traefik.http.routers.{traefik_name}-router.tls"] = "true"
|
||||
|
@ -154,7 +154,6 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
|
||||
responses={
|
||||
200: RadiusCheckAccessSerializer(),
|
||||
},
|
||||
operation_id="outposts_radius_access_check",
|
||||
)
|
||||
@action(detail=True)
|
||||
def check_access(self, request: Request, pk) -> Response:
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-12 12:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_radius", "0003_radiusproviderpropertymapping"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="radiusproviderpropertymapping",
|
||||
options={
|
||||
"verbose_name": "Radius Provider Property Mapping",
|
||||
"verbose_name_plural": "Radius Provider Property Mappings",
|
||||
},
|
||||
),
|
||||
]
|
@ -70,7 +70,7 @@ class RadiusProviderPropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-radius-form"
|
||||
return "ak-property-mapping-radius-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -81,8 +81,8 @@ class RadiusProviderPropertyMapping(PropertyMapping):
|
||||
return RadiusProviderPropertyMappingSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"Radius Provider Property Mapping {self.name}"
|
||||
return f"Radius Property Mapping {self.name}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Radius Provider Property Mapping")
|
||||
verbose_name_plural = _("Radius Provider Property Mappings")
|
||||
verbose_name = _("Radius Property Mapping")
|
||||
verbose_name_plural = _("Radius Property Mappings")
|
||||
|
@ -7,7 +7,7 @@ from authentik.providers.radius.api.providers import (
|
||||
)
|
||||
|
||||
api_urlpatterns = [
|
||||
("propertymappings/provider/radius", RadiusProviderPropertyMappingViewSet),
|
||||
("propertymappings/radius", RadiusProviderPropertyMappingViewSet),
|
||||
("outposts/radius", RadiusOutpostConfigViewSet, "radiusprovideroutpost"),
|
||||
("providers/radius", RadiusProviderViewSet),
|
||||
]
|
||||
|
@ -133,17 +133,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return "-"
|
||||
|
||||
def validate(self, attrs: dict):
|
||||
if attrs.get("signing_kp"):
|
||||
if not attrs.get("sign_assertion") and not attrs.get("sign_response"):
|
||||
raise ValidationError(
|
||||
_(
|
||||
"With a signing keypair selected, at least one of 'Sign assertion' "
|
||||
"and 'Sign Response' must be selected."
|
||||
)
|
||||
)
|
||||
return super().validate(attrs)
|
||||
|
||||
class Meta:
|
||||
model = SAMLProvider
|
||||
fields = ProviderSerializer.Meta.fields + [
|
||||
@ -159,9 +148,6 @@ class SAMLProviderSerializer(ProviderSerializer):
|
||||
"signature_algorithm",
|
||||
"signing_kp",
|
||||
"verification_kp",
|
||||
"encryption_kp",
|
||||
"sign_assertion",
|
||||
"sign_response",
|
||||
"sp_binding",
|
||||
"default_relay_state",
|
||||
"url_download_metadata",
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-12 12:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_providers_saml", "0014_alter_samlprovider_digest_algorithm_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="samlpropertymapping",
|
||||
options={
|
||||
"verbose_name": "SAML Provider Property Mapping",
|
||||
"verbose_name_plural": "SAML Provider Property Mappings",
|
||||
},
|
||||
),
|
||||
]
|
@ -1,39 +0,0 @@
|
||||
# Generated by Django 5.0.8 on 2024-08-15 14:52
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_crypto", "0004_alter_certificatekeypair_name"),
|
||||
("authentik_providers_saml", "0015_alter_samlpropertymapping_options"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="samlprovider",
|
||||
name="encryption_kp",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
help_text="When selected, incoming assertions are encrypted by the IdP using the public key of the encryption keypair. The assertion is decrypted by the SP using the the private key.",
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="+",
|
||||
to="authentik_crypto.certificatekeypair",
|
||||
verbose_name="Encryption Keypair",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="samlprovider",
|
||||
name="sign_assertion",
|
||||
field=models.BooleanField(default=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="samlprovider",
|
||||
name="sign_response",
|
||||
field=models.BooleanField(default=False),
|
||||
),
|
||||
]
|
@ -144,28 +144,11 @@ class SAMLProvider(Provider):
|
||||
on_delete=models.SET_NULL,
|
||||
verbose_name=_("Signing Keypair"),
|
||||
)
|
||||
encryption_kp = models.ForeignKey(
|
||||
CertificateKeyPair,
|
||||
default=None,
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text=_(
|
||||
"When selected, incoming assertions are encrypted by the IdP using the public "
|
||||
"key of the encryption keypair. The assertion is decrypted by the SP using the "
|
||||
"the private key."
|
||||
),
|
||||
on_delete=models.SET_NULL,
|
||||
verbose_name=_("Encryption Keypair"),
|
||||
related_name="+",
|
||||
)
|
||||
|
||||
default_relay_state = models.TextField(
|
||||
default="", blank=True, help_text=_("Default relay_state value for IDP-initiated logins")
|
||||
)
|
||||
|
||||
sign_assertion = models.BooleanField(default=True)
|
||||
sign_response = models.BooleanField(default=False)
|
||||
|
||||
@property
|
||||
def launch_url(self) -> str | None:
|
||||
"""Use IDP-Initiated SAML flow as launch URL"""
|
||||
@ -208,7 +191,7 @@ class SAMLPropertyMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-provider-saml-form"
|
||||
return "ak-property-mapping-saml-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
@ -221,8 +204,8 @@ class SAMLPropertyMapping(PropertyMapping):
|
||||
return f"{self.name} ({name})"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("SAML Provider Property Mapping")
|
||||
verbose_name_plural = _("SAML Provider Property Mappings")
|
||||
verbose_name = _("SAML Property Mapping")
|
||||
verbose_name_plural = _("SAML Property Mappings")
|
||||
|
||||
|
||||
class SAMLProviderImportModel(CreatableType, Provider):
|
||||
|
@ -18,11 +18,7 @@ from authentik.providers.saml.processors.authn_request_parser import AuthNReques
|
||||
from authentik.providers.saml.utils import get_random_id
|
||||
from authentik.providers.saml.utils.time import get_time_string
|
||||
from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME
|
||||
from authentik.sources.saml.exceptions import (
|
||||
InvalidEncryption,
|
||||
InvalidSignature,
|
||||
UnsupportedNameIDFormat,
|
||||
)
|
||||
from authentik.sources.saml.exceptions import InvalidSignature, UnsupportedNameIDFormat
|
||||
from authentik.sources.saml.processors.constants import (
|
||||
DIGEST_ALGORITHM_TRANSLATION_MAP,
|
||||
NS_MAP,
|
||||
@ -260,17 +256,9 @@ class AssertionProcessor:
|
||||
assertion,
|
||||
xmlsec.constants.TransformExclC14N,
|
||||
sign_algorithm_transform,
|
||||
ns=xmlsec.constants.DSigNs,
|
||||
ns="ds", # type: ignore
|
||||
)
|
||||
assertion.append(signature)
|
||||
if self.provider.encryption_kp:
|
||||
encryption = xmlsec.template.encrypted_data_create(
|
||||
assertion,
|
||||
xmlsec.constants.TransformAes128Cbc,
|
||||
self._assertion_id,
|
||||
ns=xmlsec.constants.DSigNs,
|
||||
)
|
||||
assertion.append(encryption)
|
||||
|
||||
assertion.append(self.get_assertion_subject())
|
||||
assertion.append(self.get_assertion_conditions())
|
||||
@ -298,86 +286,41 @@ class AssertionProcessor:
|
||||
response.append(self.get_assertion())
|
||||
return response
|
||||
|
||||
def _sign(self, element: Element):
|
||||
"""Sign an XML element based on the providers' configured signing settings"""
|
||||
digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
|
||||
self.provider.digest_algorithm, xmlsec.constants.TransformSha1
|
||||
)
|
||||
xmlsec.tree.add_ids(element, ["ID"])
|
||||
signature_node = xmlsec.tree.find_node(element, xmlsec.constants.NodeSignature)
|
||||
ref = xmlsec.template.add_reference(
|
||||
signature_node,
|
||||
digest_algorithm_transform,
|
||||
uri="#" + self._assertion_id,
|
||||
)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
|
||||
key_info = xmlsec.template.ensure_key_info(signature_node)
|
||||
xmlsec.template.add_x509_data(key_info)
|
||||
|
||||
ctx = xmlsec.SignatureContext()
|
||||
|
||||
key = xmlsec.Key.from_memory(
|
||||
self.provider.signing_kp.key_data,
|
||||
xmlsec.constants.KeyDataFormatPem,
|
||||
None,
|
||||
)
|
||||
key.load_cert_from_memory(
|
||||
self.provider.signing_kp.certificate_data,
|
||||
xmlsec.constants.KeyDataFormatCertPem,
|
||||
)
|
||||
ctx.key = key
|
||||
try:
|
||||
ctx.sign(signature_node)
|
||||
except xmlsec.Error as exc:
|
||||
raise InvalidSignature() from exc
|
||||
|
||||
def _encrypt(self, element: Element, parent: Element):
|
||||
"""Encrypt SAMLResponse EncryptedAssertion Element"""
|
||||
manager = xmlsec.KeysManager()
|
||||
key = xmlsec.Key.from_memory(
|
||||
self.provider.encryption_kp.key_data,
|
||||
xmlsec.constants.KeyDataFormatPem,
|
||||
)
|
||||
key.load_cert_from_memory(
|
||||
self.provider.encryption_kp.certificate_data,
|
||||
xmlsec.constants.KeyDataFormatCertPem,
|
||||
)
|
||||
|
||||
manager.add_key(key)
|
||||
encryption_context = xmlsec.EncryptionContext(manager)
|
||||
encryption_context.key = xmlsec.Key.generate(
|
||||
xmlsec.constants.KeyDataAes, 128, xmlsec.constants.KeyDataTypeSession
|
||||
)
|
||||
|
||||
container = SubElement(parent, f"{{{NS_SAML_ASSERTION}}}EncryptedAssertion")
|
||||
enc_data = xmlsec.template.encrypted_data_create(
|
||||
container, xmlsec.Transform.AES128, type=xmlsec.EncryptionType.ELEMENT, ns="xenc"
|
||||
)
|
||||
xmlsec.template.encrypted_data_ensure_cipher_value(enc_data)
|
||||
key_info = xmlsec.template.encrypted_data_ensure_key_info(enc_data, ns="ds")
|
||||
enc_key = xmlsec.template.add_encrypted_key(key_info, xmlsec.Transform.RSA_OAEP)
|
||||
xmlsec.template.encrypted_data_ensure_cipher_value(enc_key)
|
||||
|
||||
try:
|
||||
enc_data = encryption_context.encrypt_xml(enc_data, element)
|
||||
except xmlsec.Error as exc:
|
||||
raise InvalidEncryption() from exc
|
||||
|
||||
parent.remove(enc_data)
|
||||
container.append(enc_data)
|
||||
|
||||
def build_response(self) -> str:
|
||||
"""Build string XML Response and sign if signing is enabled."""
|
||||
root_response = self.get_response()
|
||||
if self.provider.signing_kp:
|
||||
if self.provider.sign_assertion:
|
||||
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
|
||||
self._sign(assertion)
|
||||
if self.provider.sign_response:
|
||||
response = root_response.xpath("//samlp:Response", namespaces=NS_MAP)[0]
|
||||
self._sign(response)
|
||||
if self.provider.encryption_kp:
|
||||
digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
|
||||
self.provider.digest_algorithm, xmlsec.constants.TransformSha1
|
||||
)
|
||||
assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0]
|
||||
self._encrypt(assertion, root_response)
|
||||
xmlsec.tree.add_ids(assertion, ["ID"])
|
||||
signature_node = xmlsec.tree.find_node(assertion, xmlsec.constants.NodeSignature)
|
||||
ref = xmlsec.template.add_reference(
|
||||
signature_node,
|
||||
digest_algorithm_transform,
|
||||
uri="#" + self._assertion_id,
|
||||
)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
|
||||
xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
|
||||
key_info = xmlsec.template.ensure_key_info(signature_node)
|
||||
xmlsec.template.add_x509_data(key_info)
|
||||
|
||||
ctx = xmlsec.SignatureContext()
|
||||
|
||||
key = xmlsec.Key.from_memory(
|
||||
self.provider.signing_kp.key_data,
|
||||
xmlsec.constants.KeyDataFormatPem,
|
||||
None,
|
||||
)
|
||||
key.load_cert_from_memory(
|
||||
self.provider.signing_kp.certificate_data,
|
||||
xmlsec.constants.KeyDataFormatCertPem,
|
||||
)
|
||||
ctx.key = key
|
||||
try:
|
||||
ctx.sign(signature_node)
|
||||
except xmlsec.Error as exc:
|
||||
raise InvalidSignature() from exc
|
||||
|
||||
return etree.tostring(root_response).decode("utf-8") # nosec
|
||||
|
@ -126,7 +126,7 @@ class MetadataProcessor:
|
||||
entity_descriptor,
|
||||
xmlsec.constants.TransformExclC14N,
|
||||
sign_algorithm_transform,
|
||||
ns=xmlsec.constants.DSigNs,
|
||||
ns="ds", # type: ignore
|
||||
)
|
||||
entity_descriptor.append(signature)
|
||||
|
||||
|
@ -8,7 +8,7 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.flows.models import FlowDesignation
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.tests.utils import load_fixture
|
||||
@ -29,52 +29,12 @@ class TestSAMLProviderAPI(APITestCase):
|
||||
name=generate_id(),
|
||||
authorization_flow=create_test_flow(),
|
||||
)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
Application.objects.create(name=generate_id(), provider=provider, slug=generate_id())
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:samlprovider-detail", kwargs={"pk": provider.pk}),
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
def test_create_validate_signing_kp(self):
|
||||
"""Test create"""
|
||||
cert = create_test_cert()
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:samlprovider-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"authorization_flow": create_test_flow().pk,
|
||||
"acs_url": "http://localhost",
|
||||
"signing_kp": cert.pk,
|
||||
},
|
||||
)
|
||||
self.assertEqual(400, response.status_code)
|
||||
self.assertJSONEqual(
|
||||
response.content,
|
||||
{
|
||||
"non_field_errors": [
|
||||
(
|
||||
"With a signing keypair selected, at least one "
|
||||
"of 'Sign assertion' and 'Sign Response' must be selected."
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
response = self.client.post(
|
||||
reverse("authentik_api:samlprovider-list"),
|
||||
data={
|
||||
"name": generate_id(),
|
||||
"authorization_flow": create_test_flow().pk,
|
||||
"acs_url": "http://localhost",
|
||||
"signing_kp": cert.pk,
|
||||
"sign_assertion": True,
|
||||
},
|
||||
)
|
||||
self.assertEqual(201, response.status_code)
|
||||
|
||||
def test_metadata(self):
|
||||
"""Test metadata export (normal)"""
|
||||
self.client.logout()
|
||||
|
@ -78,12 +78,12 @@ class TestAuthNRequest(TestCase):
|
||||
|
||||
@apply_blueprint("system/providers-saml.yaml")
|
||||
def setUp(self):
|
||||
self.cert = create_test_cert()
|
||||
cert = create_test_cert()
|
||||
self.provider: SAMLProvider = SAMLProvider.objects.create(
|
||||
authorization_flow=create_test_flow(),
|
||||
acs_url="http://testserver/source/saml/provider/acs/",
|
||||
signing_kp=self.cert,
|
||||
verification_kp=self.cert,
|
||||
signing_kp=cert,
|
||||
verification_kp=cert,
|
||||
)
|
||||
self.provider.property_mappings.set(SAMLPropertyMapping.objects.all())
|
||||
self.provider.save()
|
||||
@ -91,8 +91,8 @@ class TestAuthNRequest(TestCase):
|
||||
slug="provider",
|
||||
issuer="authentik",
|
||||
pre_authentication_flow=create_test_flow(),
|
||||
signing_kp=self.cert,
|
||||
verification_kp=self.cert,
|
||||
signing_kp=cert,
|
||||
verification_kp=cert,
|
||||
)
|
||||
|
||||
def test_signed_valid(self):
|
||||
@ -112,34 +112,7 @@ class TestAuthNRequest(TestCase):
|
||||
self.assertEqual(parsed_request.id, request_proc.request_id)
|
||||
self.assertEqual(parsed_request.relay_state, "test_state")
|
||||
|
||||
def test_request_encrypt(self):
|
||||
"""Test full SAML Request/Response flow, fully encrypted"""
|
||||
self.provider.encryption_kp = self.cert
|
||||
self.provider.save()
|
||||
self.source.encryption_kp = self.cert
|
||||
self.source.save()
|
||||
http_request = get_request("/")
|
||||
|
||||
# First create an AuthNRequest
|
||||
request_proc = RequestProcessor(self.source, http_request, "test_state")
|
||||
request = request_proc.build_auth_n()
|
||||
|
||||
# To get an assertion we need a parsed request (parsed by provider)
|
||||
parsed_request = AuthNRequestParser(self.provider).parse(
|
||||
b64encode(request.encode()).decode(), "test_state"
|
||||
)
|
||||
# Now create a response and convert it to string (provider)
|
||||
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
|
||||
response = response_proc.build_response()
|
||||
|
||||
# Now parse the response (source)
|
||||
http_request.POST = QueryDict(mutable=True)
|
||||
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
|
||||
|
||||
response_parser = ResponseProcessor(self.source, http_request)
|
||||
response_parser.parse()
|
||||
|
||||
def test_request_signed(self):
|
||||
def test_request_full_signed(self):
|
||||
"""Test full SAML Request/Response flow, fully signed"""
|
||||
http_request = get_request("/")
|
||||
|
||||
@ -162,32 +135,6 @@ class TestAuthNRequest(TestCase):
|
||||
response_parser = ResponseProcessor(self.source, http_request)
|
||||
response_parser.parse()
|
||||
|
||||
def test_request_signed_both(self):
|
||||
"""Test full SAML Request/Response flow, fully signed"""
|
||||
self.provider.sign_assertion = True
|
||||
self.provider.sign_response = True
|
||||
self.provider.save()
|
||||
http_request = get_request("/")
|
||||
|
||||
# First create an AuthNRequest
|
||||
request_proc = RequestProcessor(self.source, http_request, "test_state")
|
||||
request = request_proc.build_auth_n()
|
||||
|
||||
# To get an assertion we need a parsed request (parsed by provider)
|
||||
parsed_request = AuthNRequestParser(self.provider).parse(
|
||||
b64encode(request.encode()).decode(), "test_state"
|
||||
)
|
||||
# Now create a response and convert it to string (provider)
|
||||
response_proc = AssertionProcessor(self.provider, http_request, parsed_request)
|
||||
response = response_proc.build_response()
|
||||
|
||||
# Now parse the response (source)
|
||||
http_request.POST = QueryDict(mutable=True)
|
||||
http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()
|
||||
|
||||
response_parser = ResponseProcessor(self.source, http_request)
|
||||
response_parser.parse()
|
||||
|
||||
def test_request_id_invalid(self):
|
||||
"""Test generated AuthNRequest with invalid request ID"""
|
||||
http_request = get_request("/")
|
||||
|
@ -54,11 +54,7 @@ class TestServiceProviderMetadataParser(TestCase):
|
||||
request = self.factory.get("/")
|
||||
metadata = lxml_from_string(MetadataProcessor(provider, request).build_entity_descriptor())
|
||||
|
||||
schema = etree.XMLSchema(
|
||||
etree.parse(
|
||||
source="schemas/saml-schema-metadata-2.0.xsd", parser=etree.XMLParser()
|
||||
) # nosec
|
||||
)
|
||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-metadata-2.0.xsd")) # nosec
|
||||
self.assertTrue(schema.validate(metadata))
|
||||
|
||||
def test_schema_want_authn_requests_signed(self):
|
||||
|
@ -47,9 +47,7 @@ class TestSchema(TestCase):
|
||||
|
||||
metadata = lxml_from_string(request)
|
||||
|
||||
schema = etree.XMLSchema(
|
||||
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
|
||||
)
|
||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
|
||||
self.assertTrue(schema.validate(metadata))
|
||||
|
||||
def test_response_schema(self):
|
||||
@ -70,7 +68,5 @@ class TestSchema(TestCase):
|
||||
|
||||
metadata = lxml_from_string(response)
|
||||
|
||||
schema = etree.XMLSchema(
|
||||
etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec
|
||||
)
|
||||
schema = etree.XMLSchema(etree.parse("schemas/saml-schema-protocol-2.0.xsd")) # nosec
|
||||
self.assertTrue(schema.validate(metadata))
|
||||
|
@ -44,6 +44,6 @@ urlpatterns = [
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
("propertymappings/provider/saml", SAMLPropertyMappingViewSet),
|
||||
("propertymappings/saml", SAMLPropertyMappingViewSet),
|
||||
("providers/saml", SAMLProviderViewSet),
|
||||
]
|
||||
|
@ -6,7 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
|
||||
from authentik.providers.scim.models import SCIMProvider
|
||||
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects
|
||||
from authentik.providers.scim.tasks import scim_sync
|
||||
|
||||
|
||||
class SCIMProviderSerializer(ProviderSerializer):
|
||||
@ -42,4 +42,3 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
|
||||
search_fields = ["name", "url"]
|
||||
ordering = ["name", "url"]
|
||||
sync_single_task = scim_sync
|
||||
sync_objects_task = scim_sync_objects
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Group client"""
|
||||
|
||||
from itertools import batched
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydanticscim.group import GroupMember
|
||||
from pydanticscim.responses import PatchOp, PatchOperation
|
||||
@ -58,22 +56,17 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
if not scim_group.externalId:
|
||||
scim_group.externalId = str(obj.pk)
|
||||
|
||||
if not self._config.patch.supported:
|
||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||
connections = SCIMProviderUser.objects.filter(
|
||||
provider=self.provider, user__pk__in=users
|
||||
)
|
||||
members = []
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.scim_id,
|
||||
)
|
||||
users = list(obj.users.order_by("id").values_list("id", flat=True))
|
||||
connections = SCIMProviderUser.objects.filter(provider=self.provider, user__pk__in=users)
|
||||
members = []
|
||||
for user in connections:
|
||||
members.append(
|
||||
GroupMember(
|
||||
value=user.scim_id,
|
||||
)
|
||||
if members:
|
||||
scim_group.members = members
|
||||
else:
|
||||
del scim_group.members
|
||||
)
|
||||
if members:
|
||||
scim_group.members = members
|
||||
return scim_group
|
||||
|
||||
def delete(self, obj: Group):
|
||||
@ -100,19 +93,16 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
scim_id = response.get("id")
|
||||
if not scim_id or scim_id == "":
|
||||
raise StopSync("SCIM Response with missing or invalid `id`")
|
||||
connection = SCIMProviderGroup.objects.create(
|
||||
return SCIMProviderGroup.objects.create(
|
||||
provider=self.provider, group=group, scim_id=scim_id
|
||||
)
|
||||
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||
self._patch_add_users(group, users)
|
||||
return connection
|
||||
|
||||
def update(self, group: Group, connection: SCIMProviderGroup):
|
||||
"""Update existing group"""
|
||||
scim_group = self.to_schema(group, connection)
|
||||
scim_group.id = connection.scim_id
|
||||
try:
|
||||
self._request(
|
||||
return self._request(
|
||||
"PUT",
|
||||
f"/Groups/{connection.scim_id}",
|
||||
json=scim_group.model_dump(
|
||||
@ -120,8 +110,6 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
exclude_unset=True,
|
||||
),
|
||||
)
|
||||
users = list(group.users.order_by("id").values_list("id", flat=True))
|
||||
return self._patch_add_users(group, users)
|
||||
except NotFoundSyncException:
|
||||
# Resource missing is handled by self.write, which will re-create the group
|
||||
raise
|
||||
@ -164,18 +152,14 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
group_id: str,
|
||||
*ops: PatchOperation,
|
||||
):
|
||||
chunk_size = self._config.bulk.maxOperations
|
||||
if chunk_size < 1:
|
||||
chunk_size = len(ops)
|
||||
for chunk in batched(ops, chunk_size):
|
||||
req = PatchRequest(Operations=list(chunk))
|
||||
self._request(
|
||||
"PATCH",
|
||||
f"/Groups/{group_id}",
|
||||
json=req.model_dump(
|
||||
mode="json",
|
||||
),
|
||||
)
|
||||
req = PatchRequest(Operations=ops)
|
||||
self._request(
|
||||
"PATCH",
|
||||
f"/Groups/{group_id}",
|
||||
json=req.model_dump(
|
||||
mode="json",
|
||||
),
|
||||
)
|
||||
|
||||
def _patch_add_users(self, group: Group, users_set: set[int]):
|
||||
"""Add users in users_set to group"""
|
||||
@ -196,14 +180,11 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
*[
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
value=[{"value": x}],
|
||||
)
|
||||
for x in user_ids
|
||||
],
|
||||
PatchOperation(
|
||||
op=PatchOp.add,
|
||||
path="members",
|
||||
value=[{"value": x} for x in user_ids],
|
||||
),
|
||||
)
|
||||
|
||||
def _patch_remove_users(self, group: Group, users_set: set[int]):
|
||||
@ -225,12 +206,9 @@ class SCIMGroupClient(SCIMClient[Group, SCIMProviderGroup, SCIMGroupSchema]):
|
||||
return
|
||||
self._patch(
|
||||
scim_group.scim_id,
|
||||
*[
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
||||
value=[{"value": x}],
|
||||
)
|
||||
for x in user_ids
|
||||
],
|
||||
PatchOperation(
|
||||
op=PatchOp.remove,
|
||||
path="members",
|
||||
value=[{"value": x} for x in user_ids],
|
||||
),
|
||||
)
|
||||
|
@ -1,11 +1,9 @@
|
||||
"""Custom SCIM schemas"""
|
||||
|
||||
from pydantic import Field
|
||||
from pydanticscim.group import Group as BaseGroup
|
||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
||||
from pydanticscim.responses import SCIMError as BaseSCIMError
|
||||
from pydanticscim.service_provider import Bulk as BaseBulk
|
||||
from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import Bulk, ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import (
|
||||
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
||||
)
|
||||
@ -31,16 +29,10 @@ class Group(BaseGroup):
|
||||
meta: dict | None = None
|
||||
|
||||
|
||||
class Bulk(BaseBulk):
|
||||
|
||||
maxOperations: int = Field()
|
||||
|
||||
|
||||
class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||
"""ServiceProviderConfig with fallback"""
|
||||
|
||||
_is_fallback: bool | None = False
|
||||
bulk: Bulk = Field(..., description="A complex type that specifies bulk configuration options.")
|
||||
|
||||
@property
|
||||
def is_fallback(self) -> bool:
|
||||
@ -53,7 +45,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||
"""Get default configuration, which doesn't support any optional features as fallback"""
|
||||
return ServiceProviderConfiguration(
|
||||
patch=Patch(supported=False),
|
||||
bulk=Bulk(supported=False, maxOperations=0),
|
||||
bulk=Bulk(supported=False),
|
||||
filter=Filter(supported=False),
|
||||
changePassword=ChangePassword(supported=False),
|
||||
sort=Sort(supported=False),
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user