Compare commits


1 Commit

Author SHA1 Message Date
617d913ca2 events/batch: add event batching mechanism [AUTH-134] 2024-02-05 14:17:50 +01:00
263 changed files with 7452 additions and 19655 deletions

View File

@ -1,20 +1,12 @@
[bumpversion]
current_version = 2024.2.4
current_version = 2023.10.7
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
serialize =
{major}.{minor}.{patch}-{rc_t}{rc_n}
{major}.{minor}.{patch}
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch}
message = release: {new_version}
tag_name = version/{new_version}
[bumpversion:part:rc_t]
values =
rc
final
optional_value = final
[bumpversion:file:pyproject.toml]
[bumpversion:file:docker-compose.yml]
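
For reference, a minimal Python sketch (not part of the diff) of how the rc-aware parse pattern from the 2024.2.4 side of this config groups a version string; the two serialize templates then emit the long or short form depending on whether the rc groups matched.

import re

# Same pattern as the `parse` option above, written as a normal Python regex
PARSE = re.compile(
    r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"
    r"(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\d*))?"
)

for version in ("2024.2.4", "2024.2.4-rc1"):
    groups = PARSE.fullmatch(version).groupdict()
    print(version, groups)
# 2024.2.4     -> rc_t/rc_n are None, serialized as {major}.{minor}.{patch}
# 2024.2.4-rc1 -> rc_t='rc', rc_n='1', serialized as {major}.{minor}.{patch}-{rc_t}{rc_n}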

View File

@ -9,6 +9,9 @@ inputs:
runs:
using: "composite"
steps:
- name: Generate config
id: ev
uses: ./.github/actions/docker-push-variables
- name: Find Comment
uses: peter-evans/find-comment@v2
id: fc

View File

@ -1,47 +1,64 @@
---
name: "Prepare docker environment variables"
description: "Prepare docker environment variables"
inputs:
image-name:
required: true
description: "Docker image prefix"
image-arch:
required: false
description: "Docker image arch"
outputs:
shouldBuild:
description: "Whether to build image or not"
value: ${{ steps.ev.outputs.shouldBuild }}
branchName:
description: "Branch name"
value: ${{ steps.ev.outputs.branchName }}
branchNameContainer:
description: "Branch name (for containers)"
value: ${{ steps.ev.outputs.branchNameContainer }}
timestamp:
description: "Timestamp"
value: ${{ steps.ev.outputs.timestamp }}
sha:
description: "sha"
value: ${{ steps.ev.outputs.sha }}
shortHash:
description: "shortHash"
value: ${{ steps.ev.outputs.shortHash }}
version:
description: "Version"
description: "version"
value: ${{ steps.ev.outputs.version }}
prerelease:
description: "Prerelease"
value: ${{ steps.ev.outputs.prerelease }}
imageTags:
description: "Docker image tags"
value: ${{ steps.ev.outputs.imageTags }}
imageMainTag:
description: "Docker image main tag"
value: ${{ steps.ev.outputs.imageMainTag }}
versionFamily:
description: "versionFamily"
value: ${{ steps.ev.outputs.versionFamily }}
runs:
using: "composite"
steps:
- name: Generate config
id: ev
shell: bash
env:
IMAGE_NAME: ${{ inputs.image-name }}
IMAGE_ARCH: ${{ inputs.image-arch }}
PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
shell: python
run: |
python3 ${{ github.action_path }}/push_vars.py
"""Helper script to get the actual branch name, docker safe"""
import configparser
import os
from time import time
parser = configparser.ConfigParser()
parser.read(".bumpversion.cfg")
branch_name = os.environ["GITHUB_REF"]
if os.environ.get("GITHUB_HEAD_REF", "") != "":
branch_name = os.environ["GITHUB_HEAD_REF"]
should_build = str(os.environ.get("DOCKER_USERNAME", "") != "").lower()
version = parser.get("bumpversion", "current_version")
version_family = ".".join(version.split(".")[:-1])
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
sha = os.environ["GITHUB_SHA"] if not "${{ github.event.pull_request.head.sha }}" else "${{ github.event.pull_request.head.sha }}"
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print("branchName=%s" % branch_name, file=_output)
print("branchNameContainer=%s" % safe_branch_name, file=_output)
print("timestamp=%s" % int(time()), file=_output)
print("sha=%s" % sha, file=_output)
print("shortHash=%s" % sha[:7], file=_output)
print("shouldBuild=%s" % should_build, file=_output)
print("version=%s" % version, file=_output)
print("versionFamily=%s" % version_family, file=_output)

View File

@ -1,62 +0,0 @@
"""Helper script to get the actual branch name, docker safe"""
import configparser
import os
from time import time
parser = configparser.ConfigParser()
parser.read(".bumpversion.cfg")
should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower()
branch_name = os.environ["GITHUB_REF"]
if os.environ.get("GITHUB_HEAD_REF", "") != "":
branch_name = os.environ["GITHUB_HEAD_REF"]
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
image_names = os.getenv("IMAGE_NAME").split(",")
image_arch = os.getenv("IMAGE_ARCH") or None
is_pull_request = bool(os.getenv("PR_HEAD_SHA"))
is_release = "dev" not in image_names[0]
sha = os.environ["GITHUB_SHA"] if not is_pull_request else os.getenv("PR_HEAD_SHA")
# 2042.1.0 or 2042.1.0-rc1
version = parser.get("bumpversion", "current_version")
# 2042.1
version_family = ".".join(version.split("-", 1)[0].split(".")[:-1])
prerelease = "-" in version
image_tags = []
if is_release:
for name in image_names:
image_tags += [
f"{name}:{version}",
]
if not prerelease:
image_tags += [
f"{name}:latest",
f"{name}:{version_family}",
]
else:
suffix = ""
if image_arch and image_arch != "amd64":
suffix = f"-{image_arch}"
for name in image_names:
image_tags += [
f"{name}:gh-{sha}{suffix}", # Used for ArgoCD and PR comments
f"{name}:gh-{safe_branch_name}{suffix}", # For convenience
f"{name}:gh-{safe_branch_name}-{int(time())}-{sha[:7]}{suffix}", # Use by FluxCD
]
image_main_tag = image_tags[0]
image_tags_rendered = ",".join(image_tags)
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print("shouldBuild=%s" % should_build, file=_output)
print("sha=%s" % sha, file=_output)
print("version=%s" % version, file=_output)
print("prerelease=%s" % prerelease, file=_output)
print("imageTags=%s" % image_tags_rendered, file=_output)
print("imageMainTag=%s" % image_main_tag, file=_output)

View File

@ -1,7 +0,0 @@
#!/bin/bash -x
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
GITHUB_OUTPUT=/dev/stdout \
GITHUB_REF=ref \
GITHUB_SHA=sha \
IMAGE_NAME=ghcr.io/goauthentik/server,beryju/authentik \
python $SCRIPT_DIR/push_vars.py

View File

@ -3,4 +3,3 @@ keypairs
hass
warmup
ontext
singed

View File

@ -27,6 +27,7 @@ If an API change has been made
If changes to the frontend have been made
- [ ] The code has been formatted (`make web`)
- [ ] The translation files have been updated (`make i18n-extract`)
If applicable

View File

@ -1,4 +1,3 @@
---
name: authentik-ci-main
on:
@ -8,7 +7,7 @@ on:
- next
- version-*
paths-ignore:
- website/**
- website
pull_request:
branches:
- main
@ -30,7 +29,7 @@ jobs:
- codespell
- isort
- pending-migrations
# - pylint
- pylint
- pyright
- ruff
runs-on: ubuntu-latest
@ -70,7 +69,7 @@ jobs:
cp authentik/lib/default.yml local.env.yml
cp -R .github ..
cp -R scripts ..
git checkout $(git tag --sort=version:refname | grep '^version/' | grep -vE -- '-rc[0-9]+$' | tail -n1)
git checkout version/$(python -c "from authentik import __version__; print(__version__)")
rm -rf .github/ scripts/
mv ../.github ../scripts .
- name: Setup authentik env (stable)
@ -135,7 +134,7 @@ jobs:
- name: Setup authentik env
uses: ./.github/actions/setup
- name: Create k8s Kind Cluster
uses: helm/kind-action@v1.9.0
uses: helm/kind-action@v1.8.0
- name: run integration
run: |
poetry run coverage run manage.py test tests/integration
@ -207,12 +206,6 @@ jobs:
steps:
- run: echo mark
build:
strategy:
fail-fast: false
matrix:
arch:
- amd64
- arm64
needs: ci-core-mark
runs-on: ubuntu-latest
permissions:
@ -232,12 +225,9 @@ jobs:
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
with:
image-name: ghcr.io/goauthentik/dev-server
image-arch: ${{ matrix.arch }}
- name: Login to Container Registry
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
uses: docker/login-action@v3
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
@ -251,16 +241,69 @@ jobs:
secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
tags: ${{ steps.ev.outputs.imageTags }}
push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
tags: |
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.sha }}
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-arm64:
needs: ci-core-mark
runs-on: ubuntu-latest
permissions:
# Needed to upload container images to ghcr.io
packages: write
timeout-minutes: 120
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
- name: Login to Container Registry
uses: docker/login-action@v3
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: generate ts client
run: make gen-client-ts
- name: Build Docker Image
uses: docker/build-push-action@v5
with:
context: .
secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
tags: |
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-arm64
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.sha }}-arm64
ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}-arm64
build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
platforms: linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/${{ matrix.arch }}
pr-comment:
needs:
- build
- build-arm64
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request' }}
permissions:
@ -276,9 +319,7 @@ jobs:
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
with:
image-name: ghcr.io/goauthentik/dev-server
- name: Comment on PR
uses: ./.github/actions/comment-pr-instructions
with:
tag: gh-${{ steps.ev.outputs.imageMainTag }}
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}

View File

@ -1,4 +1,3 @@
---
name: authentik-ci-outpost
on:
@ -29,7 +28,7 @@ jobs:
- name: Generate API
run: make gen-client-go
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
uses: golangci/golangci-lint-action@v3
with:
version: v1.54.2
args: --timeout 5000s --verbose
@ -84,11 +83,9 @@ jobs:
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
with:
image-name: ghcr.io/goauthentik/dev-${{ matrix.type }}
- name: Login to Container Registry
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
uses: docker/login-action@v3
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
@ -98,11 +95,15 @@ jobs:
- name: Build Docker Image
uses: docker/build-push-action@v5
with:
tags: ${{ steps.ev.outputs.imageTags }}
file: ${{ matrix.type }}.Dockerfile
push: ${{ steps.ev.outputs.shouldBuild == 'true' }}
tags: |
ghcr.io/goauthentik/dev-${{ matrix.type }}:gh-${{ steps.ev.outputs.branchNameContainer }}
ghcr.io/goauthentik/dev-${{ matrix.type }}:gh-${{ steps.ev.outputs.sha }}
file: ${{ matrix.type }}.Dockerfile
build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
platforms: linux/amd64,linux/arm64
context: .
cache-from: type=gha

View File

@ -1,4 +1,3 @@
---
name: authentik-on-release
on:
@ -20,10 +19,6 @@ jobs:
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
with:
image-name: ghcr.io/goauthentik/server,beryju/authentik
- name: Docker Login Registry
uses: docker/login-action@v3
with:
@ -43,12 +38,21 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
push: true
push: ${{ github.event_name == 'release' }}
secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
tags: ${{ steps.ev.outputs.imageTags }}
tags: |
beryju/authentik:${{ steps.ev.outputs.version }},
beryju/authentik:${{ steps.ev.outputs.versionFamily }},
beryju/authentik:latest,
ghcr.io/goauthentik/server:${{ steps.ev.outputs.version }},
ghcr.io/goauthentik/server:${{ steps.ev.outputs.versionFamily }},
ghcr.io/goauthentik/server:latest
platforms: linux/amd64,linux/arm64
build-args: |
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost:
runs-on: ubuntu-latest
permissions:
@ -74,10 +78,6 @@ jobs:
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
with:
image-name: ghcr.io/goauthentik/${{ matrix.type }},beryju/authentik-${{ matrix.type }}
- name: make empty clients
run: |
mkdir -p ./gen-ts-api
@ -96,11 +96,20 @@ jobs:
- name: Build Docker Image
uses: docker/build-push-action@v5
with:
push: true
tags: ${{ steps.ev.outputs.imageTags }}
push: ${{ github.event_name == 'release' }}
tags: |
beryju/authentik-${{ matrix.type }}:${{ steps.ev.outputs.version }},
beryju/authentik-${{ matrix.type }}:${{ steps.ev.outputs.versionFamily }},
beryju/authentik-${{ matrix.type }}:latest,
ghcr.io/goauthentik/${{ matrix.type }}:${{ steps.ev.outputs.version }},
ghcr.io/goauthentik/${{ matrix.type }}:${{ steps.ev.outputs.versionFamily }},
ghcr.io/goauthentik/${{ matrix.type }}:latest
file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64
context: .
build-args: |
VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost-binary:
timeout-minutes: 120
runs-on: ubuntu-latest
@ -172,18 +181,15 @@ jobs:
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
with:
image-name: ghcr.io/goauthentik/server
- name: Get static files from docker image
run: |
docker pull ${{ steps.ev.outputs.imageMainTag }}
container=$(docker container create ${{ steps.ev.outputs.imageMainTag }})
docker pull ghcr.io/goauthentik/server:latest
container=$(docker container create ghcr.io/goauthentik/server:latest)
docker cp ${container}:web/ .
- name: Create a Sentry.io release
uses: getsentry/action-release@v1
continue-on-error: true
if: ${{ github.event_name == 'release' }}
env:
SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }}
SENTRY_ORG: authentik-security-inc

View File

@ -1,4 +1,3 @@
---
name: authentik-on-tag
on:
@ -29,13 +28,13 @@ jobs:
with:
app_id: ${{ secrets.GH_APP_ID }}
private_key: ${{ secrets.GH_APP_PRIVATE_KEY }}
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
- name: Extract version number
id: get_version
uses: actions/github-script@v7
with:
image-name: ghcr.io/goauthentik/server
github-token: ${{ steps.generate_token.outputs.token }}
script: |
return context.payload.ref.replace(/\/refs\/tags\/version\//, '');
- name: Create Release
id: create_release
uses: actions/create-release@v1.1.4
@ -43,6 +42,6 @@ jobs:
GITHUB_TOKEN: ${{ steps.generate_token.outputs.token }}
with:
tag_name: ${{ github.ref }}
release_name: Release ${{ steps.ev.outputs.version }}
release_name: Release ${{ steps.get_version.outputs.result }}
draft: true
prerelease: ${{ steps.ev.outputs.prerelease == 'true' }}
prerelease: false

View File

@ -1,8 +1,9 @@
---
name: authentik-backend-translate-extract-compile
name: authentik-backend-translate-compile
on:
schedule:
- cron: "0 0 * * *" # every day at midnight
push:
branches: [main]
paths:
- "locale/**"
workflow_dispatch:
env:
@ -24,20 +25,16 @@ jobs:
token: ${{ steps.generate_token.outputs.token }}
- name: Setup authentik env
uses: ./.github/actions/setup
- name: run extract
run: |
poetry run make i18n-extract
- name: run compile
run: |
poetry run ak compilemessages
make web-check-compile
run: poetry run ak compilemessages
- name: Create Pull Request
uses: peter-evans/create-pull-request@v6
id: cpr
with:
token: ${{ steps.generate_token.outputs.token }}
branch: extract-compile-backend-translation
commit-message: "core, web: update translations"
title: "core, web: update translations"
body: "core, web: update translations"
branch: compile-backend-translation
commit-message: "core: compile backend translations"
title: "core: compile backend translations"
body: "core: compile backend translations"
delete-branch: true
signoff: true

View File

@ -37,7 +37,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build
# Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.0-bookworm AS go-builder
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.21.6-bookworm AS go-builder
ARG TARGETOS
ARG TARGETARCH
@ -83,7 +83,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 5: Python dependencies
FROM docker.io/python:3.12.2-slim-bookworm AS python-deps
FROM docker.io/python:3.12.1-slim-bookworm AS python-deps
WORKDIR /ak-root/poetry
@ -103,13 +103,12 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
--mount=type=cache,target=/root/.cache/pip \
--mount=type=cache,target=/root/.cache/pypoetry \
python -m venv /ak-root/venv/ && \
bash -c "source ${VENV_PATH}/bin/activate && \
pip3 install --upgrade pip && \
pip3 install poetry && \
poetry install --only=main --no-ansi --no-interaction --no-root"
pip3 install --upgrade pip && \
pip3 install poetry && \
poetry install --only=main --no-ansi --no-interaction
# Stage 6: Run
FROM docker.io/python:3.12.2-slim-bookworm AS final-image
FROM docker.io/python:3.12.1-slim-bookworm AS final-image
ARG GIT_BUILD_HASH
ARG VERSION

View File

@ -5,12 +5,9 @@ PWD = $(shell pwd)
UID = $(shell id -u)
GID = $(shell id -g)
NPM_VERSION = $(shell python -m scripts.npm_version)
PY_SOURCES = authentik tests scripts lifecycle .github
PY_SOURCES = authentik tests scripts lifecycle
DOCKER_IMAGE ?= "authentik:test"
GEN_API_TS = "gen-ts-api"
GEN_API_GO = "gen-go-api"
pg_user := $(shell python -m authentik.lib.config postgresql.user 2>/dev/null)
pg_host := $(shell python -m authentik.lib.config postgresql.host 2>/dev/null)
pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null)
@ -79,15 +76,7 @@ migrate: ## Run the Authentik Django server's migrations
i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that require translation into files to send to a translation service
core-i18n-extract:
ak makemessages \
--add-location file \
--no-obsolete \
--ignore web \
--ignore internal \
--ignore ${GEN_API_TS} \
--ignore ${GEN_API_GO} \
--ignore website \
-l en
ak makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en
install: web-install website-install core-install ## Install all requires dependencies for `web`, `website` and `core`
@ -125,7 +114,7 @@ gen-diff: ## (Release) generate the changelog diff between the current schema a
docker run \
--rm -v ${PWD}:/local \
--user ${UID}:${GID} \
docker.io/openapitools/openapi-diff:2.1.0-beta.8 \
docker.io/openapitools/openapi-diff:2.1.0-beta.6 \
--markdown /local/diff.md \
/local/old_schema.yml /local/schema.yml
rm old_schema.yml
@ -134,11 +123,11 @@ gen-diff: ## (Release) generate the changelog diff between the current schema a
npx prettier --write diff.md
gen-clean-ts: ## Remove generated API client for Typescript
rm -rf ./${GEN_API_TS}/
rm -rf ./web/node_modules/@goauthentik/api/
rm -rf gen-ts-api/
rm -rf web/node_modules/@goauthentik/api/
gen-clean-go: ## Remove generated API client for Go
rm -rf ./${GEN_API_GO}/
rm -rf gen-go-api/
gen-clean: gen-clean-ts gen-clean-go ## Remove generated API clients
@ -149,31 +138,31 @@ gen-client-ts: gen-clean-ts ## Build and install the authentik API for Typescri
docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \
-i /local/schema.yml \
-g typescript-fetch \
-o /local/${GEN_API_TS} \
-o /local/gen-ts-api \
-c /local/scripts/api-ts-config.yaml \
--additional-properties=npmVersion=${NPM_VERSION} \
--git-repo-id authentik \
--git-user-id goauthentik
mkdir -p web/node_modules/@goauthentik/api
cd ./${GEN_API_TS} && npm i
\cp -rf ./${GEN_API_TS}/* web/node_modules/@goauthentik/api
cd gen-ts-api && npm i
\cp -rfv gen-ts-api/* web/node_modules/@goauthentik/api
gen-client-go: gen-clean-go ## Build and install the authentik API for Golang
mkdir -p ./${GEN_API_GO} ./${GEN_API_GO}/templates
wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./${GEN_API_GO}/config.yaml
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./${GEN_API_GO}/templates/README.mustache
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O ./${GEN_API_GO}/templates/go.mod.mustache
cp schema.yml ./${GEN_API_GO}/
mkdir -p ./gen-go-api ./gen-go-api/templates
wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./gen-go-api/config.yaml
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./gen-go-api/templates/README.mustache
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O ./gen-go-api/templates/go.mod.mustache
cp schema.yml ./gen-go-api/
docker run \
--rm -v ${PWD}/${GEN_API_GO}:/local \
--rm -v ${PWD}/gen-go-api:/local \
--user ${UID}:${GID} \
docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \
-i /local/schema.yml \
-g go \
-o /local/ \
-c /local/config.yaml
go mod edit -replace goauthentik.io/api/v3=./${GEN_API_GO}
rm -rf ./${GEN_API_GO}/config.yaml ./${GEN_API_GO}/templates/
go mod edit -replace goauthentik.io/api/v3=./gen-go-api
rm -rf ./gen-go-api/config.yaml ./gen-go-api/templates/
gen-dev-config: ## Generate a local development config file
python -m scripts.generate_config
@ -187,7 +176,7 @@ gen: gen-build gen-client-ts
web-build: web-install ## Build the Authentik UI
cd web && npm run build
web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
web: web-lint-fix web-lint web-check-compile web-i18n-extract ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
web-install: ## Install the necessary libraries to build the Authentik UI
cd web && npm ci

View File

@ -3,7 +3,7 @@
from os import environ
from typing import Optional
__version__ = "2024.2.4"
__version__ = "2023.10.7"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -15,3 +15,7 @@ class AuthentikAdminConfig(ManagedAppConfig):
label = "authentik_admin"
verbose_name = "authentik Admin"
default = True
def reconcile_global_load_admin_signals(self):
"""Load admin signals"""
self.import_module("authentik.admin.signals")

View File

@ -14,23 +14,18 @@ LOGGER = get_logger()
def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[list[str]] = None):
"""Check permissions for a single custom action"""
def _check_obj_perm(self: ModelViewSet, request: Request):
# Check obj_perm both globally and on the specific object
# Having the global permission has higher priority
if request.user.has_perm(obj_perm):
return
obj = self.get_object()
if not request.user.has_perm(obj_perm, obj):
LOGGER.debug("denying access for object", user=request.user, perm=obj_perm, obj=obj)
self.permission_denied(request)
def wrapper_outer(func: Callable):
def wrapper_outter(func: Callable):
"""Check permissions for a single custom action"""
@wraps(func)
def wrapper(self: ModelViewSet, request: Request, *args, **kwargs) -> Response:
if obj_perm:
_check_obj_perm(self, request)
obj = self.get_object()
if not request.user.has_perm(obj_perm, obj):
LOGGER.debug(
"denying access for object", user=request.user, perm=obj_perm, obj=obj
)
return self.permission_denied(request)
if global_perms:
for other_perm in global_perms:
if not request.user.has_perm(other_perm):
@ -40,4 +35,4 @@ def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[l
return wrapper
return wrapper_outer
return wrapper_outter
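
A hypothetical usage sketch (viewset and permission codenames are illustrative, not taken from the diff): the decorator is meant to sit on top of a DRF @action so the object and global permission checks run before the handler. The import path follows the older layout shown in the surrounding diffs.

from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from authentik.api.decorators import permission_required


class WidgetViewSet(ModelViewSet):  # hypothetical viewset
    @permission_required("example_app.view_widget_metrics", ["example_app.view_event"])
    @action(detail=True, methods=["GET"])
    def metrics(self, request: Request, pk=None) -> Response:
        """Only reached once the object permission (and any global ones) pass."""
        return Response({"count": 0})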

View File

@ -0,0 +1,35 @@
"""test decorators api"""
from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase
from authentik.core.models import Application, User
from authentik.lib.generators import generate_id
class TestAPIDecorators(APITestCase):
"""test decorators api"""
def setUp(self) -> None:
super().setUp()
self.user = User.objects.create(username="test-user")
def test_obj_perm_denied(self):
"""Test object perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name=generate_id(), slug=generate_id())
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 403)
def test_other_perm_denied(self):
"""Test other perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name=generate_id(), slug=generate_id())
assign_perm("authentik_core.view_application", self.user, app)
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 403)

View File

@ -68,11 +68,7 @@ class ConfigView(APIView):
"""Get all capabilities this server instance supports"""
caps = []
deb_test = settings.DEBUG or settings.TEST
if (
CONFIG.get("storage.media.backend", "file") == "s3"
or Path(settings.STORAGES["default"]["OPTIONS"]["location"]).is_mount()
or deb_test
):
if Path(settings.MEDIA_ROOT).is_mount() or deb_test:
caps.append(Capabilities.CAN_SAVE_MEDIA)
for processor in get_context_processors():
if cap := processor.capability():

View File

@ -10,13 +10,13 @@ from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, ModelSerializer
from rest_framework.viewsets import ModelViewSet
from authentik.api.decorators import permission_required
from authentik.blueprints.models import BlueprintInstance
from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.oci import OCI_PREFIX
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField, PassiveSerializer
from authentik.rbac.decorators import permission_required
class ManagedSerializer:

View File

@ -21,27 +21,10 @@ class ManagedAppConfig(AppConfig):
self.logger = get_logger().bind(app_name=app_name)
def ready(self) -> None:
self.import_related()
self.reconcile_global()
self.reconcile_tenant()
return super().ready()
def import_related(self):
"""Automatically import related modules which rely on just being imported
to register themselves (mainly django signals and celery tasks)"""
def import_relative(rel_module: str):
try:
module_name = f"{self.name}.{rel_module}"
import_module(module_name)
self.logger.info("Imported related module", module=module_name)
except ModuleNotFoundError:
pass
import_relative("checks")
import_relative("tasks")
import_relative("signals")
def import_module(self, path: str):
"""Load module"""
import_module(path)

View File

@ -74,7 +74,7 @@ class Exporter:
class FlowExporter(Exporter):
"""Exporter customized to only return objects related to `flow`"""
"""Exporter customised to only return objects related to `flow`"""
flow: Flow
with_policies: bool

View File

@ -39,8 +39,7 @@ from authentik.core.models import (
Source,
UserSourceConnection,
)
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsage
from authentik.enterprise.models import LicenseKey, LicenseUsage
from authentik.enterprise.providers.rac.models import ConnectionToken
from authentik.events.models import SystemTask
from authentik.events.utils import cleanse_dict

View File

@ -3,7 +3,6 @@
from dataclasses import asdict, dataclass, field
from hashlib import sha512
from pathlib import Path
from sys import platform
from typing import Optional
from dacite.core import from_dict
@ -61,12 +60,7 @@ def start_blueprint_watcher():
if _file_watcher_started:
return
observer = Observer()
kwargs = {}
if platform.startswith("linux"):
kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
observer.schedule(
BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
)
observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True)
observer.start()
_file_watcher_started = True
@ -74,36 +68,26 @@ def start_blueprint_watcher():
class BlueprintEventHandler(FileSystemEventHandler):
"""Event handler for blueprint events"""
# We only ever get creation and modification events.
# See the creation of the Observer instance above for the event filtering.
# Even though we filter to only get file events, we might still get
# directory events as some implementations such as inotify do not support
# filtering on file/directory.
def dispatch(self, event: FileSystemEvent) -> None:
"""Call specific event handler method. Ignores directory changes."""
def on_any_event(self, event: FileSystemEvent):
if not isinstance(event, (FileCreatedEvent, FileModifiedEvent)):
return
if event.is_directory:
return None
return super().dispatch(event)
def on_created(self, event: FileSystemEvent):
"""Process file creation"""
LOGGER.debug("new blueprint file created, starting discovery")
for tenant in Tenant.objects.filter(ready=True):
with tenant:
blueprints_discovery.delay()
def on_modified(self, event: FileSystemEvent):
"""Process file modification"""
path = Path(event.src_path)
return
root = Path(CONFIG.get("blueprints_dir")).absolute()
path = Path(event.src_path).absolute()
rel_path = str(path.relative_to(root))
for tenant in Tenant.objects.filter(ready=True):
with tenant:
for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
apply_blueprint.delay(instance.pk.hex)
root = Path(CONFIG.get("blueprints_dir")).absolute()
path = Path(event.src_path).absolute()
rel_path = str(path.relative_to(root))
if isinstance(event, FileCreatedEvent):
LOGGER.debug("new blueprint file created, starting discovery", path=rel_path)
blueprints_discovery.delay(rel_path)
if isinstance(event, FileModifiedEvent):
for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
apply_blueprint.delay(instance.pk.hex)
@CELERY_APP.task(
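
A standalone sketch (paths are stand-ins) of the path normalization both versions of the handler above rely on: the absolute event path is made relative to the configured blueprints directory before it is matched against BlueprintInstance.path.

from pathlib import Path

blueprints_dir = Path("/blueprints").absolute()        # stand-in for CONFIG.get("blueprints_dir")
event_src_path = "/blueprints/flows/login-flow.yaml"   # stand-in for event.src_path

rel_path = str(Path(event_src_path).absolute().relative_to(blueprints_dir))
print(rel_path)  # "flows/login-flow.yaml" -> BlueprintInstance.objects.filter(path=rel_path, enabled=True)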

View File

@ -9,7 +9,6 @@ from sentry_sdk.hub import Hub
from authentik import get_full_version
from authentik.brands.models import Brand
from authentik.tenants.models import Tenant
_q_default = Q(default=True)
DEFAULT_BRAND = Brand(domain="fallback")
@ -31,14 +30,13 @@ def get_brand_for_request(request: HttpRequest) -> Brand:
def context_processor(request: HttpRequest) -> dict[str, Any]:
"""Context Processor that injects brand object into every template"""
brand = getattr(request, "brand", DEFAULT_BRAND)
tenant = getattr(request, "tenant", Tenant())
trace = ""
span = Hub.current.scope.span
if span:
trace = span.to_traceparent()
return {
"brand": brand,
"footer_links": tenant.footer_links,
"footer_links": request.tenant.footer_links,
"sentry_trace": trace,
"version": get_full_version(),
}

View File

@ -2,7 +2,7 @@
from copy import copy
from datetime import timedelta
from typing import Iterator, Optional
from typing import Optional
from django.core.cache import cache
from django.db.models import QuerySet
@ -23,6 +23,7 @@ from structlog.stdlib import get_logger
from structlog.testing import capture_logs
from authentik.admin.api.metrics import CoordinateSerializer
from authentik.api.decorators import permission_required
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
@ -38,7 +39,6 @@ from authentik.lib.utils.file import (
from authentik.policies.api.exec import PolicyTestResultSerializer
from authentik.policies.engine import PolicyEngine
from authentik.policies.types import PolicyResult
from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger()
@ -131,14 +131,14 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
return queryset
def _get_allowed_applications(
self, pagined_apps: Iterator[Application], user: Optional[User] = None
self, queryset: QuerySet, user: Optional[User] = None
) -> list[Application]:
applications = []
request = self.request._request
if user:
request = copy(request)
request.user = user
for application in pagined_apps:
for application in queryset:
engine = PolicyEngine(application, request.user, request)
engine.build()
if engine.passing:
@ -215,7 +215,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
return super().list(request)
queryset = self._filter_queryset_for_list(self.get_queryset())
pagined_apps = self.paginate_queryset(queryset)
self.paginate_queryset(queryset)
if "for_user" in request.query_params:
try:
@ -229,18 +229,18 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
raise ValidationError({"for_user": "User not found"})
except ValueError as exc:
raise ValidationError from exc
allowed_applications = self._get_allowed_applications(pagined_apps, user=for_user)
allowed_applications = self._get_allowed_applications(queryset, user=for_user)
serializer = self.get_serializer(allowed_applications, many=True)
return self.get_paginated_response(serializer.data)
allowed_applications = []
if not should_cache:
allowed_applications = self._get_allowed_applications(pagined_apps)
allowed_applications = self._get_allowed_applications(queryset)
if should_cache:
allowed_applications = cache.get(user_app_cache_key(self.request.user.pk))
if not allowed_applications:
LOGGER.debug("Caching allowed application list")
allowed_applications = self._get_allowed_applications(pagined_apps)
allowed_applications = self._get_allowed_applications(queryset)
cache.set(
user_app_cache_key(self.request.user.pk),
allowed_applications,

View File

@ -15,11 +15,11 @@ from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError
from rest_framework.viewsets import ModelViewSet
from authentik.api.decorators import permission_required
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField, PassiveSerializer
from authentik.core.models import Group, User
from authentik.rbac.api.roles import RoleSerializer
from authentik.rbac.decorators import permission_required
class GroupMemberSerializer(ModelSerializer):

View File

@ -14,6 +14,7 @@ from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer, SerializerMethodField
from rest_framework.viewsets import GenericViewSet
from authentik.api.decorators import permission_required
from authentik.blueprints.api import ManagedSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
@ -22,7 +23,6 @@ from authentik.core.models import PropertyMapping
from authentik.events.utils import sanitize_item
from authentik.lib.utils.reflection import all_subclasses
from authentik.policies.api.exec import PolicyTestSerializer
from authentik.rbac.decorators import permission_required
class PropertyMappingTestResultSerializer(PassiveSerializer):
@ -118,11 +118,7 @@ class PropertyMappingViewSet(
@action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"])
def test(self, request: Request, pk: str) -> Response:
"""Test Property Mapping"""
_mapping: PropertyMapping = self.get_object()
# Use `get_subclass` to get correct class and correct `.evaluate` implementation
mapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk)
# FIXME: when we separate policy mappings between ones for sources
# and ones for providers, we need to make the user field optional for the source mapping
mapping: PropertyMapping = self.get_object()
test_params = PolicyTestSerializer(data=request.data)
if not test_params.is_valid():
return Response(test_params.errors, status=400)

View File

@ -16,6 +16,7 @@ from rest_framework.viewsets import GenericViewSet
from structlog.stdlib import get_logger
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
from authentik.api.decorators import permission_required
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
@ -29,7 +30,6 @@ from authentik.lib.utils.file import (
)
from authentik.lib.utils.reflection import all_subclasses
from authentik.policies.engine import PolicyEngine
from authentik.rbac.decorators import permission_required
LOGGER = get_logger()

View File

@ -15,15 +15,15 @@ from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet
from authentik.api.authorization import OwnerSuperuserPermissions
from authentik.api.decorators import permission_required
from authentik.blueprints.api import ManagedSerializer
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.users import UserSerializer
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents
from authentik.events.models import Event, EventAction
from authentik.events.utils import model_to_dict
from authentik.rbac.decorators import permission_required
class TokenSerializer(ManagedSerializer, ModelSerializer):
@ -36,13 +36,6 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
self.fields["key"] = CharField(required=False)
def validate_user(self, user: User):
"""Ensure user of token cannot be changed"""
if self.instance and self.instance.user_id:
if user.pk != self.instance.user_id:
raise ValidationError("User cannot be changed")
return user
def validate(self, attrs: dict[Any, str]) -> dict[Any, str]:
"""Ensure only API or App password tokens are created."""
request: Request = self.context.get("request")

View File

@ -49,6 +49,7 @@ from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.admin.api.metrics import CoordinateSerializer
from authentik.api.decorators import permission_required
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.brands.models import Brand
from authentik.core.api.used_by import UsedByMixin
@ -73,7 +74,6 @@ from authentik.flows.models import FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
from authentik.flows.views.executor import QS_KEY_TOKEN
from authentik.lib.avatars import get_avatar
from authentik.rbac.decorators import permission_required
from authentik.stages.email.models import EmailStage
from authentik.stages.email.tasks import send_mails
from authentik.stages.email.utils import TemplateEmailMessage
@ -154,7 +154,7 @@ class UserSerializer(ModelSerializer):
def get_avatar(self, user: User) -> str:
"""User's avatar, either a http/https URL or a data URI"""
return get_avatar(user, self.context.get("request"))
return get_avatar(user, self.context["request"])
def validate_path(self, path: str) -> str:
"""Validate path"""
@ -218,7 +218,7 @@ class UserSelfSerializer(ModelSerializer):
def get_avatar(self, user: User) -> str:
"""User's avatar, either a http/https URL or a data URI"""
return get_avatar(user, self.context.get("request"))
return get_avatar(user, self.context["request"])
@extend_schema_field(
ListSerializer(
@ -533,7 +533,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
400: OpenApiResponse(description="Bad request"),
},
)
@action(detail=True, methods=["POST"], permission_classes=[])
@action(detail=True, methods=["POST"])
def set_password(self, request: Request, pk: int) -> Response:
"""Set password for user"""
user: User = self.get_object()
@ -611,7 +611,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
email_stage: EmailStage = stages.first()
message = TemplateEmailMessage(
subject=_(email_stage.subject),
to=[(for_user.name, for_user.email)],
to=[for_user.email],
template_name=email_stage.template,
language=for_user.locale(request),
template_context={
@ -631,7 +631,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
"401": OpenApiResponse(description="Access denied"),
},
)
@action(detail=True, methods=["POST"], permission_classes=[])
@action(detail=True, methods=["POST"])
def impersonate(self, request: Request, pk: int) -> Response:
"""Impersonate a user"""
if not request.tenant.impersonation:

View File

@ -14,6 +14,10 @@ class AuthentikCoreConfig(ManagedAppConfig):
mountpoint = ""
default = True
def reconcile_global_load_core_signals(self):
"""Load core signals"""
self.import_module("authentik.core.signals")
def reconcile_global_debug_worker_hook(self):
"""Dispatch startup tasks inline when debugging"""
if settings.DEBUG:

View File

@ -43,9 +43,7 @@ class TokenBackend(InbuiltBackend):
self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any
) -> Optional[User]:
try:
# pylint: disable=no-member
user = User._default_manager.get_by_natural_key(username)
# pylint: disable=no-member
except User.DoesNotExist:
# Run the default password hasher once to reduce the timing
# difference between an existing and a nonexistent user (#20760).

View File

@ -37,7 +37,6 @@ def clean_expired_models(self: SystemTask):
messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}")
# Special case
amount = 0
# pylint: disable=no-member
for session in AuthenticatedSession.objects.all():
cache_key = f"{KEY_PREFIX}{session.session_key}"
value = None
@ -50,7 +49,6 @@ def clean_expired_models(self: SystemTask):
session.delete()
amount += 1
LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)
# pylint: disable=no-member
messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")
self.set_status(TaskStatus.SUCCESSFUL, *messages)

View File

@ -7,8 +7,8 @@ from guardian.shortcuts import get_anonymous_user
from rest_framework.test import APITestCase
from authentik.core.api.tokens import TokenSerializer
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents
from authentik.core.tests.utils import create_test_admin_user, create_test_user
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.lib.generators import generate_id
@ -17,7 +17,7 @@ class TestTokenAPI(APITestCase):
def setUp(self) -> None:
super().setUp()
self.user = create_test_user()
self.user = User.objects.create(username="testuser")
self.admin = create_test_admin_user()
self.client.force_login(self.user)
@ -76,24 +76,6 @@ class TestTokenAPI(APITestCase):
self.assertEqual(token.intent, TokenIntents.INTENT_API)
self.assertEqual(token.expiring, False)
def test_token_change_user(self):
"""Test creating a token and then changing the user"""
ident = generate_id()
response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident})
self.assertEqual(response.status_code, 201)
token = Token.objects.get(identifier=ident)
self.assertEqual(token.user, self.user)
self.assertEqual(token.intent, TokenIntents.INTENT_API)
self.assertEqual(token.expiring, True)
self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token))
response = self.client.put(
reverse("authentik_api:token-detail", kwargs={"identifier": ident}),
data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk},
)
self.assertEqual(response.status_code, 400)
token.refresh_from_db()
self.assertEqual(token.user, self.user)
def test_list(self):
"""Test Token List (Test normal authentication)"""
Token.objects.all().delete()

View File

@ -24,13 +24,13 @@ from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.authorization import SecretKeyFilter
from authentik.api.decorators import permission_required
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.crypto.apps import MANAGED_KEY
from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction
from authentik.rbac.decorators import permission_required
LOGGER = get_logger()

View File

@ -1,6 +1,6 @@
"""authentik crypto app config"""
from datetime import datetime, timezone
from datetime import datetime
from typing import Optional
from authentik.blueprints.apps import ManagedAppConfig
@ -17,6 +17,10 @@ class AuthentikCryptoConfig(ManagedAppConfig):
verbose_name = "authentik Crypto"
default = True
def reconcile_global_load_crypto_tasks(self):
"""Load crypto tasks"""
self.import_module("authentik.crypto.tasks")
def _create_update_cert(self):
from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair
@ -43,9 +47,9 @@ class AuthentikCryptoConfig(ManagedAppConfig):
cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter(
managed=MANAGED_KEY
).first()
now = datetime.now(tz=timezone.utc)
now = datetime.now()
if not cert or (
now < cert.certificate.not_valid_before_utc or now > cert.certificate.not_valid_after_utc
now < cert.certificate.not_valid_before or now > cert.certificate.not_valid_after
):
self._create_update_cert()

View File

@ -1,7 +1,6 @@
"""Enterprise API Views"""
from dataclasses import asdict
from datetime import timedelta
from datetime import datetime, timedelta
from django.utils.timezone import now
from django.utils.translation import gettext as _
@ -9,29 +8,29 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, IntegerField
from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet
from authentik.api.decorators import permission_required
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
from authentik.enterprise.models import License
from authentik.rbac.decorators import permission_required
from authentik.enterprise.models import License, LicenseKey
from authentik.root.install_id import get_install_id
class EnterpriseRequiredMixin:
"""Mixin to validate that a valid enterprise license
exists before allowing to save the object"""
exists before allowing to safe the object"""
def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists"""
if not LicenseKey.cached_summary().has_license:
total = LicenseKey.get_total()
if not total.is_valid():
raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs)
@ -62,6 +61,19 @@ class LicenseSerializer(ModelSerializer):
}
class LicenseSummary(PassiveSerializer):
"""Serializer for license status"""
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField()
has_license = BooleanField()
class LicenseForecastSerializer(PassiveSerializer):
"""Serializer for license forecast"""
@ -99,13 +111,31 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
@extend_schema(
request=OpenApiTypes.NONE,
responses={
200: LicenseSummarySerializer(),
200: LicenseSummary(),
},
)
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
def summary(self, request: Request) -> Response:
"""Get the total license status"""
response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
total = LicenseKey.get_total()
last_valid = LicenseKey.last_valid_date()
# TODO: move this to a different place?
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(total.exp)
response = LicenseSummary(
data={
"internal_users": total.internal_users,
"external_users": total.external_users,
"valid": total.is_valid(),
"show_admin_warning": show_admin_warning,
"show_user_warning": show_user_warning,
"read_only": read_only,
"latest_valid": latest_valid,
"has_license": License.objects.all().count() > 0,
}
)
response.is_valid(raise_exception=True)
return Response(response.data)
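
A small sketch (thresholds taken from the handler above) of how the warning flags escalate as the last in-limits usage record ages:

from datetime import datetime, timedelta, timezone

def summary_flags(last_valid: datetime, now: datetime) -> dict[str, bool]:
    """Same two/four/six-week thresholds as the summary endpoint above."""
    return {
        "show_admin_warning": last_valid < now - timedelta(weeks=2),
        "show_user_warning": last_valid < now - timedelta(weeks=4),
        "read_only": last_valid < now - timedelta(weeks=6),
    }

now = datetime.now(timezone.utc)
print(summary_flags(now - timedelta(weeks=5), now))
# {'show_admin_warning': True, 'show_user_warning': True, 'read_only': False}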

View File

@ -17,12 +17,16 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
verbose_name = "authentik Enterprise"
default = True
def reconcile_global_load_enterprise_signals(self):
"""Load enterprise signals"""
self.import_module("authentik.enterprise.signals")
def enabled(self):
"""Return true if enterprise is enabled and valid"""
return self.check_enabled() or settings.TEST
def check_enabled(self):
"""Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseKey
return LicenseKey.cached_summary().valid
return LicenseKey.get_total().is_valid()

View File

@ -11,6 +11,7 @@ from django.db.models.expressions import BaseExpression, Combinable
from django.db.models.signals import post_init
from django.http import HttpRequest
from authentik.core.models import User
from authentik.events.middleware import AuditMiddleware, should_log_model
from authentik.events.utils import cleanse_dict, sanitize_item
@ -18,19 +19,26 @@ from authentik.events.utils import cleanse_dict, sanitize_item
class EnterpriseAuditMiddleware(AuditMiddleware):
"""Enterprise audit middleware"""
_enabled = None
@property
def enabled(self):
"""Check if audit logging is enabled"""
return apps.get_app_config("authentik_enterprise").enabled()
"""Lazy check if audit logging is enabled"""
if self._enabled is None:
self._enabled = apps.get_app_config("authentik_enterprise").enabled()
return self._enabled
def connect(self, request: HttpRequest):
super().connect(request)
if not self.enabled:
return
user = getattr(request, "user", self.anonymous_user)
if not user.is_authenticated:
user = self.anonymous_user
if not hasattr(request, "request_id"):
return
post_init.connect(
partial(self.post_init_handler, request=request),
partial(self.post_init_handler, user=user, request=request),
dispatch_uid=request.request_id,
weak=False,
)
@ -72,7 +80,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
diff[key] = {"previous_value": value, "new_value": after.get(key)}
return sanitize_item(diff)
def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
def post_init_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_):
"""post_init django model handler"""
if not should_log_model(instance):
return
@ -87,6 +95,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
# pylint: disable=too-many-arguments
def post_save_handler(
self,
user: User,
request: HttpRequest,
sender,
instance: Model,
@ -108,4 +117,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
for field_set in ignored_field_sets:
if set(diff.keys()) == set(field_set):
return None
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
return super().post_save_handler(
user, request, sender, instance, created, thread_kwargs, **_
)
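
A standalone sketch (names illustrative) of the signal-wiring pattern used in the middleware above: per-request context is bound into the receiver with functools.partial, connected with weak=False, and keyed by the request id as dispatch_uid so the same id can disconnect it when the request ends.

from functools import partial

from django.db.models.signals import post_init


def audit_post_init(sender, instance, *, request_id, **_):
    """Hypothetical receiver that would record the instance for this request."""
    print(f"[{request_id}] {sender.__name__} instantiated")


request_id = "req-123"  # stand-in for request.request_id
post_init.connect(
    partial(audit_post_init, request_id=request_id),
    dispatch_uid=request_id,
    weak=False,
)
# ... once the response has been sent:
post_init.disconnect(dispatch_uid=request_id)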

View File

@ -1,214 +0,0 @@
"""Enterprise license"""
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, DateTimeField, IntegerField
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseUsage
from authentik.root.install_id import get_install_id
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60 # 2 Hours
@lru_cache()
def get_licensing_key() -> Certificate:
"""Get Root CA PEM"""
with open("authentik/enterprise/public.pem", "rb") as _key:
return load_pem_x509_certificate(_key.read())
def get_license_aud() -> str:
"""Get the JWT audience field"""
return f"enterprise.goauthentik.io/license/{get_install_id()}"
class LicenseFlags(Enum):
"""License flags"""
@dataclass
class LicenseSummary:
"""Internal representation of a license summary"""
internal_users: int
external_users: int
valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
latest_valid: datetime
has_license: bool
class LicenseSummarySerializer(PassiveSerializer):
"""Serializer for license status"""
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField()
has_license = BooleanField()
@dataclass
class LicenseKey:
"""License JWT claims"""
aud: str
exp: int
name: str
internal_users: int = 0
external_users: int = 0
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
except PyJWTError:
raise ValidationError("Unable to verify license")
x5c: list[str] = headers.get("x5c", [])
if len(x5c) < 1:
raise ValidationError("Unable to verify license")
try:
our_cert = load_der_x509_certificate(b64decode(x5c[0]))
intermediate = load_der_x509_certificate(b64decode(x5c[1]))
our_cert.verify_directly_issued_by(intermediate)
intermediate.verify_directly_issued_by(get_licensing_key())
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license")
try:
body = from_dict(
LicenseKey,
decode(
jwt,
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
),
)
except PyJWTError:
raise ValidationError("Unable to verify license")
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
if exp_ts <= total.exp:
total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total
@staticmethod
def base_user_qs() -> QuerySet:
"""Base query set for all users"""
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@staticmethod
def get_external_user_count():
"""Get current external user count"""
# Count since start of the month
last_month = now().replace(day=1)
return (
LicenseKey.base_user_qs()
.filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
.count()
)
def is_valid(self) -> bool:
"""Check if the given license body covers all users
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
)
summary = asdict(self.summary())
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return summary
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
def summary(self) -> LicenseSummary:
"""Summary of license status"""
has_license = License.objects.all().count() > 0
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary(
show_admin_warning=show_admin_warning and has_license,
show_user_warning=show_user_warning and has_license,
read_only=read_only and has_license,
latest_valid=latest_valid,
internal_users=self.internal_users,
external_users=self.external_users,
valid=self.is_valid(),
has_license=has_license,
)
@staticmethod
def cached_summary() -> LicenseSummary:
"""Helper method which looks up the last summary"""
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary:
return LicenseKey.get_total().summary()
return from_dict(LicenseSummary, summary)
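# Illustrative sketch only, not part of this change: a hypothetical helper showing
# how the validation entry point above could be exercised outside the API layer
# (`try_install_license` is not defined anywhere in authentik).
def try_install_license(jwt: str) -> bool:
    """Return True if the JWT verifies against the licensing CA and covers current usage."""
    try:
        key = LicenseKey.validate(jwt)  # raises ValidationError on any verification failure
    except ValidationError:
        return False
    return key.is_valid()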

View File

@ -1,64 +0,0 @@
"""Enterprise middleware"""
from collections.abc import Callable
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.urls import resolve
from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path
class EnterpriseMiddleware:
"""Enterprise middleware"""
get_response: Callable[[HttpRequest], HttpResponse]
logger: BoundLogger
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
self.logger = get_logger().bind()
def __call__(self, request: HttpRequest) -> HttpResponse:
resolver_match = resolve(request.path_info)
request.resolver_match = resolver_match
if not self.is_request_allowed(request):
self.logger.warning("Refusing request due to expired/invalid license")
return JsonResponse(
{
"detail": "Request denied due to expired/invalid license.",
"code": "denied_license",
},
status=400,
)
return self.get_response(request)
def is_request_allowed(self, request: HttpRequest) -> bool:
"""Check if a specific request is allowed"""
if self.is_request_always_allowed(request):
return True
cached_status = LicenseKey.cached_summary()
if not cached_status:
return True
if cached_status.read_only:
return False
return True
def is_request_always_allowed(self, request: HttpRequest):
"""Check if a request is always allowed"""
# Always allow "safe" methods
if request.method.lower() in ["get", "head", "options", "trace"]:
return True
# Always allow requests to manage licenses
if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
return True
# Flow executor is mounted as an API path but explicitly allowed
if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
return True
# Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names:
return True
return False

View File

@ -1,20 +1,159 @@
"""Enterprise models"""
from datetime import timedelta
from typing import TYPE_CHECKING
from base64 import b64decode
from binascii import Error
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from uuid import uuid4
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from django.contrib.postgres.indexes import HashIndex
from django.db import models
from django.db.models.query import QuerySet
from django.utils.timezone import now
from django.utils.translation import gettext as _
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer
from authentik.core.models import ExpiringModel
from authentik.core.models import ExpiringModel, User, UserTypes
from authentik.lib.models import SerializerModel
from authentik.root.install_id import get_install_id
if TYPE_CHECKING:
from authentik.enterprise.license import LicenseKey
@lru_cache()
def get_licensing_key() -> Certificate:
"""Get Root CA PEM"""
with open("authentik/enterprise/public.pem", "rb") as _key:
return load_pem_x509_certificate(_key.read())
def get_license_aud() -> str:
"""Get the JWT audience field"""
return f"enterprise.goauthentik.io/license/{get_install_id()}"
class LicenseFlags(Enum):
"""License flags"""
@dataclass
class LicenseKey:
"""License JWT claims"""
aud: str
exp: int
name: str
internal_users: int = 0
external_users: int = 0
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
except PyJWTError:
raise ValidationError("Unable to verify license")
x5c: list[str] = headers.get("x5c", [])
if len(x5c) < 1:
raise ValidationError("Unable to verify license")
try:
our_cert = load_der_x509_certificate(b64decode(x5c[0]))
intermediate = load_der_x509_certificate(b64decode(x5c[1]))
our_cert.verify_directly_issued_by(intermediate)
intermediate.verify_directly_issued_by(get_licensing_key())
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license")
try:
body = from_dict(
LicenseKey,
decode(
jwt,
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
),
)
except PyJWTError:
raise ValidationError("Unable to verify license")
return body
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0:
total.exp = exp_ts
if exp_ts <= total.exp:
total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total
@staticmethod
def base_user_qs() -> QuerySet:
"""Base query set for all users"""
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@staticmethod
def get_external_user_count():
"""Get current external user count"""
# Count since start of the month
last_month = now().replace(day=1)
return (
LicenseKey.base_user_qs()
.filter(type=UserTypes.EXTERNAL, last_login__gte=last_month)
.count()
)
def is_valid(self) -> bool:
"""Check if the given license body covers all users
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if LicenseUsage.objects.filter(record_date__gte=threshold).exists():
return
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
)
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
class License(SerializerModel):
@ -35,10 +174,8 @@ class License(SerializerModel):
return LicenseSerializer
@property
def status(self) -> "LicenseKey":
def status(self) -> LicenseKey:
"""Get parsed license status"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.validate(self.key)
class Meta:

View File

@ -5,7 +5,7 @@ from typing import Optional
from django.utils.translation import gettext_lazy as _
from authentik.core.models import User, UserTypes
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseKey
from authentik.policies.types import PolicyRequest, PolicyResult
from authentik.policies.views import PolicyAccessView

View File

@ -1,53 +0,0 @@
"""RAC Provider API Views"""
from django_filters.rest_framework.backends import DjangoFilterBackend
from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
from authentik.core.api.groups import GroupMemberSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
from authentik.enterprise.providers.rac.models import ConnectionToken
class ConnectionTokenSerializer(EnterpriseRequiredMixin, ModelSerializer):
"""ConnectionToken Serializer"""
provider_obj = RACProviderSerializer(source="provider", read_only=True)
endpoint_obj = EndpointSerializer(source="endpoint", read_only=True)
user = GroupMemberSerializer(source="session.user", read_only=True)
class Meta:
model = ConnectionToken
fields = [
"pk",
"provider",
"provider_obj",
"endpoint",
"endpoint_obj",
"user",
]
class ConnectionTokenViewSet(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""ConnectionToken Viewset"""
queryset = ConnectionToken.objects.all().select_related("session", "endpoint")
serializer_class = ConnectionTokenSerializer
filterset_fields = ["endpoint", "session__user", "provider"]
search_fields = ["endpoint__name", "provider__name"]
ordering = ["endpoint__name", "provider__name"]
permission_classes = [OwnerSuperuserPermissions]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]

View File

@ -16,12 +16,7 @@ class RACProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
class Meta:
model = RACProvider
fields = ProviderSerializer.Meta.fields + [
"settings",
"outpost_set",
"connection_expiry",
"delete_token_on_disconnect",
]
fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs

View File

@ -12,3 +12,7 @@ class AuthentikEnterpriseProviderRAC(EnterpriseConfig):
default = True
mountpoint = ""
ws_mountpoint = "authentik.enterprise.providers.rac.urls"
def reconcile_global_load_rac_signals(self):
"""Load rac signals"""
self.import_module("authentik.enterprise.providers.rac.signals")

View File

@ -43,7 +43,6 @@ class RACClientConsumer(AsyncWebsocketConsumer):
logger: BoundLogger
async def connect(self):
self.logger = get_logger()
await self.accept("guacamole")
await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
await self.channel_layer.group_add(
@ -65,11 +64,9 @@ class RACClientConsumer(AsyncWebsocketConsumer):
@database_sync_to_async
def init_outpost_connection(self):
"""Initialize guac connection settings"""
self.token = (
ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"])
.select_related("endpoint", "provider", "session", "session__user")
.first()
)
self.token = ConnectionToken.filter_not_expired(
token=self.scope["url_route"]["kwargs"]["token"]
).first()
if not self.token:
raise DenyConnection()
self.provider = self.token.provider
@ -110,9 +107,6 @@ class RACClientConsumer(AsyncWebsocketConsumer):
OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
msg,
)
if self.provider and self.provider.delete_token_on_disconnect:
self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
self.token.delete()
async def receive(self, text_data=None, bytes_data=None):
"""Mirror data received from client to the dest_channel_id

View File

@ -1,181 +0,0 @@
# Generated by Django 5.0.1 on 2024-02-11 19:04
import uuid
import django.db.models.deletion
from django.db import migrations, models
import authentik.core.models
import authentik.lib.utils.time
class Migration(migrations.Migration):
replaces = [
("authentik_providers_rac", "0001_initial"),
("authentik_providers_rac", "0002_endpoint_maximum_connections"),
("authentik_providers_rac", "0003_alter_connectiontoken_options_and_more"),
]
initial = True
dependencies = [
("authentik_core", "0032_group_roles"),
("authentik_policies", "0011_policybinding_failure_result_and_more"),
]
operations = [
migrations.CreateModel(
name="RACPropertyMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
("static_settings", models.JSONField(default=dict)),
],
options={
"verbose_name": "RAC Property Mapping",
"verbose_name_plural": "RAC Property Mappings",
},
bases=("authentik_core.propertymapping",),
),
migrations.CreateModel(
name="RACProvider",
fields=[
(
"provider_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.provider",
),
),
("settings", models.JSONField(default=dict)),
(
"auth_mode",
models.TextField(
choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt"
),
),
(
"connection_expiry",
models.TextField(
default="hours=8",
help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)",
validators=[authentik.lib.utils.time.timedelta_string_validator],
),
),
(
"delete_token_on_disconnect",
models.BooleanField(
default=False,
help_text="When set to true, connection tokens will be deleted upon disconnect.",
),
),
],
options={
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
},
bases=("authentik_core.provider",),
),
migrations.CreateModel(
name="Endpoint",
fields=[
(
"policybindingmodel_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_policies.policybindingmodel",
),
),
("name", models.TextField()),
("host", models.TextField()),
(
"protocol",
models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]),
),
("settings", models.JSONField(default=dict)),
(
"auth_mode",
models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]),
),
(
"property_mappings",
models.ManyToManyField(
blank=True, default=None, to="authentik_core.propertymapping"
),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.racprovider",
),
),
("maximum_connections", models.IntegerField(default=1)),
],
options={
"verbose_name": "RAC Endpoint",
"verbose_name_plural": "RAC Endpoints",
},
bases=("authentik_policies.policybindingmodel", models.Model),
),
migrations.CreateModel(
name="ConnectionToken",
fields=[
(
"expires",
models.DateTimeField(default=authentik.core.models.default_token_duration),
),
("expiring", models.BooleanField(default=True)),
(
"connection_token_uuid",
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
),
("token", models.TextField(default=authentik.core.models.default_token_key)),
("settings", models.JSONField(default=dict)),
(
"endpoint",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.endpoint",
),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.racprovider",
),
),
(
"session",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_core.authenticatedsession",
),
),
],
options={
"abstract": False,
"verbose_name": "RAC Connection token",
"verbose_name_plural": "RAC Connection tokens",
},
),
]

View File

@ -1,28 +0,0 @@
# Generated by Django 5.0.1 on 2024-02-11 19:04
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_rac", "0002_endpoint_maximum_connections"),
]
operations = [
migrations.AlterModelOptions(
name="connectiontoken",
options={
"verbose_name": "RAC Connection token",
"verbose_name_plural": "RAC Connection tokens",
},
),
migrations.AddField(
model_name="racprovider",
name="delete_token_on_disconnect",
field=models.BooleanField(
default=False,
help_text="When set to true, connection tokens will be deleted upon disconnect.",
),
),
]

View File

@ -1,18 +1,17 @@
"""RAC Models"""
from typing import Any, Optional
from typing import Optional
from uuid import uuid4
from deepmerge import always_merger
from django.db import models
from django.db.models import QuerySet
from django.http import HttpRequest
from django.utils.translation import gettext as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User, default_token_key
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key
from authentik.events.models import Event, EventAction
from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_string_validator
@ -52,10 +51,6 @@ class RACProvider(Provider):
"(Format: hours=-1;minutes=-2;seconds=-3)"
),
)
delete_token_on_disconnect = models.BooleanField(
default=False,
help_text=_("When set to true, connection tokens will be deleted upon disconnect."),
)
@property
def launch_url(self) -> Optional[str]:
@ -112,12 +107,6 @@ class RACPropertyMapping(PropertyMapping):
static_settings = models.JSONField(default=dict)
def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any:
"""Evaluate `self.expression` using `**kwargs` as Context."""
if len(self.static_settings) > 0:
return self.static_settings
return super().evaluate(user, request, **kwargs)
@property
def component(self) -> str:
return "ak-property-mapping-rac-form"
@ -166,6 +155,9 @@ class ConnectionToken(ExpiringModel):
def mapping_evaluator(mappings: QuerySet):
for mapping in mappings:
mapping: RACPropertyMapping
if len(mapping.static_settings) > 0:
always_merger.merge(settings, mapping.static_settings)
continue
try:
mapping_settings = mapping.evaluate(
self.session.user, None, endpoint=self.endpoint, provider=self.provider
@ -199,13 +191,3 @@ class ConnectionToken(ExpiringModel):
continue
settings[key] = str(value)
return settings
def __str__(self):
return (
f"RAC Connection token {self.session.user} to "
f"{self.endpoint.provider.name}/{self.endpoint.name}"
)
class Meta:
verbose_name = _("RAC Connection token")
verbose_name_plural = _("RAC Connection tokens")

View File

@ -45,8 +45,8 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **
@receiver(post_save, sender=Endpoint)
def post_save_endpoint(sender: type[Model], instance, created: bool, **_):
"""Clear user's endpoint cache upon endpoint creation"""
def post_save_application(sender: type[Model], instance, created: bool, **_):
"""Clear user's application cache upon application creation"""
if not created: # pragma: no cover
return

View File

@ -70,7 +70,6 @@ class TestEndpointsAPI(APITestCase):
"authorization_flow": None,
"property_mappings": [],
"connection_expiry": "hours=8",
"delete_token_on_disconnect": False,
"component": "ak-provider-rac-form",
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
@ -125,7 +124,6 @@ class TestEndpointsAPI(APITestCase):
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"connection_expiry": "hours=8",
"delete_token_on_disconnect": False,
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",
@ -154,7 +152,6 @@ class TestEndpointsAPI(APITestCase):
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"connection_expiry": "hours=8",
"delete_token_on_disconnect": False,
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",

View File

@ -11,8 +11,7 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.models import License, LicenseKey
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
from authentik.lib.generators import generate_id
from authentik.policies.denied import AccessDeniedResponse
@ -40,7 +39,7 @@ class TestRACViews(APITestCase):
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
@ -71,7 +70,7 @@ class TestRACViews(APITestCase):
self.assertEqual(final_response.status_code, 200)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
@ -100,7 +99,7 @@ class TestRACViews(APITestCase):
self.assertIsInstance(response, AccessDeniedResponse)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",

View File

@ -6,7 +6,6 @@ from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from authentik.core.channels import TokenOutpostMiddleware
from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet
from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet
@ -46,5 +45,4 @@ api_urlpatterns = [
("providers/rac", RACProviderViewSet),
("propertymappings/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet),
("rac/connection_tokens", ConnectionTokenViewSet),
]

View File

@ -104,15 +104,14 @@ class RACFinalStage(RedirectStage):
# Check if we're already at the maximum connection limit
all_tokens = ConnectionToken.filter_not_expired(
endpoint=self.endpoint,
)
if self.endpoint.maximum_connections > -1:
if all_tokens.count() >= self.endpoint.maximum_connections:
msg = [_("Maximum connection limit reached.")]
# Check if any other tokens exist for the current user, and inform them
# they are already connected
if all_tokens.filter(session__user=self.request.user).exists():
msg.append(_("(You are already connected in another tab/window)"))
return self.executor.stage_invalid(" ".join(msg))
).exclude(endpoint__maximum_connections__lte=-1)
if all_tokens.count() >= self.endpoint.maximum_connections:
msg = [_("Maximum connection limit reached.")]
# Check if any other tokens exist for the current user, and inform them
# they are already connected
if all_tokens.filter(session__user=self.request.user).exists():
msg.append(_("(You are already connected in another tab/window)"))
return self.executor.stage_invalid(" ".join(msg))
return super().dispatch(request, *args, **kwargs)
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:

View File

@ -5,9 +5,9 @@ from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"enterprise_update_usage": {
"task": "authentik.enterprise.tasks.enterprise_update_usage",
"schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"),
"enterprise_calculate_license": {
"task": "authentik.enterprise.tasks.calculate_license",
"schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/2"),
"options": {"queue": "authentik_scheduled"},
}
}
@ -16,5 +16,3 @@ TENANT_APPS = [
"authentik.enterprise.audit",
"authentik.enterprise.providers.rac",
]
MIDDLEWARE = ["authentik.enterprise.middleware.EnterpriseMiddleware"]

View File

@ -2,14 +2,11 @@
from datetime import datetime
from django.core.cache import cache
from django.db.models.signals import post_save, pre_save
from django.db.models.signals import pre_save
from django.dispatch import receiver
from django.utils.timezone import get_current_timezone
from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE
from authentik.enterprise.models import License
from authentik.enterprise.tasks import enterprise_update_usage
@receiver(pre_save, sender=License)
@ -20,10 +17,3 @@ def pre_save_license(sender: type[License], instance: License, **_):
instance.internal_users = status.internal_users
instance.external_users = status.external_users
instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone())
@receiver(post_save, sender=License)
def post_save_license(sender: type[License], instance: License, **_):
"""Trigger license usage calculation when license is saved"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
enterprise_update_usage.delay()

View File

@ -1,14 +1,10 @@
"""Enterprise tasks"""
from authentik.enterprise.license import LicenseKey
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.enterprise.models import LicenseKey
from authentik.root.celery import CELERY_APP
@CELERY_APP.task(bind=True, base=SystemTask)
@prefill_task
def enterprise_update_usage(self: SystemTask):
"""Update enterprise license status"""
@CELERY_APP.task()
def calculate_license():
"""Calculate licensing status"""
LicenseKey.get_total().record_usage()
self.set_status(TaskStatus.SUCCESSFUL)

View File

@ -8,8 +8,7 @@ from django.test import TestCase
from django.utils.timezone import now
from rest_framework.exceptions import ValidationError
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.models import License, LicenseKey
from authentik.lib.generators import generate_id
_exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
@ -19,7 +18,7 @@ class TestEnterpriseLicense(TestCase):
"""Enterprise license tests"""
@patch(
"authentik.enterprise.license.LicenseKey.validate",
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
@ -42,7 +41,7 @@ class TestEnterpriseLicense(TestCase):
License.objects.create(key=generate_id())
@patch(
"authentik.enterprise.license.LicenseKey.validate",
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",

View File

@ -38,6 +38,7 @@ class EventSerializer(ModelSerializer):
"created",
"expires",
"brand",
"batch_id",
]

View File

@ -12,6 +12,7 @@ from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet
from authentik.api.decorators import permission_required
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer
from authentik.events.models import (
@ -23,7 +24,6 @@ from authentik.events.models import (
TransportMode,
)
from authentik.events.utils import get_user
from authentik.rbac.decorators import permission_required
class NotificationTransportSerializer(ModelSerializer):
@ -53,6 +53,9 @@ class NotificationTransportSerializer(ModelSerializer):
"webhook_url",
"webhook_mapping",
"send_once",
"enable_batching",
"batch_timeout",
"max_batch_size",
]

View File

@ -1,5 +1,6 @@
"""Tasks API"""
from datetime import datetime, timezone
from importlib import import_module
from django.contrib import messages
@ -7,22 +8,15 @@ from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import (
CharField,
ChoiceField,
DateTimeField,
FloatField,
ListField,
SerializerMethodField,
)
from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ReadOnlyModelViewSet
from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required
from authentik.events.models import SystemTask, TaskStatus
from authentik.rbac.decorators import permission_required
LOGGER = get_logger()
@ -34,9 +28,9 @@ class SystemTaskSerializer(ModelSerializer):
full_name = SerializerMethodField()
uid = CharField(required=False)
description = CharField()
start_timestamp = DateTimeField(read_only=True)
finish_timestamp = DateTimeField(read_only=True)
duration = FloatField(read_only=True)
start_timestamp = SerializerMethodField()
finish_timestamp = SerializerMethodField()
duration = SerializerMethodField()
status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus])
messages = ListField(child=CharField())
@ -47,6 +41,18 @@ class SystemTaskSerializer(ModelSerializer):
return f"{instance.name}:{instance.uid}"
return instance.name
def get_start_timestamp(self, instance: SystemTask) -> datetime:
"""Timestamp when the task started"""
return datetime.fromtimestamp(instance.start_timestamp, tz=timezone.utc)
def get_finish_timestamp(self, instance: SystemTask) -> datetime:
"""Timestamp when the task finished"""
return datetime.fromtimestamp(instance.finish_timestamp, tz=timezone.utc)
def get_duration(self, instance: SystemTask) -> float:
"""Get the duration a task took to run"""
return max(instance.finish_timestamp - instance.start_timestamp, 0)
class Meta:
model = SystemTask
fields = [
@ -81,7 +87,7 @@ class SystemTaskViewSet(ReadOnlyModelViewSet):
500: OpenApiResponse(description="Failed to retry task"),
},
)
@action(detail=True, methods=["POST"], permission_classes=[])
@action(detail=True, methods=["post"])
def run(self, request: Request, pk=None) -> Response:
"""Run task"""
task: SystemTask = self.get_object()

View File

@ -1,12 +1,9 @@
"""authentik events app"""
from celery.schedules import crontab
from prometheus_client import Gauge, Histogram
from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.config import CONFIG, ENV_PREFIX
from authentik.lib.utils.reflection import path_to_class
from authentik.root.celery import CELERY_APP
# TODO: Deprecated metric - remove in 2024.2 or later
GAUGE_TASKS = Gauge(
@ -18,7 +15,7 @@ GAUGE_TASKS = Gauge(
SYSTEM_TASK_TIME = Histogram(
"authentik_system_tasks_time_seconds",
"Runtime of system tasks",
["tenant", "task_name", "task_uid"],
["tenant"],
)
SYSTEM_TASK_STATUS = Gauge(
"authentik_system_tasks_status",
@ -35,6 +32,10 @@ class AuthentikEventsConfig(ManagedAppConfig):
verbose_name = "authentik Events"
default = True
def reconcile_global_load_events_signals(self):
"""Load events signals"""
self.import_module("authentik.events.signals")
def reconcile_global_check_deprecations(self):
"""Check for config deprecations"""
from authentik.events.models import Event, EventAction
@ -56,7 +57,7 @@ class AuthentikEventsConfig(ManagedAppConfig):
message=msg,
).save()
def reconcile_tenant_prefill_tasks(self):
def reconcile_prefill_tasks(self):
"""Prefill tasks"""
from authentik.events.models import SystemTask
from authentik.events.system_tasks import _prefill_tasks
@ -66,28 +67,3 @@ class AuthentikEventsConfig(ManagedAppConfig):
continue
task.save()
self.logger.debug("prefilled task", task_name=task.name)
def reconcile_tenant_run_scheduled_tasks(self):
"""Run schedule tasks which are behind schedule (only applies
to tasks of which we keep metrics)"""
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask as CelerySystemTask
for task in CELERY_APP.conf["beat_schedule"].values():
schedule = task["schedule"]
if not isinstance(schedule, crontab):
continue
task_class: CelerySystemTask = path_to_class(task["task"])
if not isinstance(task_class, CelerySystemTask):
continue
db_task = task_class.db()
if not db_task:
continue
due, _ = schedule.is_due(db_task.finish_timestamp)
if due or db_task.status == TaskStatus.UNKNOWN:
self.logger.debug("Running past-due scheduled task", task=task["task"])
task_class.apply_async(
args=task.get("args", None),
kwargs=task.get("kwargs", None),
**task.get("options", {}),
)

View File

@ -82,29 +82,26 @@ class AuditMiddleware:
self.anonymous_user = get_anonymous_user()
def get_user(self, request: HttpRequest) -> User:
user = getattr(request, "user", self.anonymous_user)
if not user.is_authenticated:
return self.anonymous_user
return user
def connect(self, request: HttpRequest):
"""Connect signal for automatic logging"""
self._ensure_fallback_user()
user = getattr(request, "user", self.anonymous_user)
if not user.is_authenticated:
user = self.anonymous_user
if not hasattr(request, "request_id"):
return
post_save.connect(
partial(self.post_save_handler, request=request),
partial(self.post_save_handler, user=user, request=request),
dispatch_uid=request.request_id,
weak=False,
)
pre_delete.connect(
partial(self.pre_delete_handler, request=request),
partial(self.pre_delete_handler, user=user, request=request),
dispatch_uid=request.request_id,
weak=False,
)
m2m_changed.connect(
partial(self.m2m_changed_handler, request=request),
partial(self.m2m_changed_handler, user=user, request=request),
dispatch_uid=request.request_id,
weak=False,
)
@ -150,6 +147,7 @@ class AuditMiddleware:
# pylint: disable=too-many-arguments
def post_save_handler(
self,
user: User,
request: HttpRequest,
sender,
instance: Model,
@ -160,18 +158,16 @@ class AuditMiddleware:
"""Signal handler for all object's post_save"""
if not should_log_model(instance):
return
user = self.get_user(request)
action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED
thread = EventNewThread(action, request, user=user, model=model_to_dict(instance))
thread.kwargs.update(thread_kwargs or {})
thread.run()
def pre_delete_handler(self, request: HttpRequest, sender, instance: Model, **_):
def pre_delete_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_):
"""Signal handler for all object's pre_delete"""
if not should_log_model(instance): # pragma: no cover
return
user = self.get_user(request)
EventNewThread(
EventAction.MODEL_DELETED,
@ -180,13 +176,14 @@ class AuditMiddleware:
model=model_to_dict(instance),
).run()
def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_):
def m2m_changed_handler(
self, user: User, request: HttpRequest, sender, instance: Model, action: str, **_
):
"""Signal handler for all object's m2m_changed"""
if action not in ["pre_add", "pre_remove", "post_clear"]:
return
if not should_log_m2m(instance):
return
user = self.get_user(request)
EventNewThread(
EventAction.MODEL_UPDATED,

View File

@ -1,68 +0,0 @@
# Generated by Django 5.0.1 on 2024-02-07 15:42
import uuid
import django.utils.timezone
from django.db import migrations, models
import authentik.core.models
class Migration(migrations.Migration):
replaces = [
("authentik_events", "0004_systemtask"),
("authentik_events", "0005_remove_systemtask_finish_timestamp_and_more"),
]
dependencies = [
("authentik_events", "0003_rename_tenant_event_brand"),
]
operations = [
migrations.CreateModel(
name="SystemTask",
fields=[
(
"expires",
models.DateTimeField(default=authentik.core.models.default_token_duration),
),
("expiring", models.BooleanField(default=True)),
(
"uuid",
models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
("name", models.TextField()),
("uid", models.TextField(null=True)),
(
"status",
models.TextField(
choices=[
("unknown", "Unknown"),
("successful", "Successful"),
("warning", "Warning"),
("error", "Error"),
]
),
),
("description", models.TextField(null=True)),
("messages", models.JSONField()),
("task_call_module", models.TextField()),
("task_call_func", models.TextField()),
("task_call_args", models.JSONField(default=list)),
("task_call_kwargs", models.JSONField(default=dict)),
("duration", models.FloatField(default=0)),
("finish_timestamp", models.DateTimeField(default=django.utils.timezone.now)),
("start_timestamp", models.DateTimeField(default=django.utils.timezone.now)),
],
options={
"verbose_name": "System Task",
"verbose_name_plural": "System Tasks",
"permissions": [("run_task", "Run task")],
"default_permissions": ["view"],
"unique_together": {("name", "uid")},
},
),
]

View File

@ -1,37 +0,0 @@
# Generated by Django 5.0.1 on 2024-02-06 18:02
import django.utils.timezone
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0004_systemtask"),
]
operations = [
migrations.RemoveField(
model_name="systemtask",
name="finish_timestamp",
),
migrations.RemoveField(
model_name="systemtask",
name="start_timestamp",
),
migrations.AddField(
model_name="systemtask",
name="duration",
field=models.FloatField(default=0),
),
migrations.AddField(
model_name="systemtask",
name="finish_timestamp",
field=models.DateTimeField(default=django.utils.timezone.now),
),
migrations.AddField(
model_name="systemtask",
name="start_timestamp",
field=models.DateTimeField(default=django.utils.timezone.now),
),
]

View File

@ -172,6 +172,63 @@ class EventManager(Manager):
return self.get_queryset().get_events_per(time_since, extract, data_points)
class EventBatch(ExpiringModel):
"""Model to store information about batches of events."""
batch_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
event_type = models.CharField(max_length=255)
event_app = models.CharField(max_length=255)
event_user = models.CharField(max_length=255)
start_time = models.DateTimeField(auto_now_add=True)
end_time = models.DateTimeField(null=True, blank=True)
event_count = models.IntegerField(default=0)
last_updated = models.DateTimeField(auto_now=True)
max_batch_size = models.IntegerField(default=10)
batch_timeout = models.IntegerField(default=60) # Timeout in seconds
sent = models.BooleanField(default=False)
def add_event_to_batch(self, event):
"""Add an event to the batch and check if it's ready to send."""
self.add_event(event)
if self.check_batch_limits():
self.process_batch()
@staticmethod
def get_or_create_batch(action, app, user):
"""Get or create a batch for a given action."""
return EventBatch.objects.filter(
event_type=action, event_app=app, event_user=user, end_time__isnull=True
).first() or EventBatch.objects.create(event_type=action, event_app=app, event_user=user)
def check_batch_limits(self):
"""Check if the batch has reached its size or timeout limits."""
time_elapsed = now() - self.start_time
return self.event_count >= self.max_batch_size or time_elapsed >= timedelta(
seconds=self.batch_timeout
)
def add_event(self, event):
"""Add an event to the batch."""
self.event_count += 1
self.save()
def create_batch_summary(self):
"""Create a summary message for the batch."""
return f"Batched Event Summary: {self.event_type} action \
on {self.event_app} app by {self.event_user} user \
occurred {self.event_count} times between {self.start_time} and {now()}"
def process_batch(self):
"""Process the batch and check if it's ready to send."""
summary_message = self.create_batch_summary()
return summary_message
def send_notification(self):
"""Send notification for this batch."""
# Implement the logic to send notification
pass
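# Illustrative sketch only, not part of this change: the intended EventBatch
# lifecycle, assuming an existing Event instance `event` and the default
# size/timeout limits defined above.
#
#   batch = EventBatch.get_or_create_batch(event.action, event.app, event.user)
#   batch.add_event_to_batch(event)      # bumps event_count, processes once limits are hit
#   if batch.check_batch_limits():
#       summary = batch.process_batch()  # "Batched Event Summary: ... occurred N times ..."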
class Event(SerializerModel, ExpiringModel):
"""An individual Audit/Metrics/Notification/Error Event"""
@ -187,6 +244,8 @@ class Event(SerializerModel, ExpiringModel):
# Shadow the expires attribute from ExpiringModel to override the default duration
expires = models.DateTimeField(default=default_event_duration)
batch_id = models.UUIDField(null=True, blank=True)
objects = EventManager()
@staticmethod
@ -214,6 +273,7 @@ class Event(SerializerModel, ExpiringModel):
# Also ensure that closest django app has the correct prefix
if len(django_apps) > 0 and django_apps[0].startswith(app):
app = django_apps[0]
cleaned_kwargs = cleanse_dict(sanitize_dict(kwargs))
event = Event(action=action, app=app, context=cleaned_kwargs)
return event
@ -275,6 +335,9 @@ class Event(SerializerModel, ExpiringModel):
return self
def save(self, *args, **kwargs):
# Creating a batch for this event in the save method
batch = EventBatch.get_or_create_batch(self.action, self.app, self.user)
self.batch_id = batch.batch_id
if self._state.adding:
LOGGER.info(
"Created Event",
@ -334,7 +397,17 @@ class NotificationTransport(SerializerModel):
),
)
enable_batching = models.BooleanField(default=False)
batch_timeout = models.IntegerField(default=60) # Timeout in seconds
max_batch_size = models.IntegerField(default=10)
def send(self, notification: "Notification") -> list[str]:
"""Send a batched notification or a single notification"""
if self.enable_batching:
return self.process_batch(notification)
return self.send_notification(notification)
def send_notification(self, notification: "Notification") -> list[str]:
"""Send notification to user, called from async task"""
if self.mode == TransportMode.LOCAL:
return self.send_local(notification)
@ -451,13 +524,6 @@ class NotificationTransport(SerializerModel):
def send_email(self, notification: "Notification") -> list[str]:
"""Send notification via global email configuration"""
if notification.user.email.strip() == "":
LOGGER.info(
"Discarding notification as user has no email address",
user=notification.user,
notification=notification,
)
return None
subject_prefix = "authentik Notification: "
context = {
"key_value": {
@ -487,7 +553,7 @@ class NotificationTransport(SerializerModel):
}
mail = TemplateEmailMessage(
subject=subject_prefix + context["title"],
to=[(notification.user.name, notification.user.email)],
to=[f"{notification.user.name} <{notification.user.email}>"],
language=notification.user.locale(),
template_name="email/event_notification.html",
template_context=context,
@ -627,9 +693,8 @@ class SystemTask(SerializerModel, ExpiringModel):
name = models.TextField()
uid = models.TextField(null=True)
start_timestamp = models.DateTimeField(default=now)
finish_timestamp = models.DateTimeField(default=now)
duration = models.FloatField(default=0)
start_timestamp = models.FloatField()
finish_timestamp = models.FloatField()
status = models.TextField(choices=TaskStatus.choices)
@ -649,18 +714,17 @@ class SystemTask(SerializerModel, ExpiringModel):
def update_metrics(self):
"""Update prometheus metrics"""
duration = max(self.finish_timestamp - self.start_timestamp, 0)
# TODO: Deprecated metric - remove in 2024.2 or later
GAUGE_TASKS.labels(
tenant=connection.schema_name,
task_name=self.name,
task_uid=self.uid or "",
status=self.status.lower(),
).set(self.duration)
).set(duration)
SYSTEM_TASK_TIME.labels(
tenant=connection.schema_name,
task_name=self.name,
task_uid=self.uid or "",
).observe(self.duration)
).observe(duration)
SYSTEM_TASK_STATUS.labels(
tenant=connection.schema_name,
task_name=self.name,

View File

@ -1,7 +1,7 @@
"""Monitored tasks"""
from datetime import datetime, timedelta
from time import perf_counter
from datetime import timedelta
from timeit import default_timer
from typing import Any, Optional
from django.utils.timezone import now
@ -24,17 +24,14 @@ class SystemTask(TenantTask):
# For tasks that should only be listed if they failed, set this to False
save_on_success: bool
_status: TaskStatus
_status: Optional[TaskStatus]
_messages: list[str]
_uid: Optional[str]
# Precise start time from perf_counter
_start_precise: Optional[float] = None
_start: Optional[datetime] = None
_start: Optional[float] = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._status = TaskStatus.SUCCESSFUL
self.save_on_success = True
self._uid = None
self._status = None
@ -56,17 +53,9 @@ class SystemTask(TenantTask):
self._messages = [exception_to_string(exception)]
def before_start(self, task_id, args, kwargs):
self._start_precise = perf_counter()
self._start = now()
self._start = default_timer()
return super().before_start(task_id, args, kwargs)
def db(self) -> Optional[DBSystemTask]:
"""Get DB object for latest task"""
return DBSystemTask.objects.filter(
name=self.__name__,
uid=self._uid,
).first()
# pylint: disable=too-many-arguments
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
@ -83,13 +72,12 @@ class SystemTask(TenantTask):
uid=self._uid,
defaults={
"description": self.__doc__,
"start_timestamp": self._start or now(),
"finish_timestamp": now(),
"duration": max(perf_counter() - self._start_precise, 0),
"start_timestamp": self._start or default_timer(),
"finish_timestamp": default_timer(),
"task_call_module": self.__module__,
"task_call_func": self.__name__,
"task_call_args": sanitize_item(args),
"task_call_kwargs": sanitize_item(kwargs),
"task_call_args": args,
"task_call_kwargs": kwargs,
"status": self._status,
"messages": sanitize_item(self._messages),
"expires": now() + timedelta(hours=self.result_timeout_hours),
@ -108,13 +96,12 @@ class SystemTask(TenantTask):
uid=self._uid,
defaults={
"description": self.__doc__,
"start_timestamp": self._start or now(),
"finish_timestamp": now(),
"duration": max(perf_counter() - self._start_precise, 0),
"start_timestamp": self._start or default_timer(),
"finish_timestamp": default_timer(),
"task_call_module": self.__module__,
"task_call_func": self.__name__,
"task_call_args": sanitize_item(args),
"task_call_kwargs": sanitize_item(kwargs),
"task_call_args": args,
"task_call_kwargs": kwargs,
"status": self._status,
"messages": sanitize_item(self._messages),
"expires": now() + timedelta(hours=self.result_timeout_hours),
@ -136,14 +123,11 @@ def prefill_task(func):
DBSystemTask(
name=func.__name__,
description=func.__doc__,
start_timestamp=now(),
finish_timestamp=now(),
status=TaskStatus.UNKNOWN,
messages=sanitize_item([_("Task has not been run yet.")]),
task_call_module=func.__module__,
task_call_func=func.__name__,
expiring=False,
duration=0,
)
)
return func

View File

@ -3,11 +3,13 @@
from datetime import timedelta
from typing import Optional
from django.db.models.query_utils import Q
from django.utils import timezone
from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import User
from authentik.events.models import EventBatch # Importing the EventBatch model
from authentik.events.models import (
Event,
Notification,
@ -17,6 +19,13 @@ from authentik.events.models import (
TaskStatus,
)
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.root.celery import CELERY_APP
@ -107,20 +116,44 @@ def notification_transport(
event = Event.objects.filter(pk=event_pk).first()
if not event:
return
user = User.objects.filter(pk=user_pk).first()
if not user:
return
trigger = NotificationRule.objects.filter(pk=trigger_pk).first()
if not trigger:
return
notification = Notification(
severity=trigger.severity, body=event.summary, event=event, user=user
)
# Check if batching is enabled and process accordingly
transport = NotificationTransport.objects.filter(pk=transport_pk).first()
if not transport:
return
if transport.enable_batching:
# Collect the event into the current batch instead of sending immediately
batch = EventBatch.get_or_create_batch(event.action, event.app, event.user)
batch.add_event_to_batch(event)
# Nothing to send until the batch reaches its size or timeout limit
if not batch.check_batch_limits():
return
batch_summary = batch.process_batch()
batch.delete()
notification = Notification(
severity=trigger.severity, body=batch_summary, event=event, user=user
)
transport.send_notification(notification)
self.set_status(TaskStatus.SUCCESSFUL)
except (NotificationTransportError, PropertyMappingExpressionException) as exc:
self.set_error(exc)
raise exc
@ -143,4 +176,21 @@ def notification_cleanup(self: SystemTask):
for notification in notifications:
notification.delete()
LOGGER.debug("Expired notifications", amount=amount)
self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications")
# Scheduled task to flush batches that have gone idle without reaching their size limit
@CELERY_APP.task()
def check_and_send_pending_batches():
"""Check for pending batches that haven't been sent and have been idle for a specified time."""
idle_time = timezone.now() - timedelta(minutes=10)  # idle threshold before flushing a partial batch
pending_batches = EventBatch.objects.filter(sent=False, last_updated__lt=idle_time)
for batch in pending_batches:
batch.send_notification()
batch.sent = True
batch.save()
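# Illustrative sketch only, not part of this change: no beat schedule for
# check_and_send_pending_batches is added here; a hypothetical entry, mirroring
# the format of the existing CELERY_BEAT_SCHEDULE entries, might look like:
#
#   "events_send_pending_batches": {
#       "task": "authentik.events.tasks.check_and_send_pending_batches",
#       "schedule": crontab(minute="*/10"),
#       "options": {"queue": "authentik_scheduled"},
#   },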

View File

@ -3,10 +3,9 @@
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Application, Token, TokenIntents
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
class TestEventsMiddleware(APITestCase):
@ -48,30 +47,3 @@ class TestEventsMiddleware(APITestCase):
context__model__name="test-delete",
).exists()
)
def test_create_with_api(self):
"""Test model creation event (with API token auth)"""
self.client.logout()
token = Token.objects.create(user=self.user, intent=TokenIntents.INTENT_API, expiring=False)
uid = generate_id()
self.client.post(
reverse("authentik_api:application-list"),
data={"name": uid, "slug": uid},
HTTP_AUTHORIZATION=f"Bearer {token.key}",
)
self.assertTrue(Application.objects.filter(name=uid).exists())
event = Event.objects.filter(
action=EventAction.MODEL_CREATED,
context__model__model_name="application",
context__model__app="authentik_core",
context__model__name=uid,
).first()
self.assertIsNotNone(event)
self.assertEqual(
event.user,
{
"pk": self.user.pk,
"email": self.user.email,
"username": self.user.username,
},
)

View File

@ -15,6 +15,7 @@ from rest_framework.serializers import ModelSerializer, SerializerMethodField
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required
from authentik.blueprints.v1.exporter import FlowExporter
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer
from authentik.core.api.used_by import UsedByMixin
@ -32,7 +33,6 @@ from authentik.lib.utils.file import (
set_file_url,
)
from authentik.lib.views import bad_request_message
from authentik.rbac.decorators import permission_required
LOGGER = get_logger()

View File

@ -31,6 +31,10 @@ class AuthentikFlowsConfig(ManagedAppConfig):
verbose_name = "authentik Flows"
default = True
def reconcile_global_load_flows_signals(self):
"""Load flows signals"""
self.import_module("authentik.flows.signals")
def reconcile_global_load_stages(self):
"""Ensure all stages are loaded"""
from authentik.flows.models import Stage

View File

@ -1,7 +1,6 @@
"""flow views tests"""
from unittest.mock import MagicMock, PropertyMock, patch
from urllib.parse import urlencode
from django.http import HttpRequest, HttpResponse
from django.test.client import RequestFactory
@ -19,12 +18,7 @@ from authentik.flows.models import (
from authentik.flows.planner import FlowPlan, FlowPlanner
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import (
NEXT_ARG_NAME,
QS_QUERY,
SESSION_KEY_PLAN,
FlowExecutorView,
)
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
@ -127,73 +121,16 @@ class TestFlowExecutor(FlowTestCase):
TO_STAGE_RESPONSE_MOCK,
)
def test_invalid_flow_redirect(self):
"""Test invalid flow with valid redirect destination"""
"""Tests that an invalid flow still redirects"""
flow = create_test_flow(
FlowDesignation.AUTHENTICATION,
)
dest = "/unique-string"
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}")
response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}")
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/unique-string")
@patch(
"authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK,
)
def test_invalid_flow_invalid_redirect(self):
"""Test invalid flow redirect with an invalid URL"""
flow = create_test_flow(
FlowDesignation.AUTHENTICATION,
)
dest = "http://something.example.com/unique-string"
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}")
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow,
component="ak-stage-access-denied",
error_message="Invalid next URL",
)
@patch(
"authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK,
)
def test_valid_flow_redirect(self):
"""Test valid flow with valid redirect destination"""
flow = create_test_flow()
dest = "/unique-string"
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}")
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/unique-string")
@patch(
"authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK,
)
def test_valid_flow_invalid_redirect(self):
"""Test valid flow redirect with an invalid URL"""
flow = create_test_flow()
dest = "http://something.example.com/unique-string"
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}")
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow,
component="ak-stage-access-denied",
error_message="Invalid next URL",
)
self.assertEqual(response.url, reverse("authentik_core:root-redirect"))
@patch(
"authentik.flows.views.executor.to_stage_response",

View File

@ -12,7 +12,6 @@ from django.shortcuts import get_object_or_404, redirect
from django.template.response import TemplateResponse
from django.urls import reverse
from django.utils.decorators import method_decorator
from django.utils.translation import gettext as _
from django.views.decorators.clickjacking import xframe_options_sameorigin
from django.views.generic import View
from drf_spectacular.types import OpenApiTypes
@ -179,8 +178,6 @@ class FlowExecutorView(APIView):
self.cancel()
self._logger.debug("f(exec): Continuing existing plan")
# Initial flow request, check if we have an upstream query string passed in
request.session[SESSION_KEY_GET] = get_params
# Don't check session again as we've either already loaded the plan or we need to plan
if not self.plan:
request.session[SESSION_KEY_HISTORY] = []
@ -195,6 +192,8 @@ class FlowExecutorView(APIView):
# To match behaviour with loading an empty flow plan from cache,
# we don't show an error message here, but rather call _flow_done()
return self._flow_done()
# Initial flow request, check if we have an upstream query string passed in
request.session[SESSION_KEY_GET] = get_params
# We don't save the Plan after getting the next stage
# as it hasn't been successfully passed yet
try:
@ -393,11 +392,7 @@ class FlowExecutorView(APIView):
NEXT_ARG_NAME, "authentik_core:root-redirect"
)
self.cancel()
if next_param and not is_url_absolute(next_param):
return to_stage_response(self.request, redirect_with_qs(next_param))
return to_stage_response(
self.request, self.stage_invalid(error_message=_("Invalid next URL"))
)
return to_stage_response(self.request, redirect_with_qs(next_param))
def stage_ok(self) -> HttpResponse:
"""Callback called by stages upon successful completion.

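One side of the hunk above follows the `next` parameter only when it is a relative URL, and otherwise returns an access-denied stage with "Invalid next URL". A minimal sketch of that guard, assuming `is_url_absolute` simply checks for a network location (the real helper lives in authentik's lib utilities and is not shown in this diff):

from typing import Optional
from urllib.parse import urlparse

def is_url_absolute(url: str) -> bool:
    """Assumed reading of the helper used above: any URL with a netloc is absolute."""
    return bool(urlparse(url).netloc)

def resolve_next(next_param: Optional[str]) -> Optional[str]:
    """Return the next URL only if it is safe to follow (relative)."""
    if next_param and not is_url_absolute(next_param):
        return next_param
    # Caller falls back to stage_invalid(error_message=_("Invalid next URL")).
    return None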
View File

@ -30,6 +30,10 @@ class AuthentikOutpostConfig(ManagedAppConfig):
verbose_name = "authentik Outpost"
default = True
def reconcile_global_load_outposts_signals(self):
"""Load outposts signals"""
self.import_module("authentik.outposts.signals")
def reconcile_tenant_embedded_outpost(self):
"""Ensure embedded outpost"""
from authentik.outposts.models import (

View File

@ -13,6 +13,7 @@ from rest_framework.viewsets import GenericViewSet
from structlog.stdlib import get_logger
from structlog.testing import capture_logs
from authentik.api.decorators import permission_required
from authentik.core.api.applications import user_app_cache_key
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer
@ -22,7 +23,6 @@ from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSe
from authentik.policies.models import Policy, PolicyBinding
from authentik.policies.process import PolicyProcess
from authentik.policies.types import CACHE_PREFIX, PolicyRequest
from authentik.rbac.decorators import permission_required
LOGGER = get_logger()

View File

@ -35,3 +35,7 @@ class AuthentikPoliciesConfig(ManagedAppConfig):
label = "authentik_policies"
verbose_name = "authentik Policies"
default = True
def reconcile_global_load_policies_signals(self):
"""Load policies signals"""
self.import_module("authentik.policies.signals")
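Several apps.py hunks in this compare add or drop reconcile methods that import an app's signal (and task) modules at startup. A minimal sketch of the pattern with a hypothetical app name; only `ManagedAppConfig`, the class attributes, and `import_module` are taken from the diffs above:

from authentik.blueprints.apps import ManagedAppConfig

class AuthentikExampleConfig(ManagedAppConfig):
    """Hypothetical app config following the same reconcile pattern."""

    name = "authentik.example"
    label = "authentik_example"
    verbose_name = "authentik Example"
    default = True

    def reconcile_global_load_example_signals(self):
        """Load example signals"""
        self.import_module("authentik.example.signals")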

View File

@ -2,7 +2,7 @@
from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection
from time import perf_counter
from timeit import default_timer
from typing import Iterator, Optional
from django.core.cache import cache
@ -84,10 +84,10 @@ class PolicyEngine:
def _check_cache(self, binding: PolicyBinding):
if not self.use_cache:
return False
before = perf_counter()
before = default_timer()
key = cache_key(binding, self.request)
cached_policy = cache.get(key, None)
duration = max(perf_counter() - before, 0)
duration = max(default_timer() - before, 0)
if not cached_policy:
return False
self.logger.debug(

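This hunk only swaps the clock used around the cache lookup; since `timeit.default_timer` is `time.perf_counter` on Python 3, the measured duration is identical either way. A small sketch of the timing pattern used in `_check_cache`:

from time import perf_counter

from django.core.cache import cache

def timed_cache_get(key: str):
    """Fetch a key from the cache and report how long the lookup took."""
    before = perf_counter()  # timeit.default_timer resolves to the same clock
    value = cache.get(key, None)
    duration = max(perf_counter() - before, 0)
    return value, duration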
View File

@ -2,8 +2,6 @@
from authentik.blueprints.apps import ManagedAppConfig
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
class AuthentikPolicyReputationConfig(ManagedAppConfig):
"""Authentik reputation app config"""
@ -12,3 +10,11 @@ class AuthentikPolicyReputationConfig(ManagedAppConfig):
label = "authentik_policies_reputation"
verbose_name = "authentik Policies.Reputation"
default = True
def reconcile_global_load_policies_reputation_signals(self):
"""Load policies.reputation signals"""
self.import_module("authentik.policies.reputation.signals")
def reconcile_global_load_policies_reputation_tasks(self):
"""Load policies.reputation tasks"""
self.import_module("authentik.policies.reputation.tasks")

View File

@ -19,6 +19,7 @@ from authentik.policies.types import PolicyRequest, PolicyResult
from authentik.root.middleware import ClientIPMiddleware
LOGGER = get_logger()
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
def reputation_expiry():

View File

@ -8,7 +8,7 @@ from structlog.stdlib import get_logger
from authentik.core.signals import login_failed
from authentik.lib.config import CONFIG
from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
from authentik.policies.reputation.models import CACHE_KEY_PREFIX
from authentik.policies.reputation.tasks import save_reputation
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.identification.signals import identification_failed
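The hunks above only move where `CACHE_KEY_PREFIX` is imported from; the prefix itself namespaces cached reputation scores. A purely illustrative sketch of writing under such a prefix; the key layout and payload shape are assumptions and do not appear in this diff:

from django.core.cache import cache

CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"

def update_cached_score(remote_ip: str, identifier: str, amount: int) -> None:
    # Hypothetical key layout: prefix + IP + identifier.
    key = f"{CACHE_KEY_PREFIX}{remote_ip}/{identifier}"
    entry = cache.get(key, {"ip": remote_ip, "identifier": identifier, "score": 0})
    entry["score"] += amount
    cache.set(key, entry)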

View File

@ -7,8 +7,8 @@ from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
from authentik.policies.reputation.models import Reputation
from authentik.policies.reputation.signals import CACHE_KEY_PREFIX
from authentik.root.celery import CELERY_APP
LOGGER = get_logger()

View File

@ -6,8 +6,7 @@ from django.test import RequestFactory, TestCase
from authentik.core.models import User
from authentik.lib.generators import generate_id
from authentik.policies.reputation.api import ReputationPolicySerializer
from authentik.policies.reputation.apps import CACHE_KEY_PREFIX
from authentik.policies.reputation.models import Reputation, ReputationPolicy
from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy
from authentik.policies.reputation.tasks import save_reputation
from authentik.policies.types import PolicyRequest
from authentik.stages.password import BACKEND_INBUILT

View File

@ -15,13 +15,13 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from authentik.api.decorators import permission_required
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
from authentik.core.models import Provider
from authentik.providers.oauth2.id_token import IDToken
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, ScopeMapping
from authentik.rbac.decorators import permission_required
class OAuth2ProviderSerializer(ProviderSerializer):

View File

@ -36,21 +36,8 @@ class TestAuthorize(OAuthTestCase):
def test_invalid_grant_type(self):
"""Test with invalid grant type"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid/Foo",
)
with self.assertRaises(AuthorizeError):
request = self.factory.get(
"/",
data={
"response_type": "invalid",
"client_id": "test",
"redirect_uri": "http://local.invalid/Foo",
},
)
request = self.factory.get("/", data={"response_type": "invalid"})
OAuthAuthorizationParams.from_request(request)
def test_invalid_client_id(self):
@ -357,12 +344,7 @@ class TestAuthorize(OAuthTestCase):
]
)
)
provider.property_mappings.add(
ScopeMapping.objects.create(
name=generate_id(), scope_name="test", expression="""return {"sub": "foo"}"""
)
)
Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
@ -383,7 +365,7 @@ class TestAuthorize(OAuthTestCase):
"response_type": "id_token",
"client_id": "test",
"state": state,
"scope": "openid test",
"scope": "openid",
"redirect_uri": "http://localhost",
"nonce": generate_id(),
},
@ -408,7 +390,6 @@ class TestAuthorize(OAuthTestCase):
)
jwt = self.validate_jwt(token, provider)
self.assertEqual(jwt["amr"], ["pwd"])
self.assertEqual(jwt["sub"], "foo")
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,

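The removed assertions above depend on a scope mapping whose expression overrides the `sub` claim when the client requests the extra `test` scope alongside `openid`. A compact restatement of that setup, assuming `provider` is the OAuth2Provider created earlier in the test:

from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider, ScopeMapping

def add_sub_override(provider: OAuth2Provider) -> None:
    """Attach a mapping so tokens issued for scope "openid test" carry sub == "foo"."""
    provider.property_mappings.add(
        ScopeMapping.objects.create(
            name=generate_id(),
            scope_name="test",
            expression='return {"sub": "foo"}',
        )
    )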
View File

@ -4,10 +4,9 @@ from urllib.parse import urlencode
from django.urls import reverse
from authentik.core.models import Application, Group
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
@ -78,23 +77,3 @@ class TesOAuth2DeviceInit(OAuthTestCase):
+ "?"
+ urlencode({QS_KEY_CODE: token.user_code}),
)
def test_device_init_denied(self):
"""Test device init"""
group = Group.objects.create(name="foo")
PolicyBinding.objects.create(
group=group,
target=self.application,
order=0,
)
token = DeviceToken.objects.create(
user_code="foo",
provider=self.provider,
)
res = self.client.get(
reverse("authentik_providers_oauth2_root:device-login")
+ "?"
+ urlencode({QS_KEY_CODE: token.user_code})
)
self.assertEqual(res.status_code, 200)
self.assertIn(b"Permission denied", res.content)

View File

@ -121,18 +121,44 @@ class OAuthAuthorizationParams:
redirect_uri = query_dict.get("redirect_uri", "")
response_type = query_dict.get("response_type", "")
grant_type = None
# Determine which flow to use.
if response_type in [ResponseTypes.CODE]:
grant_type = GrantTypes.AUTHORIZATION_CODE
elif response_type in [
ResponseTypes.ID_TOKEN,
ResponseTypes.ID_TOKEN_TOKEN,
]:
grant_type = GrantTypes.IMPLICIT
elif response_type in [
ResponseTypes.CODE_TOKEN,
ResponseTypes.CODE_ID_TOKEN,
ResponseTypes.CODE_ID_TOKEN_TOKEN,
]:
grant_type = GrantTypes.HYBRID
# Grant type validation.
if not grant_type:
LOGGER.warning("Invalid response type", type=response_type)
raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state)
# Validate and check the response_mode against the predefined dict
# Set to Query or Fragment if not defined in request
response_mode = query_dict.get("response_mode", False)
if response_mode not in ResponseMode.values:
response_mode = ResponseMode.QUERY
if grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]:
response_mode = ResponseMode.FRAGMENT
max_age = query_dict.get("max_age")
return OAuthAuthorizationParams(
client_id=query_dict.get("client_id", ""),
redirect_uri=redirect_uri,
response_type=response_type,
response_mode=response_mode,
grant_type="",
grant_type=grant_type,
scope=set(query_dict.get("scope", "").split()),
state=state,
nonce=query_dict.get("nonce"),
@ -152,7 +178,6 @@ class OAuthAuthorizationParams:
LOGGER.warning("Invalid client identifier", client_id=self.client_id)
raise ClientIdError(client_id=self.client_id)
self.check_redirect_uri()
self.check_grant()
self.check_scope(github_compat)
self.check_nonce()
self.check_code_challenge()
@ -161,34 +186,6 @@ class OAuthAuthorizationParams:
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
def check_grant(self):
"""Check grant"""
# Determine which flow to use.
if self.response_type in [ResponseTypes.CODE]:
self.grant_type = GrantTypes.AUTHORIZATION_CODE
elif self.response_type in [
ResponseTypes.ID_TOKEN,
ResponseTypes.ID_TOKEN_TOKEN,
]:
self.grant_type = GrantTypes.IMPLICIT
elif self.response_type in [
ResponseTypes.CODE_TOKEN,
ResponseTypes.CODE_ID_TOKEN,
ResponseTypes.CODE_ID_TOKEN_TOKEN,
]:
self.grant_type = GrantTypes.HYBRID
# Grant type validation.
if not self.grant_type:
LOGGER.warning("Invalid response type", type=self.response_type)
raise AuthorizeError(self.redirect_uri, "unsupported_response_type", "", self.state)
if self.response_mode not in ResponseMode.values:
self.response_mode = ResponseMode.QUERY
if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]:
self.response_mode = ResponseMode.FRAGMENT
def check_redirect_uri(self):
"""Redirect URI validation."""
allowed_redirect_urls = self.provider.redirect_uris.split()
@ -260,9 +257,9 @@ class OAuthAuthorizationParams:
if SCOPE_OFFLINE_ACCESS in self.scope:
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
if PROMPT_CONSENT not in self.prompt:
# Instead of ignoring the `offline_access` scope when `prompt`
# isn't set to `consent`, we override it ourselves
self.prompt.add(PROMPT_CONSENT)
raise AuthorizeError(
self.redirect_uri, "consent_required", self.grant_type, self.state
)
if self.response_type not in [
ResponseTypes.CODE,
ResponseTypes.CODE_TOKEN,

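Both variants of the large hunk above implement the same mapping, once as a free-standing block in `from_request` and once as a `check_grant` method: the OAuth `response_type` determines the grant type, and implicit or hybrid grants force the fragment response mode. The two sides also differ on `offline_access` without `prompt=consent`; one forces the consent prompt, the other raises `consent_required`. A string-level restatement of the mapping; the real code uses the `ResponseTypes`, `GrantTypes` and `ResponseMode` enums shown above:

from typing import Optional

def derive_grant_type(response_type: str) -> Optional[str]:
    if response_type == "code":
        return "authorization_code"
    if response_type in ("id_token", "id_token token"):
        return "implicit"
    if response_type in ("code token", "code id_token", "code id_token token"):
        return "hybrid"
    # Caller raises AuthorizeError with "unsupported_response_type".
    return None

def derive_response_mode(requested: Optional[str], grant_type: str) -> str:
    # Default to query; implicit and hybrid flows return tokens in the fragment.
    mode = requested if requested in ("query", "fragment", "form_post") else "query"
    if grant_type in ("implicit", "hybrid"):
        mode = "fragment"
    return mode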
View File

@ -12,11 +12,10 @@ from django.views.decorators.csrf import csrf_exempt
from rest_framework.throttling import AnonRateThrottle
from structlog.stdlib import get_logger
from authentik.core.models import Application
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application
LOGGER = get_logger()
@ -39,9 +38,7 @@ class DeviceView(View):
).first()
if not provider:
return HttpResponseBadRequest()
try:
_ = provider.application
except Application.DoesNotExist:
if not get_application(provider):
return HttpResponseBadRequest()
self.provider = provider
self.client_id = client_id
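One side of this hunk checks `provider.application` inline with a try/except, the other calls the `get_application` helper imported from `device_init`. Based on the inline version, a plausible reading of that helper:

from typing import Optional

from authentik.core.models import Application
from authentik.providers.oauth2.models import OAuth2Provider

def get_application(provider: OAuth2Provider) -> Optional[Application]:
    """Return the provider's application, or None if none is attached."""
    try:
        return provider.application
    except Application.DoesNotExist:
        return None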

View File

@ -1,10 +1,11 @@
"""Device flow views"""
from typing import Any, Optional
from typing import Optional
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _
from rest_framework.exceptions import ValidationError
from django.views import View
from rest_framework.exceptions import ErrorDetail
from rest_framework.fields import CharField, IntegerField
from structlog.stdlib import get_logger
@ -17,7 +18,6 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO,
from authentik.flows.stage import ChallengeStageView
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.utils.urls import redirect_with_qs
from authentik.policies.views import PolicyAccessView
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
from authentik.providers.oauth2.views.device_finish import (
PLAN_CONTEXT_DEVICE,
@ -44,52 +44,48 @@ def get_application(provider: OAuth2Provider) -> Optional[Application]:
return None
class CodeValidatorView(PolicyAccessView):
"""Helper to validate frontside token"""
def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]:
"""Validate user token"""
token = DeviceToken.objects.filter(
user_code=code,
).first()
if not token:
return None
def __init__(self, code: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.code = code
app = get_application(token.provider)
if not app:
return None
def resolve_provider_application(self):
self.token = DeviceToken.objects.filter(user_code=self.code).first()
if not self.token:
raise Application.DoesNotExist
self.provider = self.token.provider
self.application = self.token.provider.application
def get(self, request: HttpRequest, *args, **kwargs):
scope_descriptions = UserInfoView().get_scope_descriptions(self.token.scope, self.provider)
planner = FlowPlanner(self.provider.authorization_flow)
planner.allow_empty_flows = True
planner.use_cache = False
try:
plan = planner.plan(
request,
{
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_APPLICATION: self.application,
# OAuth2 related params
PLAN_CONTEXT_DEVICE: self.token,
# Consent related params
PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
% {"application": self.application.name},
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
},
)
except FlowNonApplicableException:
LOGGER.warning("Flow not applicable to user")
return None
plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
request.GET,
flow_slug=self.token.provider.authorization_flow.slug,
scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider)
planner = FlowPlanner(token.provider.authorization_flow)
planner.allow_empty_flows = True
try:
plan = planner.plan(
request,
{
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_APPLICATION: app,
# OAuth2 related params
PLAN_CONTEXT_DEVICE: token,
# Consent related params
PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
% {"application": app.name},
PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
},
)
except FlowNonApplicableException:
LOGGER.warning("Flow not applicable to user")
return None
plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
request.GET,
flow_slug=token.provider.authorization_flow.slug,
)
class DeviceEntryView(PolicyAccessView):
class DeviceEntryView(View):
"""View used to initiate the device-code flow, url entered by endusers"""
def dispatch(self, request: HttpRequest) -> HttpResponse:
@ -99,9 +95,7 @@ class DeviceEntryView(PolicyAccessView):
LOGGER.info("Brand has no device code flow configured", brand=brand)
return HttpResponse(status=404)
if QS_KEY_CODE in request.GET:
validation = CodeValidatorView(request.GET[QS_KEY_CODE], request=request).dispatch(
request
)
validation = validate_code(request.GET[QS_KEY_CODE], request)
if validation:
return validation
LOGGER.info("Got code from query parameter but no matching token found")
@ -136,13 +130,6 @@ class OAuthDeviceCodeChallengeResponse(ChallengeResponse):
code = IntegerField()
component = CharField(default="ak-provider-oauth2-device-code")
def validate_code(self, code: int) -> HttpResponse | None:
"""Validate code and save the returned http response"""
response = CodeValidatorView(code, request=self.stage.request).dispatch(self.stage.request)
if not response:
raise ValidationError(_("Invalid code"), "invalid")
return response
class OAuthDeviceCodeStage(ChallengeStageView):
"""Flow challenge for users to enter device codes"""
@ -158,4 +145,12 @@ class OAuthDeviceCodeStage(ChallengeStageView):
)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
return response.validated_data["code"]
code = response.validated_data["code"]
validation = validate_code(code, self.request)
if not validation:
response._errors.setdefault("code", [])
response._errors["code"].append(ErrorDetail(_("Invalid code"), "invalid"))
return self.challenge_invalid(response)
# Run cancel to cleanup the current flow
self.executor.cancel()
return validation
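From the end-user side, both variants of this flow start the same way as in the device-init tests earlier in this compare: the browser opens the device-login URL with the user code as a query parameter, and a matching DeviceToken kicks off the provider's authorization flow. A small sketch of building that URL; the value of `QS_KEY_CODE` is an assumption:

from urllib.parse import urlencode

from django.urls import reverse

QS_KEY_CODE = "code"  # assumed value of the constant imported above

def device_login_url(user_code: str) -> str:
    base = reverse("authentik_providers_oauth2_root:device-login")
    return f"{base}?{urlencode({QS_KEY_CODE: user_code})}"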

View File

@ -101,8 +101,8 @@ class UserInfoView(View):
value=value,
)
continue
always_merger.merge(final_claims, value)
LOGGER.debug("updated scope", scope=scope)
always_merger.merge(final_claims, value)
return final_claims
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
@ -121,9 +121,8 @@ class UserInfoView(View):
"""Handle GET Requests for UserInfo"""
if not self.token:
return HttpResponseBadRequest()
claims = {}
claims.setdefault("sub", self.token.id_token.sub)
claims.update(self.get_claims(self.token.provider, self.token))
claims = self.get_claims(self.token.provider, self.token)
claims["sub"] = self.token.id_token.sub
if self.token.id_token.nonce:
claims["nonce"] = self.token.id_token.nonce
response = TokenResponse(claims)
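The two sides of this hunk differ in which value wins the `sub` claim: with `setdefault` before `update`, a scope mapping may override `sub`; with the assignment after `get_claims`, the ID token's subject always wins. The per-scope merging itself uses deepmerge, roughly like this sketch with placeholder claim values:

from deepmerge import always_merger

final_claims: dict = {}
for scope_claims in ({"email": "user@example.com"}, {"profile": {"name": "User"}}):
    # always_merger.merge mutates its first argument and deep-merges nested dicts.
    always_merger.merge(final_claims, scope_claims)
final_claims["sub"] = "example-subject"  # placeholder; taken from the ID token above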

View File

@ -10,3 +10,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
label = "authentik_providers_proxy"
verbose_name = "authentik Providers.Proxy"
default = True
def reconcile_global_load_providers_proxy_signals(self):
"""Load proxy signals"""
self.import_module("authentik.providers.proxy.signals")

View File

@ -22,6 +22,7 @@ from rest_framework.serializers import PrimaryKeyRelatedField, ValidationError
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer
@ -32,7 +33,6 @@ from authentik.providers.saml.processors.assertion import AssertionProcessor
from authentik.providers.saml.processors.authn_request_parser import AuthNRequest
from authentik.providers.saml.processors.metadata import MetadataProcessor
from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser
from authentik.rbac.decorators import permission_required
from authentik.sources.saml.processors.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT
LOGGER = get_logger()

View File

@ -10,3 +10,7 @@ class AuthentikProviderSCIMConfig(ManagedAppConfig):
label = "authentik_providers_scim"
verbose_name = "authentik Providers.SCIM"
default = True
def reconcile_global_load_signals(self):
"""Load signals"""
self.import_module("authentik.providers.scim.signals")

View File

@ -15,10 +15,10 @@ from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from authentik.api.decorators import permission_required
from authentik.core.api.utils import PassiveSerializer
from authentik.policies.event_matcher.models import model_choices
from authentik.rbac.api.rbac import PermissionAssignSerializer
from authentik.rbac.decorators import permission_required
from authentik.rbac.models import Role

View File

@ -16,11 +16,11 @@ from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet
from authentik.api.decorators import permission_required
from authentik.core.api.groups import GroupMemberSerializer
from authentik.core.models import User, UserTypes
from authentik.policies.event_matcher.models import model_choices
from authentik.rbac.api.rbac import PermissionAssignSerializer
from authentik.rbac.decorators import permission_required
class UserObjectPermissionSerializer(ModelSerializer):

View File

@ -10,3 +10,7 @@ class AuthentikRBACConfig(ManagedAppConfig):
label = "authentik_rbac"
verbose_name = "authentik RBAC"
default = True
def reconcile_global_load_rbac_signals(self):
"""Load rbac signals"""
self.import_module("authentik.rbac.signals")

View File

@ -1,58 +0,0 @@
"""test decorators api"""
from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_user
from authentik.lib.generators import generate_id
class TestAPIDecorators(APITestCase):
"""test decorators api"""
def setUp(self) -> None:
super().setUp()
self.user = create_test_user()
def test_obj_perm_denied(self):
"""Test object perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name=generate_id(), slug=generate_id())
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 403)
def test_obj_perm_global(self):
"""Test object perm successful (global)"""
assign_perm("authentik_core.view_application", self.user)
assign_perm("authentik_events.view_event", self.user)
self.client.force_login(self.user)
app = Application.objects.create(name=generate_id(), slug=generate_id())
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 200)
def test_obj_perm_scoped(self):
"""Test object perm successful (scoped)"""
assign_perm("authentik_events.view_event", self.user)
app = Application.objects.create(name=generate_id(), slug=generate_id())
assign_perm("authentik_core.view_application", self.user, app)
self.client.force_login(self.user)
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 200)
def test_other_perm_denied(self):
"""Test other perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name=generate_id(), slug=generate_id())
assign_perm("authentik_core.view_application", self.user, app)
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 403)
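The success cases in this removed test file hinge on the difference between a global and an object-scoped permission grant, both via django-guardian. A minimal sketch of the two shapes, assuming `user` and `app` were created as in the tests above:

from guardian.shortcuts import assign_perm

def grant_view_application(user, app=None) -> None:
    if app is not None:
        # Object-scoped: the user may view only this one application.
        assign_perm("authentik_core.view_application", user, app)
    else:
        # Global: the user may view any application.
        assign_perm("authentik_core.view_application", user)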

View File

@ -91,10 +91,13 @@ def _get_startup_tasks_default_tenant() -> list[Callable]:
def _get_startup_tasks_all_tenants() -> list[Callable]:
"""Get all tasks to be run on startup for all tenants"""
from authentik.admin.tasks import clear_update_notifications
from authentik.outposts.tasks import outpost_connection_discovery, outpost_controller_all
from authentik.providers.proxy.tasks import proxy_set_defaults
return [
clear_update_notifications,
outpost_connection_discovery,
outpost_controller_all,
proxy_set_defaults,
]
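A hedged sketch of how such a task list could be consumed; it assumes every entry is a Celery task exposing `.delay()` and leaves out the per-tenant schema switching, neither of which is shown in this diff:

from typing import Callable

def enqueue_startup_tasks(tasks: list[Callable]) -> None:
    """Queue every startup task for asynchronous execution."""
    for task in tasks:
        task.delay()  # assumption: entries are Celery tasks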

View File

@ -7,8 +7,6 @@ from psycopg import connect
from authentik.lib.config import CONFIG
QUERY = """SELECT id FROM public.authentik_install_id ORDER BY id LIMIT 1;"""
@lru_cache
def get_install_id() -> str:
@ -20,7 +18,7 @@ def get_install_id() -> str:
if settings.TEST:
return str(uuid4())
with connection.cursor() as cursor:
cursor.execute(QUERY)
cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;")
return cursor.fetchone()[0]
@ -40,5 +38,5 @@ def get_install_id_raw():
sslkey=CONFIG.get("postgresql.sslkey"),
)
cursor = conn.cursor()
cursor.execute(QUERY)
cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;")
return cursor.fetchone()[0]
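Recombining the lines of this hunk into one self-contained sketch of the cached lookup (using the module-level `QUERY` constant from one side of the diff):

from functools import lru_cache
from uuid import uuid4

from django.conf import settings
from django.db import connection

QUERY = "SELECT id FROM public.authentik_install_id ORDER BY id LIMIT 1;"

@lru_cache
def get_install_id() -> str:
    """Return the install id, computed once per process; random in tests."""
    if settings.TEST:
        return str(uuid4())
    with connection.cursor() as cursor:
        cursor.execute(QUERY)
        return cursor.fetchone()[0]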

Some files were not shown because too many files have changed in this diff.