Compare commits
280 Commits
import-org
...
celery-2-d
Author | SHA1 | Date | |
---|---|---|---|
80eb56b016 | |||
3a1d0fbd35 | |||
c85471575a | |||
5d00dc7e9e | |||
6982e7d1c9 | |||
c7fe987c5a | |||
e48739c8a0 | |||
b2ee585c43 | |||
97e8ea8e76 | |||
1f1e0c9db1 | |||
ca47a803fe | |||
c606eb53b0 | |||
c2dc38e804 | |||
9aad7e29a4 | |||
62357133b0 | |||
99d2d91257 | |||
69d9363fce | |||
315f9073cc | |||
c94fa13826 | |||
6b3fbb0abf | |||
06191d6cfc | |||
c608dc5110 | |||
bb05d6063d | |||
c5b4a630c9 | |||
4a73c65710 | |||
51e645c26a | |||
de3607b9a6 | |||
fb5804b914 | |||
287f41fed8 | |||
172b595a9f | |||
d61995d0e2 | |||
cfc7f6b993 | |||
056460dac0 | |||
bebbbe9b90 | |||
188d3c69c1 | |||
877f312145 | |||
f471a98bc7 | |||
b30b736a71 | |||
d3aa43ced0 | |||
92c837f7b5 | |||
7fcd65e318 | |||
e874cfc21d | |||
ec7bdf74aa | |||
9c01c7d890 | |||
bdb0564d4c | |||
e87bc94b95 | |||
09688f7a55 | |||
31bb995490 | |||
5d813438f0 | |||
a3865abaa9 | |||
7100d3c674 | |||
d92764789f | |||
c702fa1f95 | |||
786bada7d0 | |||
c0c2d2ad3c | |||
dc287989db | |||
03204f6943 | |||
fcd369e466 | |||
cb79407bc1 | |||
04a88daf34 | |||
c6a49da5c3 | |||
bfeeecf3fa | |||
690766d377 | |||
4990abdf4a | |||
76616cf1c5 | |||
48c0b5449e | |||
8956606564 | |||
0908d0b559 | |||
7e5c90fcdc | |||
4e210b8299 | |||
bf65ca4c70 | |||
670a88659e | |||
0ebbaeea6f | |||
49a9911271 | |||
f3c0ca1a59 | |||
64e7bff16c | |||
f264989f9e | |||
8f4353181e | |||
813b7aa8ba | |||
3bb3e0d1ef | |||
101e5adeba | |||
ae228c91e3 | |||
411c52491e | |||
85869806a2 | |||
d132db475e | |||
13b5aa604b | |||
97a5acdff5 | |||
ea38f2d120 | |||
94867aaebf | |||
0e67c1d818 | |||
2a460201bb | |||
f99cb3e9fb | |||
e4bd05f444 | |||
80c4eb9bef | |||
96b4d5aee4 | |||
6321537c8d | |||
43975ec231 | |||
9b13922fc2 | |||
031456629b | |||
2433ed1c9b | |||
9868d54320 | |||
747a3ed6e9 | |||
527e849ce2 | |||
cfcd54ca19 | |||
faed9cd66e | |||
897d0dbcbd | |||
a12e991798 | |||
e5b86c3578 | |||
07ff433134 | |||
21b3e0c8cb | |||
cbdec236dd | |||
2509ccde1c | |||
7e7b33dba7 | |||
13e1e44626 | |||
e634f23fc8 | |||
8554a8e0c5 | |||
b80abffafc | |||
204f21699e | |||
0fd478fa3e | |||
7d7e47e972 | |||
92a33a408f | |||
d18a54e9e6 | |||
e6614a0705 | |||
4c491cf221 | |||
17434c84cf | |||
234fb2a0c6 | |||
00612f921d | |||
8b67015190 | |||
5a5176e21f | |||
8980282a02 | |||
2ca9edb1bc | |||
61d970cda4 | |||
16fd9cab67 | |||
8c7818a252 | |||
374779102a | |||
0ac854458a | |||
1cfaddf49d | |||
5ae69f5987 | |||
c62e1d5457 | |||
5b8681b1af | |||
e0dcade9ad | |||
1a6ab7f24b | |||
769844314c | |||
e211604860 | |||
7ed711e8f0 | |||
196b276345 | |||
3c62c80ff1 | |||
a031f1107a | |||
8f399bba3f | |||
e354e877ea | |||
f254b8cf8c | |||
814b06322a | |||
217063ef7b | |||
c2f7883a5c | |||
bd64c34787 | |||
7518d4391f | |||
e67bd79c66 | |||
2fc6da53c1 | |||
250a98cf98 | |||
f2926fa1eb | |||
5e2af4a740 | |||
41f2ca42cc | |||
7ef547b357 | |||
1a9c529e92 | |||
75d19bfe76 | |||
7f8f7376e0 | |||
7c49de9cba | |||
00ac9b6367 | |||
0e786f7040 | |||
03d363ba84 | |||
3f33519ec0 | |||
cae03beb6d | |||
e4c1e5aed0 | |||
5acdd67cba | |||
40dbac7a65 | |||
1b4ed02959 | |||
a95e730cdb | |||
d8c13159e1 | |||
5f951ca3ef | |||
338da72622 | |||
90debcdd70 | |||
3766ca86e8 | |||
59c8472628 | |||
293616e6b0 | |||
f7305d58b1 | |||
ba94f63705 | |||
06b2e0d14b | |||
80a5f44491 | |||
aca0bde46d | |||
e671811ad2 | |||
3140325493 | |||
6c0b879b30 | |||
0e0fb37dd7 | |||
d2cacdc640 | |||
e65fabf040 | |||
107b96e65c | |||
5d7ba51872 | |||
3037701a14 | |||
66f8377c79 | |||
86f81d92aa | |||
369437f2a1 | |||
4b8b80f1d4 | |||
f839aef33a | |||
eb87e30076 | |||
4302f91028 | |||
b0af20b0d5 | |||
9b556cf4c4 | |||
7118219544 | |||
475600ea87 | |||
2139e0be05 | |||
a43a0f77fb | |||
8a073e8c60 | |||
35640fcdfa | |||
c62f73400a | |||
c92cbd7e22 | |||
c3b0d09e04 | |||
c5a40fced3 | |||
9cc6ebabc1 | |||
e89659fe71 | |||
3c1512028d | |||
c7f80686de | |||
8a8386cfcb | |||
e60165ee45 | |||
bc6085adc7 | |||
d413e2875c | |||
144986f48e | |||
c9f1e34beb | |||
39f769b150 | |||
78180e376f | |||
d51150102c | |||
d5da16ad26 | |||
2b12e32fcf | |||
a9b9661155 | |||
f5f0cef275 | |||
b756965511 | |||
db900c4a42 | |||
72cb62085b | |||
5a42815850 | |||
b9083a906a | |||
04be734c49 | |||
1ed6cf7517 | |||
d6c4f97158 | |||
781704fa38 | |||
28f4d7d566 | |||
991778b2be | |||
9465dafd7d | |||
75c13a8801 | |||
8ae0f145f5 | |||
4d0e0e3afe | |||
7aeb874ded | |||
ffc695f7b8 | |||
93cb621af3 | |||
3a34680196 | |||
2335a3130a | |||
0bc4b69f52 | |||
43c5c1276d | |||
a3ebfd9bbd | |||
af5b894e62 | |||
c982066235 | |||
1f6c1522b6 | |||
bae83ba78e | |||
0d0aeab4ee | |||
7fe91339ad | |||
44dea1d208 | |||
6d3be40022 | |||
07773f92a0 | |||
dbc4a2b730 | |||
df15a78aac | |||
61b517edfa | |||
082b342f65 | |||
9a536ee4b9 | |||
677f04cab2 | |||
3ddc35cddc | |||
ae211226ef | |||
6662611347 | |||
c4b988c632 | |||
2b1ee8cd5c | |||
e8cfc2b91e | |||
de54404ab7 | |||
f8c3b64274 |
@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 2025.6.2
|
||||
current_version = 2025.6.3
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
||||
|
@ -38,6 +38,8 @@ jobs:
|
||||
# Needed for attestation
|
||||
id-token: write
|
||||
attestations: write
|
||||
# Needed for checkout
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: docker/setup-qemu-action@v3.6.0
|
||||
|
1
.github/workflows/ci-main-daily.yml
vendored
1
.github/workflows/ci-main-daily.yml
vendored
@ -9,6 +9,7 @@ on:
|
||||
|
||||
jobs:
|
||||
test-container:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
4
.github/workflows/ci-main.yml
vendored
4
.github/workflows/ci-main.yml
vendored
@ -247,11 +247,13 @@ jobs:
|
||||
# Needed for attestation
|
||||
id-token: write
|
||||
attestations: write
|
||||
# Needed for checkout
|
||||
contents: read
|
||||
needs: ci-core-mark
|
||||
uses: ./.github/workflows/_reusable-docker-build.yaml
|
||||
secrets: inherit
|
||||
with:
|
||||
image_name: ghcr.io/goauthentik/dev-server
|
||||
image_name: ${{ github.repository == 'goauthentik/authentik-internal' && 'ghcr.io/goauthentik/internal-server' || 'ghcr.io/goauthentik/dev-server' }}
|
||||
release: false
|
||||
pr-comment:
|
||||
needs:
|
||||
|
1
.github/workflows/ci-outpost.yml
vendored
1
.github/workflows/ci-outpost.yml
vendored
@ -59,6 +59,7 @@ jobs:
|
||||
with:
|
||||
jobs: ${{ toJSON(needs) }}
|
||||
build-container:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
timeout-minutes: 120
|
||||
needs:
|
||||
- ci-outpost-mark
|
||||
|
2
.github/workflows/ci-website.yml
vendored
2
.github/workflows/ci-website.yml
vendored
@ -63,6 +63,7 @@ jobs:
|
||||
working-directory: website/
|
||||
run: npm run ${{ matrix.job }}
|
||||
build-container:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
# Needed to upload container images to ghcr.io
|
||||
@ -122,3 +123,4 @@ jobs:
|
||||
- uses: re-actors/alls-green@release/v1
|
||||
with:
|
||||
jobs: ${{ toJSON(needs) }}
|
||||
allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }}
|
||||
|
21
.github/workflows/repo-mirror-cleanup.yml
vendored
Normal file
21
.github/workflows/repo-mirror-cleanup.yml
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
name: "authentik-repo-mirror-cleanup"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
to_internal:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- if: ${{ env.MIRROR_KEY != '' }}
|
||||
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb
|
||||
with:
|
||||
target_repo_url: git@github.com:goauthentik/authentik-internal.git
|
||||
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
|
||||
args: --tags --force --prune
|
||||
env:
|
||||
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}
|
9
.github/workflows/repo-mirror.yml
vendored
9
.github/workflows/repo-mirror.yml
vendored
@ -11,11 +11,10 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- if: ${{ env.MIRROR_KEY != '' }}
|
||||
uses: pixta-dev/repository-mirroring-action@v1
|
||||
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb
|
||||
with:
|
||||
target_repo_url:
|
||||
git@github.com:goauthentik/authentik-internal.git
|
||||
ssh_private_key:
|
||||
${{ secrets.GH_MIRROR_KEY }}
|
||||
target_repo_url: git@github.com:goauthentik/authentik-internal.git
|
||||
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
|
||||
args: --tags --force
|
||||
env:
|
||||
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}
|
||||
|
@ -16,6 +16,7 @@ env:
|
||||
|
||||
jobs:
|
||||
compile:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- id: generate_token
|
||||
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -100,9 +100,6 @@ ipython_config.py
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
@ -166,8 +163,6 @@ dmypy.json
|
||||
|
||||
# pyenv
|
||||
|
||||
# celery beat schedule file
|
||||
|
||||
# SageMath parsed files
|
||||
|
||||
# Environments
|
||||
|
@ -75,7 +75,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||
|
||||
# Stage 4: Download uv
|
||||
FROM ghcr.io/astral-sh/uv:0.7.15 AS uv
|
||||
FROM ghcr.io/astral-sh/uv:0.7.17 AS uv
|
||||
# Stage 5: Base python image
|
||||
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
|
||||
|
||||
@ -122,6 +122,7 @@ ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
|
||||
|
||||
RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
|
||||
--mount=type=bind,target=uv.lock,src=uv.lock \
|
||||
--mount=type=bind,target=packages,src=packages \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --frozen --no-install-project --no-dev
|
||||
|
||||
@ -167,6 +168,7 @@ COPY ./blueprints /blueprints
|
||||
COPY ./lifecycle/ /lifecycle
|
||||
COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf
|
||||
COPY --from=go-builder /go/authentik /bin/authentik
|
||||
COPY ./packages/ /ak-root/packages
|
||||
COPY --from=python-deps /ak-root/.venv /ak-root/.venv
|
||||
COPY --from=node-builder /work/web/dist/ /web/dist/
|
||||
COPY --from=node-builder /work/web/authentik/ /web/authentik/
|
||||
|
2
Makefile
2
Makefile
@ -6,7 +6,7 @@ PWD = $(shell pwd)
|
||||
UID = $(shell id -u)
|
||||
GID = $(shell id -g)
|
||||
NPM_VERSION = $(shell python -m scripts.generate_semver)
|
||||
PY_SOURCES = authentik tests scripts lifecycle .github
|
||||
PY_SOURCES = authentik packages tests scripts lifecycle .github
|
||||
DOCKER_IMAGE ?= "authentik:test"
|
||||
|
||||
GEN_API_TS = gen-ts-api
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from os import environ
|
||||
|
||||
__version__ = "2025.6.2"
|
||||
__version__ = "2025.6.3"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ class VersionSerializer(PassiveSerializer):
|
||||
return __version__
|
||||
version_in_cache = cache.get(VERSION_CACHE_KEY)
|
||||
if not version_in_cache: # pragma: no cover
|
||||
update_latest_version.delay()
|
||||
update_latest_version.send()
|
||||
return __version__
|
||||
return version_in_cache
|
||||
|
||||
|
@ -1,57 +0,0 @@
|
||||
"""authentik administration overview"""
|
||||
|
||||
from socket import gethostname
|
||||
|
||||
from django.conf import settings
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from packaging.version import parse
|
||||
from rest_framework.fields import BooleanField, CharField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from authentik import get_full_version
|
||||
from authentik.rbac.permissions import HasPermission
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
|
||||
class WorkerView(APIView):
|
||||
"""Get currently connected worker count."""
|
||||
|
||||
permission_classes = [HasPermission("authentik_rbac.view_system_info")]
|
||||
|
||||
@extend_schema(
|
||||
responses=inline_serializer(
|
||||
"Worker",
|
||||
fields={
|
||||
"worker_id": CharField(),
|
||||
"version": CharField(),
|
||||
"version_matching": BooleanField(),
|
||||
},
|
||||
many=True,
|
||||
)
|
||||
)
|
||||
def get(self, request: Request) -> Response:
|
||||
"""Get currently connected worker count."""
|
||||
raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
|
||||
our_version = parse(get_full_version())
|
||||
response = []
|
||||
for worker in raw:
|
||||
key = list(worker.keys())[0]
|
||||
version = worker[key].get("version")
|
||||
version_matching = False
|
||||
if version:
|
||||
version_matching = parse(version) == our_version
|
||||
response.append(
|
||||
{"worker_id": key, "version": version, "version_matching": version_matching}
|
||||
)
|
||||
# In debug we run with `task_always_eager`, so tasks are ran on the main process
|
||||
if settings.DEBUG: # pragma: no cover
|
||||
response.append(
|
||||
{
|
||||
"worker_id": f"authentik-debug@{gethostname()}",
|
||||
"version": get_full_version(),
|
||||
"version_matching": True,
|
||||
}
|
||||
)
|
||||
return Response(response)
|
@ -3,6 +3,9 @@
|
||||
from prometheus_client import Info
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
PROM_INFO = Info("authentik_version", "Currently running authentik version")
|
||||
|
||||
@ -30,3 +33,15 @@ class AuthentikAdminConfig(ManagedAppConfig):
|
||||
notification_version = notification.event.context["new_version"]
|
||||
if LOCAL_VERSION >= parse(notification_version):
|
||||
notification.delete()
|
||||
|
||||
@property
|
||||
def global_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.admin.tasks import update_latest_version
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=update_latest_version,
|
||||
crontab=f"{fqdn_rand('admin_latest_version')} * * * *",
|
||||
paused=CONFIG.get_bool("disable_update_check"),
|
||||
),
|
||||
]
|
||||
|
@ -1,15 +0,0 @@
|
||||
"""authentik admin settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
from django_tenants.utils import get_public_schema_name
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"admin_latest_version": {
|
||||
"task": "authentik.admin.tasks.update_latest_version",
|
||||
"schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
|
||||
"tenant_schemas": [get_public_schema_name()],
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
}
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
"""admin signals"""
|
||||
|
||||
from django.dispatch import receiver
|
||||
from packaging.version import parse
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from authentik import get_full_version
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
|
||||
GAUGE_WORKERS = Gauge(
|
||||
"authentik_admin_workers",
|
||||
"Currently connected workers, their versions and if they are the same version as authentik",
|
||||
["version", "version_matched"],
|
||||
)
|
||||
|
||||
|
||||
_version = parse(get_full_version())
|
||||
|
||||
|
||||
@receiver(monitoring_set)
|
||||
def monitoring_set_workers(sender, **kwargs):
|
||||
"""Set worker gauge"""
|
||||
raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
|
||||
worker_version_count = {}
|
||||
for worker in raw:
|
||||
key = list(worker.keys())[0]
|
||||
version = worker[key].get("version")
|
||||
version_matching = False
|
||||
if version:
|
||||
version_matching = parse(version) == _version
|
||||
worker_version_count.setdefault(version, {"count": 0, "matching": version_matching})
|
||||
worker_version_count[version]["count"] += 1
|
||||
for version, stats in worker_version_count.items():
|
||||
GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"])
|
@ -2,6 +2,8 @@
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from dramatiq import actor
|
||||
from packaging.version import parse
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
@ -9,10 +11,9 @@ from structlog.stdlib import get_logger
|
||||
from authentik import __version__, get_build_hash
|
||||
from authentik.admin.apps import PROM_INFO
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
VERSION_NULL = "0.0.0"
|
||||
@ -32,13 +33,12 @@ def _set_prom_info():
|
||||
)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def update_latest_version(self: SystemTask):
|
||||
"""Update latest version info"""
|
||||
@actor(description=_("Update latest version info."))
|
||||
def update_latest_version():
|
||||
self: Task = CurrentTask.get_task()
|
||||
if CONFIG.get_bool("disable_update_check"):
|
||||
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
|
||||
self.set_status(TaskStatus.WARNING, "Version check disabled.")
|
||||
self.info("Version check disabled.")
|
||||
return
|
||||
try:
|
||||
response = get_http_session().get(
|
||||
@ -48,7 +48,7 @@ def update_latest_version(self: SystemTask):
|
||||
data = response.json()
|
||||
upstream_version = data.get("stable", {}).get("version")
|
||||
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
|
||||
self.info("Successfully updated latest Version")
|
||||
_set_prom_info()
|
||||
# Check if upstream version is newer than what we're running,
|
||||
# and if no event exists yet, create one.
|
||||
@ -71,7 +71,7 @@ def update_latest_version(self: SystemTask):
|
||||
).save()
|
||||
except (RequestException, IndexError) as exc:
|
||||
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
|
||||
self.set_error(exc)
|
||||
raise exc
|
||||
|
||||
|
||||
_set_prom_info()
|
||||
|
@ -29,13 +29,6 @@ class TestAdminAPI(TestCase):
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["version_current"], __version__)
|
||||
|
||||
def test_workers(self):
|
||||
"""Test Workers API"""
|
||||
response = self.client.get(reverse("authentik_api:admin_workers"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(len(body), 0)
|
||||
|
||||
def test_apps(self):
|
||||
"""Test apps API"""
|
||||
response = self.client.get(reverse("authentik_api:apps-list"))
|
||||
|
@ -30,7 +30,7 @@ class TestAdminTasks(TestCase):
|
||||
"""Test Update checker with valid response"""
|
||||
with Mocker() as mocker, CONFIG.patch("disable_update_check", False):
|
||||
mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
|
||||
update_latest_version.delay().get()
|
||||
update_latest_version.send()
|
||||
self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
@ -40,7 +40,7 @@ class TestAdminTasks(TestCase):
|
||||
).exists()
|
||||
)
|
||||
# test that a consecutive check doesn't create a duplicate event
|
||||
update_latest_version.delay().get()
|
||||
update_latest_version.send()
|
||||
self.assertEqual(
|
||||
len(
|
||||
Event.objects.filter(
|
||||
@ -56,7 +56,7 @@ class TestAdminTasks(TestCase):
|
||||
"""Test Update checker with invalid response"""
|
||||
with Mocker() as mocker:
|
||||
mocker.get("https://version.goauthentik.io/version.json", status_code=400)
|
||||
update_latest_version.delay().get()
|
||||
update_latest_version.send()
|
||||
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
|
||||
self.assertFalse(
|
||||
Event.objects.filter(
|
||||
@ -67,14 +67,15 @@ class TestAdminTasks(TestCase):
|
||||
def test_version_disabled(self):
|
||||
"""Test Update checker while its disabled"""
|
||||
with CONFIG.patch("disable_update_check", True):
|
||||
update_latest_version.delay().get()
|
||||
update_latest_version.send()
|
||||
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
|
||||
|
||||
def test_clear_update_notifications(self):
|
||||
"""Test clear of previous notification"""
|
||||
admin_config = apps.get_app_config("authentik_admin")
|
||||
Event.objects.create(
|
||||
action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"}
|
||||
action=EventAction.UPDATE_AVAILABLE,
|
||||
context={"new_version": "99999999.9999999.9999999"},
|
||||
)
|
||||
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
|
||||
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})
|
||||
|
@ -6,13 +6,11 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet
|
||||
from authentik.admin.api.system import SystemView
|
||||
from authentik.admin.api.version import VersionView
|
||||
from authentik.admin.api.version_history import VersionHistoryViewSet
|
||||
from authentik.admin.api.workers import WorkerView
|
||||
|
||||
api_urlpatterns = [
|
||||
("admin/apps", AppsViewSet, "apps"),
|
||||
("admin/models", ModelViewSet, "models"),
|
||||
path("admin/version/", VersionView.as_view(), name="admin_version"),
|
||||
("admin/version/history", VersionHistoryViewSet, "version_history"),
|
||||
path("admin/workers/", WorkerView.as_view(), name="admin_workers"),
|
||||
path("admin/system/", SystemView.as_view(), name="admin_system"),
|
||||
]
|
||||
|
@ -39,7 +39,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
|
||||
"""Ensure the path (if set) specified is retrievable"""
|
||||
if path == "" or path.startswith(OCI_PREFIX):
|
||||
return path
|
||||
files: list[dict] = blueprints_find_dict.delay().get()
|
||||
files: list[dict] = blueprints_find_dict.send().get_result(block=True)
|
||||
if path not in [file["path"] for file in files]:
|
||||
raise ValidationError(_("Blueprint file does not exist"))
|
||||
return path
|
||||
@ -115,7 +115,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
|
||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||
def available(self, request: Request) -> Response:
|
||||
"""Get blueprints"""
|
||||
files: list[dict] = blueprints_find_dict.delay().get()
|
||||
files: list[dict] = blueprints_find_dict.send().get_result(block=True)
|
||||
return Response(files)
|
||||
|
||||
@permission_required("authentik_blueprints.view_blueprintinstance")
|
||||
@ -129,5 +129,5 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
|
||||
def apply(self, request: Request, *args, **kwargs) -> Response:
|
||||
"""Apply a blueprint"""
|
||||
blueprint = self.get_object()
|
||||
apply_blueprint.delay(str(blueprint.pk)).get()
|
||||
apply_blueprint.send_with_options(args=(blueprint.pk,), rel_obj=blueprint)
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
@ -6,9 +6,12 @@ from inspect import ismethod
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
from dramatiq.broker import get_broker
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.root.signals import startup
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
|
||||
class ManagedAppConfig(AppConfig):
|
||||
@ -34,7 +37,7 @@ class ManagedAppConfig(AppConfig):
|
||||
|
||||
def import_related(self):
|
||||
"""Automatically import related modules which rely on just being imported
|
||||
to register themselves (mainly django signals and celery tasks)"""
|
||||
to register themselves (mainly django signals and tasks)"""
|
||||
|
||||
def import_relative(rel_module: str):
|
||||
try:
|
||||
@ -80,6 +83,16 @@ class ManagedAppConfig(AppConfig):
|
||||
func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY
|
||||
return func
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
"""Get a list of schedule specs that must exist in each tenant"""
|
||||
return []
|
||||
|
||||
@property
|
||||
def global_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
"""Get a list of schedule specs that must exist in the default tenant"""
|
||||
return []
|
||||
|
||||
def _reconcile_tenant(self) -> None:
|
||||
"""reconcile ourselves for tenanted methods"""
|
||||
from authentik.tenants.models import Tenant
|
||||
@ -100,8 +113,12 @@ class ManagedAppConfig(AppConfig):
|
||||
"""
|
||||
from django_tenants.utils import get_public_schema_name, schema_context
|
||||
|
||||
with schema_context(get_public_schema_name()):
|
||||
self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
|
||||
try:
|
||||
with schema_context(get_public_schema_name()):
|
||||
self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
|
||||
except (DatabaseError, ProgrammingError, InternalError) as exc:
|
||||
self.logger.debug("Failed to access database to run reconcile", exc=exc)
|
||||
return
|
||||
|
||||
|
||||
class AuthentikBlueprintsConfig(ManagedAppConfig):
|
||||
@ -112,19 +129,29 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
|
||||
verbose_name = "authentik Blueprints"
|
||||
default = True
|
||||
|
||||
@ManagedAppConfig.reconcile_global
|
||||
def load_blueprints_v1_tasks(self):
|
||||
"""Load v1 tasks"""
|
||||
self.import_module("authentik.blueprints.v1.tasks")
|
||||
|
||||
@ManagedAppConfig.reconcile_tenant
|
||||
def blueprints_discovery(self):
|
||||
"""Run blueprint discovery"""
|
||||
from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
|
||||
|
||||
blueprints_discovery.delay()
|
||||
clear_failed_blueprints.delay()
|
||||
|
||||
def import_models(self):
|
||||
super().import_models()
|
||||
self.import_module("authentik.blueprints.v1.meta.apply_blueprint")
|
||||
|
||||
@ManagedAppConfig.reconcile_global
|
||||
def tasks_middlewares(self):
|
||||
from authentik.blueprints.v1.tasks import BlueprintWatcherMiddleware
|
||||
|
||||
get_broker().add_middleware(BlueprintWatcherMiddleware())
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=blueprints_discovery,
|
||||
crontab=f"{fqdn_rand('blueprints_v1_discover')} * * * *",
|
||||
send_on_startup=True,
|
||||
),
|
||||
ScheduleSpec(
|
||||
actor=clear_failed_blueprints,
|
||||
crontab=f"{fqdn_rand('blueprints_v1_cleanup')} * * * *",
|
||||
send_on_startup=True,
|
||||
),
|
||||
]
|
||||
|
@ -3,6 +3,7 @@
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.contenttypes.fields import GenericRelation
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@ -71,6 +72,13 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
|
||||
enabled = models.BooleanField(default=True)
|
||||
managed_models = ArrayField(models.TextField(), default=list)
|
||||
|
||||
# Manual link to tasks instead of using TasksModel because of loop imports
|
||||
tasks = GenericRelation(
|
||||
"authentik_tasks.Task",
|
||||
content_type_field="rel_obj_content_type",
|
||||
object_id_field="rel_obj_id",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Blueprint Instance")
|
||||
verbose_name_plural = _("Blueprint Instances")
|
||||
|
@ -1,18 +0,0 @@
|
||||
"""blueprint Settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"blueprints_v1_discover": {
|
||||
"task": "authentik.blueprints.v1.tasks.blueprints_discovery",
|
||||
"schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
"blueprints_v1_cleanup": {
|
||||
"task": "authentik.blueprints.v1.tasks.clear_failed_blueprints",
|
||||
"schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
2
authentik/blueprints/tasks.py
Normal file
2
authentik/blueprints/tasks.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Import all v1 tasks for auto task discovery
|
||||
from authentik.blueprints.v1.tasks import * # noqa: F403
|
@ -5,7 +5,6 @@ from collections.abc import Callable
|
||||
from django.apps import apps
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.blueprints.v1.importer import is_model_allowed
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.providers.oauth2.models import RefreshToken
|
||||
|
||||
@ -22,10 +21,13 @@ def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable:
|
||||
return
|
||||
model_class = test_model()
|
||||
self.assertTrue(isinstance(model_class, SerializerModel))
|
||||
# Models that have subclasses don't have to have a serializer
|
||||
if len(test_model.__subclasses__()) > 0:
|
||||
return
|
||||
self.assertIsNotNone(model_class.serializer)
|
||||
if model_class.serializer.Meta().model == RefreshToken:
|
||||
return
|
||||
self.assertEqual(model_class.serializer.Meta().model, test_model)
|
||||
self.assertTrue(issubclass(test_model, model_class.serializer.Meta().model))
|
||||
|
||||
return tester
|
||||
|
||||
@ -34,6 +36,6 @@ for app in apps.get_app_configs():
|
||||
if not app.label.startswith("authentik"):
|
||||
continue
|
||||
for model in app.get_models():
|
||||
if not is_model_allowed(model):
|
||||
if not issubclass(model, SerializerModel):
|
||||
continue
|
||||
setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model))
|
||||
|
@ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
||||
file.seek(0)
|
||||
file_hash = sha512(file.read().encode()).hexdigest()
|
||||
file.flush()
|
||||
blueprints_discovery()
|
||||
blueprints_discovery.send()
|
||||
instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
|
||||
self.assertEqual(instance.last_applied_hash, file_hash)
|
||||
self.assertEqual(
|
||||
@ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
||||
)
|
||||
)
|
||||
file.flush()
|
||||
blueprints_discovery()
|
||||
blueprints_discovery.send()
|
||||
blueprint = BlueprintInstance.objects.filter(name="foo").first()
|
||||
self.assertEqual(
|
||||
blueprint.last_applied_hash,
|
||||
@ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
|
||||
)
|
||||
)
|
||||
file.flush()
|
||||
blueprints_discovery()
|
||||
blueprints_discovery.send()
|
||||
blueprint.refresh_from_db()
|
||||
self.assertEqual(
|
||||
blueprint.last_applied_hash,
|
||||
|
@ -57,7 +57,6 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
|
||||
EndpointDeviceConnection,
|
||||
)
|
||||
from authentik.events.logs import LogEvent, capture_logs
|
||||
from authentik.events.models import SystemTask
|
||||
from authentik.events.utils import cleanse_dict
|
||||
from authentik.flows.models import FlowToken, Stage
|
||||
from authentik.lib.models import SerializerModel
|
||||
@ -77,6 +76,7 @@ from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
|
||||
from authentik.rbac.models import Role
|
||||
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
|
||||
from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
# Context set when the serializer is created in a blueprint context
|
||||
@ -118,7 +118,7 @@ def excluded_models() -> list[type[Model]]:
|
||||
SCIMProviderGroup,
|
||||
SCIMProviderUser,
|
||||
Tenant,
|
||||
SystemTask,
|
||||
Task,
|
||||
ConnectionToken,
|
||||
AuthorizationCode,
|
||||
AccessToken,
|
||||
|
@ -44,7 +44,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
|
||||
return MetaResult()
|
||||
LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance)
|
||||
|
||||
apply_blueprint(str(self.blueprint_instance.pk))
|
||||
apply_blueprint(self.blueprint_instance.pk)
|
||||
return MetaResult()
|
||||
|
||||
|
||||
|
@ -4,12 +4,17 @@ from dataclasses import asdict, dataclass, field
|
||||
from hashlib import sha512
|
||||
from pathlib import Path
|
||||
from sys import platform
|
||||
from uuid import UUID
|
||||
|
||||
from dacite.core import from_dict
|
||||
from django.conf import settings
|
||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
from django.utils.text import slugify
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound
|
||||
from dramatiq.actor import actor
|
||||
from dramatiq.middleware import Middleware
|
||||
from structlog.stdlib import get_logger
|
||||
from watchdog.events import (
|
||||
FileCreatedEvent,
|
||||
@ -31,15 +36,13 @@ from authentik.blueprints.v1.importer import Importer
|
||||
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
|
||||
from authentik.blueprints.v1.oci import OCI_PREFIX
|
||||
from authentik.events.logs import capture_logs
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.events.utils import sanitize_dict
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
_file_watcher_started = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -53,22 +56,21 @@ class BlueprintFile:
|
||||
meta: BlueprintMetadata | None = field(default=None)
|
||||
|
||||
|
||||
def start_blueprint_watcher():
|
||||
"""Start blueprint watcher, if it's not running already."""
|
||||
# This function might be called twice since it's called on celery startup
|
||||
class BlueprintWatcherMiddleware(Middleware):
|
||||
def start_blueprint_watcher(self):
|
||||
"""Start blueprint watcher"""
|
||||
observer = Observer()
|
||||
kwargs = {}
|
||||
if platform.startswith("linux"):
|
||||
kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
|
||||
observer.schedule(
|
||||
BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
|
||||
)
|
||||
observer.start()
|
||||
|
||||
global _file_watcher_started # noqa: PLW0603
|
||||
if _file_watcher_started:
|
||||
return
|
||||
observer = Observer()
|
||||
kwargs = {}
|
||||
if platform.startswith("linux"):
|
||||
kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
|
||||
observer.schedule(
|
||||
BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
|
||||
)
|
||||
observer.start()
|
||||
_file_watcher_started = True
|
||||
def after_worker_boot(self, broker, worker):
|
||||
if not settings.TEST:
|
||||
self.start_blueprint_watcher()
|
||||
|
||||
|
||||
class BlueprintEventHandler(FileSystemEventHandler):
|
||||
@ -92,7 +94,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
|
||||
LOGGER.debug("new blueprint file created, starting discovery")
|
||||
for tenant in Tenant.objects.filter(ready=True):
|
||||
with tenant:
|
||||
blueprints_discovery.delay()
|
||||
Schedule.dispatch_by_actor(blueprints_discovery)
|
||||
|
||||
def on_modified(self, event: FileSystemEvent):
|
||||
"""Process file modification"""
|
||||
@ -103,14 +105,14 @@ class BlueprintEventHandler(FileSystemEventHandler):
|
||||
with tenant:
|
||||
for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
|
||||
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
|
||||
apply_blueprint.delay(instance.pk.hex)
|
||||
apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
@actor(
|
||||
description=_("Find blueprints as `blueprints_find` does, but return a safe dict."),
|
||||
throws=(DatabaseError, ProgrammingError, InternalError),
|
||||
)
|
||||
def blueprints_find_dict():
|
||||
"""Find blueprints as `blueprints_find` does, but return a safe dict"""
|
||||
blueprints = []
|
||||
for blueprint in blueprints_find():
|
||||
blueprints.append(sanitize_dict(asdict(blueprint)))
|
||||
@ -146,21 +148,19 @@ def blueprints_find() -> list[BlueprintFile]:
|
||||
return blueprints
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True
|
||||
@actor(
|
||||
description=_("Find blueprints and check if they need to be created in the database."),
|
||||
throws=(DatabaseError, ProgrammingError, InternalError),
|
||||
)
|
||||
@prefill_task
|
||||
def blueprints_discovery(self: SystemTask, path: str | None = None):
|
||||
"""Find blueprints and check if they need to be created in the database"""
|
||||
def blueprints_discovery(path: str | None = None):
|
||||
self: Task = CurrentTask.get_task()
|
||||
count = 0
|
||||
for blueprint in blueprints_find():
|
||||
if path and blueprint.path != path:
|
||||
continue
|
||||
check_blueprint_v1_file(blueprint)
|
||||
count += 1
|
||||
self.set_status(
|
||||
TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count))
|
||||
)
|
||||
self.info(f"Successfully imported {count} files.")
|
||||
|
||||
|
||||
def check_blueprint_v1_file(blueprint: BlueprintFile):
|
||||
@ -187,22 +187,26 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
|
||||
)
|
||||
if instance.last_applied_hash != blueprint.hash:
|
||||
LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
|
||||
apply_blueprint.delay(str(instance.pk))
|
||||
apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=SystemTask,
|
||||
)
|
||||
def apply_blueprint(self: SystemTask, instance_pk: str):
|
||||
"""Apply single blueprint"""
|
||||
self.save_on_success = False
|
||||
@actor(description=_("Apply single blueprint."))
|
||||
def apply_blueprint(instance_pk: UUID):
|
||||
try:
|
||||
self: Task = CurrentTask.get_task()
|
||||
except CurrentTaskNotFound:
|
||||
self = Task()
|
||||
self.set_uid(str(instance_pk))
|
||||
instance: BlueprintInstance | None = None
|
||||
try:
|
||||
instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
|
||||
if not instance or not instance.enabled:
|
||||
if not instance:
|
||||
self.warning(f"Could not find blueprint {instance_pk}, skipping")
|
||||
return
|
||||
self.set_uid(slugify(instance.name))
|
||||
if not instance.enabled:
|
||||
self.info(f"Blueprint {instance.name} is disabled, skipping")
|
||||
return
|
||||
blueprint_content = instance.retrieve()
|
||||
file_hash = sha512(blueprint_content.encode()).hexdigest()
|
||||
importer = Importer.from_string(blueprint_content, instance.context)
|
||||
@ -212,19 +216,18 @@ def apply_blueprint(self: SystemTask, instance_pk: str):
|
||||
if not valid:
|
||||
instance.status = BlueprintInstanceStatus.ERROR
|
||||
instance.save()
|
||||
self.set_status(TaskStatus.ERROR, *logs)
|
||||
self.logs(logs)
|
||||
return
|
||||
with capture_logs() as logs:
|
||||
applied = importer.apply()
|
||||
if not applied:
|
||||
instance.status = BlueprintInstanceStatus.ERROR
|
||||
instance.save()
|
||||
self.set_status(TaskStatus.ERROR, *logs)
|
||||
self.logs(logs)
|
||||
return
|
||||
instance.status = BlueprintInstanceStatus.SUCCESSFUL
|
||||
instance.last_applied_hash = file_hash
|
||||
instance.last_applied = now()
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
except (
|
||||
OSError,
|
||||
DatabaseError,
|
||||
@ -235,15 +238,14 @@ def apply_blueprint(self: SystemTask, instance_pk: str):
|
||||
) as exc:
|
||||
if instance:
|
||||
instance.status = BlueprintInstanceStatus.ERROR
|
||||
self.set_error(exc)
|
||||
self.error(exc)
|
||||
finally:
|
||||
if instance:
|
||||
instance.save()
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
@actor(description=_("Remove blueprints which couldn't be fetched."))
|
||||
def clear_failed_blueprints():
|
||||
"""Remove blueprints which couldn't be fetched"""
|
||||
# Exclude OCI blueprints as those might be temporarily unavailable
|
||||
for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX):
|
||||
try:
|
||||
|
@ -9,6 +9,7 @@ class AuthentikBrandsConfig(ManagedAppConfig):
|
||||
name = "authentik.brands"
|
||||
label = "authentik_brands"
|
||||
verbose_name = "authentik Brands"
|
||||
default = True
|
||||
mountpoints = {
|
||||
"authentik.brands.urls_root": "",
|
||||
}
|
||||
|
@ -1,8 +1,7 @@
|
||||
"""authentik core app config"""
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
|
||||
class AuthentikCoreConfig(ManagedAppConfig):
|
||||
@ -14,14 +13,6 @@ class AuthentikCoreConfig(ManagedAppConfig):
|
||||
mountpoint = ""
|
||||
default = True
|
||||
|
||||
@ManagedAppConfig.reconcile_global
|
||||
def debug_worker_hook(self):
|
||||
"""Dispatch startup tasks inline when debugging"""
|
||||
if settings.DEBUG:
|
||||
from authentik.root.celery import worker_ready_hook
|
||||
|
||||
worker_ready_hook()
|
||||
|
||||
@ManagedAppConfig.reconcile_tenant
|
||||
def source_inbuilt(self):
|
||||
"""Reconcile inbuilt source"""
|
||||
@ -34,3 +25,18 @@ class AuthentikCoreConfig(ManagedAppConfig):
|
||||
},
|
||||
managed=Source.MANAGED_INBUILT,
|
||||
)
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.core.tasks import clean_expired_models, clean_temporary_users
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=clean_expired_models,
|
||||
crontab="2-59/5 * * * *",
|
||||
),
|
||||
ScheduleSpec(
|
||||
actor=clean_temporary_users,
|
||||
crontab="9-59/5 * * * *",
|
||||
),
|
||||
]
|
||||
|
@ -1,21 +0,0 @@
|
||||
"""Run bootstrap tasks"""
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django_tenants.utils import get_public_schema_name
|
||||
|
||||
from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Run bootstrap tasks to ensure certain objects are created"""
|
||||
|
||||
def handle(self, **options):
|
||||
for task in _get_startup_tasks_default_tenant():
|
||||
with Tenant.objects.get(schema_name=get_public_schema_name()):
|
||||
task()
|
||||
|
||||
for task in _get_startup_tasks_all_tenants():
|
||||
for tenant in Tenant.objects.filter(ready=True):
|
||||
with tenant:
|
||||
task()
|
@ -1,47 +0,0 @@
|
||||
"""Run worker"""
|
||||
|
||||
from sys import exit as sysexit
|
||||
from tempfile import tempdir
|
||||
|
||||
from celery.apps.worker import Worker
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import close_old_connections
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.debug import start_debug_server
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Run worker"""
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--beat",
|
||||
action="store_false",
|
||||
help="When set, this worker will _not_ run Beat (scheduled) tasks",
|
||||
)
|
||||
|
||||
def handle(self, **options):
|
||||
LOGGER.debug("Celery options", **options)
|
||||
close_old_connections()
|
||||
start_debug_server()
|
||||
worker: Worker = CELERY_APP.Worker(
|
||||
no_color=False,
|
||||
quiet=True,
|
||||
optimization="fair",
|
||||
autoscale=(CONFIG.get_int("worker.concurrency"), 1),
|
||||
task_events=True,
|
||||
beat=options.get("beat", True),
|
||||
schedule_filename=f"{tempdir}/celerybeat-schedule",
|
||||
queues=["authentik", "authentik_scheduled", "authentik_events"],
|
||||
)
|
||||
for task in CELERY_APP.tasks:
|
||||
LOGGER.debug("Registered task", task=task)
|
||||
|
||||
worker.start()
|
||||
sysexit(worker.exitcode)
|
@ -1082,6 +1082,12 @@ class AuthenticatedSession(SerializerModel):
|
||||
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.core.api.authenticated_sessions import AuthenticatedSessionSerializer
|
||||
|
||||
return AuthenticatedSessionSerializer
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Authenticated Session")
|
||||
verbose_name_plural = _("Authenticated Sessions")
|
||||
|
@ -3,6 +3,9 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from dramatiq.actor import actor
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import (
|
||||
@ -11,17 +14,14 @@ from authentik.core.models import (
|
||||
ExpiringModel,
|
||||
User,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def clean_expired_models(self: SystemTask):
|
||||
"""Remove expired objects"""
|
||||
messages = []
|
||||
@actor(description=_("Remove expired objects."))
|
||||
def clean_expired_models():
|
||||
self: Task = CurrentTask.get_task()
|
||||
for cls in ExpiringModel.__subclasses__():
|
||||
cls: ExpiringModel
|
||||
objects = (
|
||||
@ -31,16 +31,13 @@ def clean_expired_models(self: SystemTask):
|
||||
for obj in objects:
|
||||
obj.expire_action()
|
||||
LOGGER.debug("Expired models", model=cls, amount=amount)
|
||||
messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def clean_temporary_users(self: SystemTask):
|
||||
"""Remove temporary users created by SAML Sources"""
|
||||
@actor(description=_("Remove temporary users created by SAML Sources."))
|
||||
def clean_temporary_users():
|
||||
self: Task = CurrentTask.get_task()
|
||||
_now = datetime.now()
|
||||
messages = []
|
||||
deleted_users = 0
|
||||
for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):
|
||||
if not user.attributes.get(USER_ATTRIBUTE_EXPIRES):
|
||||
@ -52,5 +49,4 @@ def clean_temporary_users(self: SystemTask):
|
||||
LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta)
|
||||
user.delete()
|
||||
deleted_users += 1
|
||||
messages.append(f"Successfully deleted {deleted_users} users.")
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
self.info(f"Successfully deleted {deleted_users} users.")
|
||||
|
@ -36,7 +36,7 @@ class TestTasks(APITestCase):
|
||||
expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API
|
||||
)
|
||||
key = token.key
|
||||
clean_expired_models.delay().get()
|
||||
clean_expired_models.send()
|
||||
token.refresh_from_db()
|
||||
self.assertNotEqual(key, token.key)
|
||||
|
||||
@ -50,5 +50,5 @@ class TestTasks(APITestCase):
|
||||
USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()),
|
||||
},
|
||||
)
|
||||
clean_temporary_users.delay().get()
|
||||
clean_temporary_users.send()
|
||||
self.assertFalse(User.objects.filter(username=username))
|
||||
|
@ -4,6 +4,8 @@ from datetime import UTC, datetime
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
MANAGED_KEY = "goauthentik.io/crypto/jwt-managed"
|
||||
|
||||
@ -67,3 +69,14 @@ class AuthentikCryptoConfig(ManagedAppConfig):
|
||||
"key_data": builder.private_key,
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.crypto.tasks import certificate_discovery
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=certificate_discovery,
|
||||
crontab=f"{fqdn_rand('crypto_certificate_discovery')} * * * *",
|
||||
),
|
||||
]
|
||||
|
@ -1,13 +0,0 @@
|
||||
"""Crypto task Settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"crypto_certificate_discovery": {
|
||||
"task": "authentik.crypto.tasks.certificate_discovery",
|
||||
"schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
@ -7,13 +7,13 @@ from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||
from cryptography.x509.base import load_pem_x509_certificate
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from dramatiq.actor import actor
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -36,10 +36,9 @@ def ensure_certificate_valid(body: str):
|
||||
return body
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def certificate_discovery(self: SystemTask):
|
||||
"""Discover, import and update certificates from the filesystem"""
|
||||
@actor(description=_("Discover, import and update certificates from the filesystem."))
|
||||
def certificate_discovery():
|
||||
self: Task = CurrentTask.get_task()
|
||||
certs = {}
|
||||
private_keys = {}
|
||||
discovered = 0
|
||||
@ -84,6 +83,4 @@ def certificate_discovery(self: SystemTask):
|
||||
dirty = True
|
||||
if dirty:
|
||||
cert.save()
|
||||
self.set_status(
|
||||
TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=discovered))
|
||||
)
|
||||
self.info(f"Successfully imported {discovered} files.")
|
||||
|
@ -338,7 +338,7 @@ class TestCrypto(APITestCase):
|
||||
with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key:
|
||||
_key.write(builder.private_key)
|
||||
with CONFIG.patch("cert_discovery_dir", temp_dir):
|
||||
certificate_discovery()
|
||||
certificate_discovery.send()
|
||||
keypair: CertificateKeyPair = CertificateKeyPair.objects.filter(
|
||||
managed=MANAGED_DISCOVERED % "foo"
|
||||
).first()
|
||||
|
@ -3,6 +3,8 @@
|
||||
from django.conf import settings
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
|
||||
class EnterpriseConfig(ManagedAppConfig):
|
||||
@ -26,3 +28,14 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
return LicenseKey.cached_summary().status.is_valid
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.enterprise.tasks import enterprise_update_usage
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=enterprise_update_usage,
|
||||
crontab=f"{fqdn_rand('enterprise_update_usage')} */2 * * *",
|
||||
),
|
||||
]
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""authentik Unique Password policy app config"""
|
||||
|
||||
from authentik.enterprise.apps import EnterpriseConfig
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
|
||||
class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig):
|
||||
@ -8,3 +10,21 @@ class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig):
|
||||
label = "authentik_policies_unique_password"
|
||||
verbose_name = "authentik Enterprise.Policies.Unique Password"
|
||||
default = True
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.enterprise.policies.unique_password.tasks import (
|
||||
check_and_purge_password_history,
|
||||
trim_password_histories,
|
||||
)
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=trim_password_histories,
|
||||
crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *",
|
||||
),
|
||||
ScheduleSpec(
|
||||
actor=check_and_purge_password_history,
|
||||
crontab=f"{fqdn_rand('policies_unique_password_purge')} */24 * * *",
|
||||
),
|
||||
]
|
||||
|
@ -1,20 +0,0 @@
|
||||
"""Unique Password Policy settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"policies_unique_password_trim_history": {
|
||||
"task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories",
|
||||
"schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
"policies_unique_password_check_purge": {
|
||||
"task": (
|
||||
"authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history"
|
||||
),
|
||||
"schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
@ -1,35 +1,37 @@
|
||||
from django.db.models.aggregates import Count
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from dramatiq.actor import actor
|
||||
from structlog import get_logger
|
||||
|
||||
from authentik.enterprise.policies.unique_password.models import (
|
||||
UniquePasswordPolicy,
|
||||
UserPasswordHistory,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def check_and_purge_password_history(self: SystemTask):
|
||||
"""Check if any UniquePasswordPolicy exists, and if not, purge the password history table.
|
||||
This is run on a schedule instead of being triggered by policy binding deletion.
|
||||
"""
|
||||
@actor(
|
||||
description=_(
|
||||
"Check if any UniquePasswordPolicy exists, and if not, purge the password history table."
|
||||
)
|
||||
)
|
||||
def check_and_purge_password_history():
|
||||
self: Task = CurrentTask.get_task()
|
||||
|
||||
if not UniquePasswordPolicy.objects.exists():
|
||||
UserPasswordHistory.objects.all().delete()
|
||||
LOGGER.debug("Purged UserPasswordHistory table as no policies are in use")
|
||||
self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory")
|
||||
self.info("Successfully purged UserPasswordHistory")
|
||||
return
|
||||
|
||||
self.set_status(
|
||||
TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists"
|
||||
)
|
||||
self.info("Not purging password histories, a unique password policy exists")
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def trim_password_histories(self: SystemTask):
|
||||
@actor(description=_("Remove user password history that are too old."))
|
||||
def trim_password_histories():
|
||||
"""Removes rows from UserPasswordHistory older than
|
||||
the `n` most recent entries.
|
||||
|
||||
@ -37,6 +39,8 @@ def trim_password_histories(self: SystemTask):
|
||||
UniquePasswordPolicy policies.
|
||||
"""
|
||||
|
||||
self: Task = CurrentTask.get_task()
|
||||
|
||||
# No policy, we'll let the cleanup above do its thing
|
||||
if not UniquePasswordPolicy.objects.exists():
|
||||
return
|
||||
@ -63,4 +67,4 @@ def trim_password_histories(self: SystemTask):
|
||||
|
||||
num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete()
|
||||
LOGGER.debug("Deleted stale password history records", count=num_deleted)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records")
|
||||
self.info(f"Delete {num_deleted} stale password history records")
|
||||
|
@ -76,7 +76,7 @@ class TestCheckAndPurgePasswordHistory(TestCase):
|
||||
self.assertTrue(UserPasswordHistory.objects.exists())
|
||||
|
||||
# Run the task - should purge since no policy is in use
|
||||
check_and_purge_password_history()
|
||||
check_and_purge_password_history.send()
|
||||
|
||||
# Verify the table is empty
|
||||
self.assertFalse(UserPasswordHistory.objects.exists())
|
||||
@ -99,7 +99,7 @@ class TestCheckAndPurgePasswordHistory(TestCase):
|
||||
self.assertTrue(UserPasswordHistory.objects.exists())
|
||||
|
||||
# Run the task - should NOT purge since a policy is in use
|
||||
check_and_purge_password_history()
|
||||
check_and_purge_password_history.send()
|
||||
|
||||
# Verify the entries still exist
|
||||
self.assertTrue(UserPasswordHistory.objects.exists())
|
||||
@ -142,7 +142,7 @@ class TestTrimPasswordHistory(TestCase):
|
||||
enabled=True,
|
||||
order=0,
|
||||
)
|
||||
trim_password_histories.delay()
|
||||
trim_password_histories.send()
|
||||
user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user)
|
||||
self.assertEqual(len(user_pwd_history_qs), 1)
|
||||
|
||||
@ -159,7 +159,7 @@ class TestTrimPasswordHistory(TestCase):
|
||||
enabled=False,
|
||||
order=0,
|
||||
)
|
||||
trim_password_histories.delay()
|
||||
trim_password_histories.send()
|
||||
self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())
|
||||
|
||||
def test_trim_password_history_fewer_records_than_maximum_is_no_op(self):
|
||||
@ -174,5 +174,5 @@ class TestTrimPasswordHistory(TestCase):
|
||||
enabled=True,
|
||||
order=0,
|
||||
)
|
||||
trim_password_histories.delay()
|
||||
trim_password_histories.send()
|
||||
self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())
|
||||
|
@ -55,5 +55,5 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
|
||||
]
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_single_task = google_workspace_sync
|
||||
sync_task = google_workspace_sync
|
||||
sync_objects_task = google_workspace_sync_objects
|
||||
|
@ -7,6 +7,7 @@ from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import Actor
|
||||
from google.oauth2.service_account import Credentials
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
@ -110,6 +111,12 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
help_text=_("Property mappings used for group creation/updating."),
|
||||
)
|
||||
|
||||
@property
|
||||
def sync_actor(self) -> Actor:
|
||||
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
|
||||
|
||||
return google_workspace_sync
|
||||
|
||||
def client_for_model(
|
||||
self,
|
||||
model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup],
|
||||
|
@ -1,13 +0,0 @@
|
||||
"""Google workspace provider task Settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"providers_google_workspace_sync": {
|
||||
"task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all",
|
||||
"schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
@ -2,15 +2,13 @@
|
||||
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.enterprise.providers.google_workspace.tasks import (
|
||||
google_workspace_sync,
|
||||
google_workspace_sync_direct,
|
||||
google_workspace_sync_m2m,
|
||||
google_workspace_sync_direct_dispatch,
|
||||
google_workspace_sync_m2m_dispatch,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.signals import register_signals
|
||||
|
||||
register_signals(
|
||||
GoogleWorkspaceProvider,
|
||||
task_sync_single=google_workspace_sync,
|
||||
task_sync_direct=google_workspace_sync_direct,
|
||||
task_sync_m2m=google_workspace_sync_m2m,
|
||||
task_sync_direct_dispatch=google_workspace_sync_direct_dispatch,
|
||||
task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch,
|
||||
)
|
||||
|
@ -1,37 +1,48 @@
|
||||
"""Google Provider tasks"""
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
|
||||
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
|
||||
from authentik.lib.sync.outgoing.tasks import SyncTasks
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
sync_tasks = SyncTasks(GoogleWorkspaceProvider)
|
||||
|
||||
|
||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
|
||||
@actor(description=_("Sync Google Workspace provider objects."))
|
||||
def google_workspace_sync_objects(*args, **kwargs):
|
||||
return sync_tasks.sync_objects(*args, **kwargs)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
|
||||
)
|
||||
def google_workspace_sync(self, provider_pk: int, *args, **kwargs):
|
||||
@actor(description=_("Full sync for Google Workspace provider."))
|
||||
def google_workspace_sync(provider_pk: int, *args, **kwargs):
|
||||
"""Run full sync for Google Workspace provider"""
|
||||
return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects)
|
||||
return sync_tasks.sync(provider_pk, google_workspace_sync_objects)
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def google_workspace_sync_all():
|
||||
return sync_tasks.sync_all(google_workspace_sync)
|
||||
|
||||
|
||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
|
||||
@actor(description=_("Sync a direct object (user, group) for Google Workspace provider."))
|
||||
def google_workspace_sync_direct(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_direct(*args, **kwargs)
|
||||
|
||||
|
||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
|
||||
@actor(
|
||||
description=_(
|
||||
"Dispatch syncs for a direct object (user, group) for Google Workspace providers."
|
||||
)
|
||||
)
|
||||
def google_workspace_sync_direct_dispatch(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs)
|
||||
|
||||
|
||||
@actor(description=_("Sync a related object (memberships) for Google Workspace provider."))
|
||||
def google_workspace_sync_m2m(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_m2m(*args, **kwargs)
|
||||
|
||||
|
||||
@actor(
|
||||
description=_(
|
||||
"Dispatch syncs for a related object (memberships) for Google Workspace providers."
|
||||
)
|
||||
)
|
||||
def google_workspace_sync_m2m_dispatch(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs)
|
||||
|
@ -324,7 +324,7 @@ class GoogleWorkspaceGroupTests(TestCase):
|
||||
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
|
||||
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
|
||||
):
|
||||
google_workspace_sync.delay(self.provider.pk).get()
|
||||
google_workspace_sync.send(self.provider.pk).get_result()
|
||||
self.assertTrue(
|
||||
GoogleWorkspaceProviderGroup.objects.filter(
|
||||
group=different_group, provider=self.provider
|
||||
|
@ -302,7 +302,7 @@ class GoogleWorkspaceUserTests(TestCase):
|
||||
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
|
||||
MagicMock(return_value={"developerKey": self.api_key, "http": http}),
|
||||
):
|
||||
google_workspace_sync.delay(self.provider.pk).get()
|
||||
google_workspace_sync.send(self.provider.pk).get_result()
|
||||
self.assertTrue(
|
||||
GoogleWorkspaceProviderUser.objects.filter(
|
||||
user=different_user, provider=self.provider
|
||||
|
@ -53,5 +53,5 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
|
||||
]
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
sync_single_task = microsoft_entra_sync
|
||||
sync_task = microsoft_entra_sync
|
||||
sync_objects_task = microsoft_entra_sync_objects
|
||||
|
@ -8,6 +8,7 @@ from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import Actor
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from authentik.core.models import (
|
||||
@ -99,6 +100,12 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
|
||||
help_text=_("Property mappings used for group creation/updating."),
|
||||
)
|
||||
|
||||
@property
|
||||
def sync_actor(self) -> Actor:
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
|
||||
|
||||
return microsoft_entra_sync
|
||||
|
||||
def client_for_model(
|
||||
self,
|
||||
model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup],
|
||||
|
@ -1,13 +0,0 @@
|
||||
"""Microsoft Entra provider task Settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"providers_microsoft_entra_sync": {
|
||||
"task": "authentik.enterprise.providers.microsoft_entra.tasks.microsoft_entra_sync_all",
|
||||
"schedule": crontab(minute=fqdn_rand("microsoft_entra_sync_all"), hour="*/4"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
@ -2,15 +2,13 @@
|
||||
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||
from authentik.enterprise.providers.microsoft_entra.tasks import (
|
||||
microsoft_entra_sync,
|
||||
microsoft_entra_sync_direct,
|
||||
microsoft_entra_sync_m2m,
|
||||
microsoft_entra_sync_direct_dispatch,
|
||||
microsoft_entra_sync_m2m_dispatch,
|
||||
)
|
||||
from authentik.lib.sync.outgoing.signals import register_signals
|
||||
|
||||
register_signals(
|
||||
MicrosoftEntraProvider,
|
||||
task_sync_single=microsoft_entra_sync,
|
||||
task_sync_direct=microsoft_entra_sync_direct,
|
||||
task_sync_m2m=microsoft_entra_sync_m2m,
|
||||
task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch,
|
||||
task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch,
|
||||
)
|
||||
|
@ -1,37 +1,46 @@
|
||||
"""Microsoft Entra Provider tasks"""
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
|
||||
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
|
||||
from authentik.lib.sync.outgoing.tasks import SyncTasks
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
sync_tasks = SyncTasks(MicrosoftEntraProvider)
|
||||
|
||||
|
||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
|
||||
@actor(description=_("Sync Microsoft Entra provider objects."))
|
||||
def microsoft_entra_sync_objects(*args, **kwargs):
|
||||
return sync_tasks.sync_objects(*args, **kwargs)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
|
||||
)
|
||||
def microsoft_entra_sync(self, provider_pk: int, *args, **kwargs):
|
||||
@actor(description=_("Full sync for Microsoft Entra provider."))
|
||||
def microsoft_entra_sync(provider_pk: int, *args, **kwargs):
|
||||
"""Run full sync for Microsoft Entra provider"""
|
||||
return sync_tasks.sync_single(self, provider_pk, microsoft_entra_sync_objects)
|
||||
return sync_tasks.sync(provider_pk, microsoft_entra_sync_objects)
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def microsoft_entra_sync_all():
|
||||
return sync_tasks.sync_all(microsoft_entra_sync)
|
||||
|
||||
|
||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
|
||||
@actor(description=_("Sync a direct object (user, group) for Microsoft Entra provider."))
|
||||
def microsoft_entra_sync_direct(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_direct(*args, **kwargs)
|
||||
|
||||
|
||||
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
|
||||
@actor(
|
||||
description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.")
|
||||
)
|
||||
def microsoft_entra_sync_direct_dispatch(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs)
|
||||
|
||||
|
||||
@actor(description=_("Sync a related object (memberships) for Microsoft Entra provider."))
|
||||
def microsoft_entra_sync_m2m(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_m2m(*args, **kwargs)
|
||||
|
||||
|
||||
@actor(
|
||||
description=_(
|
||||
"Dispatch syncs for a related object (memberships) for Microsoft Entra providers."
|
||||
)
|
||||
)
|
||||
def microsoft_entra_sync_m2m_dispatch(*args, **kwargs):
|
||||
return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs)
|
||||
|
@ -252,9 +252,13 @@ class MicrosoftEntraGroupTests(TestCase):
|
||||
member_add.assert_called_once()
|
||||
self.assertEqual(
|
||||
member_add.call_args[0][0].odata_id,
|
||||
f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter(
|
||||
f"https://graph.microsoft.com/v1.0/directoryObjects/{
|
||||
MicrosoftEntraProviderUser.objects.filter(
|
||||
provider=self.provider,
|
||||
).first().microsoft_id}",
|
||||
)
|
||||
.first()
|
||||
.microsoft_id
|
||||
}",
|
||||
)
|
||||
|
||||
def test_group_create_member_remove(self):
|
||||
@ -311,9 +315,13 @@ class MicrosoftEntraGroupTests(TestCase):
|
||||
member_add.assert_called_once()
|
||||
self.assertEqual(
|
||||
member_add.call_args[0][0].odata_id,
|
||||
f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter(
|
||||
f"https://graph.microsoft.com/v1.0/directoryObjects/{
|
||||
MicrosoftEntraProviderUser.objects.filter(
|
||||
provider=self.provider,
|
||||
).first().microsoft_id}",
|
||||
)
|
||||
.first()
|
||||
.microsoft_id
|
||||
}",
|
||||
)
|
||||
member_remove.assert_called_once()
|
||||
|
||||
@ -413,7 +421,7 @@ class MicrosoftEntraGroupTests(TestCase):
|
||||
),
|
||||
) as group_list,
|
||||
):
|
||||
microsoft_entra_sync.delay(self.provider.pk).get()
|
||||
microsoft_entra_sync.send(self.provider.pk).get_result()
|
||||
self.assertTrue(
|
||||
MicrosoftEntraProviderGroup.objects.filter(
|
||||
group=different_group, provider=self.provider
|
||||
|
@ -397,7 +397,7 @@ class MicrosoftEntraUserTests(APITestCase):
|
||||
AsyncMock(return_value=GroupCollectionResponse(value=[])),
|
||||
),
|
||||
):
|
||||
microsoft_entra_sync.delay(self.provider.pk).get()
|
||||
microsoft_entra_sync.send(self.provider.pk).get_result()
|
||||
self.assertTrue(
|
||||
MicrosoftEntraProviderUser.objects.filter(
|
||||
user=different_user, provider=self.provider
|
||||
|
@ -17,6 +17,7 @@ from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.lib.models import CreatedUpdatedModel
|
||||
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
|
||||
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider
|
||||
from authentik.tasks.models import TasksModel
|
||||
|
||||
|
||||
class EventTypes(models.TextChoices):
|
||||
@ -42,7 +43,7 @@ class SSFEventStatus(models.TextChoices):
|
||||
SENT = "sent"
|
||||
|
||||
|
||||
class SSFProvider(BackchannelProvider):
|
||||
class SSFProvider(TasksModel, BackchannelProvider):
|
||||
"""Shared Signals Framework provider to allow applications to
|
||||
receive user events from authentik."""
|
||||
|
||||
|
@ -18,7 +18,7 @@ from authentik.enterprise.providers.ssf.models import (
|
||||
EventTypes,
|
||||
SSFProvider,
|
||||
)
|
||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_event
|
||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_events
|
||||
from authentik.events.middleware import audit_ignore
|
||||
from authentik.stages.authenticator.models import Device
|
||||
from authentik.stages.authenticator_duo.models import DuoDevice
|
||||
@ -66,7 +66,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
|
||||
|
||||
As this signal is also triggered with a regular logout, we can't be sure
|
||||
if the session has been deleted by an admin or by the user themselves."""
|
||||
send_ssf_event(
|
||||
send_ssf_events(
|
||||
EventTypes.CAEP_SESSION_REVOKED,
|
||||
{
|
||||
"initiating_entity": "user",
|
||||
@ -88,7 +88,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
|
||||
@receiver(password_changed)
|
||||
def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_):
|
||||
"""Credential change trigger (password changed)"""
|
||||
send_ssf_event(
|
||||
send_ssf_events(
|
||||
EventTypes.CAEP_CREDENTIAL_CHANGE,
|
||||
{
|
||||
"credential_type": "password",
|
||||
@ -126,7 +126,7 @@ def ssf_device_post_save(sender: type[Model], instance: Device, created: bool, *
|
||||
}
|
||||
if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID:
|
||||
data["fido2_aaguid"] = instance.aaguid
|
||||
send_ssf_event(
|
||||
send_ssf_events(
|
||||
EventTypes.CAEP_CREDENTIAL_CHANGE,
|
||||
data,
|
||||
sub_id={
|
||||
@ -153,7 +153,7 @@ def ssf_device_post_delete(sender: type[Model], instance: Device, **_):
|
||||
}
|
||||
if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID:
|
||||
data["fido2_aaguid"] = instance.aaguid
|
||||
send_ssf_event(
|
||||
send_ssf_events(
|
||||
EventTypes.CAEP_CREDENTIAL_CHANGE,
|
||||
data,
|
||||
sub_id={
|
||||
|
@ -1,7 +1,11 @@
|
||||
from celery import group
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from dramatiq.actor import actor
|
||||
from requests.exceptions import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
@ -13,19 +17,16 @@ from authentik.enterprise.providers.ssf.models import (
|
||||
Stream,
|
||||
StreamEvent,
|
||||
)
|
||||
from authentik.events.logs import LogEvent
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
session = get_http_session()
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def send_ssf_event(
|
||||
def send_ssf_events(
|
||||
event_type: EventTypes,
|
||||
data: dict,
|
||||
stream_filter: dict | None = None,
|
||||
@ -33,7 +34,7 @@ def send_ssf_event(
|
||||
**extra_data,
|
||||
):
|
||||
"""Wrapper to send an SSF event to multiple streams"""
|
||||
payload = []
|
||||
events_data = {}
|
||||
if not stream_filter:
|
||||
stream_filter = {}
|
||||
stream_filter["events_requested__contains"] = [event_type]
|
||||
@ -41,16 +42,22 @@ def send_ssf_event(
|
||||
extra_data.setdefault("txn", request.request_id)
|
||||
for stream in Stream.objects.filter(**stream_filter):
|
||||
event_data = stream.prepare_event_payload(event_type, data, **extra_data)
|
||||
payload.append((str(stream.uuid), event_data))
|
||||
return _send_ssf_event.delay(payload)
|
||||
events_data[stream.uuid] = event_data
|
||||
ssf_events_dispatch.send(events_data)
|
||||
|
||||
|
||||
def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
|
||||
@actor(description=_("Dispatch SSF events."))
|
||||
def ssf_events_dispatch(events_data: dict[str, dict[str, Any]]):
|
||||
for stream_uuid, event_data in events_data.items():
|
||||
stream = Stream.objects.filter(pk=stream_uuid).first()
|
||||
if not stream:
|
||||
continue
|
||||
send_ssf_event.send_with_options(args=(stream_uuid, event_data), rel_obj=stream.provider)
|
||||
|
||||
|
||||
def _check_app_access(stream: Stream, event_data: dict) -> bool:
|
||||
"""Check if event is related to user and if so, check
|
||||
if the user has access to the application"""
|
||||
stream = Stream.objects.filter(pk=stream_uuid).first()
|
||||
if not stream:
|
||||
return False
|
||||
# `event_data` is a dict version of a StreamEvent
|
||||
sub_id = event_data.get("payload", {}).get("sub_id", {})
|
||||
email = sub_id.get("user", {}).get("email", None)
|
||||
@ -65,42 +72,22 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
|
||||
return engine.passing
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def _send_ssf_event(event_data: list[tuple[str, dict]]):
|
||||
tasks = []
|
||||
for stream, data in event_data:
|
||||
if not _check_app_access(stream, data):
|
||||
continue
|
||||
event = StreamEvent.objects.create(**data)
|
||||
tasks.extend(send_single_ssf_event(stream, str(event.uuid)))
|
||||
main_task = group(*tasks)
|
||||
main_task()
|
||||
@actor(description=_("Send an SSF event."))
|
||||
def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
|
||||
self: Task = CurrentTask.get_task()
|
||||
|
||||
|
||||
def send_single_ssf_event(stream_id: str, evt_id: str):
|
||||
stream = Stream.objects.filter(pk=stream_id).first()
|
||||
stream = Stream.objects.filter(pk=stream_uuid).first()
|
||||
if not stream:
|
||||
return
|
||||
event = StreamEvent.objects.filter(pk=evt_id).first()
|
||||
if not event:
|
||||
if not _check_app_access(stream, event_data):
|
||||
return
|
||||
event = StreamEvent.objects.create(**event_data)
|
||||
self.set_uid(event.pk)
|
||||
if event.status == SSFEventStatus.SENT:
|
||||
return
|
||||
if stream.delivery_method == DeliveryMethods.RISC_PUSH:
|
||||
return [ssf_push_event.si(str(event.pk))]
|
||||
return []
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def ssf_push_event(self: SystemTask, event_id: str):
|
||||
self.save_on_success = False
|
||||
event = StreamEvent.objects.filter(pk=event_id).first()
|
||||
if not event:
|
||||
return
|
||||
self.set_uid(event_id)
|
||||
if event.status == SSFEventStatus.SENT:
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
if stream.delivery_method != DeliveryMethods.RISC_PUSH:
|
||||
return
|
||||
|
||||
try:
|
||||
response = session.post(
|
||||
event.stream.endpoint_url,
|
||||
@ -110,26 +97,17 @@ def ssf_push_event(self: SystemTask, event_id: str):
|
||||
response.raise_for_status()
|
||||
event.status = SSFEventStatus.SENT
|
||||
event.save()
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
return
|
||||
except RequestException as exc:
|
||||
LOGGER.warning("Failed to send SSF event", exc=exc)
|
||||
self.set_status(TaskStatus.ERROR)
|
||||
attrs = {}
|
||||
if exc.response:
|
||||
attrs["response"] = {
|
||||
"content": exc.response.text,
|
||||
"status": exc.response.status_code,
|
||||
}
|
||||
self.set_error(
|
||||
exc,
|
||||
LogEvent(
|
||||
_("Failed to send request"),
|
||||
log_level="warning",
|
||||
logger=self.__name__,
|
||||
attributes=attrs,
|
||||
),
|
||||
)
|
||||
self.warning(exc)
|
||||
self.warning("Failed to send request", **attrs)
|
||||
# Re-up the expiry of the stream event
|
||||
event.expires = now() + timedelta_from_string(event.stream.provider.event_retention)
|
||||
event.status = SSFEventStatus.PENDING_FAILED
|
||||
|
@ -13,7 +13,7 @@ from authentik.enterprise.providers.ssf.models import (
|
||||
SSFProvider,
|
||||
Stream,
|
||||
)
|
||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_event
|
||||
from authentik.enterprise.providers.ssf.tasks import send_ssf_events
|
||||
from authentik.enterprise.providers.ssf.views.base import SSFView
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -109,7 +109,7 @@ class StreamView(SSFView):
|
||||
"User does not have permission to create stream for this provider."
|
||||
)
|
||||
instance: Stream = stream.save(provider=self.provider)
|
||||
send_ssf_event(
|
||||
send_ssf_events(
|
||||
EventTypes.SET_VERIFICATION,
|
||||
{
|
||||
"state": None,
|
||||
|
@ -6,7 +6,7 @@ from djangoql.ast import Name
|
||||
from djangoql.exceptions import DjangoQLError
|
||||
from djangoql.queryset import apply_search
|
||||
from djangoql.schema import DjangoQLSchema
|
||||
from rest_framework.filters import SearchFilter
|
||||
from rest_framework.filters import BaseFilterBackend, SearchFilter
|
||||
from rest_framework.request import Request
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
@ -39,19 +39,21 @@ class BaseSchema(DjangoQLSchema):
|
||||
return super().resolve_name(name)
|
||||
|
||||
|
||||
class QLSearch(SearchFilter):
|
||||
class QLSearch(BaseFilterBackend):
|
||||
"""rest_framework search filter which uses DjangoQL"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._fallback = SearchFilter()
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return apps.get_app_config("authentik_enterprise").enabled()
|
||||
|
||||
def get_search_terms(self, request) -> str:
|
||||
"""
|
||||
Search terms are set by a ?search=... query parameter,
|
||||
and may be comma and/or whitespace delimited.
|
||||
"""
|
||||
params = request.query_params.get(self.search_param, "")
|
||||
def get_search_terms(self, request: Request) -> str:
|
||||
"""Search terms are set by a ?search=... query parameter,
|
||||
and may be comma and/or whitespace delimited."""
|
||||
params = request.query_params.get("search", "")
|
||||
params = params.replace("\x00", "") # strip null characters
|
||||
return params
|
||||
|
||||
@ -70,9 +72,9 @@ class QLSearch(SearchFilter):
|
||||
search_query = self.get_search_terms(request)
|
||||
schema = self.get_schema(request, view)
|
||||
if len(search_query) == 0 or not self.enabled:
|
||||
return super().filter_queryset(request, queryset, view)
|
||||
return self._fallback.filter_queryset(request, queryset, view)
|
||||
try:
|
||||
return apply_search(queryset, search_query, schema=schema)
|
||||
except DjangoQLError as exc:
|
||||
LOGGER.debug("Failed to parse search expression", exc=exc)
|
||||
return super().filter_queryset(request, queryset, view)
|
||||
return self._fallback.filter_queryset(request, queryset, view)
|
||||
|
@ -57,7 +57,7 @@ class QLTest(APITestCase):
|
||||
)
|
||||
self.assertEqual(res.status_code, 200)
|
||||
content = loads(res.content)
|
||||
self.assertGreaterEqual(content["pagination"]["count"], 1)
|
||||
self.assertEqual(content["pagination"]["count"], 1)
|
||||
self.assertEqual(content["results"][0]["username"], self.user.username)
|
||||
|
||||
def test_search_json(self):
|
||||
|
@ -1,17 +1,5 @@
|
||||
"""Enterprise additional settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"enterprise_update_usage": {
|
||||
"task": "authentik.enterprise.tasks.enterprise_update_usage",
|
||||
"schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
}
|
||||
}
|
||||
|
||||
TENANT_APPS = [
|
||||
"authentik.enterprise.audit",
|
||||
"authentik.enterprise.policies.unique_password",
|
||||
|
@ -10,6 +10,7 @@ from django.utils.timezone import get_current_timezone
|
||||
from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE
|
||||
from authentik.enterprise.models import License
|
||||
from authentik.enterprise.tasks import enterprise_update_usage
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
|
||||
|
||||
@receiver(pre_save, sender=License)
|
||||
@ -26,7 +27,7 @@ def pre_save_license(sender: type[License], instance: License, **_):
|
||||
def post_save_license(sender: type[License], instance: License, **_):
|
||||
"""Trigger license usage calculation when license is saved"""
|
||||
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
|
||||
enterprise_update_usage.delay()
|
||||
Schedule.dispatch_by_actor(enterprise_update_usage)
|
||||
|
||||
|
||||
@receiver(post_delete, sender=License)
|
||||
|
@ -1,14 +1,11 @@
|
||||
"""Enterprise tasks"""
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def enterprise_update_usage(self: SystemTask):
|
||||
"""Update enterprise license status"""
|
||||
@actor(description=_("Update enterprise license status."))
|
||||
def enterprise_update_usage():
|
||||
LicenseKey.get_total().record_usage()
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
|
@ -1,104 +0,0 @@
|
||||
"""Tasks API"""
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
from django.contrib import messages
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import (
|
||||
CharField,
|
||||
ChoiceField,
|
||||
DateTimeField,
|
||||
FloatField,
|
||||
SerializerMethodField,
|
||||
)
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.utils import ModelSerializer
|
||||
from authentik.events.logs import LogEventSerializer
|
||||
from authentik.events.models import SystemTask, TaskStatus
|
||||
from authentik.rbac.decorators import permission_required
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class SystemTaskSerializer(ModelSerializer):
|
||||
"""Serialize TaskInfo and TaskResult"""
|
||||
|
||||
name = CharField()
|
||||
full_name = SerializerMethodField()
|
||||
uid = CharField(required=False)
|
||||
description = CharField()
|
||||
start_timestamp = DateTimeField(read_only=True)
|
||||
finish_timestamp = DateTimeField(read_only=True)
|
||||
duration = FloatField(read_only=True)
|
||||
|
||||
status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus])
|
||||
messages = LogEventSerializer(many=True)
|
||||
|
||||
def get_full_name(self, instance: SystemTask) -> str:
|
||||
"""Get full name with UID"""
|
||||
if instance.uid:
|
||||
return f"{instance.name}:{instance.uid}"
|
||||
return instance.name
|
||||
|
||||
class Meta:
|
||||
model = SystemTask
|
||||
fields = [
|
||||
"uuid",
|
||||
"name",
|
||||
"full_name",
|
||||
"uid",
|
||||
"description",
|
||||
"start_timestamp",
|
||||
"finish_timestamp",
|
||||
"duration",
|
||||
"status",
|
||||
"messages",
|
||||
"expires",
|
||||
"expiring",
|
||||
]
|
||||
|
||||
|
||||
class SystemTaskViewSet(ReadOnlyModelViewSet):
|
||||
"""Read-only view set that returns all background tasks"""
|
||||
|
||||
queryset = SystemTask.objects.all()
|
||||
serializer_class = SystemTaskSerializer
|
||||
filterset_fields = ["name", "uid", "status"]
|
||||
ordering = ["name", "uid", "status"]
|
||||
search_fields = ["name", "description", "uid", "status"]
|
||||
|
||||
@permission_required(None, ["authentik_events.run_task"])
|
||||
@extend_schema(
|
||||
request=OpenApiTypes.NONE,
|
||||
responses={
|
||||
204: OpenApiResponse(description="Task retried successfully"),
|
||||
404: OpenApiResponse(description="Task not found"),
|
||||
500: OpenApiResponse(description="Failed to retry task"),
|
||||
},
|
||||
)
|
||||
@action(detail=True, methods=["POST"], permission_classes=[])
|
||||
def run(self, request: Request, pk=None) -> Response:
|
||||
"""Run task"""
|
||||
task: SystemTask = self.get_object()
|
||||
try:
|
||||
task_module = import_module(task.task_call_module)
|
||||
task_func = getattr(task_module, task.task_call_func)
|
||||
LOGGER.info("Running task", task=task_func)
|
||||
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
|
||||
messages.success(
|
||||
self.request,
|
||||
_("Successfully started task {name}.".format_map({"name": task.name})),
|
||||
)
|
||||
return Response(status=204)
|
||||
except (ImportError, AttributeError) as exc: # pragma: no cover
|
||||
LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc)
|
||||
# if we get an import error, the module path has probably changed
|
||||
task.delete()
|
||||
return Response(status=500)
|
@ -1,12 +1,11 @@
|
||||
"""authentik events app"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
from prometheus_client import Gauge, Histogram
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.config import CONFIG, ENV_PREFIX
|
||||
from authentik.lib.utils.reflection import path_to_class
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
# TODO: Deprecated metric - remove in 2024.2 or later
|
||||
GAUGE_TASKS = Gauge(
|
||||
@ -35,6 +34,17 @@ class AuthentikEventsConfig(ManagedAppConfig):
|
||||
verbose_name = "authentik Events"
|
||||
default = True
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.events.tasks import notification_cleanup
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=notification_cleanup,
|
||||
crontab=f"{fqdn_rand('notification_cleanup')} */8 * * *",
|
||||
),
|
||||
]
|
||||
|
||||
@ManagedAppConfig.reconcile_global
|
||||
def check_deprecations(self):
|
||||
"""Check for config deprecations"""
|
||||
@ -56,41 +66,3 @@ class AuthentikEventsConfig(ManagedAppConfig):
|
||||
replacement_env=replace_env,
|
||||
message=msg,
|
||||
).save()
|
||||
|
||||
@ManagedAppConfig.reconcile_tenant
|
||||
def prefill_tasks(self):
|
||||
"""Prefill tasks"""
|
||||
from authentik.events.models import SystemTask
|
||||
from authentik.events.system_tasks import _prefill_tasks
|
||||
|
||||
for task in _prefill_tasks:
|
||||
if SystemTask.objects.filter(name=task.name).exists():
|
||||
continue
|
||||
task.save()
|
||||
self.logger.debug("prefilled task", task_name=task.name)
|
||||
|
||||
@ManagedAppConfig.reconcile_tenant
|
||||
def run_scheduled_tasks(self):
|
||||
"""Run schedule tasks which are behind schedule (only applies
|
||||
to tasks of which we keep metrics)"""
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask as CelerySystemTask
|
||||
|
||||
for task in CELERY_APP.conf["beat_schedule"].values():
|
||||
schedule = task["schedule"]
|
||||
if not isinstance(schedule, crontab):
|
||||
continue
|
||||
task_class: CelerySystemTask = path_to_class(task["task"])
|
||||
if not isinstance(task_class, CelerySystemTask):
|
||||
continue
|
||||
db_task = task_class.db()
|
||||
if not db_task:
|
||||
continue
|
||||
due, _ = schedule.is_due(db_task.finish_timestamp)
|
||||
if due or db_task.status == TaskStatus.UNKNOWN:
|
||||
self.logger.debug("Running past-due scheduled task", task=task["task"])
|
||||
task_class.apply_async(
|
||||
args=task.get("args", None),
|
||||
kwargs=task.get("kwargs", None),
|
||||
**task.get("options", {}),
|
||||
)
|
||||
|
22
authentik/events/migrations/0011_alter_systemtask_options.py
Normal file
22
authentik/events/migrations/0011_alter_systemtask_options.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Generated by Django 5.1.11 on 2025-06-24 15:36
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="systemtask",
|
||||
options={
|
||||
"default_permissions": (),
|
||||
"permissions": (),
|
||||
"verbose_name": "System Task",
|
||||
"verbose_name_plural": "System Tasks",
|
||||
},
|
||||
),
|
||||
]
|
@ -5,12 +5,11 @@ from datetime import timedelta
|
||||
from difflib import get_close_matches
|
||||
from functools import lru_cache
|
||||
from inspect import currentframe
|
||||
from smtplib import SMTPException
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from django.apps import apps
|
||||
from django.db import connection, models
|
||||
from django.db import models
|
||||
from django.http import HttpRequest
|
||||
from django.http.request import QueryDict
|
||||
from django.utils.timezone import now
|
||||
@ -27,7 +26,6 @@ from authentik.core.middleware import (
|
||||
SESSION_KEY_IMPERSONATE_USER,
|
||||
)
|
||||
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
|
||||
from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME
|
||||
from authentik.events.context_processors.base import get_context_processors
|
||||
from authentik.events.utils import (
|
||||
cleanse_dict,
|
||||
@ -43,6 +41,7 @@ from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
from authentik.root.middleware import ClientIPMiddleware
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
from authentik.tasks.models import TasksModel
|
||||
from authentik.tenants.models import Tenant
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
@ -267,7 +266,8 @@ class Event(SerializerModel, ExpiringModel):
|
||||
models.Index(fields=["created"]),
|
||||
models.Index(fields=["client_ip"]),
|
||||
models.Index(
|
||||
models.F("context__authorized_application"), name="authentik_e_ctx_app__idx"
|
||||
models.F("context__authorized_application"),
|
||||
name="authentik_e_ctx_app__idx",
|
||||
),
|
||||
]
|
||||
|
||||
@ -281,7 +281,7 @@ class TransportMode(models.TextChoices):
|
||||
EMAIL = "email", _("Email")
|
||||
|
||||
|
||||
class NotificationTransport(SerializerModel):
|
||||
class NotificationTransport(TasksModel, SerializerModel):
|
||||
"""Action which is executed when a Rule matches"""
|
||||
|
||||
uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
@ -446,6 +446,8 @@ class NotificationTransport(SerializerModel):
|
||||
|
||||
def send_email(self, notification: "Notification") -> list[str]:
|
||||
"""Send notification via global email configuration"""
|
||||
from authentik.stages.email.tasks import send_mail
|
||||
|
||||
if notification.user.email.strip() == "":
|
||||
LOGGER.info(
|
||||
"Discarding notification as user has no email address",
|
||||
@ -487,17 +489,14 @@ class NotificationTransport(SerializerModel):
|
||||
template_name="email/event_notification.html",
|
||||
template_context=context,
|
||||
)
|
||||
# Email is sent directly here, as the call to send() should have been from a task.
|
||||
try:
|
||||
from authentik.stages.email.tasks import send_mail
|
||||
|
||||
return send_mail(mail.__dict__)
|
||||
except (SMTPException, ConnectionError, OSError) as exc:
|
||||
raise NotificationTransportError(exc) from exc
|
||||
send_mail.send_with_options(args=(mail.__dict__,), rel_obj=self)
|
||||
return []
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.events.api.notification_transports import NotificationTransportSerializer
|
||||
from authentik.events.api.notification_transports import (
|
||||
NotificationTransportSerializer,
|
||||
)
|
||||
|
||||
return NotificationTransportSerializer
|
||||
|
||||
@ -547,7 +546,7 @@ class Notification(SerializerModel):
|
||||
verbose_name_plural = _("Notifications")
|
||||
|
||||
|
||||
class NotificationRule(SerializerModel, PolicyBindingModel):
|
||||
class NotificationRule(TasksModel, SerializerModel, PolicyBindingModel):
|
||||
"""Decide when to create a Notification based on policies attached to this object."""
|
||||
|
||||
name = models.TextField(unique=True)
|
||||
@ -611,7 +610,9 @@ class NotificationWebhookMapping(PropertyMapping):
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[type[Serializer]]:
|
||||
from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer
|
||||
from authentik.events.api.notification_mappings import (
|
||||
NotificationWebhookMappingSerializer,
|
||||
)
|
||||
|
||||
return NotificationWebhookMappingSerializer
|
||||
|
||||
@ -624,7 +625,7 @@ class NotificationWebhookMapping(PropertyMapping):
|
||||
|
||||
|
||||
class TaskStatus(models.TextChoices):
|
||||
"""Possible states of tasks"""
|
||||
"""DEPRECATED do not use"""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
SUCCESSFUL = "successful"
|
||||
@ -632,8 +633,8 @@ class TaskStatus(models.TextChoices):
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SystemTask(SerializerModel, ExpiringModel):
|
||||
"""Info about a system task running in the background along with details to restart the task"""
|
||||
class SystemTask(ExpiringModel):
|
||||
"""DEPRECATED do not use"""
|
||||
|
||||
uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
name = models.TextField()
|
||||
@ -653,41 +654,13 @@ class SystemTask(SerializerModel, ExpiringModel):
|
||||
task_call_args = models.JSONField(default=list)
|
||||
task_call_kwargs = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
|
||||
return SystemTaskSerializer
|
||||
|
||||
def update_metrics(self):
|
||||
"""Update prometheus metrics"""
|
||||
# TODO: Deprecated metric - remove in 2024.2 or later
|
||||
GAUGE_TASKS.labels(
|
||||
tenant=connection.schema_name,
|
||||
task_name=self.name,
|
||||
task_uid=self.uid or "",
|
||||
status=self.status.lower(),
|
||||
).set(self.duration)
|
||||
SYSTEM_TASK_TIME.labels(
|
||||
tenant=connection.schema_name,
|
||||
task_name=self.name,
|
||||
task_uid=self.uid or "",
|
||||
).observe(self.duration)
|
||||
SYSTEM_TASK_STATUS.labels(
|
||||
tenant=connection.schema_name,
|
||||
task_name=self.name,
|
||||
task_uid=self.uid or "",
|
||||
status=self.status.lower(),
|
||||
).inc()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"System Task {self.name}"
|
||||
|
||||
class Meta:
|
||||
unique_together = (("name", "uid"),)
|
||||
# Remove "add", "change" and "delete" permissions as those are not used
|
||||
default_permissions = ["view"]
|
||||
permissions = [("run_task", _("Run task"))]
|
||||
default_permissions = ()
|
||||
permissions = ()
|
||||
verbose_name = _("System Task")
|
||||
verbose_name_plural = _("System Tasks")
|
||||
indexes = ExpiringModel.Meta.indexes
|
||||
|
@ -1,13 +0,0 @@
|
||||
"""Event Settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"events_notification_cleanup": {
|
||||
"task": "authentik.events.tasks.notification_cleanup",
|
||||
"schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
@ -12,13 +12,10 @@ from rest_framework.request import Request
|
||||
|
||||
from authentik.core.models import AuthenticatedSession, User
|
||||
from authentik.core.signals import login_failed, password_changed
|
||||
from authentik.events.apps import SYSTEM_TASK_STATUS
|
||||
from authentik.events.models import Event, EventAction, SystemTask
|
||||
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.root.monitoring import monitoring_set
|
||||
from authentik.stages.invitation.models import Invitation
|
||||
from authentik.stages.invitation.signals import invitation_used
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
@ -114,19 +111,15 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest
|
||||
@receiver(post_save, sender=Event)
|
||||
def event_post_save_notification(sender, instance: Event, **_):
|
||||
"""Start task to check if any policies trigger an notification on this event"""
|
||||
event_notification_handler.delay(instance.event_uuid.hex)
|
||||
from authentik.events.tasks import event_trigger_dispatch
|
||||
|
||||
event_trigger_dispatch.send(instance.event_uuid)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=User)
|
||||
def event_user_pre_delete_cleanup(sender, instance: User, **_):
|
||||
"""If gdpr_compliance is enabled, remove all the user's events"""
|
||||
from authentik.events.tasks import gdpr_cleanup
|
||||
|
||||
if get_current_tenant().gdpr_compliance:
|
||||
gdpr_cleanup.delay(instance.pk)
|
||||
|
||||
|
||||
@receiver(monitoring_set)
|
||||
def monitoring_system_task(sender, **_):
|
||||
"""Update metrics when task is saved"""
|
||||
SYSTEM_TASK_STATUS.clear()
|
||||
for task in SystemTask.objects.all():
|
||||
task.update_metrics()
|
||||
gdpr_cleanup.send(instance.pk)
|
||||
|
@ -1,156 +0,0 @@
|
||||
"""Monitored tasks"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
from tenant_schemas_celery.task import TenantTask
|
||||
|
||||
from authentik.events.logs import LogEvent
|
||||
from authentik.events.models import Event, EventAction, TaskStatus
|
||||
from authentik.events.models import SystemTask as DBSystemTask
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
|
||||
class SystemTask(TenantTask):
|
||||
"""Task which can save its state to the cache"""
|
||||
|
||||
logger: BoundLogger
|
||||
|
||||
# For tasks that should only be listed if they failed, set this to False
|
||||
save_on_success: bool
|
||||
|
||||
_status: TaskStatus
|
||||
_messages: list[LogEvent]
|
||||
|
||||
_uid: str | None
|
||||
# Precise start time from perf_counter
|
||||
_start_precise: float | None = None
|
||||
_start: datetime | None = None
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._status = TaskStatus.SUCCESSFUL
|
||||
self.save_on_success = True
|
||||
self._uid = None
|
||||
self._status = None
|
||||
self._messages = []
|
||||
self.result_timeout_hours = 6
|
||||
|
||||
def set_uid(self, uid: str):
|
||||
"""Set UID, so in the case of an unexpected error its saved correctly"""
|
||||
self._uid = uid
|
||||
|
||||
def set_status(self, status: TaskStatus, *messages: LogEvent):
|
||||
"""Set result for current run, will overwrite previous result."""
|
||||
self._status = status
|
||||
self._messages = list(messages)
|
||||
for idx, msg in enumerate(self._messages):
|
||||
if not isinstance(msg, LogEvent):
|
||||
self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info")
|
||||
|
||||
def set_error(self, exception: Exception, *messages: LogEvent):
|
||||
"""Set result to error and save exception"""
|
||||
self._status = TaskStatus.ERROR
|
||||
self._messages = list(messages)
|
||||
self._messages.extend(
|
||||
[LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error")]
|
||||
)
|
||||
|
||||
def before_start(self, task_id, args, kwargs):
|
||||
self._start_precise = perf_counter()
|
||||
self._start = now()
|
||||
self.logger = get_logger().bind(task_id=task_id)
|
||||
return super().before_start(task_id, args, kwargs)
|
||||
|
||||
def db(self) -> DBSystemTask | None:
|
||||
"""Get DB object for latest task"""
|
||||
return DBSystemTask.objects.filter(
|
||||
name=self.__name__,
|
||||
uid=self._uid,
|
||||
).first()
|
||||
|
||||
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
|
||||
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
|
||||
if not self._status:
|
||||
return
|
||||
if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success:
|
||||
DBSystemTask.objects.filter(
|
||||
name=self.__name__,
|
||||
uid=self._uid,
|
||||
).delete()
|
||||
return
|
||||
DBSystemTask.objects.update_or_create(
|
||||
name=self.__name__,
|
||||
uid=self._uid,
|
||||
defaults={
|
||||
"description": self.__doc__,
|
||||
"start_timestamp": self._start or now(),
|
||||
"finish_timestamp": now(),
|
||||
"duration": max(perf_counter() - self._start_precise, 0),
|
||||
"task_call_module": self.__module__,
|
||||
"task_call_func": self.__name__,
|
||||
"task_call_args": sanitize_item(args),
|
||||
"task_call_kwargs": sanitize_item(kwargs),
|
||||
"status": self._status,
|
||||
"messages": sanitize_item(self._messages),
|
||||
"expires": now() + timedelta(hours=self.result_timeout_hours),
|
||||
"expiring": True,
|
||||
},
|
||||
)
|
||||
|
||||
def on_failure(self, exc, task_id, args, kwargs, einfo):
|
||||
super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
|
||||
if not self._status:
|
||||
self.set_error(exc)
|
||||
DBSystemTask.objects.update_or_create(
|
||||
name=self.__name__,
|
||||
uid=self._uid,
|
||||
defaults={
|
||||
"description": self.__doc__,
|
||||
"start_timestamp": self._start or now(),
|
||||
"finish_timestamp": now(),
|
||||
"duration": max(perf_counter() - self._start_precise, 0),
|
||||
"task_call_module": self.__module__,
|
||||
"task_call_func": self.__name__,
|
||||
"task_call_args": sanitize_item(args),
|
||||
"task_call_kwargs": sanitize_item(kwargs),
|
||||
"status": self._status,
|
||||
"messages": sanitize_item(self._messages),
|
||||
"expires": now() + timedelta(hours=self.result_timeout_hours + 3),
|
||||
"expiring": True,
|
||||
},
|
||||
)
|
||||
Event.new(
|
||||
EventAction.SYSTEM_TASK_EXCEPTION,
|
||||
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
|
||||
).save()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def prefill_task(func):
|
||||
"""Ensure a task's details are always in cache, so it can always be triggered via API"""
|
||||
_prefill_tasks.append(
|
||||
DBSystemTask(
|
||||
name=func.__name__,
|
||||
description=func.__doc__,
|
||||
start_timestamp=now(),
|
||||
finish_timestamp=now(),
|
||||
status=TaskStatus.UNKNOWN,
|
||||
messages=sanitize_item([_("Task has not been run yet.")]),
|
||||
task_call_module=func.__module__,
|
||||
task_call_func=func.__name__,
|
||||
expiring=False,
|
||||
duration=0,
|
||||
)
|
||||
)
|
||||
return func
|
||||
|
||||
|
||||
_prefill_tasks = []
|
@ -1,41 +1,49 @@
|
||||
"""Event notification tasks"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from django.db.models.query_utils import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from dramatiq.actor import actor
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.expression.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import (
|
||||
Event,
|
||||
Notification,
|
||||
NotificationRule,
|
||||
NotificationTransport,
|
||||
NotificationTransportError,
|
||||
TaskStatus,
|
||||
)
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.models import PolicyBinding, PolicyEngineMode
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def event_notification_handler(event_uuid: str):
|
||||
"""Start task for each trigger definition"""
|
||||
@actor(description=_("Dispatch new event notifications."))
|
||||
def event_trigger_dispatch(event_uuid: UUID):
|
||||
for trigger in NotificationRule.objects.all():
|
||||
event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events")
|
||||
event_trigger_handler.send_with_options(args=(event_uuid, trigger.name), rel_obj=trigger)
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||
@actor(
|
||||
description=_(
|
||||
"Check if policies attached to NotificationRule match event "
|
||||
"and dispatch notification tasks."
|
||||
)
|
||||
)
|
||||
def event_trigger_handler(event_uuid: UUID, trigger_name: str):
|
||||
"""Check if policies attached to NotificationRule match event"""
|
||||
self: Task = CurrentTask.get_task()
|
||||
|
||||
event: Event = Event.objects.filter(event_uuid=event_uuid).first()
|
||||
if not event:
|
||||
LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid)
|
||||
self.warning("event doesn't exist yet or anymore", event_uuid=event_uuid)
|
||||
return
|
||||
|
||||
trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first()
|
||||
if not trigger:
|
||||
return
|
||||
@ -70,57 +78,46 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||
|
||||
LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
|
||||
# Create the notification objects
|
||||
count = 0
|
||||
for transport in trigger.transports.all():
|
||||
for user in trigger.destination_users(event):
|
||||
LOGGER.debug("created notification")
|
||||
notification_transport.apply_async(
|
||||
args=[
|
||||
notification_transport.send_with_options(
|
||||
args=(
|
||||
transport.pk,
|
||||
str(event.pk),
|
||||
event.pk,
|
||||
user.pk,
|
||||
str(trigger.pk),
|
||||
],
|
||||
queue="authentik_events",
|
||||
trigger.pk,
|
||||
),
|
||||
rel_obj=transport,
|
||||
)
|
||||
count += 1
|
||||
if transport.send_once:
|
||||
break
|
||||
self.info(f"Created {count} notification tasks")
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
autoretry_for=(NotificationTransportError,),
|
||||
retry_backoff=True,
|
||||
base=SystemTask,
|
||||
)
|
||||
def notification_transport(
|
||||
self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
|
||||
):
|
||||
@actor(description=_("Send notification."))
|
||||
def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str):
|
||||
"""Send notification over specified transport"""
|
||||
self.save_on_success = False
|
||||
try:
|
||||
event = Event.objects.filter(pk=event_pk).first()
|
||||
if not event:
|
||||
return
|
||||
user = User.objects.filter(pk=user_pk).first()
|
||||
if not user:
|
||||
return
|
||||
trigger = NotificationRule.objects.filter(pk=trigger_pk).first()
|
||||
if not trigger:
|
||||
return
|
||||
notification = Notification(
|
||||
severity=trigger.severity, body=event.summary, event=event, user=user
|
||||
)
|
||||
transport = NotificationTransport.objects.filter(pk=transport_pk).first()
|
||||
if not transport:
|
||||
return
|
||||
transport.send(notification)
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
except (NotificationTransportError, PropertyMappingExpressionException) as exc:
|
||||
self.set_error(exc)
|
||||
raise exc
|
||||
event = Event.objects.filter(pk=event_pk).first()
|
||||
if not event:
|
||||
return
|
||||
user = User.objects.filter(pk=user_pk).first()
|
||||
if not user:
|
||||
return
|
||||
trigger = NotificationRule.objects.filter(pk=trigger_pk).first()
|
||||
if not trigger:
|
||||
return
|
||||
notification = Notification(
|
||||
severity=trigger.severity, body=event.summary, event=event, user=user
|
||||
)
|
||||
transport: NotificationTransport = NotificationTransport.objects.filter(pk=transport_pk).first()
|
||||
if not transport:
|
||||
return
|
||||
transport.send(notification)
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
@actor(description=_("Cleanup events for GDPR compliance."))
|
||||
def gdpr_cleanup(user_pk: int):
|
||||
"""cleanup events from gdpr_compliance"""
|
||||
events = Event.objects.filter(user__pk=user_pk)
|
||||
@ -128,12 +125,12 @@ def gdpr_cleanup(user_pk: int):
|
||||
events.delete()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def notification_cleanup(self: SystemTask):
|
||||
@actor(description=_("Cleanup seen notifications and notifications whose event expired."))
|
||||
def notification_cleanup():
|
||||
"""Cleanup seen notifications and notifications whose event expired."""
|
||||
self: Task = CurrentTask.get_task()
|
||||
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
|
||||
amount = notifications.count()
|
||||
notifications.delete()
|
||||
LOGGER.debug("Expired notifications", amount=amount)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications")
|
||||
self.info(f"Expired {amount} Notifications")
|
||||
|
@ -1,103 +0,0 @@
|
||||
"""Test Monitored tasks"""
|
||||
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tasks import clean_expired_models
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.events.models import SystemTask as DBSystemTask
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
|
||||
class TestSystemTasks(APITestCase):
|
||||
"""Test Monitored tasks"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_failed_successful_remove_state(self):
|
||||
"""Test that a task with `save_on_success` set to `False` that failed saves
|
||||
a state, and upon successful completion will delete the state"""
|
||||
should_fail = True
|
||||
uid = generate_id()
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=SystemTask,
|
||||
)
|
||||
def test_task(self: SystemTask):
|
||||
self.save_on_success = False
|
||||
self.set_uid(uid)
|
||||
self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL)
|
||||
|
||||
# First test successful run
|
||||
should_fail = False
|
||||
test_task.delay().get()
|
||||
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||
|
||||
# Then test failed
|
||||
should_fail = True
|
||||
test_task.delay().get()
|
||||
task = DBSystemTask.objects.filter(name="test_task", uid=uid).first()
|
||||
self.assertEqual(task.status, TaskStatus.ERROR)
|
||||
|
||||
# Then after that, the state should be removed
|
||||
should_fail = False
|
||||
test_task.delay().get()
|
||||
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
|
||||
|
||||
def test_tasks(self):
|
||||
"""Test Task API"""
|
||||
clean_expired_models.delay().get()
|
||||
response = self.client.get(reverse("authentik_api:systemtask-list"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
|
||||
|
||||
def test_tasks_single(self):
|
||||
"""Test Task API (read single)"""
|
||||
clean_expired_models.delay().get()
|
||||
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_api:systemtask-detail",
|
||||
kwargs={"pk": str(task.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
|
||||
self.assertEqual(body["name"], "clean_expired_models")
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_tasks_run(self):
|
||||
"""Test Task API (run)"""
|
||||
clean_expired_models.delay().get()
|
||||
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:systemtask-run",
|
||||
kwargs={"pk": str(task.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 204)
|
||||
|
||||
def test_tasks_run_404(self):
|
||||
"""Test Task API (run, 404)"""
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_api:systemtask-run",
|
||||
kwargs={"pk": "qwerqewrqrqewrqewr"},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
@ -5,13 +5,11 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin
|
||||
from authentik.events.api.notification_rules import NotificationRuleViewSet
|
||||
from authentik.events.api.notification_transports import NotificationTransportViewSet
|
||||
from authentik.events.api.notifications import NotificationViewSet
|
||||
from authentik.events.api.tasks import SystemTaskViewSet
|
||||
|
||||
api_urlpatterns = [
|
||||
("events/events", EventViewSet),
|
||||
("events/notifications", NotificationViewSet),
|
||||
("events/transports", NotificationTransportViewSet),
|
||||
("events/rules", NotificationRuleViewSet),
|
||||
("events/system_tasks", SystemTaskViewSet),
|
||||
("propertymappings/notification", NotificationWebhookMappingViewSet),
|
||||
]
|
||||
|
@ -41,6 +41,7 @@ REDIS_ENV_KEYS = [
|
||||
# Old key -> new key
|
||||
DEPRECATIONS = {
|
||||
"geoip": "events.context_processors.geoip",
|
||||
"worker.concurrency": "worker.processes",
|
||||
"redis.broker_url": "broker.url",
|
||||
"redis.broker_transport_options": "broker.transport_options",
|
||||
"redis.cache_timeout": "cache.timeout",
|
||||
|
@ -21,6 +21,10 @@ def start_debug_server(**kwargs) -> bool:
|
||||
|
||||
listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901")
|
||||
host, _, port = listen.rpartition(":")
|
||||
debugpy.listen((host, int(port)), **kwargs) # nosec
|
||||
try:
|
||||
debugpy.listen((host, int(port)), **kwargs) # nosec
|
||||
except RuntimeError:
|
||||
LOGGER.warning("Could not start debug server. Continuing without")
|
||||
return False
|
||||
LOGGER.debug("Starting debug server", host=host, port=port)
|
||||
return True
|
||||
|
@ -157,7 +157,14 @@ web:
|
||||
path: /
|
||||
|
||||
worker:
|
||||
concurrency: 2
|
||||
processes: 2
|
||||
threads: 1
|
||||
consumer_listen_timeout: "seconds=30"
|
||||
task_max_retries: 20
|
||||
task_default_time_limit: "minutes=10"
|
||||
task_purge_interval: "days=1"
|
||||
task_expiration: "days=30"
|
||||
scheduler_interval: "seconds=60"
|
||||
|
||||
storage:
|
||||
media:
|
||||
|
@ -88,7 +88,6 @@ def get_logger_config():
|
||||
"authentik": global_level,
|
||||
"django": "WARNING",
|
||||
"django.request": "ERROR",
|
||||
"celery": "WARNING",
|
||||
"selenium": "WARNING",
|
||||
"docker": "WARNING",
|
||||
"urllib3": "WARNING",
|
||||
|
@ -3,8 +3,6 @@
|
||||
from asyncio.exceptions import CancelledError
|
||||
from typing import Any
|
||||
|
||||
from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError
|
||||
from celery.exceptions import CeleryError
|
||||
from channels_redis.core import ChannelFull
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
|
||||
@ -22,7 +20,6 @@ from sentry_sdk import HttpTransport, get_current_scope
|
||||
from sentry_sdk import init as sentry_sdk_init
|
||||
from sentry_sdk.api import set_tag
|
||||
from sentry_sdk.integrations.argv import ArgvIntegration
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sentry_sdk.integrations.django import DjangoIntegration
|
||||
from sentry_sdk.integrations.redis import RedisIntegration
|
||||
from sentry_sdk.integrations.socket import SocketIntegration
|
||||
@ -71,10 +68,6 @@ ignored_classes = (
|
||||
LocalProtocolError,
|
||||
# rest_framework error
|
||||
APIException,
|
||||
# celery errors
|
||||
WorkerLostError,
|
||||
CeleryError,
|
||||
SoftTimeLimitExceeded,
|
||||
# custom baseclass
|
||||
SentryIgnoredException,
|
||||
# ldap errors
|
||||
@ -115,7 +108,6 @@ def sentry_init(**sentry_init_kwargs):
|
||||
ArgvIntegration(),
|
||||
StdlibIntegration(),
|
||||
DjangoIntegration(transaction_style="function_name", cache_spans=True),
|
||||
CeleryIntegration(),
|
||||
RedisIntegration(),
|
||||
ThreadingIntegration(propagate_hub=True),
|
||||
SocketIntegration(),
|
||||
@ -160,14 +152,11 @@ def before_send(event: dict, hint: dict) -> dict | None:
|
||||
return None
|
||||
if "logger" in event:
|
||||
if event["logger"] in [
|
||||
"kombu",
|
||||
"asyncio",
|
||||
"multiprocessing",
|
||||
"django_redis",
|
||||
"django.security.DisallowedHost",
|
||||
"django_redis.cache",
|
||||
"celery.backends.redis",
|
||||
"celery.worker",
|
||||
"paramiko.transport",
|
||||
]:
|
||||
return None
|
||||
|
12
authentik/lib/sync/api.py
Normal file
12
authentik/lib/sync/api.py
Normal file
@ -0,0 +1,12 @@
|
||||
from rest_framework.fields import BooleanField, ChoiceField, DateTimeField
|
||||
|
||||
from authentik.core.api.utils import PassiveSerializer
|
||||
from authentik.tasks.models import TaskStatus
|
||||
|
||||
|
||||
class SyncStatusSerializer(PassiveSerializer):
|
||||
"""Provider/source sync status"""
|
||||
|
||||
is_running = BooleanField()
|
||||
last_successful_sync = DateTimeField(required=False)
|
||||
last_sync_status = ChoiceField(required=False, choices=TaskStatus.choices)
|
@ -1,7 +1,7 @@
|
||||
"""Sync constants"""
|
||||
|
||||
PAGE_SIZE = 100
|
||||
PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
|
||||
PAGE_TIMEOUT_MS = 60 * 60 * 0.5 * 1000 # Half an hour
|
||||
HTTP_CONFLICT = 409
|
||||
HTTP_NO_CONTENT = 204
|
||||
HTTP_SERVICE_UNAVAILABLE = 503
|
||||
|
@ -1,7 +1,5 @@
|
||||
from celery import Task
|
||||
from django.utils.text import slugify
|
||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from dramatiq.actor import Actor
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField, CharField, ChoiceField
|
||||
from rest_framework.request import Request
|
||||
@ -9,18 +7,12 @@ from rest_framework.response import Response
|
||||
|
||||
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.api.tasks import SystemTaskSerializer
|
||||
from authentik.events.logs import LogEvent, LogEventSerializer
|
||||
from authentik.events.logs import LogEventSerializer
|
||||
from authentik.lib.sync.api import SyncStatusSerializer
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
|
||||
class SyncStatusSerializer(PassiveSerializer):
|
||||
"""Provider sync status"""
|
||||
|
||||
is_running = BooleanField(read_only=True)
|
||||
tasks = SystemTaskSerializer(many=True, read_only=True)
|
||||
from authentik.tasks.models import Task, TaskStatus
|
||||
|
||||
|
||||
class SyncObjectSerializer(PassiveSerializer):
|
||||
@ -45,15 +37,10 @@ class SyncObjectResultSerializer(PassiveSerializer):
|
||||
class OutgoingSyncProviderStatusMixin:
|
||||
"""Common API Endpoints for Outgoing sync providers"""
|
||||
|
||||
sync_single_task: type[Task] = None
|
||||
sync_objects_task: type[Task] = None
|
||||
sync_task: Actor
|
||||
sync_objects_task: Actor
|
||||
|
||||
@extend_schema(
|
||||
responses={
|
||||
200: SyncStatusSerializer(),
|
||||
404: OpenApiResponse(description="Task not found"),
|
||||
}
|
||||
)
|
||||
@extend_schema(responses={200: SyncStatusSerializer()})
|
||||
@action(
|
||||
methods=["GET"],
|
||||
detail=True,
|
||||
@ -64,18 +51,39 @@ class OutgoingSyncProviderStatusMixin:
|
||||
def sync_status(self, request: Request, pk: int) -> Response:
|
||||
"""Get provider's sync status"""
|
||||
provider: OutgoingSyncProvider = self.get_object()
|
||||
tasks = list(
|
||||
get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
|
||||
name=self.sync_single_task.__name__,
|
||||
uid=slugify(provider.name),
|
||||
)
|
||||
)
|
||||
|
||||
status = {}
|
||||
|
||||
with provider.sync_lock as lock_acquired:
|
||||
status = {
|
||||
"tasks": tasks,
|
||||
# If we could not acquire the lock, it means a task is using it, and thus is running
|
||||
"is_running": not lock_acquired,
|
||||
}
|
||||
# If we could not acquire the lock, it means a task is using it, and thus is running
|
||||
status["is_running"] = not lock_acquired
|
||||
|
||||
sync_schedule = None
|
||||
for schedule in provider.schedules.all():
|
||||
if schedule.actor_name == self.sync_task.actor_name:
|
||||
sync_schedule = schedule
|
||||
|
||||
if not sync_schedule:
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
|
||||
last_task: Task = (
|
||||
sync_schedule.tasks.exclude(
|
||||
aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED)
|
||||
)
|
||||
.order_by("-mtime")
|
||||
.first()
|
||||
)
|
||||
last_successful_task: Task = (
|
||||
sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO))
|
||||
.order_by("-mtime")
|
||||
.first()
|
||||
)
|
||||
|
||||
if last_task:
|
||||
status["last_sync_status"] = last_task.aggregated_status
|
||||
if last_successful_task:
|
||||
status["last_successful_sync"] = last_successful_task.mtime
|
||||
|
||||
return Response(SyncStatusSerializer(status).data)
|
||||
|
||||
@extend_schema(
|
||||
@ -94,14 +102,20 @@ class OutgoingSyncProviderStatusMixin:
|
||||
provider: OutgoingSyncProvider = self.get_object()
|
||||
params = SyncObjectSerializer(data=request.data)
|
||||
params.is_valid(raise_exception=True)
|
||||
res: list[LogEvent] = self.sync_objects_task.delay(
|
||||
params.validated_data["sync_object_model"],
|
||||
page=1,
|
||||
provider_pk=provider.pk,
|
||||
pk=params.validated_data["sync_object_id"],
|
||||
override_dry_run=params.validated_data["override_dry_run"],
|
||||
).get()
|
||||
return Response(SyncObjectResultSerializer(instance={"messages": res}).data)
|
||||
msg = self.sync_objects_task.send_with_options(
|
||||
kwargs={
|
||||
"object_type": params.validated_data["sync_object_model"],
|
||||
"page": 1,
|
||||
"provider_pk": provider.pk,
|
||||
"override_dry_run": params.validated_data["override_dry_run"],
|
||||
"pk": params.validated_data["sync_object_id"],
|
||||
},
|
||||
rel_obj=provider,
|
||||
)
|
||||
msg.get_result(block=True)
|
||||
task: Task = msg.options["task"]
|
||||
task.refresh_from_db()
|
||||
return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data)
|
||||
|
||||
|
||||
class OutgoingSyncConnectionCreateMixin:
|
||||
|
@ -1,12 +1,18 @@
|
||||
from typing import Any, Self
|
||||
|
||||
import pglock
|
||||
from django.core.paginator import Paginator
|
||||
from django.db import connection, models
|
||||
from django.db.models import Model, QuerySet, TextChoices
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import Actor
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS
|
||||
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
from authentik.tasks.schedules.models import ScheduledModel
|
||||
|
||||
|
||||
class OutgoingSyncDeleteAction(TextChoices):
|
||||
@ -18,7 +24,7 @@ class OutgoingSyncDeleteAction(TextChoices):
|
||||
SUSPEND = "suspend"
|
||||
|
||||
|
||||
class OutgoingSyncProvider(Model):
|
||||
class OutgoingSyncProvider(ScheduledModel, Model):
|
||||
"""Base abstract models for providers implementing outgoing sync"""
|
||||
|
||||
dry_run = models.BooleanField(
|
||||
@ -39,6 +45,19 @@ class OutgoingSyncProvider(Model):
|
||||
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
|
||||
return Paginator(self.get_object_qs(type), PAGE_SIZE)
|
||||
|
||||
def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int:
|
||||
num_pages: int = self.get_paginator(type).num_pages
|
||||
return int(num_pages * PAGE_TIMEOUT_MS * 1.5)
|
||||
|
||||
def get_sync_time_limit_ms(self) -> int:
|
||||
return int(
|
||||
(self.get_object_sync_time_limit_ms(User) + self.get_object_sync_time_limit_ms(Group))
|
||||
* 1.5
|
||||
)
|
||||
|
||||
@property
|
||||
def sync_lock(self) -> pglock.advisory:
|
||||
"""Postgres lock for syncing to prevent multiple parallel syncs happening"""
|
||||
@ -47,3 +66,22 @@ class OutgoingSyncProvider(Model):
|
||||
timeout=0,
|
||||
side_effect=pglock.Return,
|
||||
)
|
||||
|
||||
@property
|
||||
def sync_actor(self) -> Actor:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def schedule_specs(self) -> list[ScheduleSpec]:
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=self.sync_actor,
|
||||
uid=self.pk,
|
||||
args=(self.pk,),
|
||||
options={
|
||||
"time_limit": self.get_sync_time_limit_ms(),
|
||||
},
|
||||
send_on_save=True,
|
||||
crontab=f"{fqdn_rand(self.pk)} */4 * * *",
|
||||
),
|
||||
]
|
||||
|
@ -1,12 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from django.core.paginator import Paginator
|
||||
from django.db.models import Model
|
||||
from django.db.models.query import Q
|
||||
from django.db.models.signals import m2m_changed, post_save, pre_delete
|
||||
from dramatiq.actor import Actor
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
@ -14,45 +10,30 @@ from authentik.lib.utils.reflection import class_to_path
|
||||
|
||||
def register_signals(
|
||||
provider_type: type[OutgoingSyncProvider],
|
||||
task_sync_single: Callable[[int], None],
|
||||
task_sync_direct: Callable[[int], None],
|
||||
task_sync_m2m: Callable[[int], None],
|
||||
task_sync_direct_dispatch: Actor[[str, str | int, str], None],
|
||||
task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None],
|
||||
):
|
||||
"""Register sync signals"""
|
||||
uid = class_to_path(provider_type)
|
||||
|
||||
def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_):
|
||||
"""Trigger sync when Provider is saved"""
|
||||
users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE)
|
||||
groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE)
|
||||
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
|
||||
time_limit = soft_time_limit * 1.5
|
||||
task_sync_single.apply_async(
|
||||
(instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
|
||||
)
|
||||
|
||||
post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False)
|
||||
|
||||
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
|
||||
"""Post save handler"""
|
||||
if not provider_type.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
).exists():
|
||||
return
|
||||
task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)
|
||||
task_sync_direct_dispatch.send(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
Direction.add.value,
|
||||
)
|
||||
|
||||
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
|
||||
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
|
||||
|
||||
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
|
||||
"""Pre-delete handler"""
|
||||
if not provider_type.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
).exists():
|
||||
return
|
||||
task_sync_direct.delay(
|
||||
class_to_path(instance.__class__), instance.pk, Direction.remove.value
|
||||
).get(propagate=False)
|
||||
task_sync_direct_dispatch.send(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
Direction.remove.value,
|
||||
)
|
||||
|
||||
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
|
||||
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
|
||||
@ -63,16 +44,6 @@ def register_signals(
|
||||
"""Sync group membership"""
|
||||
if action not in ["post_add", "post_remove"]:
|
||||
return
|
||||
if not provider_type.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
).exists():
|
||||
return
|
||||
# reverse: instance is a Group, pk_set is a list of user pks
|
||||
# non-reverse: instance is a User, pk_set is a list of groups
|
||||
if reverse:
|
||||
task_sync_m2m.delay(str(instance.pk), action, list(pk_set))
|
||||
else:
|
||||
for group_pk in pk_set:
|
||||
task_sync_m2m.delay(group_pk, action, [instance.pk])
|
||||
task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse)
|
||||
|
||||
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
|
||||
|
@ -1,23 +1,17 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
|
||||
from celery import group
|
||||
from celery.exceptions import Retry
|
||||
from celery.result import allow_join_result
|
||||
from django.core.paginator import Paginator
|
||||
from django.db.models import Model, QuerySet
|
||||
from django.db.models.query import Q
|
||||
from django.utils.text import slugify
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from dramatiq.actor import Actor
|
||||
from dramatiq.composition import group
|
||||
from dramatiq.errors import Retry
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.core.expression.exceptions import SkipObjectException
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.events.logs import LogEvent
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
|
||||
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS
|
||||
from authentik.lib.sync.outgoing.base import Direction
|
||||
from authentik.lib.sync.outgoing.exceptions import (
|
||||
BadRequestSyncException,
|
||||
@ -27,11 +21,12 @@ from authentik.lib.sync.outgoing.exceptions import (
|
||||
)
|
||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
|
||||
class SyncTasks:
|
||||
"""Container for all sync 'tasks' (this class doesn't actually contain celery
|
||||
tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
|
||||
"""Container for all sync 'tasks' (this class doesn't actually contain
|
||||
tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)"""
|
||||
|
||||
logger: BoundLogger
|
||||
|
||||
@ -39,107 +34,104 @@ class SyncTasks:
|
||||
super().__init__()
|
||||
self._provider_model = provider_model
|
||||
|
||||
def sync_all(self, single_sync: Callable[[int], None]):
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
self.trigger_single_task(provider, single_sync)
|
||||
|
||||
def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
|
||||
"""Wrapper single sync task that correctly sets time limits based
|
||||
on the amount of objects that will be synced"""
|
||||
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
|
||||
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
|
||||
time_limit = soft_time_limit * 1.5
|
||||
return sync_task.apply_async(
|
||||
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
|
||||
)
|
||||
|
||||
def sync_single(
|
||||
def sync_paginator(
|
||||
self,
|
||||
task: SystemTask,
|
||||
provider_pk: int,
|
||||
sync_objects: Callable[[int, int], list[str]],
|
||||
current_task: Task,
|
||||
provider: OutgoingSyncProvider,
|
||||
sync_objects: Actor[[str, int, int, bool], None],
|
||||
paginator: Paginator,
|
||||
object_type: type[User | Group],
|
||||
**options,
|
||||
):
|
||||
tasks = []
|
||||
for page in paginator.page_range:
|
||||
page_sync = sync_objects.message_with_options(
|
||||
args=(class_to_path(object_type), page, provider.pk),
|
||||
time_limit=PAGE_TIMEOUT_MS,
|
||||
# Assign tasks to the same schedule as the current one
|
||||
rel_obj=current_task.rel_obj,
|
||||
**options,
|
||||
)
|
||||
tasks.append(page_sync)
|
||||
return tasks
|
||||
|
||||
def sync(
|
||||
self,
|
||||
provider_pk: int,
|
||||
sync_objects: Actor[[str, int, int, bool], None],
|
||||
):
|
||||
task: Task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
provider_pk=provider_pk,
|
||||
)
|
||||
provider = self._provider_model.objects.filter(
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
task.warning("No provider found. Is it assigned to an application?")
|
||||
return
|
||||
task.set_uid(slugify(provider.name))
|
||||
messages = []
|
||||
messages.append(_("Starting full provider sync"))
|
||||
task.info("Starting full provider sync")
|
||||
self.logger.debug("Starting provider sync")
|
||||
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
|
||||
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
|
||||
with allow_join_result(), provider.sync_lock as lock_acquired:
|
||||
with provider.sync_lock as lock_acquired:
|
||||
if not lock_acquired:
|
||||
task.info("Synchronization is already running. Skipping.")
|
||||
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
|
||||
return
|
||||
try:
|
||||
messages.append(_("Syncing users"))
|
||||
user_results = (
|
||||
group(
|
||||
[
|
||||
sync_objects.signature(
|
||||
args=(class_to_path(User), page, provider_pk),
|
||||
time_limit=PAGE_TIMEOUT,
|
||||
soft_time_limit=PAGE_TIMEOUT,
|
||||
)
|
||||
for page in users_paginator.page_range
|
||||
]
|
||||
users_tasks = group(
|
||||
self.sync_paginator(
|
||||
current_task=task,
|
||||
provider=provider,
|
||||
sync_objects=sync_objects,
|
||||
paginator=provider.get_paginator(User),
|
||||
object_type=User,
|
||||
)
|
||||
.apply_async()
|
||||
.get()
|
||||
)
|
||||
for result in user_results:
|
||||
for msg in result:
|
||||
messages.append(LogEvent(**msg))
|
||||
messages.append(_("Syncing groups"))
|
||||
group_results = (
|
||||
group(
|
||||
[
|
||||
sync_objects.signature(
|
||||
args=(class_to_path(Group), page, provider_pk),
|
||||
time_limit=PAGE_TIMEOUT,
|
||||
soft_time_limit=PAGE_TIMEOUT,
|
||||
)
|
||||
for page in groups_paginator.page_range
|
||||
]
|
||||
group_tasks = group(
|
||||
self.sync_paginator(
|
||||
current_task=task,
|
||||
provider=provider,
|
||||
sync_objects=sync_objects,
|
||||
paginator=provider.get_paginator(Group),
|
||||
object_type=Group,
|
||||
)
|
||||
.apply_async()
|
||||
.get()
|
||||
)
|
||||
for result in group_results:
|
||||
for msg in result:
|
||||
messages.append(LogEvent(**msg))
|
||||
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User))
|
||||
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("transient sync exception", exc=exc)
|
||||
raise task.retry(exc=exc) from exc
|
||||
task.warning("Sync encountered a transient exception. Retrying", exc=exc)
|
||||
raise Retry() from exc
|
||||
except StopSync as exc:
|
||||
task.set_error(exc)
|
||||
task.error(exc)
|
||||
return
|
||||
task.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
def sync_objects(
|
||||
self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter
|
||||
self,
|
||||
object_type: str,
|
||||
page: int,
|
||||
provider_pk: int,
|
||||
override_dry_run=False,
|
||||
**filter,
|
||||
):
|
||||
task: Task = CurrentTask.get_task()
|
||||
_object_type: type[Model] = path_to_class(object_type)
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
provider_pk=provider_pk,
|
||||
object_type=object_type,
|
||||
)
|
||||
messages = []
|
||||
provider = self._provider_model.objects.filter(pk=provider_pk).first()
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
return messages
|
||||
task.warning("No provider found. Is it assigned to an application?")
|
||||
return
|
||||
task.set_uid(slugify(provider.name))
|
||||
# Override dry run mode if requested, however don't save the provider
|
||||
# so that scheduled sync tasks still run in dry_run mode
|
||||
if override_dry_run:
|
||||
@ -147,25 +139,13 @@ class SyncTasks:
|
||||
try:
|
||||
client = provider.client_for_model(_object_type)
|
||||
except TransientSyncException:
|
||||
return messages
|
||||
return
|
||||
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
|
||||
if client.can_discover:
|
||||
self.logger.debug("starting discover")
|
||||
client.discover()
|
||||
self.logger.debug("starting sync for page", page=page)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
"Syncing page {page} of {object_type}".format(
|
||||
page=page, object_type=_object_type._meta.verbose_name_plural
|
||||
)
|
||||
),
|
||||
log_level="info",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
)
|
||||
)
|
||||
)
|
||||
task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}")
|
||||
for obj in paginator.page(page).object_list:
|
||||
obj: Model
|
||||
try:
|
||||
@ -174,89 +154,58 @@ class SyncTasks:
|
||||
self.logger.debug("skipping object due to SkipObject", obj=obj)
|
||||
continue
|
||||
except DryRunRejected as exc:
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_("Dropping mutating request due to dry run"),
|
||||
log_level="info",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={
|
||||
"obj": sanitize_item(obj),
|
||||
"method": exc.method,
|
||||
"url": exc.url,
|
||||
"body": exc.body,
|
||||
},
|
||||
)
|
||||
)
|
||||
task.info(
|
||||
"Dropping mutating request due to dry run",
|
||||
obj=sanitize_item(obj),
|
||||
method=exc.method,
|
||||
url=exc.url,
|
||||
body=exc.body,
|
||||
)
|
||||
except BadRequestSyncException as exc:
|
||||
self.logger.warning("failed to sync object", exc=exc, obj=obj)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
|
||||
)
|
||||
)
|
||||
task.warning(
|
||||
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
|
||||
arguments=exc.args[1:],
|
||||
obj=sanitize_item(obj),
|
||||
)
|
||||
except TransientSyncException as exc:
|
||||
self.logger.warning("failed to sync object", exc=exc, user=obj)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
(
|
||||
"Failed to sync {object_type} {object_name} "
|
||||
"due to transient error: {error}"
|
||||
).format_map(
|
||||
{
|
||||
"object_type": obj._meta.verbose_name,
|
||||
"object_name": str(obj),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={"obj": sanitize_item(obj)},
|
||||
)
|
||||
)
|
||||
task.warning(
|
||||
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to "
|
||||
"transient error: {str(exc)}",
|
||||
obj=sanitize_item(obj),
|
||||
)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc)
|
||||
messages.append(
|
||||
asdict(
|
||||
LogEvent(
|
||||
_(
|
||||
"Stopping sync due to error: {error}".format_map(
|
||||
{
|
||||
"error": exc.detail(),
|
||||
}
|
||||
)
|
||||
),
|
||||
log_level="warning",
|
||||
logger=f"{provider._meta.verbose_name}@{object_type}",
|
||||
attributes={"obj": sanitize_item(obj)},
|
||||
)
|
||||
)
|
||||
task.warning(
|
||||
f"Stopping sync due to error: {exc.detail()}",
|
||||
obj=sanitize_item(obj),
|
||||
)
|
||||
break
|
||||
return messages
|
||||
|
||||
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
|
||||
def sync_signal_direct_dispatch(
|
||||
self,
|
||||
task_sync_signal_direct: Actor[[str, str | int, int, str], None],
|
||||
model: str,
|
||||
pk: str | int,
|
||||
raw_op: str,
|
||||
):
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
task_sync_signal_direct.send_with_options(
|
||||
args=(model, pk, provider.pk, raw_op),
|
||||
rel_obj=provider,
|
||||
)
|
||||
|
||||
def sync_signal_direct(
|
||||
self,
|
||||
model: str,
|
||||
pk: str | int,
|
||||
provider_pk: int,
|
||||
raw_op: str,
|
||||
):
|
||||
task: Task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
)
|
||||
@ -264,65 +213,108 @@ class SyncTasks:
|
||||
instance = model_class.objects.filter(pk=pk).first()
|
||||
if not instance:
|
||||
return
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
task.warning("No provider found. Is it assigned to an application?")
|
||||
return
|
||||
task.set_uid(slugify(provider.name))
|
||||
operation = Direction(raw_op)
|
||||
client = provider.client_for_model(instance.__class__)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset = provider.get_object_qs(instance.__class__)
|
||||
if not queryset:
|
||||
return
|
||||
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=instance.pk).exists():
|
||||
return
|
||||
|
||||
try:
|
||||
if operation == Direction.add:
|
||||
client.write(instance)
|
||||
if operation == Direction.remove:
|
||||
client.delete(instance)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
return
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
|
||||
def sync_signal_m2m_dispatch(
|
||||
self,
|
||||
task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
|
||||
instance_pk: str,
|
||||
action: str,
|
||||
pk_set: list[int],
|
||||
reverse: bool,
|
||||
):
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
client = provider.client_for_model(instance.__class__)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset = provider.get_object_qs(instance.__class__)
|
||||
if not queryset:
|
||||
continue
|
||||
# reverse: instance is a Group, pk_set is a list of user pks
|
||||
# non-reverse: instance is a User, pk_set is a list of groups
|
||||
if reverse:
|
||||
task_sync_signal_m2m.send_with_options(
|
||||
args=(instance_pk, provider.pk, action, list(pk_set)),
|
||||
rel_obj=provider,
|
||||
)
|
||||
else:
|
||||
for pk in pk_set:
|
||||
task_sync_signal_m2m.send_with_options(
|
||||
args=(pk, provider.pk, action, [instance_pk]),
|
||||
rel_obj=provider,
|
||||
)
|
||||
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=instance.pk).exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
if operation == Direction.add:
|
||||
client.write(instance)
|
||||
if operation == Direction.remove:
|
||||
client.delete(instance)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
continue
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
|
||||
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
|
||||
def sync_signal_m2m(
|
||||
self,
|
||||
group_pk: str,
|
||||
provider_pk: int,
|
||||
action: str,
|
||||
pk_set: list[int],
|
||||
):
|
||||
task: Task = CurrentTask.get_task()
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
)
|
||||
group = Group.objects.filter(pk=group_pk).first()
|
||||
if not group:
|
||||
return
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset: QuerySet = provider.get_object_qs(Group)
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=group_pk).exists():
|
||||
continue
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
task.warning("No provider found. Is it assigned to an application?")
|
||||
return
|
||||
task.set_uid(slugify(provider.name))
|
||||
|
||||
client = provider.client_for_model(Group)
|
||||
try:
|
||||
operation = None
|
||||
if action == "post_add":
|
||||
operation = Direction.add
|
||||
if action == "post_remove":
|
||||
operation = Direction.remove
|
||||
client.update_group(group, operation, pk_set)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
continue
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset: QuerySet = provider.get_object_qs(Group)
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=group_pk).exists():
|
||||
return
|
||||
|
||||
client = provider.client_for_model(Group)
|
||||
try:
|
||||
operation = None
|
||||
if action == "post_add":
|
||||
operation = Direction.add
|
||||
if action == "post_remove":
|
||||
operation = Direction.remove
|
||||
client.update_group(group, operation, pk_set)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
return
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
|
@ -5,6 +5,8 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@ -60,3 +62,27 @@ class AuthentikOutpostConfig(ManagedAppConfig):
|
||||
outpost.save()
|
||||
else:
|
||||
Outpost.objects.filter(managed=MANAGED_OUTPOST).delete()
|
||||
|
||||
@property
|
||||
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.outposts.tasks import outpost_token_ensurer
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=outpost_token_ensurer,
|
||||
crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def global_schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.outposts.tasks import outpost_connection_discovery
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=outpost_connection_discovery,
|
||||
crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *",
|
||||
send_on_startup=True,
|
||||
paused=not CONFIG.get_bool("outposts.discover"),
|
||||
),
|
||||
]
|
||||
|
@ -101,7 +101,13 @@ class KubernetesController(BaseController):
|
||||
all_logs = []
|
||||
for reconcile_key in self.reconcile_order:
|
||||
if reconcile_key in self.outpost.config.kubernetes_disabled_components:
|
||||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||
all_logs.append(
|
||||
LogEvent(
|
||||
log_level="info",
|
||||
event=f"{reconcile_key.title()}: Disabled",
|
||||
logger=str(type(self)),
|
||||
)
|
||||
)
|
||||
continue
|
||||
with capture_logs() as logs:
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
@ -134,7 +140,13 @@ class KubernetesController(BaseController):
|
||||
all_logs = []
|
||||
for reconcile_key in self.reconcile_order:
|
||||
if reconcile_key in self.outpost.config.kubernetes_disabled_components:
|
||||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||
all_logs.append(
|
||||
LogEvent(
|
||||
log_level="info",
|
||||
event=f"{reconcile_key.title()}: Disabled",
|
||||
logger=str(type(self)),
|
||||
)
|
||||
)
|
||||
continue
|
||||
with capture_logs() as logs:
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
|
@ -36,7 +36,10 @@ from authentik.lib.config import CONFIG
|
||||
from authentik.lib.models import InheritanceForeignKey, SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
from authentik.outposts.controllers.k8s.utils import get_namespace
|
||||
from authentik.tasks.schedules.lib import ScheduleSpec
|
||||
from authentik.tasks.schedules.models import ScheduledModel
|
||||
|
||||
OUR_VERSION = parse(__version__)
|
||||
OUTPOST_HELLO_INTERVAL = 10
|
||||
@ -115,7 +118,7 @@ class OutpostServiceConnectionState:
|
||||
healthy: bool
|
||||
|
||||
|
||||
class OutpostServiceConnection(models.Model):
|
||||
class OutpostServiceConnection(ScheduledModel, models.Model):
|
||||
"""Connection details for an Outpost Controller, like Docker or Kubernetes"""
|
||||
|
||||
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
||||
@ -145,11 +148,11 @@ class OutpostServiceConnection(models.Model):
|
||||
@property
|
||||
def state(self) -> OutpostServiceConnectionState:
|
||||
"""Get state of service connection"""
|
||||
from authentik.outposts.tasks import outpost_service_connection_state
|
||||
from authentik.outposts.tasks import outpost_service_connection_monitor
|
||||
|
||||
state = cache.get(self.state_key, None)
|
||||
if not state:
|
||||
outpost_service_connection_state.delay(self.pk)
|
||||
outpost_service_connection_monitor.send_with_options(args=(self.pk), rel_obj=self)
|
||||
return OutpostServiceConnectionState("", False)
|
||||
return state
|
||||
|
||||
@ -160,6 +163,20 @@ class OutpostServiceConnection(models.Model):
|
||||
# since the response doesn't use the correct inheritance
|
||||
return ""
|
||||
|
||||
@property
|
||||
def schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.outposts.tasks import outpost_service_connection_monitor
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=outpost_service_connection_monitor,
|
||||
uid=self.pk,
|
||||
args=(self.pk,),
|
||||
crontab="3-59/15 * * * *",
|
||||
send_on_save=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class DockerServiceConnection(SerializerModel, OutpostServiceConnection):
|
||||
"""Service Connection to a Docker endpoint"""
|
||||
@ -244,7 +261,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection):
|
||||
return "ak-service-connection-kubernetes-form"
|
||||
|
||||
|
||||
class Outpost(SerializerModel, ManagedModel):
|
||||
class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
"""Outpost instance which manages a service user and token"""
|
||||
|
||||
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
|
||||
@ -298,6 +315,21 @@ class Outpost(SerializerModel, ManagedModel):
|
||||
"""Username for service user"""
|
||||
return f"ak-outpost-{self.uuid.hex}"
|
||||
|
||||
@property
|
||||
def schedule_specs(self) -> list[ScheduleSpec]:
|
||||
from authentik.outposts.tasks import outpost_controller
|
||||
|
||||
return [
|
||||
ScheduleSpec(
|
||||
actor=outpost_controller,
|
||||
uid=self.pk,
|
||||
args=(self.pk,),
|
||||
kwargs={"action": "up", "from_cache": False},
|
||||
crontab=f"{fqdn_rand('outpost_controller')} */4 * * *",
|
||||
send_on_save=True,
|
||||
),
|
||||
]
|
||||
|
||||
def build_user_permissions(self, user: User):
|
||||
"""Create per-object and global permissions for outpost service-account"""
|
||||
# To ensure the user only has the correct permissions, we delete all of them and re-add
|
||||
|
@ -1,28 +0,0 @@
|
||||
"""Outposts Settings"""
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from authentik.lib.utils.time import fqdn_rand
|
||||
|
||||
CELERY_BEAT_SCHEDULE = {
|
||||
"outposts_controller": {
|
||||
"task": "authentik.outposts.tasks.outpost_controller_all",
|
||||
"schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
"outposts_service_connection_check": {
|
||||
"task": "authentik.outposts.tasks.outpost_service_connection_monitor",
|
||||
"schedule": crontab(minute="3-59/15"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
"outpost_token_ensurer": {
|
||||
"task": "authentik.outposts.tasks.outpost_token_ensurer",
|
||||
"schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
"outpost_connection_discovery": {
|
||||
"task": "authentik.outposts.tasks.outpost_connection_discovery",
|
||||
"schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"),
|
||||
"options": {"queue": "authentik_scheduled"},
|
||||
},
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
"""authentik outpost signals"""
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
|
||||
from django.dispatch import receiver
|
||||
from structlog.stdlib import get_logger
|
||||
@ -9,27 +8,19 @@ from structlog.stdlib import get_logger
|
||||
from authentik.brands.models import Brand
|
||||
from authentik.core.models import AuthenticatedSession, Provider
|
||||
from authentik.crypto.models import CertificateKeyPair
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.outposts.models import Outpost, OutpostServiceConnection
|
||||
from authentik.outposts.models import Outpost, OutpostModel, OutpostServiceConnection
|
||||
from authentik.outposts.tasks import (
|
||||
CACHE_KEY_OUTPOST_DOWN,
|
||||
outpost_controller,
|
||||
outpost_post_save,
|
||||
outpost_send_update,
|
||||
outpost_session_end,
|
||||
)
|
||||
|
||||
LOGGER = get_logger()
|
||||
UPDATE_TRIGGERING_MODELS = (
|
||||
Outpost,
|
||||
OutpostServiceConnection,
|
||||
Provider,
|
||||
CertificateKeyPair,
|
||||
Brand,
|
||||
)
|
||||
|
||||
|
||||
@receiver(pre_save, sender=Outpost)
|
||||
def pre_save_outpost(sender, instance: Outpost, **_):
|
||||
def outpost_pre_save(sender, instance: Outpost, **_):
|
||||
"""Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes,
|
||||
we call down and then wait for the up after save"""
|
||||
old_instances = Outpost.objects.filter(pk=instance.pk)
|
||||
@ -44,43 +35,89 @@ def pre_save_outpost(sender, instance: Outpost, **_):
|
||||
if bool(dirty):
|
||||
LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
|
||||
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)
|
||||
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
|
||||
outpost_controller.send_with_options(
|
||||
args=(instance.pk.hex,),
|
||||
kwargs={"action": "down", "from_cache": True},
|
||||
rel_obj=instance,
|
||||
)
|
||||
|
||||
|
||||
@receiver(m2m_changed, sender=Outpost.providers.through)
|
||||
def m2m_changed_update(sender, instance: Model, action: str, **_):
|
||||
def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_):
|
||||
"""Update outpost on m2m change, when providers are added or removed"""
|
||||
if action in ["post_add", "post_remove", "post_clear"]:
|
||||
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
|
||||
if action not in ["post_add", "post_remove", "post_clear"]:
|
||||
return
|
||||
if isinstance(instance, Outpost):
|
||||
outpost_controller.send_with_options(
|
||||
args=(instance.pk,),
|
||||
rel_obj=instance.service_connection,
|
||||
)
|
||||
outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
|
||||
elif isinstance(instance, OutpostModel):
|
||||
for outpost in instance.outpost_set.all():
|
||||
outpost_controller.send_with_options(
|
||||
args=(instance.pk,),
|
||||
rel_obj=instance.service_connection,
|
||||
)
|
||||
outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
|
||||
|
||||
|
||||
@receiver(post_save)
|
||||
def post_save_update(sender, instance: Model, created: bool, **_):
|
||||
"""If an Outpost is saved, Ensure that token is created/updated
|
||||
|
||||
If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved,
|
||||
we send a message down the relevant OutpostModels WS connection to trigger an update"""
|
||||
if instance.__module__ == "django.db.migrations.recorder":
|
||||
return
|
||||
if instance.__module__ == "__fake__":
|
||||
return
|
||||
if not isinstance(instance, UPDATE_TRIGGERING_MODELS):
|
||||
return
|
||||
if isinstance(instance, Outpost) and created:
|
||||
@receiver(post_save, sender=Outpost)
|
||||
def outpost_post_save(sender, instance: Outpost, created: bool, **_):
|
||||
if created:
|
||||
LOGGER.info("New outpost saved, ensuring initial token and user are created")
|
||||
_ = instance.token
|
||||
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
|
||||
outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection)
|
||||
outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
|
||||
|
||||
|
||||
def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_):
|
||||
for outpost in instance.outpost_set.all():
|
||||
outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
|
||||
|
||||
|
||||
post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False)
|
||||
for subclass in OutpostModel.__subclasses__():
|
||||
post_save.connect(outpost_related_post_save, sender=subclass, weak=False)
|
||||
|
||||
|
||||
def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Brand, **_):
|
||||
for field in instance._meta.get_fields():
|
||||
# Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms)
|
||||
# are used, and if it has a value
|
||||
if not hasattr(field, "related_model"):
|
||||
continue
|
||||
if not field.related_model:
|
||||
continue
|
||||
if not issubclass(field.related_model, OutpostModel):
|
||||
continue
|
||||
|
||||
field_name = f"{field.name}_set"
|
||||
if not hasattr(instance, field_name):
|
||||
continue
|
||||
|
||||
LOGGER.debug("triggering outpost update from field", field=field.name)
|
||||
# Because the Outpost Model has an M2M to Provider,
|
||||
# we have to iterate over the entire QS
|
||||
for reverse in getattr(instance, field_name).all():
|
||||
if isinstance(reverse, OutpostModel):
|
||||
for outpost in reverse.outpost_set.all():
|
||||
outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
|
||||
|
||||
|
||||
post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False)
|
||||
post_save.connect(outpost_reverse_related_post_save, sender=CertificateKeyPair, weak=False)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=Outpost)
|
||||
def pre_delete_cleanup(sender, instance: Outpost, **_):
|
||||
def outpost_pre_delete_cleanup(sender, instance: Outpost, **_):
|
||||
"""Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
|
||||
instance.user.delete()
|
||||
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
|
||||
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
|
||||
outpost_controller.send(instance.pk.hex, action="down", from_cache=True)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=AuthenticatedSession)
|
||||
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
|
||||
"""Catch logout by expiring sessions being deleted"""
|
||||
outpost_session_end.delay(instance.session.session_key)
|
||||
outpost_session_end.send(instance.session.session_key)
|
||||
|
@ -10,19 +10,17 @@ from urllib.parse import urlparse
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.core.cache import cache
|
||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
from django.db.models.base import Model
|
||||
from django.utils.text import slugify
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_dramatiq_postgres.middleware import CurrentTask
|
||||
from docker.constants import DEFAULT_UNIX_SOCKET
|
||||
from dramatiq.actor import actor
|
||||
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
|
||||
from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
|
||||
from structlog.stdlib import get_logger
|
||||
from yaml import safe_load
|
||||
|
||||
from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask, prefill_task
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.reflection import path_to_class
|
||||
from authentik.outposts.consumer import OUTPOST_GROUP
|
||||
from authentik.outposts.controllers.base import BaseController, ControllerException
|
||||
from authentik.outposts.controllers.docker import DockerClient
|
||||
@ -31,7 +29,6 @@ from authentik.outposts.models import (
|
||||
DockerServiceConnection,
|
||||
KubernetesServiceConnection,
|
||||
Outpost,
|
||||
OutpostModel,
|
||||
OutpostServiceConnection,
|
||||
OutpostType,
|
||||
ServiceConnectionInvalid,
|
||||
@ -44,7 +41,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController
|
||||
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
|
||||
from authentik.providers.radius.controllers.docker import RadiusDockerController
|
||||
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
LOGGER = get_logger()
|
||||
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
|
||||
@ -83,8 +80,8 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
|
||||
return None
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def outpost_service_connection_state(connection_pk: Any):
|
||||
@actor(description=_("Update cached state of service connection."))
|
||||
def outpost_service_connection_monitor(connection_pk: Any):
|
||||
"""Update cached state of a service connection"""
|
||||
connection: OutpostServiceConnection = (
|
||||
OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first()
|
||||
@ -108,37 +105,11 @@ def outpost_service_connection_state(connection_pk: Any):
|
||||
cache.set(connection.state_key, state, timeout=None)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=SystemTask,
|
||||
throws=(DatabaseError, ProgrammingError, InternalError),
|
||||
)
|
||||
@prefill_task
|
||||
def outpost_service_connection_monitor(self: SystemTask):
|
||||
"""Regularly check the state of Outpost Service Connections"""
|
||||
connections = OutpostServiceConnection.objects.all()
|
||||
for connection in connections.iterator():
|
||||
outpost_service_connection_state.delay(connection.pk)
|
||||
self.set_status(
|
||||
TaskStatus.SUCCESSFUL,
|
||||
f"Successfully updated {len(connections)} connections.",
|
||||
)
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
throws=(DatabaseError, ProgrammingError, InternalError),
|
||||
)
|
||||
def outpost_controller_all():
|
||||
"""Launch Controller for all Outposts which support it"""
|
||||
for outpost in Outpost.objects.exclude(service_connection=None):
|
||||
outpost_controller.delay(outpost.pk.hex, "up", from_cache=False)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
def outpost_controller(
|
||||
self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
|
||||
):
|
||||
@actor(description=_("Create/update/monitor/delete the deployment of an Outpost."))
|
||||
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
|
||||
"""Create/update/monitor/delete the deployment of an Outpost"""
|
||||
self: Task = CurrentTask.get_task()
|
||||
self.set_uid(outpost_pk)
|
||||
logs = []
|
||||
if from_cache:
|
||||
outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
|
||||
@ -159,125 +130,65 @@ def outpost_controller(
|
||||
logs = getattr(controller, f"{action}_with_logs")()
|
||||
LOGGER.debug("-----------------Outpost Controller logs end-------------------")
|
||||
except (ControllerException, ServiceConnectionInvalid) as exc:
|
||||
self.set_error(exc)
|
||||
self.error(exc)
|
||||
else:
|
||||
if from_cache:
|
||||
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *logs)
|
||||
self.logs(logs)
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=SystemTask)
|
||||
@prefill_task
|
||||
def outpost_token_ensurer(self: SystemTask):
|
||||
"""Periodically ensure that all Outposts have valid Service Accounts
|
||||
and Tokens"""
|
||||
@actor(description=_("Ensure that all Outposts have valid Service Accounts and Tokens."))
|
||||
def outpost_token_ensurer():
|
||||
"""
|
||||
Periodically ensure that all Outposts have valid Service Accounts and Tokens
|
||||
"""
|
||||
self: Task = CurrentTask.get_task()
|
||||
all_outposts = Outpost.objects.all()
|
||||
for outpost in all_outposts:
|
||||
_ = outpost.token
|
||||
outpost.build_user_permissions(outpost.user)
|
||||
self.set_status(
|
||||
TaskStatus.SUCCESSFUL,
|
||||
f"Successfully checked {len(all_outposts)} Outposts.",
|
||||
)
|
||||
self.info(f"Successfully checked {len(all_outposts)} Outposts.")
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
def outpost_post_save(model_class: str, model_pk: Any):
|
||||
"""If an Outpost is saved, Ensure that token is created/updated
|
||||
|
||||
If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved,
|
||||
we send a message down the relevant OutpostModels WS connection to trigger an update"""
|
||||
model: Model = path_to_class(model_class)
|
||||
try:
|
||||
instance = model.objects.get(pk=model_pk)
|
||||
except model.DoesNotExist:
|
||||
LOGGER.warning("Model does not exist", model=model, pk=model_pk)
|
||||
@actor(description=_("Send update to outpost"))
|
||||
def outpost_send_update(pk: Any):
|
||||
"""Update outpost instance"""
|
||||
outpost = Outpost.objects.filter(pk=pk).first()
|
||||
if not outpost:
|
||||
return
|
||||
|
||||
if isinstance(instance, Outpost):
|
||||
LOGGER.debug("Trigger reconcile for outpost", instance=instance)
|
||||
outpost_controller.delay(str(instance.pk))
|
||||
|
||||
if isinstance(instance, OutpostModel | Outpost):
|
||||
LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance)
|
||||
outpost_send_update(instance)
|
||||
|
||||
if isinstance(instance, OutpostServiceConnection):
|
||||
LOGGER.debug("triggering ServiceConnection state update", instance=instance)
|
||||
outpost_service_connection_state.delay(str(instance.pk))
|
||||
|
||||
for field in instance._meta.get_fields():
|
||||
# Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms)
|
||||
# are used, and if it has a value
|
||||
if not hasattr(field, "related_model"):
|
||||
continue
|
||||
if not field.related_model:
|
||||
continue
|
||||
if not issubclass(field.related_model, OutpostModel):
|
||||
continue
|
||||
|
||||
field_name = f"{field.name}_set"
|
||||
if not hasattr(instance, field_name):
|
||||
continue
|
||||
|
||||
LOGGER.debug("triggering outpost update from field", field=field.name)
|
||||
# Because the Outpost Model has an M2M to Provider,
|
||||
# we have to iterate over the entire QS
|
||||
for reverse in getattr(instance, field_name).all():
|
||||
outpost_send_update(reverse)
|
||||
|
||||
|
||||
def outpost_send_update(model_instance: Model):
|
||||
"""Send outpost update to all registered outposts, regardless to which authentik
|
||||
instance they are connected"""
|
||||
channel_layer = get_channel_layer()
|
||||
if isinstance(model_instance, OutpostModel):
|
||||
for outpost in model_instance.outpost_set.all():
|
||||
_outpost_single_update(outpost, channel_layer)
|
||||
elif isinstance(model_instance, Outpost):
|
||||
_outpost_single_update(model_instance, channel_layer)
|
||||
|
||||
|
||||
def _outpost_single_update(outpost: Outpost, layer=None):
|
||||
"""Update outpost instances connected to a single outpost"""
|
||||
# Ensure token again, because this function is called when anything related to an
|
||||
# OutpostModel is saved, so we can be sure permissions are right
|
||||
_ = outpost.token
|
||||
outpost.build_user_permissions(outpost.user)
|
||||
if not layer: # pragma: no cover
|
||||
layer = get_channel_layer()
|
||||
layer = get_channel_layer()
|
||||
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
|
||||
LOGGER.debug("sending update", channel=group, outpost=outpost)
|
||||
async_to_sync(layer.group_send)(group, {"type": "event.update"})
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
base=SystemTask,
|
||||
bind=True,
|
||||
)
|
||||
def outpost_connection_discovery(self: SystemTask):
|
||||
@actor(description=_("Checks the local environment and create Service connections."))
|
||||
def outpost_connection_discovery():
|
||||
"""Checks the local environment and create Service connections."""
|
||||
messages = []
|
||||
self: Task = CurrentTask.get_task()
|
||||
if not CONFIG.get_bool("outposts.discover"):
|
||||
messages.append("Outpost integration discovery is disabled")
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
self.info("Outpost integration discovery is disabled")
|
||||
return
|
||||
# Explicitly check against token filename, as that's
|
||||
# only present when the integration is enabled
|
||||
if Path(SERVICE_TOKEN_FILENAME).exists():
|
||||
messages.append("Detected in-cluster Kubernetes Config")
|
||||
self.info("Detected in-cluster Kubernetes Config")
|
||||
if not KubernetesServiceConnection.objects.filter(local=True).exists():
|
||||
messages.append("Created Service Connection for in-cluster")
|
||||
self.info("Created Service Connection for in-cluster")
|
||||
KubernetesServiceConnection.objects.create(
|
||||
name="Local Kubernetes Cluster", local=True, kubeconfig={}
|
||||
)
|
||||
# For development, check for the existence of a kubeconfig file
|
||||
kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
|
||||
if kubeconfig_path.exists():
|
||||
messages.append("Detected kubeconfig")
|
||||
self.info("Detected kubeconfig")
|
||||
kubeconfig_local_name = f"k8s-{gethostname()}"
|
||||
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
|
||||
messages.append("Creating kubeconfig Service Connection")
|
||||
self.info("Creating kubeconfig Service Connection")
|
||||
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
|
||||
KubernetesServiceConnection.objects.create(
|
||||
name=kubeconfig_local_name,
|
||||
@ -286,20 +197,18 @@ def outpost_connection_discovery(self: SystemTask):
|
||||
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
|
||||
socket = Path(unix_socket_path)
|
||||
if socket.exists() and access(socket, R_OK):
|
||||
messages.append("Detected local docker socket")
|
||||
self.info("Detected local docker socket")
|
||||
if len(DockerServiceConnection.objects.filter(local=True)) == 0:
|
||||
messages.append("Created Service Connection for docker")
|
||||
self.info("Created Service Connection for docker")
|
||||
DockerServiceConnection.objects.create(
|
||||
name="Local Docker connection",
|
||||
local=True,
|
||||
url=unix_socket_path,
|
||||
)
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *messages)
|
||||
|
||||
|
||||
@CELERY_APP.task()
|
||||
@actor(description=_("Terminate session on all outposts."))
|
||||
def outpost_session_end(session_id: str):
|
||||
"""Update outpost instances connected to a single outpost"""
|
||||
layer = get_channel_layer()
|
||||
hashed_session_id = hash_session_key(session_id)
|
||||
for outpost in Outpost.objects.all():
|
||||
|
@ -37,6 +37,7 @@ class OutpostTests(TestCase):
|
||||
|
||||
# We add a provider, user should only have access to outpost and provider
|
||||
outpost.providers.add(provider)
|
||||
provider.refresh_from_db()
|
||||
permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by(
|
||||
"content_type__model"
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
|
||||
def proxy_set_defaults(self):
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
|
||||
# TODO: figure out if this can be in pre_save + post_save signals
|
||||
for provider in ProxyProvider.objects.all():
|
||||
provider.set_oauth_defaults()
|
||||
provider.save()
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user