Compare commits

..

18 Commits

Author SHA1 Message Date
9075270b01 release: 2024.6.1 2024-07-11 21:45:54 +02:00
d17a39a431 website/docs: add 2024.6.1 release notes (cherry-pick #10456) (#10458)
website/docs: add 2024.6.1 release notes (#10456)

* website/docs: add 2024.6.1 release notes



* update



* fix version requirement for sfe



---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-11 19:11:28 +02:00
db1d091d2e core: revert backchannel only filtering (cherry-pick #10455) (#10457)
core: revert backchannel only filtering (#10455)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-11 16:58:29 +02:00
f98204e78e core: fix source flow_manager not resuming flow when linking (cherry-pick #10436) (#10438)
core: fix source flow_manager not resuming flow when linking (#10436)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-10 15:20:15 +02:00
3f663cab0f web/admin: fix access token list calling wrong API (cherry-pick #10434) (#10435)
web/admin: fix access token list calling wrong API (#10434)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-10 14:17:47 +02:00
3fe129e107 core: fix migrations missing using db_alias (cherry-pick #10409) (#10410)
core: fix migrations missing using db_alias (#10409)

Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2024-07-09 10:48:29 +02:00
f26d41aef9 web: bump API Client version (#10389)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: authentik-automation[bot] <135050075+authentik-automation[bot]@users.noreply.github.com>
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
# Conflicts:
#	web/package-lock.json
#	web/package.json
2024-07-05 20:49:31 +02:00
5d8b5998ae web/flows: Simplified flow executor (#10296)
* initial sfe

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* build sfe

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* downgrade bootstrap

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix path

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make IE compatible

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix query string missing

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add autosubmit stage

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add background image

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add code support

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add support for combo ident/password

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix logo rendering

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* only use for edge 18 and before

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix lint

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add docs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add webauthn support

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* migrate to TS for some creature comforts

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix ci

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* dedupe dependabot

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use API client...kinda

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add more docs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add more polyfills yay

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* turn powered by into span

prevent issues in restricted browsers where users might not be able to return

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* allow non-link footer entries

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix tsc errors

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* Apply suggestions from code review

Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
Signed-off-by: Jens L. <jens@beryju.org>

* auto switch for macos

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* reword

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* Update website/docs/flow/executors/if-flow.md

Signed-off-by: Jens L. <jens@beryju.org>

* format

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Signed-off-by: Jens L. <jens@beryju.org>
Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
# Conflicts:
#	.github/workflows/ci-web.yml
#	Dockerfile
#	website/developer-docs/api/flow-executor.md
2024-07-05 20:43:14 +02:00
7a5e136346 stages/authenticator_validate: fix friendly_name being required (cherry-pick #10382) (#10385)
stages/authenticator_validate: fix friendly_name being required (#10382)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-05 15:50:14 +02:00
bfbab6357a sources/oauth: fix link not being saved (cherry-pick #10374) (#10376)
sources/oauth: fix link not being saved (#10374)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-04 16:58:38 +02:00
5997b93f15 sources/saml: fix pickle error, add saml auth tests (cherry-pick #10348) (#10352)
sources/saml: fix pickle error, add saml auth tests (#10348)

* test with persistent nameid



* fix pickle



* user_write: dont attempt to write to read only property



* add test for enroll + auth



* unwrap lazy user



---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-03 18:34:22 +02:00
6cdae09dc0 providers/saml: fix metadata import error handling (cherry-pick #10349) (#10350)
Co-authored-by: Jens L <jens@goauthentik.io>
fix metadata import error handling (#10349)
2024-07-03 16:01:50 +00:00
ff0ef7a2b3 web: set noopener and noreferrer on all external links (#10304)
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-07-02 14:54:03 +02:00
3986104a20 provider/scim: Fix exception handling for missing ServiceProviderConfig (cherry-pick #10322) (#10335)
provider/scim: Fix exception handling for missing ServiceProviderConfig (#10322)

Co-authored-by: Michael Poutre <m1kep.my.mail@gmail.com>
2024-07-02 13:53:27 +02:00
1aa60e7864 core: remove transitionary old JS urls (cherry-pick #10317) (#10321)
core: remove transitionary old JS urls (#10317)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-01 21:00:05 +02:00
045578dd07 web/flows: remove background image link (cherry-pick #10318) (#10320)
web/flows: remove background image link (#10318)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-07-01 20:28:30 +02:00
f23d70dc75 stages/user_login: fix ?next parameter not carried through broken session binding (cherry-pick #10301) (#10302)
stages/user_login: fix ?next parameter not carried through broken session binding (#10301)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
2024-06-29 23:17:13 +02:00
496f3426d9 website/docs: update geoip and asn documentation following field changes (cherry-pick #10265) (#10266)
Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
2024-06-27 13:26:31 +00:00
2006 changed files with 25083 additions and 546148 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2024.8.1 current_version = 2024.6.1
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?

View File

@ -54,10 +54,9 @@ runs:
authentik: authentik:
outposts: outposts:
container_image_base: ghcr.io/goauthentik/dev-%(type)s:gh-%(build_hash)s container_image_base: ghcr.io/goauthentik/dev-%(type)s:gh-%(build_hash)s
global: image:
image: repository: ghcr.io/goauthentik/dev-server
repository: ghcr.io/goauthentik/dev-server tag: ${{ inputs.tag }}
tag: ${{ inputs.tag }}
``` ```
For arm64, use these values: For arm64, use these values:
@ -66,10 +65,9 @@ runs:
authentik: authentik:
outposts: outposts:
container_image_base: ghcr.io/goauthentik/dev-%(type)s:gh-%(build_hash)s container_image_base: ghcr.io/goauthentik/dev-%(type)s:gh-%(build_hash)s
global: image:
image: repository: ghcr.io/goauthentik/dev-server
repository: ghcr.io/goauthentik/dev-server tag: ${{ inputs.tag }}-arm64
tag: ${{ inputs.tag }}-arm64
``` ```
Afterwards, run the upgrade commands from the latest release notes. Afterwards, run the upgrade commands from the latest release notes.

View File

@ -29,15 +29,9 @@ outputs:
imageTags: imageTags:
description: "Docker image tags" description: "Docker image tags"
value: ${{ steps.ev.outputs.imageTags }} value: ${{ steps.ev.outputs.imageTags }}
attestImageNames:
description: "Docker image names used for attestation"
value: ${{ steps.ev.outputs.attestImageNames }}
imageMainTag: imageMainTag:
description: "Docker image main tag" description: "Docker image main tag"
value: ${{ steps.ev.outputs.imageMainTag }} value: ${{ steps.ev.outputs.imageMainTag }}
imageMainName:
description: "Docker image main name"
value: ${{ steps.ev.outputs.imageMainName }}
runs: runs:
using: "composite" using: "composite"

View File

@ -7,7 +7,7 @@ from time import time
parser = configparser.ConfigParser() parser = configparser.ConfigParser()
parser.read(".bumpversion.cfg") parser.read(".bumpversion.cfg")
should_build = str(len(os.environ.get("DOCKER_USERNAME", "")) > 0).lower() should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower()
branch_name = os.environ["GITHUB_REF"] branch_name = os.environ["GITHUB_REF"]
if os.environ.get("GITHUB_HEAD_REF", "") != "": if os.environ.get("GITHUB_HEAD_REF", "") != "":
@ -50,25 +50,13 @@ else:
f"{name}:gh-{safe_branch_name}-{int(time())}-{sha[:7]}{suffix}", # Use by FluxCD f"{name}:gh-{safe_branch_name}-{int(time())}-{sha[:7]}{suffix}", # Use by FluxCD
] ]
image_main_tag = image_tags[0].split(":")[-1] image_main_tag = image_tags[0]
image_tags_rendered = ",".join(image_tags)
def get_attest_image_names(image_with_tags: list[str]):
"""Attestation only for GHCR"""
image_tags = []
for image_name in set(name.split(":")[0] for name in image_with_tags):
if not image_name.startswith("ghcr.io"):
continue
image_tags.append(image_name)
return ",".join(set(image_tags))
with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output:
print(f"shouldBuild={should_build}", file=_output) print(f"shouldBuild={should_build}", file=_output)
print(f"sha={sha}", file=_output) print(f"sha={sha}", file=_output)
print(f"version={version}", file=_output) print(f"version={version}", file=_output)
print(f"prerelease={prerelease}", file=_output) print(f"prerelease={prerelease}", file=_output)
print(f"imageTags={','.join(image_tags)}", file=_output) print(f"imageTags={image_tags_rendered}", file=_output)
print(f"attestImageNames={get_attest_image_names(image_tags)}", file=_output)
print(f"imageMainTag={image_main_tag}", file=_output) print(f"imageMainTag={image_main_tag}", file=_output)
print(f"imageMainName={image_tags[0]}", file=_output)

View File

@ -58,10 +58,6 @@ updates:
patterns: patterns:
- "@rollup/*" - "@rollup/*"
- "rollup-*" - "rollup-*"
swc:
patterns:
- "@swc/*"
- "swc-*"
wdio: wdio:
patterns: patterns:
- "@wdio/*" - "@wdio/*"

View File

@ -35,12 +35,12 @@ jobs:
run: | run: |
export VERSION=`node -e 'console.log(require("../gen-ts-api/package.json").version)'` export VERSION=`node -e 'console.log(require("../gen-ts-api/package.json").version)'`
npm i @goauthentik/api@$VERSION npm i @goauthentik/api@$VERSION
- name: Upgrade /web/packages/sfe - name: Upgrade /web/sfe
working-directory: web/packages/sfe working-directory: web/sfe
run: | run: |
export VERSION=`node -e 'console.log(require("../gen-ts-api/package.json").version)'` export VERSION=`node -e 'console.log(require("../gen-ts-api/package.json").version)'`
npm i @goauthentik/api@$VERSION npm i @goauthentik/api@$VERSION
- uses: peter-evans/create-pull-request@v7 - uses: peter-evans/create-pull-request@v6
id: cpr id: cpr
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}

View File

@ -120,12 +120,6 @@ jobs:
with: with:
flags: unit flags: unit
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: unit
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
test-integration: test-integration:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 30 timeout-minutes: 30
@ -144,12 +138,6 @@ jobs:
with: with:
flags: integration flags: integration
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: integration
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
test-e2e: test-e2e:
name: test-e2e (${{ matrix.job.name }}) name: test-e2e (${{ matrix.job.name }})
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -202,12 +190,6 @@ jobs:
with: with:
flags: e2e flags: e2e
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: e2e
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
ci-core-mark: ci-core-mark:
needs: needs:
- lint - lint
@ -231,16 +213,13 @@ jobs:
permissions: permissions:
# Needed to upload contianer images to ghcr.io # Needed to upload contianer images to ghcr.io
packages: write packages: write
# Needed for attestation
id-token: write
attestations: write
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v3.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
@ -261,8 +240,7 @@ jobs:
- name: generate ts client - name: generate ts client
run: make gen-client-ts run: make gen-client-ts
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v6 uses: docker/build-push-action@v5
id: push
with: with:
context: . context: .
secrets: | secrets: |
@ -273,15 +251,8 @@ jobs:
build-args: | build-args: |
GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} GIT_BUILD_HASH=${{ steps.ev.outputs.sha }}
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache cache-from: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache
cache-to: ${{ steps.ev.outputs.shouldBuild == 'true' && 'type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache,mode=max' || '' }} cache-to: type=registry,ref=ghcr.io/goauthentik/dev-server:buildcache,mode=max
platforms: linux/${{ matrix.arch }} platforms: linux/${{ matrix.arch }}
- uses: actions/attest-build-provenance@v1
id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
pr-comment: pr-comment:
needs: needs:
- build - build
@ -303,7 +274,6 @@ jobs:
with: with:
image-name: ghcr.io/goauthentik/dev-server image-name: ghcr.io/goauthentik/dev-server
- name: Comment on PR - name: Comment on PR
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
uses: ./.github/actions/comment-pr-instructions uses: ./.github/actions/comment-pr-instructions
with: with:
tag: ${{ steps.ev.outputs.imageMainTag }} tag: gh-${{ steps.ev.outputs.imageMainTag }}

View File

@ -31,7 +31,7 @@ jobs:
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v6 uses: golangci/golangci-lint-action@v6
with: with:
version: latest version: v1.54.2
args: --timeout 5000s --verbose args: --timeout 5000s --verbose
skip-cache: true skip-cache: true
test-unittest: test-unittest:
@ -71,15 +71,12 @@ jobs:
permissions: permissions:
# Needed to upload contianer images to ghcr.io # Needed to upload contianer images to ghcr.io
packages: write packages: write
# Needed for attestation
id-token: write
attestations: write
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v3.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
@ -99,8 +96,7 @@ jobs:
- name: Generate API - name: Generate API
run: make gen-client-go run: make gen-client-go
- name: Build Docker Image - name: Build Docker Image
id: push uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with: with:
tags: ${{ steps.ev.outputs.imageTags }} tags: ${{ steps.ev.outputs.imageTags }}
file: ${{ matrix.type }}.Dockerfile file: ${{ matrix.type }}.Dockerfile
@ -110,14 +106,7 @@ jobs:
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
context: . context: .
cache-from: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache cache-from: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache
cache-to: ${{ steps.ev.outputs.shouldBuild == 'true' && format('type=registry,ref=ghcr.io/goauthentik/dev-{0}:buildcache,mode=max', matrix.type) || '' }} cache-to: type=registry,ref=ghcr.io/goauthentik/dev-${{ matrix.type }}:buildcache,mode=max
- uses: actions/attest-build-provenance@v1
id: attest
if: ${{ steps.ev.outputs.shouldBuild == 'true' }}
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-binary: build-binary:
timeout-minutes: 120 timeout-minutes: 120
needs: needs:

View File

@ -12,24 +12,19 @@ on:
- version-* - version-*
jobs: jobs:
lint: lint-eslint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
command:
- lint
- lint:lockfile
- tsc
- prettier-check
project: project:
- web - web
- tests/wdio - tests/wdio
include: include:
- command: tsc - command: tsc
project: web project: web
- command: lit-analyse extra_setup: |
project: web cd sfe/ && npm ci
exclude: exclude:
- command: lint:lockfile - command: lint:lockfile
project: tests/wdio project: tests/wdio
@ -43,17 +38,85 @@ jobs:
cache: "npm" cache: "npm"
cache-dependency-path: ${{ matrix.project }}/package-lock.json cache-dependency-path: ${{ matrix.project }}/package-lock.json
- working-directory: ${{ matrix.project }}/ - working-directory: ${{ matrix.project }}/
run: | run: npm ci
npm ci
${{ matrix.extra_setup }}
- name: Generate API - name: Generate API
run: make gen-client-ts run: make gen-client-ts
- name: Lint - name: Eslint
working-directory: ${{ matrix.project }}/ working-directory: ${{ matrix.project }}/
run: npm run ${{ matrix.command }} run: npm run lint
lint-lockfile:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- working-directory: web/
run: |
[ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ]
lint-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: web/package.json
cache: "npm"
cache-dependency-path: web/package-lock.json
- working-directory: web/
run: npm ci
- name: Generate API
run: make gen-client-ts
- name: TSC
working-directory: web/
run: npm run tsc
lint-prettier:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
project:
- web
- tests/wdio
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: ${{ matrix.project }}/package.json
cache: "npm"
cache-dependency-path: ${{ matrix.project }}/package-lock.json
- working-directory: ${{ matrix.project }}/
run: npm ci
- name: Generate API
run: make gen-client-ts
- name: prettier
working-directory: ${{ matrix.project }}/
run: npm run prettier-check
lint-lit-analyse:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: web/package.json
cache: "npm"
cache-dependency-path: web/package-lock.json
- working-directory: web/
run: |
npm ci
# lit-analyse doesn't understand path rewrites, so make it
# belive it's an actual module
cd node_modules/@goauthentik
ln -s ../../src/ web
- name: Generate API
run: make gen-client-ts
- name: lit-analyse
working-directory: web/
run: npm run lit-analyse
ci-web-mark: ci-web-mark:
needs: needs:
- lint - lint-lockfile
- lint-eslint
- lint-prettier
- lint-lit-analyse
- lint-build
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- run: echo mark - run: echo mark
@ -75,21 +138,3 @@ jobs:
- name: build - name: build
working-directory: web/ working-directory: web/
run: npm run build run: npm run build
test:
needs:
- ci-web-mark
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: web/package.json
cache: "npm"
cache-dependency-path: web/package-lock.json
- working-directory: web/
run: npm ci
- name: Generate API
run: make gen-client-ts
- name: test
working-directory: web/
run: npm run test || exit 0

View File

@ -12,21 +12,27 @@ on:
- version-* - version-*
jobs: jobs:
lint: lint-lockfile:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
command:
- lint:lockfile
- prettier-check
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- working-directory: website/
run: |
[ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ]
lint-prettier:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: website/package.json
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/
run: npm ci run: npm ci
- name: Lint - name: prettier
working-directory: website/ working-directory: website/
run: npm run ${{ matrix.command }} run: npm run prettier-check
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@ -63,7 +69,8 @@ jobs:
run: npm run ${{ matrix.job }} run: npm run ${{ matrix.job }}
ci-website-mark: ci-website-mark:
needs: needs:
- lint - lint-lockfile
- lint-prettier
- test - test
- build - build
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -24,7 +24,7 @@ jobs:
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
- run: poetry run ak update_webauthn_mds - run: poetry run ak update_webauthn_mds
- uses: peter-evans/create-pull-request@v7 - uses: peter-evans/create-pull-request@v6
id: cpr id: cpr
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}

View File

@ -42,7 +42,7 @@ jobs:
with: with:
githubToken: ${{ steps.generate_token.outputs.token }} githubToken: ${{ steps.generate_token.outputs.token }}
compressOnly: ${{ github.event_name != 'pull_request' }} compressOnly: ${{ github.event_name != 'pull_request' }}
- uses: peter-evans/create-pull-request@v7 - uses: peter-evans/create-pull-request@v6
if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}" if: "${{ github.event_name != 'pull_request' && steps.compress.outputs.markdown != '' }}"
id: cpr id: cpr
with: with:

View File

@ -11,13 +11,10 @@ jobs:
permissions: permissions:
# Needed to upload contianer images to ghcr.io # Needed to upload contianer images to ghcr.io
packages: write packages: write
# Needed for attestation
id-token: write
attestations: write
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v3.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
@ -43,32 +40,20 @@ jobs:
mkdir -p ./gen-ts-api mkdir -p ./gen-ts-api
mkdir -p ./gen-go-api mkdir -p ./gen-go-api
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v6 uses: docker/build-push-action@v5
id: push
with: with:
context: . context: .
push: true push: true
secrets: | secrets: |
GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }}
GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }}
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }} tags: ${{ steps.ev.outputs.imageTags }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
- uses: actions/attest-build-provenance@v1
id: attest
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-outpost: build-outpost:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
# Needed to upload contianer images to ghcr.io # Needed to upload contianer images to ghcr.io
packages: write packages: write
# Needed for attestation
id-token: write
attestations: write
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -83,7 +68,7 @@ jobs:
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v3.2.0 uses: docker/setup-qemu-action@v3.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: prepare variables - name: prepare variables
@ -109,22 +94,13 @@ jobs:
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v6 uses: docker/build-push-action@v5
id: push
with: with:
push: true push: true
build-args: |
VERSION=${{ github.ref }}
tags: ${{ steps.ev.outputs.imageTags }} tags: ${{ steps.ev.outputs.imageTags }}
file: ${{ matrix.type }}.Dockerfile file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
context: . context: .
- uses: actions/attest-build-provenance@v1
id: attest
with:
subject-name: ${{ steps.ev.outputs.attestImageNames }}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
build-outpost-binary: build-outpost-binary:
timeout-minutes: 120 timeout-minutes: 120
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -179,8 +155,8 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Run test suite in final docker images - name: Run test suite in final docker images
run: | run: |
echo "PG_PASS=$(openssl rand 32 | base64 -w 0)" >> .env echo "PG_PASS=$(openssl rand 32 | base64)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64 -w 0)" >> .env echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64)" >> .env
docker compose pull -q docker compose pull -q
docker compose up --no-start docker compose up --no-start
docker compose start postgresql redis docker compose start postgresql redis
@ -202,8 +178,8 @@ jobs:
image-name: ghcr.io/goauthentik/server image-name: ghcr.io/goauthentik/server
- name: Get static files from docker image - name: Get static files from docker image
run: | run: |
docker pull ${{ steps.ev.outputs.imageMainName }} docker pull ${{ steps.ev.outputs.imageMainTag }}
container=$(docker container create ${{ steps.ev.outputs.imageMainName }}) container=$(docker container create ${{ steps.ev.outputs.imageMainTag }})
docker cp ${container}:web/ . docker cp ${container}:web/ .
- name: Create a Sentry.io release - name: Create a Sentry.io release
uses: getsentry/action-release@v1 uses: getsentry/action-release@v1

View File

@ -14,8 +14,8 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Pre-release test - name: Pre-release test
run: | run: |
echo "PG_PASS=$(openssl rand 32 | base64 -w 0)" >> .env echo "PG_PASS=$(openssl rand 32 | base64)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64 -w 0)" >> .env echo "AUTHENTIK_SECRET_KEY=$(openssl rand 32 | base64)" >> .env
docker buildx install docker buildx install
mkdir -p ./gen-ts-api mkdir -p ./gen-ts-api
docker build -t testing:latest . docker build -t testing:latest .

View File

@ -32,7 +32,7 @@ jobs:
poetry run ak compilemessages poetry run ak compilemessages
make web-check-compile make web-check-compile
- name: Create Pull Request - name: Create Pull Request
uses: peter-evans/create-pull-request@v7 uses: peter-evans/create-pull-request@v6
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
branch: extract-compile-backend-translation branch: extract-compile-backend-translation

View File

@ -16,6 +16,6 @@
"ms-python.black-formatter", "ms-python.black-formatter",
"redhat.vscode-yaml", "redhat.vscode-yaml",
"Tobermory.es6-string-html", "Tobermory.es6-string-html",
"unifiedjs.vscode-mdx" "unifiedjs.vscode-mdx",
] ]
} }

2
.vscode/launch.json vendored
View File

@ -22,6 +22,6 @@
}, },
"justMyCode": true, "justMyCode": true,
"django": true "django": true
} },
] ]
} }

21
.vscode/settings.json vendored
View File

@ -18,21 +18,20 @@
"sso", "sso",
"totp", "totp",
"traefik", "traefik",
"webauthn" "webauthn",
], ],
"todo-tree.tree.showCountsInTree": true, "todo-tree.tree.showCountsInTree": true,
"todo-tree.tree.showBadges": true, "todo-tree.tree.showBadges": true,
"yaml.customTags": [ "yaml.customTags": [
"!Condition sequence",
"!Context scalar",
"!Enumerate sequence",
"!Env scalar",
"!Find sequence", "!Find sequence",
"!Format sequence",
"!If sequence",
"!Index scalar",
"!KeyOf scalar", "!KeyOf scalar",
"!Value scalar" "!Context scalar",
"!Context sequence",
"!Format sequence",
"!Condition sequence",
"!Env sequence",
"!Env scalar",
"!If sequence"
], ],
"typescript.preferences.importModuleSpecifier": "non-relative", "typescript.preferences.importModuleSpecifier": "non-relative",
"typescript.preferences.importModuleSpecifierEnding": "index", "typescript.preferences.importModuleSpecifierEnding": "index",
@ -49,7 +48,9 @@
"ignoreCase": false "ignoreCase": false
} }
], ],
"go.testFlags": ["-count=1"], "go.testFlags": [
"-count=1"
],
"github-actions.workflows.pinned.workflows": [ "github-actions.workflows.pinned.workflows": [
".github/workflows/ci-main.yml" ".github/workflows/ci-main.yml"
] ]

62
.vscode/tasks.json vendored
View File

@ -2,67 +2,85 @@
"version": "2.0.0", "version": "2.0.0",
"tasks": [ "tasks": [
{ {
"label": "authentik/core: make", "label": "authentik[core]: format & test",
"command": "poetry", "command": "poetry",
"args": ["run", "make", "lint-fix", "lint"], "args": [
"presentation": { "run",
"panel": "new" "make"
}, ],
"group": "test" "group": "build",
}, },
{ {
"label": "authentik/core: run", "label": "authentik[core]: run",
"command": "poetry", "command": "poetry",
"args": ["run", "ak", "server"], "args": [
"run",
"make",
"run",
],
"group": "build", "group": "build",
"presentation": { "presentation": {
"panel": "dedicated", "panel": "dedicated",
"group": "running" "group": "running"
} },
}, },
{ {
"label": "authentik/web: make", "label": "authentik[web]: format",
"command": "make", "command": "make",
"args": ["web"], "args": ["web"],
"group": "build" "group": "build",
}, },
{ {
"label": "authentik/web: watch", "label": "authentik[web]: watch",
"command": "make", "command": "make",
"args": ["web-watch"], "args": ["web-watch"],
"group": "build", "group": "build",
"presentation": { "presentation": {
"panel": "dedicated", "panel": "dedicated",
"group": "running" "group": "running"
} },
}, },
{ {
"label": "authentik: install", "label": "authentik: install",
"command": "make", "command": "make",
"args": ["install", "-j4"], "args": ["install"],
"group": "build" "group": "build",
}, },
{ {
"label": "authentik/website: make", "label": "authentik: i18n-extract",
"command": "poetry",
"args": [
"run",
"make",
"i18n-extract"
],
"group": "build",
},
{
"label": "authentik[website]: format",
"command": "make", "command": "make",
"args": ["website"], "args": ["website"],
"group": "build" "group": "build",
}, },
{ {
"label": "authentik/website: watch", "label": "authentik[website]: watch",
"command": "make", "command": "make",
"args": ["website-watch"], "args": ["website-watch"],
"group": "build", "group": "build",
"presentation": { "presentation": {
"panel": "dedicated", "panel": "dedicated",
"group": "running" "group": "running"
} },
}, },
{ {
"label": "authentik/api: generate", "label": "authentik[api]: generate",
"command": "poetry", "command": "poetry",
"args": ["run", "make", "gen"], "args": [
"run",
"make",
"gen"
],
"group": "build" "group": "build"
} },
] ]
} }

View File

@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1 # syntax=docker/dockerfile:1
# Stage 1: Build website # Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS website-builder FROM --platform=${BUILDPLATFORM} docker.io/node:22 as website-builder
ENV NODE_ENV=production ENV NODE_ENV=production
@ -20,7 +20,7 @@ COPY ./SECURITY.md /work/
RUN npm run build-bundled RUN npm run build-bundled
# Stage 2: Build webui # Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/library/node:22 AS web-builder FROM --platform=${BUILDPLATFORM} docker.io/node:22 as web-builder
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
@ -30,9 +30,12 @@ WORKDIR /work/web
RUN --mount=type=bind,target=/work/web/package.json,src=./web/package.json \ RUN --mount=type=bind,target=/work/web/package.json,src=./web/package.json \
--mount=type=bind,target=/work/web/package-lock.json,src=./web/package-lock.json \ --mount=type=bind,target=/work/web/package-lock.json,src=./web/package-lock.json \
--mount=type=bind,target=/work/web/packages/sfe/package.json,src=./web/packages/sfe/package.json \ --mount=type=bind,target=/work/web/sfe/package.json,src=./web/sfe/package.json \
--mount=type=bind,target=/work/web/sfe/package-lock.json,src=./web/sfe/package-lock.json \
--mount=type=bind,target=/work/web/scripts,src=./web/scripts \ --mount=type=bind,target=/work/web/scripts,src=./web/scripts \
--mount=type=cache,id=npm-web,sharing=shared,target=/root/.npm \ --mount=type=cache,id=npm-web,sharing=shared,target=/root/.npm \
npm ci --include=dev && \
cd sfe && \
npm ci --include=dev npm ci --include=dev
COPY ./package.json /work COPY ./package.json /work
@ -40,10 +43,12 @@ COPY ./web /work/web/
COPY ./website /work/website/ COPY ./website /work/website/
COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build RUN npm run build && \
cd sfe && \
npm run build
# Stage 3: Build go proxy # Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.23-fips-bookworm AS go-builder FROM --platform=${BUILDPLATFORM} mcr.microsoft.com/oss/go/microsoft/golang:1.22-fips-bookworm AS go-builder
ARG TARGETOS ARG TARGETOS
ARG TARGETARCH ARG TARGETARCH
@ -80,7 +85,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
go build -o /go/authentik ./cmd/server go build -o /go/authentik ./cmd/server
# Stage 4: MaxMind GeoIP # Stage 4: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 AS geoip FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.0.1 as geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN" ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
ENV GEOIPUPDATE_VERBOSE="1" ENV GEOIPUPDATE_VERBOSE="1"
@ -94,10 +99,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 5: Python dependencies # Stage 5: Python dependencies
FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS python-deps FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS python-deps
ARG TARGETARCH
ARG TARGETVARIANT
WORKDIR /ak-root/poetry WORKDIR /ak-root/poetry
@ -124,17 +126,17 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \
pip install --force-reinstall /wheels/*" pip install --force-reinstall /wheels/*"
# Stage 6: Run # Stage 6: Run
FROM ghcr.io/goauthentik/fips-python:3.12.6-slim-bookworm-fips-full AS final-image FROM ghcr.io/goauthentik/fips-python:3.12.3-slim-bookworm-fips-full AS final-image
ARG VERSION
ARG GIT_BUILD_HASH ARG GIT_BUILD_HASH
ARG VERSION
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
LABEL org.opencontainers.image.url=https://goauthentik.io LABEL org.opencontainers.image.url https://goauthentik.io
LABEL org.opencontainers.image.description="goauthentik.io Main server image, see https://goauthentik.io for more info." LABEL org.opencontainers.image.description goauthentik.io Main server image, see https://goauthentik.io for more info.
LABEL org.opencontainers.image.source=https://github.com/goauthentik/authentik LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version=${VERSION} LABEL org.opencontainers.image.version ${VERSION}
LABEL org.opencontainers.image.revision=${GIT_BUILD_HASH} LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
WORKDIR / WORKDIR /

View File

@ -43,12 +43,12 @@ help: ## Show this help
sort sort
@echo "" @echo ""
go-test: test-go:
go test -timeout 0 -v -race -cover ./... go test -timeout 0 -v -race -cover ./...
test-docker: ## Run all tests in a docker-compose test-docker: ## Run all tests in a docker-compose
echo "PG_PASS=$(shell openssl rand 32 | base64 -w 0)" >> .env echo "PG_PASS=$(shell openssl rand 32 | base64)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(shell openssl rand 32 | base64 -w 0)" >> .env echo "AUTHENTIK_SECRET_KEY=$(shell openssl rand 32 | base64)" >> .env
docker compose pull -q docker compose pull -q
docker compose up --no-start docker compose up --no-start
docker compose start postgresql redis docker compose start postgresql redis
@ -60,11 +60,9 @@ test: ## Run the server tests and produce a coverage report (locally)
coverage html coverage html
coverage report coverage report
lint-fix: lint-codespell ## Lint and automatically fix errors in the python source code. Reports spelling errors. lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
black $(PY_SOURCES) black $(PY_SOURCES)
ruff check --fix $(PY_SOURCES) ruff check --fix $(PY_SOURCES)
lint-codespell: ## Reports spelling errors.
codespell -w $(CODESPELL_ARGS) codespell -w $(CODESPELL_ARGS)
lint: ## Lint the python and golang sources lint: ## Lint the python and golang sources
@ -210,9 +208,6 @@ web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting is
web-install: ## Install the necessary libraries to build the Authentik UI web-install: ## Install the necessary libraries to build the Authentik UI
cd web && npm ci cd web && npm ci
web-test: ## Run tests for the Authentik UI
cd web && npm run test
web-watch: ## Build and watch the Authentik UI for changes, updating automatically web-watch: ## Build and watch the Authentik UI for changes, updating automatically
rm -rf web/dist/ rm -rf web/dist/
mkdir web/dist/ mkdir web/dist/
@ -244,7 +239,7 @@ website: website-lint-fix website-build ## Automatically fix formatting issues
website-install: website-install:
cd website && npm ci cd website && npm ci
website-lint-fix: lint-codespell website-lint-fix:
cd website && npm run prettier cd website && npm run prettier
website-build: website-build:

View File

@ -15,9 +15,7 @@
## What is authentik? ## What is authentik?
authentik is an open-source Identity Provider that emphasizes flexibility and versatility, with support for a wide set of protocols. authentik is an open-source Identity Provider that emphasizes flexibility and versatility. It can be seamlessly integrated into existing environments to support new protocols. authentik is also a great solution for implementing sign-up, recovery, and other similar features in your application, saving you the hassle of dealing with them.
Our [enterprise offer](https://goauthentik.io/pricing) can also be used as a self-hosted replacement for large-scale deployments of Okta/Auth0, Entra ID, Ping Identity, or other legacy IdPs for employees and B2B2C use.
## Installation ## Installation

View File

@ -2,7 +2,7 @@
from os import environ from os import environ
__version__ = "2024.8.1" __version__ = "2024.6.1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -73,7 +73,7 @@ class SystemInfoSerializer(PassiveSerializer):
"authentik_version": get_full_version(), "authentik_version": get_full_version(),
"environment": get_env(), "environment": get_env(),
"openssl_fips_enabled": ( "openssl_fips_enabled": (
backend._fips_enabled if LicenseKey.get_total().status().is_valid else None backend._fips_enabled if LicenseKey.get_total().is_valid() else None
), ),
"openssl_version": OPENSSL_VERSION, "openssl_version": OPENSSL_VERSION,
"platform": platform.platform(), "platform": platform.platform(),

View File

@ -12,7 +12,6 @@ from rest_framework.views import APIView
from authentik import __version__, get_build_hash from authentik import __version__, get_build_hash
from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version from authentik.admin.tasks import VERSION_CACHE_KEY, VERSION_NULL, update_latest_version
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.outposts.models import Outpost
class VersionSerializer(PassiveSerializer): class VersionSerializer(PassiveSerializer):
@ -23,7 +22,6 @@ class VersionSerializer(PassiveSerializer):
version_latest_valid = SerializerMethodField() version_latest_valid = SerializerMethodField()
build_hash = SerializerMethodField() build_hash = SerializerMethodField()
outdated = SerializerMethodField() outdated = SerializerMethodField()
outpost_outdated = SerializerMethodField()
def get_build_hash(self, _) -> str: def get_build_hash(self, _) -> str:
"""Get build hash, if version is not latest or released""" """Get build hash, if version is not latest or released"""
@ -49,15 +47,6 @@ class VersionSerializer(PassiveSerializer):
"""Check if we're running the latest version""" """Check if we're running the latest version"""
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance)) return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
def get_outpost_outdated(self, _) -> bool:
"""Check if any outpost is outdated/has a version mismatch"""
any_outdated = False
for outpost in Outpost.objects.all():
for state in outpost.state:
if state.version_outdated:
any_outdated = True
return any_outdated
class VersionView(APIView): class VersionView(APIView):
"""Get running and latest version.""" """Get running and latest version."""

View File

@ -1,8 +1,10 @@
"""authentik admin tasks""" """authentik admin tasks"""
import re
from django.core.cache import cache from django.core.cache import cache
from django.core.validators import URLValidator
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.translation import gettext_lazy as _
from packaging.version import parse from packaging.version import parse
from requests import RequestException from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -19,6 +21,8 @@ LOGGER = get_logger()
VERSION_NULL = "0.0.0" VERSION_NULL = "0.0.0"
VERSION_CACHE_KEY = "authentik_latest_version" VERSION_CACHE_KEY = "authentik_latest_version"
VERSION_CACHE_TIMEOUT = 8 * 60 * 60 # 8 hours VERSION_CACHE_TIMEOUT = 8 * 60 * 60 # 8 hours
# Chop of the first ^ because we want to search the entire string
URL_FINDER = URLValidator.regex.pattern[1:]
LOCAL_VERSION = parse(__version__) LOCAL_VERSION = parse(__version__)
@ -74,16 +78,10 @@ def update_latest_version(self: SystemTask):
context__new_version=upstream_version, context__new_version=upstream_version,
).exists(): ).exists():
return return
Event.new( event_dict = {"new_version": upstream_version}
EventAction.UPDATE_AVAILABLE, if match := re.search(URL_FINDER, data.get("stable", {}).get("changelog", "")):
message=_( event_dict["message"] = f"Changelog: {match.group()}"
"New version {version} available!".format( Event.new(EventAction.UPDATE_AVAILABLE, **event_dict).save()
version=upstream_version,
)
),
new_version=upstream_version,
changelog=data.get("stable", {}).get("changelog_url"),
).save()
except (RequestException, IndexError) as exc: except (RequestException, IndexError) as exc:
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.set_error(exc) self.set_error(exc)

View File

@ -17,7 +17,6 @@ RESPONSE_VALID = {
"stable": { "stable": {
"version": "99999999.9999999", "version": "99999999.9999999",
"changelog": "See https://goauthentik.io/test", "changelog": "See https://goauthentik.io/test",
"changelog_url": "https://goauthentik.io/test",
"reason": "bugfix", "reason": "bugfix",
}, },
} }
@ -36,7 +35,7 @@ class TestAdminTasks(TestCase):
Event.objects.filter( Event.objects.filter(
action=EventAction.UPDATE_AVAILABLE, action=EventAction.UPDATE_AVAILABLE,
context__new_version="99999999.9999999", context__new_version="99999999.9999999",
context__message="New version 99999999.9999999 available!", context__message="Changelog: https://goauthentik.io/test",
).exists() ).exists()
) )
# test that a consecutive check doesn't create a duplicate event # test that a consecutive check doesn't create a duplicate event
@ -46,7 +45,7 @@ class TestAdminTasks(TestCase):
Event.objects.filter( Event.objects.filter(
action=EventAction.UPDATE_AVAILABLE, action=EventAction.UPDATE_AVAILABLE,
context__new_version="99999999.9999999", context__new_version="99999999.9999999",
context__message="New version 99999999.9999999 available!", context__message="Changelog: https://goauthentik.io/test",
) )
), ),
1, 1,

View File

@ -23,11 +23,9 @@ class Command(BaseCommand):
for blueprint_path in options.get("blueprints", []): for blueprint_path in options.get("blueprints", []):
content = BlueprintInstance(path=blueprint_path).retrieve() content = BlueprintInstance(path=blueprint_path).retrieve()
importer = Importer.from_string(content) importer = Importer.from_string(content)
valid, logs = importer.validate() valid, _ = importer.validate()
if not valid: if not valid:
self.stderr.write("Blueprint invalid") self.stderr.write("blueprint invalid")
for log in logs:
self.stderr.write(f"\t{log.logger}: {log.event}: {log.attributes}")
sys_exit(1) sys_exit(1)
importer.apply() importer.apply()

View File

@ -113,19 +113,16 @@ class Command(BaseCommand):
) )
model_path = f"{model._meta.app_label}.{model._meta.model_name}" model_path = f"{model._meta.app_label}.{model._meta.model_name}"
self.schema["properties"]["entries"]["items"]["oneOf"].append( self.schema["properties"]["entries"]["items"]["oneOf"].append(
self.template_entry(model_path, model, serializer) self.template_entry(model_path, serializer)
) )
def template_entry(self, model_path: str, model: type[Model], serializer: Serializer) -> dict: def template_entry(self, model_path: str, serializer: Serializer) -> dict:
"""Template entry for a single model""" """Template entry for a single model"""
model_schema = self.to_jsonschema(serializer) model_schema = self.to_jsonschema(serializer)
model_schema["required"] = [] model_schema["required"] = []
def_name = f"model_{model_path}" def_name = f"model_{model_path}"
def_path = f"#/$defs/{def_name}" def_path = f"#/$defs/{def_name}"
self.schema["$defs"][def_name] = model_schema self.schema["$defs"][def_name] = model_schema
def_name_perm = f"model_{model_path}_permissions"
def_path_perm = f"#/$defs/{def_name_perm}"
self.schema["$defs"][def_name_perm] = self.model_permissions(model)
return { return {
"type": "object", "type": "object",
"required": ["model", "identifiers"], "required": ["model", "identifiers"],
@ -138,7 +135,6 @@ class Command(BaseCommand):
"default": "present", "default": "present",
}, },
"conditions": {"type": "array", "items": {"type": "boolean"}}, "conditions": {"type": "array", "items": {"type": "boolean"}},
"permissions": {"$ref": def_path_perm},
"attrs": {"$ref": def_path}, "attrs": {"$ref": def_path},
"identifiers": {"$ref": def_path}, "identifiers": {"$ref": def_path},
}, },
@ -189,20 +185,3 @@ class Command(BaseCommand):
if required: if required:
result["required"] = required result["required"] = required
return result return result
def model_permissions(self, model: type[Model]) -> dict:
perms = [x[0] for x in model._meta.permissions]
for action in model._meta.default_permissions:
perms.append(f"{action}_{model._meta.model_name}")
return {
"type": "array",
"items": {
"type": "object",
"required": ["permission"],
"properties": {
"permission": {"type": "string", "enum": perms},
"user": {"type": "integer"},
"role": {"type": "string"},
},
},
}

View File

@ -1,24 +0,0 @@
version: 1
entries:
- model: authentik_core.user
id: user
identifiers:
username: "%(id)s"
attrs:
name: "%(id)s"
- model: authentik_rbac.role
id: role
identifiers:
name: "%(id)s"
- model: authentik_flows.flow
identifiers:
slug: "%(id)s"
attrs:
designation: authentication
name: foo
title: foo
permissions:
- permission: view_flow
user: !KeyOf user
- permission: view_flow
role: !KeyOf role

View File

@ -1,8 +0,0 @@
version: 1
entries:
- model: authentik_rbac.role
identifiers:
name: "%(id)s"
attrs:
permissions:
- authentik_blueprints.view_blueprintinstance

View File

@ -1,9 +0,0 @@
version: 1
entries:
- model: authentik_core.user
identifiers:
username: "%(id)s"
attrs:
name: "%(id)s"
permissions:
- authentik_blueprints.view_blueprintinstance

View File

@ -1,57 +0,0 @@
"""Test blueprints v1"""
from django.test import TransactionTestCase
from guardian.shortcuts import get_perms
from authentik.blueprints.v1.importer import Importer
from authentik.core.models import User
from authentik.flows.models import Flow
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import load_fixture
from authentik.rbac.models import Role
class TestBlueprintsV1RBAC(TransactionTestCase):
"""Test Blueprints rbac attribute"""
def test_user_permission(self):
"""Test permissions"""
uid = generate_id()
import_yaml = load_fixture("fixtures/rbac_user.yaml", id=uid)
importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply())
user = User.objects.filter(username=uid).first()
self.assertIsNotNone(user)
self.assertTrue(user.has_perms(["authentik_blueprints.view_blueprintinstance"]))
def test_role_permission(self):
"""Test permissions"""
uid = generate_id()
import_yaml = load_fixture("fixtures/rbac_role.yaml", id=uid)
importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply())
role = Role.objects.filter(name=uid).first()
self.assertIsNotNone(role)
self.assertEqual(
list(role.group.permissions.all().values_list("codename", flat=True)),
["view_blueprintinstance"],
)
def test_object_permission(self):
"""Test permissions"""
uid = generate_id()
import_yaml = load_fixture("fixtures/rbac_object.yaml", id=uid)
importer = Importer.from_string(import_yaml)
self.assertTrue(importer.validate()[0])
self.assertTrue(importer.apply())
flow = Flow.objects.filter(slug=uid).first()
user = User.objects.filter(username=uid).first()
role = Role.objects.filter(name=uid).first()
self.assertIsNotNone(flow)
self.assertEqual(get_perms(user, flow), ["view_flow"])
self.assertEqual(get_perms(role.group, flow), ["view_flow"])

View File

@ -1,7 +1,7 @@
"""transfer common classes""" """transfer common classes"""
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Generator, Iterable, Mapping from collections.abc import Iterable, Mapping
from copy import copy from copy import copy
from dataclasses import asdict, dataclass, field, is_dataclass from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum from enum import Enum
@ -58,15 +58,6 @@ class BlueprintEntryDesiredState(Enum):
MUST_CREATED = "must_created" MUST_CREATED = "must_created"
@dataclass
class BlueprintEntryPermission:
"""Describe object-level permissions"""
permission: Union[str, "YAMLTag"]
user: Union[int, "YAMLTag", None] = field(default=None)
role: Union[str, "YAMLTag", None] = field(default=None)
@dataclass @dataclass
class BlueprintEntry: class BlueprintEntry:
"""Single entry of a blueprint""" """Single entry of a blueprint"""
@ -78,7 +69,6 @@ class BlueprintEntry:
conditions: list[Any] = field(default_factory=list) conditions: list[Any] = field(default_factory=list)
identifiers: dict[str, Any] = field(default_factory=dict) identifiers: dict[str, Any] = field(default_factory=dict)
attrs: dict[str, Any] | None = field(default_factory=dict) attrs: dict[str, Any] | None = field(default_factory=dict)
permissions: list[BlueprintEntryPermission] = field(default_factory=list)
id: str | None = None id: str | None = None
@ -160,17 +150,6 @@ class BlueprintEntry:
"""Get the blueprint model, with yaml tags resolved if present""" """Get the blueprint model, with yaml tags resolved if present"""
return str(self.tag_resolver(self.model, blueprint)) return str(self.tag_resolver(self.model, blueprint))
def get_permissions(
self, blueprint: "Blueprint"
) -> Generator[BlueprintEntryPermission, None, None]:
"""Get permissions of this entry, with all yaml tags resolved"""
for perm in self.permissions:
yield BlueprintEntryPermission(
permission=self.tag_resolver(perm.permission, blueprint),
user=self.tag_resolver(perm.user, blueprint),
role=self.tag_resolver(perm.role, blueprint),
)
def check_all_conditions_match(self, blueprint: "Blueprint") -> bool: def check_all_conditions_match(self, blueprint: "Blueprint") -> bool:
"""Check all conditions of this entry match (evaluate to True)""" """Check all conditions of this entry match (evaluate to True)"""
return all(self.tag_resolver(self.conditions, blueprint)) return all(self.tag_resolver(self.conditions, blueprint))
@ -328,10 +307,7 @@ class Find(YAMLTag):
else: else:
model_name = self.model_name model_name = self.model_name
try: model_class = apps.get_model(*model_name.split("."))
model_class = apps.get_model(*model_name.split("."))
except LookupError as exc:
raise EntryInvalidError.from_entry(exc, entry) from exc
query = Q() query = Q()
for cond in self.conditions: for cond in self.conditions:

View File

@ -16,7 +16,6 @@ from django.db.models.query_utils import Q
from django.db.transaction import atomic from django.db.transaction import atomic
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from guardian.models import UserObjectPermission from guardian.models import UserObjectPermission
from guardian.shortcuts import assign_perm
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer, Serializer from rest_framework.serializers import BaseSerializer, Serializer
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
@ -33,11 +32,9 @@ from authentik.blueprints.v1.common import (
from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry
from authentik.core.models import ( from authentik.core.models import (
AuthenticatedSession, AuthenticatedSession,
GroupSourceConnection,
PropertyMapping, PropertyMapping,
Provider, Provider,
Source, Source,
User,
UserSourceConnection, UserSourceConnection,
) )
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
@ -57,13 +54,11 @@ from authentik.events.utils import cleanse_dict
from authentik.flows.models import FlowToken, Stage from authentik.flows.models import FlowToken, Stage
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.reflection import get_apps
from authentik.outposts.models import OutpostServiceConnection from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel from authentik.policies.models import Policy, PolicyBindingModel
from authentik.policies.reputation.models import Reputation from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
from authentik.rbac.models import Role
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -92,7 +87,6 @@ def excluded_models() -> list[type[Model]]:
Source, Source,
PropertyMapping, PropertyMapping,
UserSourceConnection, UserSourceConnection,
GroupSourceConnection,
Stage, Stage,
OutpostServiceConnection, OutpostServiceConnection,
Policy, Policy,
@ -142,16 +136,6 @@ def transaction_rollback():
pass pass
def rbac_models() -> dict:
models = {}
for app in get_apps():
for model in app.get_models():
if not is_model_allowed(model):
continue
models[model._meta.model_name] = app.label
return models
class Importer: class Importer:
"""Import Blueprint from raw dict or YAML/JSON""" """Import Blueprint from raw dict or YAML/JSON"""
@ -170,10 +154,7 @@ class Importer:
def default_context(self): def default_context(self):
"""Default context""" """Default context"""
return { return {"goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid()}
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid,
"goauthentik.io/rbac/models": rbac_models(),
}
@staticmethod @staticmethod
def from_string(yaml_input: str, context: dict | None = None) -> "Importer": def from_string(yaml_input: str, context: dict | None = None) -> "Importer":
@ -233,17 +214,14 @@ class Importer:
return main_query | sub_query return main_query | sub_query
def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer | None: # noqa: PLR0915 def _validate_single(self, entry: BlueprintEntry) -> BaseSerializer | None:
"""Validate a single entry""" """Validate a single entry"""
if not entry.check_all_conditions_match(self._import): if not entry.check_all_conditions_match(self._import):
self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") self.logger.debug("One or more conditions of this entry are not fulfilled, skipping")
return None return None
model_app_label, model_name = entry.get_model(self._import).split(".") model_app_label, model_name = entry.get_model(self._import).split(".")
try: model: type[SerializerModel] = registry.get_model(model_app_label, model_name)
model: type[SerializerModel] = registry.get_model(model_app_label, model_name)
except LookupError as exc:
raise EntryInvalidError.from_entry(exc, entry) from exc
# Don't use isinstance since we don't want to check for inheritance # Don't use isinstance since we don't want to check for inheritance
if not is_model_allowed(model): if not is_model_allowed(model):
raise EntryInvalidError.from_entry(f"Model {model} not allowed", entry) raise EntryInvalidError.from_entry(f"Model {model} not allowed", entry)
@ -318,7 +296,10 @@ class Importer:
try: try:
full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import)) full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import))
except ValueError as exc: except ValueError as exc:
raise EntryInvalidError.from_entry(exc, entry) from exc raise EntryInvalidError.from_entry(
exc,
entry,
) from exc
always_merger.merge(full_data, updated_identifiers) always_merger.merge(full_data, updated_identifiers)
serializer_kwargs["data"] = full_data serializer_kwargs["data"] = full_data
@ -339,15 +320,6 @@ class Importer:
) from exc ) from exc
return serializer return serializer
def _apply_permissions(self, instance: Model, entry: BlueprintEntry):
"""Apply object-level permissions for an entry"""
for perm in entry.get_permissions(self._import):
if perm.user is not None:
assign_perm(perm.permission, User.objects.get(pk=perm.user), instance)
if perm.role is not None:
role = Role.objects.get(pk=perm.role)
role.assign_permission(perm.permission, obj=instance)
def apply(self) -> bool: def apply(self) -> bool:
"""Apply (create/update) models yaml, in database transaction""" """Apply (create/update) models yaml, in database transaction"""
try: try:
@ -412,7 +384,6 @@ class Importer:
if "pk" in entry.identifiers: if "pk" in entry.identifiers:
self.__pk_map[entry.identifiers["pk"]] = instance.pk self.__pk_map[entry.identifiers["pk"]] = instance.pk
entry._state = BlueprintEntryState(instance) entry._state = BlueprintEntryState(instance)
self._apply_permissions(instance, entry)
elif state == BlueprintEntryDesiredState.ABSENT: elif state == BlueprintEntryDesiredState.ABSENT:
instance: Model | None = serializer.instance instance: Model | None = serializer.instance
if instance.pk: if instance.pk:

View File

@ -55,7 +55,6 @@ class BrandSerializer(ModelSerializer):
"flow_unenrollment", "flow_unenrollment",
"flow_user_settings", "flow_user_settings",
"flow_device_code", "flow_device_code",
"default_application",
"web_certificate", "web_certificate",
"attributes", "attributes",
] ]

View File

@ -9,6 +9,3 @@ class AuthentikBrandsConfig(AppConfig):
name = "authentik.brands" name = "authentik.brands"
label = "authentik_brands" label = "authentik_brands"
verbose_name = "authentik Brands" verbose_name = "authentik Brands"
mountpoints = {
"authentik.brands.urls_root": "",
}

View File

@ -1,26 +0,0 @@
# Generated by Django 5.0.6 on 2024-07-04 20:32
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_brands", "0006_brand_authentik_b_domain_b9b24a_idx_and_more"),
("authentik_core", "0035_alter_group_options_and_more"),
]
operations = [
migrations.AddField(
model_name="brand",
name="default_application",
field=models.ForeignKey(
default=None,
help_text="When set, external users will be redirected to this application after authenticating.",
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.application",
),
),
]

View File

@ -3,7 +3,6 @@
from uuid import uuid4 from uuid import uuid4
from django.db import models from django.db import models
from django.http import HttpRequest
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -52,16 +51,6 @@ class Brand(SerializerModel):
Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code" Flow, null=True, on_delete=models.SET_NULL, related_name="brand_device_code"
) )
default_application = models.ForeignKey(
"authentik_core.Application",
null=True,
default=None,
on_delete=models.SET_DEFAULT,
help_text=_(
"When set, external users will be redirected to this application after authenticating."
),
)
web_certificate = models.ForeignKey( web_certificate = models.ForeignKey(
CertificateKeyPair, CertificateKeyPair,
null=True, null=True,
@ -99,13 +88,3 @@ class Brand(SerializerModel):
models.Index(fields=["domain"]), models.Index(fields=["domain"]),
models.Index(fields=["default"]), models.Index(fields=["default"]),
] ]
class WebfingerProvider(models.Model):
"""Provider which supports webfinger discovery"""
class Meta:
abstract = True
def webfinger(self, resource: str, request: HttpRequest) -> dict:
raise NotImplementedError()

View File

@ -5,11 +5,7 @@ from rest_framework.test import APITestCase
from authentik.brands.api import Themes from authentik.brands.api import Themes
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_brand from authentik.core.tests.utils import create_test_admin_user, create_test_brand
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.saml.models import SAMLProvider
class TestBrands(APITestCase): class TestBrands(APITestCase):
@ -79,45 +75,3 @@ class TestBrands(APITestCase):
reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True} reverse("authentik_api:brand-list"), data={"domain": "bar", "default": True}
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
def test_webfinger_no_app(self):
"""Test Webfinger"""
create_test_brand()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
)
def test_webfinger_not_supported(self):
"""Test Webfinger"""
brand = create_test_brand()
provider = SAMLProvider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
brand.default_application = app
brand.save()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(), {}
)
def test_webfinger_oidc(self):
"""Test Webfinger"""
brand = create_test_brand()
provider = OAuth2Provider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
brand.default_application = app
brand.save()
self.assertJSONEqual(
self.client.get(reverse("authentik_brands:webfinger")).content.decode(),
{
"links": [
{
"href": f"http://testserver/application/o/{app.slug}/",
"rel": "http://openid.net/specs/connect/1.0/issuer",
}
],
"subject": None,
},
)

View File

@ -1,9 +0,0 @@
"""authentik brand root URLs"""
from django.urls import path
from authentik.brands.views.webfinger import WebFingerView
urlpatterns = [
path(".well-known/webfinger", WebFingerView.as_view(), name="webfinger"),
]

View File

@ -5,7 +5,7 @@ from typing import Any
from django.db.models import F, Q from django.db.models import F, Q
from django.db.models import Value as V from django.db.models import Value as V
from django.http.request import HttpRequest from django.http.request import HttpRequest
from sentry_sdk import get_current_span from sentry_sdk.hub import Hub
from authentik import get_full_version from authentik import get_full_version
from authentik.brands.models import Brand from authentik.brands.models import Brand
@ -33,7 +33,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
brand = getattr(request, "brand", DEFAULT_BRAND) brand = getattr(request, "brand", DEFAULT_BRAND)
tenant = getattr(request, "tenant", Tenant()) tenant = getattr(request, "tenant", Tenant())
trace = "" trace = ""
span = get_current_span() span = Hub.current.scope.span
if span: if span:
trace = span.to_traceparent() trace = span.to_traceparent()
return { return {

View File

@ -1,29 +0,0 @@
from typing import Any
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.views import View
from authentik.brands.models import Brand, WebfingerProvider
from authentik.core.models import Application
class WebFingerView(View):
"""Webfinger endpoint"""
def get(self, request: HttpRequest) -> HttpResponse:
brand: Brand = request.brand
if not brand.default_application:
return JsonResponse({})
application: Application = brand.default_application
provider = application.get_provider()
if not provider or not isinstance(provider, WebfingerProvider):
return JsonResponse({})
webfinger_data = provider.webfinger(request.GET.get("resource"), request)
return JsonResponse(webfinger_data)
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
response = super().dispatch(request, *args, **kwargs)
# RFC7033 spec
response["Access-Control-Allow-Origin"] = "*"
response["Content-Type"] = "application/jrd+json"
return response

View File

@ -103,12 +103,7 @@ class ApplicationSerializer(ModelSerializer):
class ApplicationViewSet(UsedByMixin, ModelViewSet): class ApplicationViewSet(UsedByMixin, ModelViewSet):
"""Application Viewset""" """Application Viewset"""
queryset = ( queryset = Application.objects.all().prefetch_related("provider")
Application.objects.all()
.with_provider()
.prefetch_related("policies")
.prefetch_related("backchannel_providers")
)
serializer_class = ApplicationSerializer serializer_class = ApplicationSerializer
search_fields = [ search_fields = [
"name", "name",
@ -152,15 +147,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
applications.append(application) applications.append(application)
return applications return applications
def _filter_applications_with_launch_url(
self, pagined_apps: Iterator[Application]
) -> list[Application]:
applications = []
for app in pagined_apps:
if app.get_launch_url():
applications.append(app)
return applications
@extend_schema( @extend_schema(
parameters=[ parameters=[
OpenApiParameter( OpenApiParameter(
@ -218,11 +204,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
location=OpenApiParameter.QUERY, location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT, type=OpenApiTypes.INT,
), ),
OpenApiParameter(
name="only_with_launch_url",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.BOOL,
),
] ]
) )
def list(self, request: Request) -> Response: def list(self, request: Request) -> Response:
@ -235,10 +216,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
if superuser_full_list and request.user.is_superuser: if superuser_full_list and request.user.is_superuser:
return super().list(request) return super().list(request)
only_with_launch_url = str(
request.query_params.get("only_with_launch_url", "false")
).lower()
queryset = self._filter_queryset_for_list(self.get_queryset()) queryset = self._filter_queryset_for_list(self.get_queryset())
paginator: Pagination = self.paginator paginator: Pagination = self.paginator
paginated_apps = paginator.paginate_queryset(queryset, request) paginated_apps = paginator.paginate_queryset(queryset, request)
@ -274,10 +251,6 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
allowed_applications, allowed_applications,
timeout=86400, timeout=86400,
) )
if only_with_launch_url == "true":
allowed_applications = self._filter_applications_with_launch_url(allowed_applications)
serializer = self.get_serializer(allowed_applications, many=True) serializer = self.get_serializer(allowed_applications, many=True)
return self.get_paginated_response(serializer.data) return self.get_paginated_response(serializer.data)

View File

@ -2,13 +2,7 @@
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework.fields import ( from rest_framework.fields import BooleanField, CharField, IntegerField, SerializerMethodField
BooleanField,
CharField,
DateTimeField,
IntegerField,
SerializerMethodField,
)
from rest_framework.permissions import IsAdminUser, IsAuthenticated from rest_framework.permissions import IsAdminUser, IsAuthenticated
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
@ -26,9 +20,6 @@ class DeviceSerializer(MetaNameSerializer):
name = CharField() name = CharField()
type = SerializerMethodField() type = SerializerMethodField()
confirmed = BooleanField() confirmed = BooleanField()
created = DateTimeField(read_only=True)
last_updated = DateTimeField(read_only=True)
last_used = DateTimeField(read_only=True, allow_null=True)
def get_type(self, instance: Device) -> str: def get_type(self, instance: Device) -> str:
"""Get type of device""" """Get type of device"""

View File

@ -2,15 +2,8 @@
from json import dumps from json import dumps
from django_filters.filters import AllValuesMultipleFilter, BooleanFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import ( from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
OpenApiParameter,
OpenApiResponse,
extend_schema,
extend_schema_field,
)
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework import mixins from rest_framework import mixins
from rest_framework.decorators import action from rest_framework.decorators import action
@ -30,10 +23,8 @@ from authentik.core.api.utils import (
PassiveSerializer, PassiveSerializer,
) )
from authentik.core.expression.evaluator import PropertyMappingEvaluator from authentik.core.expression.evaluator import PropertyMappingEvaluator
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group, PropertyMapping, User from authentik.core.models import Group, PropertyMapping, User
from authentik.events.utils import sanitize_item from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
from authentik.policies.api.exec import PolicyTestSerializer from authentik.policies.api.exec import PolicyTestSerializer
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
@ -76,18 +67,6 @@ class PropertyMappingSerializer(ManagedSerializer, ModelSerializer, MetaNameSeri
] ]
class PropertyMappingFilterSet(FilterSet):
"""Filter for PropertyMapping"""
managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))
managed__isnull = BooleanFilter(field_name="managed", lookup_expr="isnull")
class Meta:
model = PropertyMapping
fields = ["name", "managed"]
class PropertyMappingViewSet( class PropertyMappingViewSet(
TypesMixin, TypesMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
@ -108,9 +87,11 @@ class PropertyMappingViewSet(
queryset = PropertyMapping.objects.select_subclasses() queryset = PropertyMapping.objects.select_subclasses()
serializer_class = PropertyMappingSerializer serializer_class = PropertyMappingSerializer
filterset_class = PropertyMappingFilterSet search_fields = [
"name",
]
filterset_fields = {"managed": ["isnull"]}
ordering = ["name"] ordering = ["name"]
search_fields = ["name"]
@permission_required("authentik_core.view_propertymapping") @permission_required("authentik_core.view_propertymapping")
@extend_schema( @extend_schema(
@ -164,15 +145,12 @@ class PropertyMappingViewSet(
response_data = {"successful": True, "result": ""} response_data = {"successful": True, "result": ""}
try: try:
result = mapping.evaluate(dry_run=True, **context) result = mapping.evaluate(**context)
response_data["result"] = dumps( response_data["result"] = dumps(
sanitize_item(result), indent=(4 if format_result else None) sanitize_item(result), indent=(4 if format_result else None)
) )
except PropertyMappingExpressionException as exc:
response_data["result"] = exception_to_string(exc.exc)
response_data["successful"] = False
except Exception as exc: except Exception as exc:
response_data["result"] = exception_to_string(exc) response_data["result"] = str(exc)
response_data["successful"] = False response_data["successful"] = False
response = PropertyMappingTestResultSerializer(response_data) response = PropertyMappingTestResultSerializer(response_data)
return Response(response.data) return Response(response.data)

View File

@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
from authentik.core.api.object_types import TypesMixin from authentik.core.api.object_types import TypesMixin
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import MetaNameSerializer, ModelSerializer from authentik.core.api.utils import MetaNameSerializer, ModelSerializer
from authentik.core.models import GroupSourceConnection, Source, UserSourceConnection from authentik.core.models import Source, UserSourceConnection
from authentik.core.types import UserSettingSerializer from authentik.core.types import UserSettingSerializer
from authentik.lib.utils.file import ( from authentik.lib.utils.file import (
FilePathSerializer, FilePathSerializer,
@ -60,8 +60,6 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
"enabled", "enabled",
"authentication_flow", "authentication_flow",
"enrollment_flow", "enrollment_flow",
"user_property_mappings",
"group_property_mappings",
"component", "component",
"verbose_name", "verbose_name",
"verbose_name_plural", "verbose_name_plural",
@ -190,47 +188,6 @@ class UserSourceConnectionViewSet(
queryset = UserSourceConnection.objects.all() queryset = UserSourceConnection.objects.all()
serializer_class = UserSourceConnectionSerializer serializer_class = UserSourceConnectionSerializer
permission_classes = [OwnerSuperuserPermissions] permission_classes = [OwnerSuperuserPermissions]
filterset_fields = ["user", "source__slug"] filterset_fields = ["user"]
search_fields = ["source__slug"]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ordering = ["source__slug", "pk"] ordering = ["pk"]
class GroupSourceConnectionSerializer(SourceSerializer):
"""Group Source Connection Serializer"""
source = SourceSerializer(read_only=True)
class Meta:
model = GroupSourceConnection
fields = [
"pk",
"group",
"source",
"identifier",
"created",
]
extra_kwargs = {
"group": {"read_only": True},
"identifier": {"read_only": True},
"created": {"read_only": True},
}
class GroupSourceConnectionViewSet(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""Group-source connection Viewset"""
queryset = GroupSourceConnection.objects.all()
serializer_class = GroupSourceConnectionSerializer
permission_classes = [OwnerSuperuserPermissions]
filterset_fields = ["group", "source__slug"]
search_fields = ["source__slug"]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ordering = ["source__slug", "pk"]

View File

@ -14,7 +14,6 @@ from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.rbac.filters import ObjectFilter
class DeleteAction(Enum): class DeleteAction(Enum):
@ -54,7 +53,7 @@ class UsedByMixin:
@extend_schema( @extend_schema(
responses={200: UsedBySerializer(many=True)}, responses={200: UsedBySerializer(many=True)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def used_by(self, request: Request, *args, **kwargs) -> Response: def used_by(self, request: Request, *args, **kwargs) -> Response:
"""Get a list of all objects that use this object""" """Get a list of all objects that use this object"""
model: Model = self.get_object() model: Model = self.get_object()

View File

@ -5,7 +5,6 @@ from json import loads
from typing import Any from typing import Any
from django.contrib.auth import update_session_auth_hash from django.contrib.auth import update_session_auth_hash
from django.contrib.auth.models import Permission
from django.contrib.sessions.backends.cache import KEY_PREFIX from django.contrib.sessions.backends.cache import KEY_PREFIX
from django.core.cache import cache from django.core.cache import cache
from django.db.models.functions import ExtractHour from django.db.models.functions import ExtractHour
@ -34,21 +33,15 @@ from drf_spectacular.utils import (
) )
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
from rest_framework.fields import (
BooleanField,
CharField,
ChoiceField,
DateTimeField,
IntegerField,
ListField,
SerializerMethodField,
)
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ( from rest_framework.serializers import (
BooleanField,
DateTimeField,
ListSerializer, ListSerializer,
PrimaryKeyRelatedField, PrimaryKeyRelatedField,
ValidationError,
) )
from rest_framework.validators import UniqueValidator from rest_framework.validators import UniqueValidator
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
@ -85,7 +78,6 @@ from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner
from authentik.flows.views.executor import QS_KEY_TOKEN from authentik.flows.views.executor import QS_KEY_TOKEN
from authentik.lib.avatars import get_avatar from authentik.lib.avatars import get_avatar
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
from authentik.rbac.models import get_permission_choices
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
from authentik.stages.email.tasks import send_mails from authentik.stages.email.tasks import send_mails
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
@ -149,19 +141,12 @@ class UserSerializer(ModelSerializer):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if SERIALIZER_CONTEXT_BLUEPRINT in self.context: if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
self.fields["password"] = CharField(required=False, allow_null=True) self.fields["password"] = CharField(required=False, allow_null=True)
self.fields["permissions"] = ListField(
required=False, child=ChoiceField(choices=get_permission_choices())
)
def create(self, validated_data: dict) -> User: def create(self, validated_data: dict) -> User:
"""If this serializer is used in the blueprint context, we allow for """If this serializer is used in the blueprint context, we allow for
directly setting a password. However should be done via the `set_password` directly setting a password. However should be done via the `set_password`
method instead of directly setting it like rest_framework.""" method instead of directly setting it like rest_framework."""
password = validated_data.pop("password", None) password = validated_data.pop("password", None)
permissions = Permission.objects.filter(
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
)
validated_data["user_permissions"] = permissions
instance: User = super().create(validated_data) instance: User = super().create(validated_data)
self._set_password(instance, password) self._set_password(instance, password)
return instance return instance
@ -170,10 +155,6 @@ class UserSerializer(ModelSerializer):
"""Same as `create` above, set the password directly if we're in a blueprint """Same as `create` above, set the password directly if we're in a blueprint
context""" context"""
password = validated_data.pop("password", None) password = validated_data.pop("password", None)
permissions = Permission.objects.filter(
codename__in=[x.split(".")[1] for x in validated_data.pop("permissions", [])]
)
validated_data["user_permissions"] = permissions
instance = super().update(instance, validated_data) instance = super().update(instance, validated_data)
self._set_password(instance, password) self._set_password(instance, password)
return instance return instance
@ -678,10 +659,10 @@ class UserViewSet(UsedByMixin, ModelViewSet):
if not request.tenant.impersonation: if not request.tenant.impersonation:
LOGGER.debug("User attempted to impersonate", user=request.user) LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401) return Response(status=401)
user_to_be = self.get_object() if not request.user.has_perm("impersonate"):
if not request.user.has_perm("impersonate", user_to_be):
LOGGER.debug("User attempted to impersonate without permissions", user=request.user) LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return Response(status=401) return Response(status=401)
user_to_be = self.get_object()
if user_to_be.pk == self.request.user.pk: if user_to_be.pk == self.request.user.pk:
LOGGER.debug("User attempted to impersonate themselves", user=request.user) LOGGER.debug("User attempted to impersonate themselves", user=request.user)
return Response(status=401) return Response(status=401)

View File

@ -1,32 +0,0 @@
"""Change user type"""
from authentik.core.models import User, UserTypes
from authentik.tenants.management import TenantCommand
class Command(TenantCommand):
"""Change user type"""
def add_arguments(self, parser):
parser.add_argument("--type", type=str, required=True)
parser.add_argument("--all", action="store_true", default=False)
parser.add_argument("usernames", nargs="*", type=str)
def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"])
qs = (
User.objects.exclude_anonymous()
.exclude(type=UserTypes.SERVICE_ACCOUNT)
.exclude(type=UserTypes.INTERNAL_SERVICE_ACCOUNT)
)
if options["usernames"] and options["all"]:
self.stderr.write("--all and usernames specified, only one can be specified")
return
if not options["usernames"] and not options["all"]:
self.stderr.write("--all or usernames must be specified")
return
if options["usernames"] and not options["all"]:
qs = qs.filter(username__in=options["usernames"])
updated = qs.update(type=new_type)
self.stdout.write(f"Updated {updated} users.")

View File

@ -1,43 +0,0 @@
# Generated by Django 5.0.2 on 2024-02-29 11:05
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0035_alter_group_options_and_more"),
]
operations = [
migrations.AddField(
model_name="source",
name="group_property_mappings",
field=models.ManyToManyField(
blank=True,
default=None,
related_name="source_grouppropertymappings_set",
to="authentik_core.propertymapping",
),
),
migrations.AddField(
model_name="source",
name="user_property_mappings",
field=models.ManyToManyField(
blank=True,
default=None,
related_name="source_userpropertymappings_set",
to="authentik_core.propertymapping",
),
),
migrations.AlterField(
model_name="source",
name="property_mappings",
field=models.ManyToManyField(
blank=True,
default=None,
related_name="source_set",
to="authentik_core.propertymapping",
),
),
]

View File

@ -1,18 +0,0 @@
# Generated by Django 5.0.2 on 2024-02-29 11:21
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_sources_ldap", "0005_remove_ldappropertymapping_object_field_and_more"),
("authentik_core", "0036_source_group_property_mappings_and_more"),
]
operations = [
migrations.RemoveField(
model_name="source",
name="property_mappings",
),
]

View File

@ -1,19 +0,0 @@
# Generated by Django 5.0.7 on 2024-07-22 13:32
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0037_remove_source_property_mappings"),
("authentik_flows", "0027_auto_20231028_1424"),
("authentik_policies", "0011_policybinding_failure_result_and_more"),
]
operations = [
migrations.AddIndex(
model_name="source",
index=models.Index(fields=["enabled"], name="authentik_c_enabled_d72365_idx"),
),
]

View File

@ -1,67 +0,0 @@
# Generated by Django 5.0.7 on 2024-08-01 18:52
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0038_source_authentik_c_enabled_d72365_idx"),
]
operations = [
migrations.AddField(
model_name="source",
name="group_matching_mode",
field=models.TextField(
choices=[
("identifier", "Use the source-specific identifier"),
(
"name_link",
"Link to a group with identical name. Can have security implications when a group name is used with another source.",
),
(
"name_deny",
"Use the group name, but deny enrollment when the name already exists.",
),
],
default="identifier",
help_text="How the source determines if an existing group should be used or a new group created.",
),
),
migrations.AlterField(
model_name="group",
name="name",
field=models.TextField(verbose_name="name"),
),
migrations.CreateModel(
name="GroupSourceConnection",
fields=[
(
"id",
models.AutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
),
),
("created", models.DateTimeField(auto_now_add=True)),
("last_updated", models.DateTimeField(auto_now=True)),
("identifier", models.TextField()),
(
"group",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_core.group"
),
),
(
"source",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_core.source"
),
),
],
options={
"unique_together": {("group", "source")},
},
),
]

View File

@ -11,7 +11,6 @@ from django.contrib.auth.models import AbstractUser
from django.contrib.auth.models import UserManager as DjangoUserManager from django.contrib.auth.models import UserManager as DjangoUserManager
from django.db import models from django.db import models
from django.db.models import Q, QuerySet, options from django.db.models import Q, QuerySet, options
from django.db.models.constants import LOOKUP_SEP
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now from django.utils.timezone import now
@ -29,7 +28,6 @@ from authentik.core.types import UILoginButton, UserSettingSerializer
from authentik.lib.avatars import get_avatar from authentik.lib.avatars import get_avatar
from authentik.lib.expression.exceptions import ControlFlowException from authentik.lib.expression.exceptions import ControlFlowException
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.merge import MERGE_LIST_UNIQUE
from authentik.lib.models import ( from authentik.lib.models import (
CreatedUpdatedModel, CreatedUpdatedModel,
DomainlessFormattedURLValidator, DomainlessFormattedURLValidator,
@ -102,38 +100,6 @@ class UserTypes(models.TextChoices):
INTERNAL_SERVICE_ACCOUNT = "internal_service_account" INTERNAL_SERVICE_ACCOUNT = "internal_service_account"
class AttributesMixin(models.Model):
"""Adds an attributes property to a model"""
attributes = models.JSONField(default=dict, blank=True)
class Meta:
abstract = True
def update_attributes(self, properties: dict[str, Any]):
"""Update fields and attributes, but correctly by merging dicts"""
for key, value in properties.items():
if key == "attributes":
continue
setattr(self, key, value)
final_attributes = {}
MERGE_LIST_UNIQUE.merge(final_attributes, self.attributes)
MERGE_LIST_UNIQUE.merge(final_attributes, properties.get("attributes", {}))
self.attributes = final_attributes
self.save()
@classmethod
def update_or_create_attributes(
cls, query: dict[str, Any], properties: dict[str, Any]
) -> tuple[models.Model, bool]:
"""Same as django's update_or_create but correctly updates attributes by merging dicts"""
instance = cls.objects.filter(**query).first()
if not instance:
return cls.objects.create(**properties), True
instance.update_attributes(properties)
return instance, False
class GroupQuerySet(CTEQuerySet): class GroupQuerySet(CTEQuerySet):
def with_children_recursive(self): def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents """Recursively get all groups that have the current queryset as parents
@ -168,12 +134,12 @@ class GroupQuerySet(CTEQuerySet):
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte) return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
class Group(SerializerModel, AttributesMixin): class Group(SerializerModel):
"""Group model which supports a basic hierarchy and has attributes""" """Group model which supports a basic hierarchy and has attributes"""
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
name = models.TextField(_("name")) name = models.CharField(_("name"), max_length=80)
is_superuser = models.BooleanField( is_superuser = models.BooleanField(
default=False, help_text=_("Users added to this group will be superusers.") default=False, help_text=_("Users added to this group will be superusers.")
) )
@ -188,27 +154,10 @@ class Group(SerializerModel, AttributesMixin):
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
related_name="children", related_name="children",
) )
attributes = models.JSONField(default=dict, blank=True)
objects = GroupQuerySet.as_manager() objects = GroupQuerySet.as_manager()
class Meta:
unique_together = (
(
"name",
"parent",
),
)
indexes = [models.Index(fields=["name"])]
verbose_name = _("Group")
verbose_name_plural = _("Groups")
permissions = [
("add_user_to_group", _("Add user to group")),
("remove_user_from_group", _("Remove user from group")),
]
def __str__(self):
return f"Group {self.name}"
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:
from authentik.core.api.groups import GroupSerializer from authentik.core.api.groups import GroupSerializer
@ -233,6 +182,24 @@ class Group(SerializerModel, AttributesMixin):
qs = Group.objects.filter(group_uuid=self.group_uuid) qs = Group.objects.filter(group_uuid=self.group_uuid)
return qs.with_children_recursive() return qs.with_children_recursive()
def __str__(self):
return f"Group {self.name}"
class Meta:
unique_together = (
(
"name",
"parent",
),
)
indexes = [models.Index(fields=["name"])]
verbose_name = _("Group")
verbose_name_plural = _("Groups")
permissions = [
("add_user_to_group", _("Add user to group")),
("remove_user_from_group", _("Remove user from group")),
]
class UserQuerySet(models.QuerySet): class UserQuerySet(models.QuerySet):
"""User queryset""" """User queryset"""
@ -258,7 +225,7 @@ class UserManager(DjangoUserManager):
return self.get_queryset().exclude_anonymous() return self.get_queryset().exclude_anonymous()
class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser): class User(SerializerModel, GuardianUserMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model.""" """authentik User model, based on django's contrib auth user model."""
uuid = models.UUIDField(default=uuid4, editable=False, unique=True) uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
@ -270,30 +237,10 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
ak_groups = models.ManyToManyField("Group", related_name="users") ak_groups = models.ManyToManyField("Group", related_name="users")
password_change_date = models.DateTimeField(auto_now_add=True) password_change_date = models.DateTimeField(auto_now_add=True)
attributes = models.JSONField(default=dict, blank=True)
objects = UserManager() objects = UserManager()
class Meta:
verbose_name = _("User")
verbose_name_plural = _("Users")
permissions = [
("reset_user_password", _("Reset Password")),
("impersonate", _("Can impersonate other users")),
("assign_user_permissions", _("Can assign permissions to users")),
("unassign_user_permissions", _("Can unassign permissions from users")),
("preview_user", _("Can preview user data sent to providers")),
("view_user_applications", _("View applications the user has access to")),
]
indexes = [
models.Index(fields=["last_login"]),
models.Index(fields=["password_change_date"]),
models.Index(fields=["uuid"]),
models.Index(fields=["path"]),
models.Index(fields=["type"]),
]
def __str__(self):
return self.username
@staticmethod @staticmethod
def default_path() -> str: def default_path() -> str:
"""Get the default user path""" """Get the default user path"""
@ -375,6 +322,25 @@ class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
"""Get avatar, depending on authentik.avatar setting""" """Get avatar, depending on authentik.avatar setting"""
return get_avatar(self) return get_avatar(self)
class Meta:
verbose_name = _("User")
verbose_name_plural = _("Users")
permissions = [
("reset_user_password", _("Reset Password")),
("impersonate", _("Can impersonate other users")),
("assign_user_permissions", _("Can assign permissions to users")),
("unassign_user_permissions", _("Can unassign permissions from users")),
("preview_user", _("Can preview user data sent to providers")),
("view_user_applications", _("View applications the user has access to")),
]
indexes = [
models.Index(fields=["last_login"]),
models.Index(fields=["password_change_date"]),
models.Index(fields=["uuid"]),
models.Index(fields=["path"]),
models.Index(fields=["type"]),
]
class Provider(SerializerModel): class Provider(SerializerModel):
"""Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application""" """Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application"""
@ -462,14 +428,6 @@ class BackchannelProvider(Provider):
abstract = True abstract = True
class ApplicationQuerySet(QuerySet):
def with_provider(self) -> "QuerySet[Application]":
qs = self.select_related("provider")
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
qs = qs.select_related(f"provider__{subclass}")
return qs
class Application(SerializerModel, PolicyBindingModel): class Application(SerializerModel, PolicyBindingModel):
"""Every Application which uses authentik for authentication/identification/authorization """Every Application which uses authentik for authentication/identification/authorization
needs an Application record. Other authentication types can subclass this Model to needs an Application record. Other authentication types can subclass this Model to
@ -501,8 +459,6 @@ class Application(SerializerModel, PolicyBindingModel):
meta_description = models.TextField(default="", blank=True) meta_description = models.TextField(default="", blank=True)
meta_publisher = models.TextField(default="", blank=True) meta_publisher = models.TextField(default="", blank=True)
objects = ApplicationQuerySet.as_manager()
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:
from authentik.core.api.applications import ApplicationSerializer from authentik.core.api.applications import ApplicationSerializer
@ -539,28 +495,16 @@ class Application(SerializerModel, PolicyBindingModel):
return url return url
def get_provider(self) -> Provider | None: def get_provider(self) -> Provider | None:
"""Get casted provider instance. Needs Application queryset with_provider""" """Get casted provider instance"""
if not self.provider: if not self.provider:
return None return None
# if the Application class has been cache, self.provider is set
candidates = [] # but doing a direct query lookup will fail.
base_class = Provider # In that case, just return None
for subclass in base_class.objects.get_queryset()._get_subclasses_recurse(base_class): try:
parent = self.provider return Provider.objects.get_subclass(pk=self.provider.pk)
for level in subclass.split(LOOKUP_SEP): except Provider.DoesNotExist:
try:
parent = getattr(parent, level)
except AttributeError:
break
if parent in candidates:
continue
idx = subclass.count(LOOKUP_SEP)
if type(parent) is not base_class:
idx += 1
candidates.insert(idx, parent)
if not candidates:
return None return None
return candidates[-1]
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)
@ -590,19 +534,6 @@ class SourceUserMatchingModes(models.TextChoices):
) )
class SourceGroupMatchingModes(models.TextChoices):
"""Different modes a source can handle new/returning groups"""
IDENTIFIER = "identifier", _("Use the source-specific identifier")
NAME_LINK = "name_link", _(
"Link to a group with identical name. Can have security implications "
"when a group name is used with another source."
)
NAME_DENY = "name_deny", _(
"Use the group name, but deny enrollment when the name already exists."
)
class Source(ManagedModel, SerializerModel, PolicyBindingModel): class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
@ -612,12 +543,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
user_path_template = models.TextField(default="goauthentik.io/sources/%(slug)s") user_path_template = models.TextField(default="goauthentik.io/sources/%(slug)s")
enabled = models.BooleanField(default=True) enabled = models.BooleanField(default=True)
user_property_mappings = models.ManyToManyField( property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)
"PropertyMapping", default=None, blank=True, related_name="source_userpropertymappings_set"
)
group_property_mappings = models.ManyToManyField(
"PropertyMapping", default=None, blank=True, related_name="source_grouppropertymappings_set"
)
icon = models.FileField( icon = models.FileField(
upload_to="source-icons/", upload_to="source-icons/",
default=None, default=None,
@ -652,14 +578,6 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"a new user enrolled." "a new user enrolled."
), ),
) )
group_matching_mode = models.TextField(
choices=SourceGroupMatchingModes.choices,
default=SourceGroupMatchingModes.IDENTIFIER,
help_text=_(
"How the source determines if an existing group should be used or "
"a new group created."
),
)
objects = InheritanceManager() objects = InheritanceManager()
@ -689,11 +607,6 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Return component used to edit this object""" """Return component used to edit this object"""
raise NotImplementedError raise NotImplementedError
@property
def property_mapping_type(self) -> "type[PropertyMapping]":
"""Return property mapping type used by this object"""
raise NotImplementedError
def ui_login_button(self, request: HttpRequest) -> UILoginButton | None: def ui_login_button(self, request: HttpRequest) -> UILoginButton | None:
"""If source uses a http-based flow, return UI Information about the login """If source uses a http-based flow, return UI Information about the login
button. If source doesn't use http-based flow, return None.""" button. If source doesn't use http-based flow, return None."""
@ -704,14 +617,6 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
user settings are available, or UserSettingSerializer.""" user settings are available, or UserSettingSerializer."""
return None return None
def get_base_user_properties(self, **kwargs) -> dict[str, Any | dict[str, Any]]:
"""Get base properties for a user to build final properties upon."""
raise NotImplementedError
def get_base_group_properties(self, **kwargs) -> dict[str, Any | dict[str, Any]]:
"""Get base properties for a group to build final properties upon."""
raise NotImplementedError
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)
@ -727,11 +632,6 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"name", "name",
] ]
), ),
models.Index(
fields=[
"enabled",
]
),
] ]
@ -755,27 +655,6 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel):
unique_together = (("user", "source"),) unique_together = (("user", "source"),)
class GroupSourceConnection(SerializerModel, CreatedUpdatedModel):
"""Connection between Group and Source."""
group = models.ForeignKey(Group, on_delete=models.CASCADE)
source = models.ForeignKey(Source, on_delete=models.CASCADE)
identifier = models.TextField()
objects = InheritanceManager()
@property
def serializer(self) -> type[Serializer]:
"""Get serializer for this model"""
raise NotImplementedError
def __str__(self) -> str:
return f"Group-source connection (group={self.group_id}, source={self.source_id})"
class Meta:
unique_together = (("group", "source"),)
class ExpiringModel(models.Model): class ExpiringModel(models.Model):
"""Base Model which can expire, and is automatically cleaned up.""" """Base Model which can expire, and is automatically cleaned up."""
@ -908,7 +787,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
except ControlFlowException as exc: except ControlFlowException as exc:
raise exc raise exc
except Exception as exc: except Exception as exc:
raise PropertyMappingExpressionException(exc, self) from exc raise PropertyMappingExpressionException(self, exc) from exc
def __str__(self): def __str__(self):
return f"Property Mapping {self.name}" return f"Property Mapping {self.name}"

View File

@ -52,8 +52,6 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
@receiver(user_logged_out) @receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_): def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Delete AuthenticatedSession if it exists""" """Delete AuthenticatedSession if it exists"""
if not request.session or not request.session.session_key:
return
AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete() AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()

View File

@ -4,7 +4,7 @@ from enum import Enum
from typing import Any from typing import Any
from django.contrib import messages from django.contrib import messages
from django.db import IntegrityError, transaction from django.db import IntegrityError
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.shortcuts import redirect from django.shortcuts import redirect
@ -12,20 +12,8 @@ from django.urls import reverse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.models import ( from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
Group, from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage
GroupSourceConnection,
Source,
SourceGroupMatchingModes,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
from authentik.core.sources.mapper import SourceMapper
from authentik.core.sources.stage import (
PLAN_CONTEXT_SOURCES_CONNECTION,
PostSourceStage,
)
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.flows.exceptions import FlowNonApplicableException from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage
@ -48,10 +36,7 @@ from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
from authentik.stages.user_write.stage import PLAN_CONTEXT_USER_PATH from authentik.stages.user_write.stage import PLAN_CONTEXT_USER_PATH
LOGGER = get_logger()
SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token" # nosec SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token" # nosec
PLAN_CONTEXT_SOURCE_GROUPS = "source_groups"
class Action(Enum): class Action(Enum):
@ -85,69 +70,48 @@ class SourceFlowManager:
or deny the request.""" or deny the request."""
source: Source source: Source
mapper: SourceMapper
request: HttpRequest request: HttpRequest
identifier: str identifier: str
user_connection_type: type[UserSourceConnection] = UserSourceConnection connection_type: type[UserSourceConnection] = UserSourceConnection
group_connection_type: type[GroupSourceConnection] = GroupSourceConnection
user_info: dict[str, Any] enroll_info: dict[str, Any]
policy_context: dict[str, Any] policy_context: dict[str, Any]
user_properties: dict[str, Any | dict[str, Any]]
groups_properties: dict[str, dict[str, Any | dict[str, Any]]]
def __init__( def __init__(
self, self,
source: Source, source: Source,
request: HttpRequest, request: HttpRequest,
identifier: str, identifier: str,
user_info: dict[str, Any], enroll_info: dict[str, Any],
policy_context: dict[str, Any],
) -> None: ) -> None:
self.source = source self.source = source
self.mapper = SourceMapper(self.source)
self.request = request self.request = request
self.identifier = identifier self.identifier = identifier
self.user_info = user_info self.enroll_info = enroll_info
self._logger = get_logger().bind(source=source, identifier=identifier) self._logger = get_logger().bind(source=source, identifier=identifier)
self.policy_context = policy_context self.policy_context = {}
self.user_properties = self.mapper.build_object_properties(
object_type=User, request=request, user=None, **self.user_info
)
self.groups_properties = {
group_id: self.mapper.build_object_properties(
object_type=Group,
request=request,
user=None,
group_id=group_id,
**self.user_info,
)
for group_id in self.user_properties.setdefault("groups", [])
}
del self.user_properties["groups"]
def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911 def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911
"""decide which action should be taken""" """decide which action should be taken"""
new_connection = self.user_connection_type(source=self.source, identifier=self.identifier) new_connection = self.connection_type(source=self.source, identifier=self.identifier)
# When request is authenticated, always link # When request is authenticated, always link
if self.request.user.is_authenticated: if self.request.user.is_authenticated:
new_connection.user = self.request.user new_connection.user = self.request.user
new_connection = self.update_user_connection(new_connection, **kwargs) new_connection = self.update_connection(new_connection, **kwargs)
return Action.LINK, new_connection return Action.LINK, new_connection
existing_connections = self.user_connection_type.objects.filter( existing_connections = self.connection_type.objects.filter(
source=self.source, identifier=self.identifier source=self.source, identifier=self.identifier
) )
if existing_connections.exists(): if existing_connections.exists():
connection = existing_connections.first() connection = existing_connections.first()
return Action.AUTH, self.update_user_connection(connection, **kwargs) return Action.AUTH, self.update_connection(connection, **kwargs)
# No connection exists, but we match on identifier, so enroll # No connection exists, but we match on identifier, so enroll
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER: if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet # We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) return Action.ENROLL, self.update_connection(new_connection, **kwargs)
# Check for existing users with matching attributes # Check for existing users with matching attributes
query = Q() query = Q()
@ -156,24 +120,24 @@ class SourceFlowManager:
SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.EMAIL_DENY, SourceUserMatchingModes.EMAIL_DENY,
]: ]:
if not self.user_properties.get("email", None): if not self.enroll_info.get("email", None):
self._logger.warning("Refusing to use none email") self._logger.warning("Refusing to use none email", source=self.source)
return Action.DENY, None return Action.DENY, None
query = Q(email__exact=self.user_properties.get("email", None)) query = Q(email__exact=self.enroll_info.get("email", None))
if self.source.user_matching_mode in [ if self.source.user_matching_mode in [
SourceUserMatchingModes.USERNAME_LINK, SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY, SourceUserMatchingModes.USERNAME_DENY,
]: ]:
if not self.user_properties.get("username", None): if not self.enroll_info.get("username", None):
self._logger.warning("Refusing to use none username") self._logger.warning("Refusing to use none username", source=self.source)
return Action.DENY, None return Action.DENY, None
query = Q(username__exact=self.user_properties.get("username", None)) query = Q(username__exact=self.enroll_info.get("username", None))
self._logger.debug("trying to link with existing user", query=query) self._logger.debug("trying to link with existing user", query=query)
matching_users = User.objects.filter(query) matching_users = User.objects.filter(query)
# No matching users, always enroll # No matching users, always enroll
if not matching_users.exists(): if not matching_users.exists():
self._logger.debug("no matching users found, enrolling") self._logger.debug("no matching users found, enrolling")
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) return Action.ENROLL, self.update_connection(new_connection, **kwargs)
user = matching_users.first() user = matching_users.first()
if self.source.user_matching_mode in [ if self.source.user_matching_mode in [
@ -181,7 +145,7 @@ class SourceFlowManager:
SourceUserMatchingModes.USERNAME_LINK, SourceUserMatchingModes.USERNAME_LINK,
]: ]:
new_connection.user = user new_connection.user = user
new_connection = self.update_user_connection(new_connection, **kwargs) new_connection = self.update_connection(new_connection, **kwargs)
return Action.LINK, new_connection return Action.LINK, new_connection
if self.source.user_matching_mode in [ if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_DENY, SourceUserMatchingModes.EMAIL_DENY,
@ -192,10 +156,10 @@ class SourceFlowManager:
# Should never get here as default enroll case is returned above. # Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover return Action.DENY, None # pragma: no cover
def update_user_connection( def update_connection(
self, connection: UserSourceConnection, **kwargs self, connection: UserSourceConnection, **kwargs
) -> UserSourceConnection: # pragma: no cover ) -> UserSourceConnection: # pragma: no cover
"""Optionally make changes to the user connection after it is looked up/created.""" """Optionally make changes to the connection after it is looked up/created."""
return connection return connection
def get_flow(self, **kwargs) -> HttpResponse: def get_flow(self, **kwargs) -> HttpResponse:
@ -251,31 +215,25 @@ class SourceFlowManager:
flow: Flow | None, flow: Flow | None,
connection: UserSourceConnection, connection: UserSourceConnection,
stages: list[StageView] | None = None, stages: list[StageView] | None = None,
**flow_context, **kwargs,
) -> HttpResponse: ) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor""" """Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to kwargs.update(
# authorize application
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-user"
)
flow_context.update(
{ {
# Since we authenticate the user by their token, they have no backend set # Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: BACKEND_INBUILT, PLAN_CONTEXT_AUTHENTICATION_BACKEND: BACKEND_INBUILT,
PLAN_CONTEXT_SSO: True, PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: self.source, PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_SOURCES_CONNECTION: connection, PLAN_CONTEXT_SOURCES_CONNECTION: connection,
PLAN_CONTEXT_SOURCE_GROUPS: self.groups_properties,
} }
) )
flow_context.update(self.policy_context) kwargs.update(self.policy_context)
if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session: if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session:
token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
self._logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) self._logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
plan = token.plan plan = token.plan
plan.context[PLAN_CONTEXT_IS_RESTORED] = token plan.context[PLAN_CONTEXT_IS_RESTORED] = token
plan.context.update(flow_context) plan.context.update(kwargs)
for stage in self.get_stages_to_append(flow): for stage in self.get_stages_to_append(flow):
plan.append_stage(stage) plan.append_stage(stage)
if stages: if stages:
@ -294,8 +252,8 @@ class SourceFlowManager:
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get( final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-user" NEXT_ARG_NAME, "authentik_core:if-user"
) )
if PLAN_CONTEXT_REDIRECT not in flow_context: if PLAN_CONTEXT_REDIRECT not in kwargs:
flow_context[PLAN_CONTEXT_REDIRECT] = final_redirect kwargs[PLAN_CONTEXT_REDIRECT] = final_redirect
if not flow: if not flow:
return bad_request_message( return bad_request_message(
@ -307,12 +265,9 @@ class SourceFlowManager:
# We append some stages so the initial flow we get might be empty # We append some stages so the initial flow we get might be empty
planner.allow_empty_flows = True planner.allow_empty_flows = True
planner.use_cache = False planner.use_cache = False
plan = planner.plan(self.request, flow_context) plan = planner.plan(self.request, kwargs)
for stage in self.get_stages_to_append(flow): for stage in self.get_stages_to_append(flow):
plan.append_stage(stage) plan.append_stage(stage)
plan.append_stage(
in_memory_stage(GroupUpdateStage, group_connection_type=self.group_connection_type)
)
if stages: if stages:
for stage in stages: for stage in stages:
plan.append_stage(stage) plan.append_stage(stage)
@ -399,123 +354,7 @@ class SourceFlowManager:
) )
], ],
**{ **{
PLAN_CONTEXT_PROMPT: delete_none_values(self.user_properties), PLAN_CONTEXT_PROMPT: delete_none_values(self.enroll_info),
PLAN_CONTEXT_USER_PATH: self.source.get_user_path(), PLAN_CONTEXT_USER_PATH: self.source.get_user_path(),
}, },
) )
class GroupUpdateStage(StageView):
"""Dynamically injected stage which updates the user after enrollment/authentication."""
def get_action(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, GroupSourceConnection | None]:
"""decide which action should be taken"""
new_connection = self.group_connection_type(source=self.source, identifier=group_id)
existing_connections = self.group_connection_type.objects.filter(
source=self.source, identifier=group_id
)
if existing_connections.exists():
return Action.LINK, existing_connections.first()
# No connection exists, but we match on identifier, so enroll
if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, new_connection
# Check for existing groups with matching attributes
query = Q()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
SourceGroupMatchingModes.NAME_DENY,
]:
if not group_properties.get("name", None):
LOGGER.warning(
"Refusing to use none group name", source=self.source, group_id=group_id
)
return Action.DENY, None
query = Q(name__exact=group_properties.get("name"))
LOGGER.debug(
"trying to link with existing group", source=self.source, query=query, group_id=group_id
)
matching_groups = Group.objects.filter(query)
# No matching groups, always enroll
if not matching_groups.exists():
LOGGER.debug(
"no matching groups found, enrolling", source=self.source, group_id=group_id
)
return Action.ENROLL, new_connection
group = matching_groups.first()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
]:
new_connection.group = group
return Action.LINK, new_connection
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_DENY,
]:
LOGGER.info(
"denying source because group exists",
source=self.source,
group=group,
group_id=group_id,
)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def handle_group(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> Group | None:
action, connection = self.get_action(group_id, group_properties)
if action == Action.ENROLL:
group = Group.objects.create(**group_properties)
connection.group = group
connection.save()
return group
elif action == Action.LINK:
group = connection.group
group.update_attributes(group_properties)
connection.save()
return group
return None
def handle_groups(self) -> bool:
self.source: Source = self.executor.plan.context[PLAN_CONTEXT_SOURCE]
self.user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
self.group_connection_type: GroupSourceConnection = (
self.executor.current_stage.group_connection_type
)
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
PLAN_CONTEXT_SOURCE_GROUPS
]
groups: list[Group] = []
for group_id, group_properties in raw_groups.items():
group = self.handle_group(group_id, group_properties)
if not group:
return False
groups.append(group)
with transaction.atomic():
self.user.ak_groups.remove(
*self.user.ak_groups.filter(groupsourceconnection__source=self.source)
)
self.user.ak_groups.add(*groups)
return True
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Stage used after the user has been enrolled to sync their groups from source data"""
if self.handle_groups():
return self.executor.stage_ok()
else:
return self.executor.stage_invalid("Failed to update groups. Please try again later.")
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View File

@ -1,103 +0,0 @@
from typing import Any
from django.http import HttpRequest
from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import Group, PropertyMapping, Source, User
from authentik.events.models import Event, EventAction
from authentik.lib.merge import MERGE_LIST_UNIQUE
from authentik.lib.sync.mapper import PropertyMappingManager
from authentik.policies.utils import delete_none_values
LOGGER = get_logger()
class SourceMapper:
def __init__(self, source: Source):
self.source = source
def get_manager(
self, object_type: type[User | Group], context_keys: list[str]
) -> PropertyMappingManager:
"""Get property mapping manager for this source."""
qs = PropertyMapping.objects.none()
if object_type == User:
qs = self.source.user_property_mappings.all().select_subclasses()
elif object_type == Group:
qs = self.source.group_property_mappings.all().select_subclasses()
qs = qs.order_by("name")
return PropertyMappingManager(
qs,
self.source.property_mapping_type,
["source", "properties"] + context_keys,
)
def get_base_properties(
self, object_type: type[User | Group], **kwargs
) -> dict[str, Any | dict[str, Any]]:
"""Get base properties for a user or a group to build final properties upon."""
if object_type == User:
properties = self.source.get_base_user_properties(**kwargs)
properties.setdefault("path", self.source.get_user_path())
return properties
if object_type == Group:
return self.source.get_base_group_properties(**kwargs)
return {}
def build_object_properties(
self,
object_type: type[User | Group],
manager: "PropertyMappingManager | None" = None,
user: User | None = None,
request: HttpRequest | None = None,
**kwargs,
) -> dict[str, Any | dict[str, Any]]:
"""Build a user or group properties from the source configured property mappings."""
properties = self.get_base_properties(object_type, **kwargs)
if "attributes" not in properties:
properties["attributes"] = {}
if not manager:
manager = self.get_manager(object_type, list(kwargs.keys()))
evaluations = manager.iter_eval(
user=user,
request=request,
return_mapping=True,
source=self.source,
properties=properties,
**kwargs,
)
while True:
try:
value, mapping = next(evaluations)
except StopIteration:
break
except PropertyMappingExpressionException as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property mapping: '{exc.mapping.name}'",
source=self,
mapping=exc.mapping,
).save()
LOGGER.warning(
"Mapping failed to evaluate",
exc=exc,
source=self,
mapping=exc.mapping,
)
raise exc
if not value or not isinstance(value, dict):
LOGGER.debug(
"Mapping evaluated to None or is not a dict. Skipping",
source=self,
mapping=mapping,
)
continue
MERGE_LIST_UNIQUE.merge(properties, value)
return delete_none_values(properties)

View File

@ -4,7 +4,7 @@
<!DOCTYPE html> <!DOCTYPE html>
<html> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1"> <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1">

View File

@ -9,12 +9,9 @@ from rest_framework.test import APITestCase
from authentik.core.models import Application from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.models import OAuth2Provider from authentik.providers.oauth2.models import OAuth2Provider
from authentik.providers.proxy.models import ProxyProvider
from authentik.providers.saml.models import SAMLProvider
class TestApplicationsAPI(APITestCase): class TestApplicationsAPI(APITestCase):
@ -225,31 +222,3 @@ class TestApplicationsAPI(APITestCase):
], ],
}, },
) )
def test_get_provider(self):
"""Ensure that proxy providers (at the time of writing that is the only provider
that inherits from another proxy type (OAuth) instead of inheriting from the root
provider class) is correctly looked up and selected from the database"""
slug = generate_id()
provider = ProxyProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)
slug = generate_id()
provider = SAMLProvider.objects.create(name=generate_id())
Application.objects.create(
name=generate_id(),
slug=slug,
provider=provider,
)
self.assertEqual(Application.objects.get(slug=slug).get_provider(), provider)
self.assertEqual(
Application.objects.with_provider().get(slug=slug).get_provider(), provider
)

View File

@ -3,10 +3,10 @@
from json import loads from json import loads
from django.urls import reverse from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user, create_test_user from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
@ -15,7 +15,7 @@ class TestImpersonation(APITestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.other_user = create_test_user() self.other_user = User.objects.create(username="to-impersonate")
self.user = create_test_admin_user() self.user = create_test_admin_user()
def test_impersonate_simple(self): def test_impersonate_simple(self):
@ -44,26 +44,6 @@ class TestImpersonation(APITestCase):
self.assertEqual(response_body["user"]["username"], self.user.username) self.assertEqual(response_body["user"]["username"], self.user.username)
self.assertNotIn("original", response_body) self.assertNotIn("original", response_body)
def test_impersonate_scoped(self):
"""Test impersonation with scoped permissions"""
new_user = create_test_user()
assign_perm("authentik_core.impersonate", new_user, self.other_user)
assign_perm("authentik_core.view_user", new_user, self.other_user)
self.client.force_login(new_user)
response = self.client.post(
reverse(
"authentik_api:user-impersonate",
kwargs={"pk": self.other_user.pk},
)
)
self.assertEqual(response.status_code, 201)
response = self.client.get(reverse("authentik_api:user-me"))
response_body = loads(response.content.decode())
self.assertEqual(response_body["user"]["username"], self.other_user.username)
self.assertEqual(response_body["original"]["username"], new_user.username)
def test_impersonate_denied(self): def test_impersonate_denied(self):
"""test impersonation without permissions""" """test impersonation without permissions"""
self.client.force_login(self.other_user) self.client.force_login(self.other_user)

View File

@ -38,9 +38,7 @@ class TestSourceFlowManager(TestCase):
def test_unauthenticated_enroll(self): def test_unauthenticated_enroll(self):
"""Test un-authenticated user enrolling""" """Test un-authenticated user enrolling"""
request = get_request("/", user=AnonymousUser()) request = get_request("/", user=AnonymousUser())
flow_manager = OAuthSourceFlowManager( flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
self.source, request, self.identifier, {"info": {}}, {}
)
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL) self.assertEqual(action, Action.ENROLL)
response = flow_manager.get_flow() response = flow_manager.get_flow()
@ -54,9 +52,7 @@ class TestSourceFlowManager(TestCase):
user=get_anonymous_user(), source=self.source, identifier=self.identifier user=get_anonymous_user(), source=self.source, identifier=self.identifier
) )
request = get_request("/", user=AnonymousUser()) request = get_request("/", user=AnonymousUser())
flow_manager = OAuthSourceFlowManager( flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
self.source, request, self.identifier, {"info": {}}, {}
)
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.AUTH) self.assertEqual(action, Action.AUTH)
response = flow_manager.get_flow() response = flow_manager.get_flow()
@ -68,9 +64,7 @@ class TestSourceFlowManager(TestCase):
"""Test authenticated user linking""" """Test authenticated user linking"""
user = User.objects.create(username="foo", email="foo@bar.baz") user = User.objects.create(username="foo", email="foo@bar.baz")
request = get_request("/", user=user) request = get_request("/", user=user)
flow_manager = OAuthSourceFlowManager( flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
self.source, request, self.identifier, {"info": {}}, {}
)
action, connection = flow_manager.get_action() action, connection = flow_manager.get_action()
self.assertEqual(action, Action.LINK) self.assertEqual(action, Action.LINK)
self.assertIsNone(connection.pk) self.assertIsNone(connection.pk)
@ -83,9 +77,7 @@ class TestSourceFlowManager(TestCase):
def test_unauthenticated_link(self): def test_unauthenticated_link(self):
"""Test un-authenticated user linking""" """Test un-authenticated user linking"""
flow_manager = OAuthSourceFlowManager( flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {})
self.source, get_request("/"), self.identifier, {"info": {}}, {}
)
action, connection = flow_manager.get_action() action, connection = flow_manager.get_action()
self.assertEqual(action, Action.LINK) self.assertEqual(action, Action.LINK)
self.assertIsNone(connection.pk) self.assertIsNone(connection.pk)
@ -98,7 +90,7 @@ class TestSourceFlowManager(TestCase):
# Without email, deny # Without email, deny
flow_manager = OAuthSourceFlowManager( flow_manager = OAuthSourceFlowManager(
self.source, get_request("/", user=AnonymousUser()), self.identifier, {"info": {}}, {} self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.DENY) self.assertEqual(action, Action.DENY)
@ -108,12 +100,7 @@ class TestSourceFlowManager(TestCase):
self.source, self.source,
get_request("/", user=AnonymousUser()), get_request("/", user=AnonymousUser()),
self.identifier, self.identifier,
{ {"email": "foo@bar.baz"},
"info": {
"email": "foo@bar.baz",
},
},
{},
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.LINK) self.assertEqual(action, Action.LINK)
@ -126,7 +113,7 @@ class TestSourceFlowManager(TestCase):
# Without username, deny # Without username, deny
flow_manager = OAuthSourceFlowManager( flow_manager = OAuthSourceFlowManager(
self.source, get_request("/", user=AnonymousUser()), self.identifier, {"info": {}}, {} self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.DENY) self.assertEqual(action, Action.DENY)
@ -136,10 +123,7 @@ class TestSourceFlowManager(TestCase):
self.source, self.source,
get_request("/", user=AnonymousUser()), get_request("/", user=AnonymousUser()),
self.identifier, self.identifier,
{ {"username": "foo"},
"info": {"username": "foo"},
},
{},
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.LINK) self.assertEqual(action, Action.LINK)
@ -156,11 +140,8 @@ class TestSourceFlowManager(TestCase):
get_request("/", user=AnonymousUser()), get_request("/", user=AnonymousUser()),
self.identifier, self.identifier,
{ {
"info": { "username": "bar",
"username": "bar",
},
}, },
{},
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL) self.assertEqual(action, Action.ENROLL)
@ -170,10 +151,7 @@ class TestSourceFlowManager(TestCase):
self.source, self.source,
get_request("/", user=AnonymousUser()), get_request("/", user=AnonymousUser()),
self.identifier, self.identifier,
{ {"username": "foo"},
"info": {"username": "foo"},
},
{},
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.DENY) self.assertEqual(action, Action.DENY)
@ -187,10 +165,7 @@ class TestSourceFlowManager(TestCase):
self.source, self.source,
get_request("/", user=AnonymousUser()), get_request("/", user=AnonymousUser()),
self.identifier, self.identifier,
{ {"username": "foo"},
"info": {"username": "foo"},
},
{},
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL) self.assertEqual(action, Action.ENROLL)
@ -216,10 +191,7 @@ class TestSourceFlowManager(TestCase):
self.source, self.source,
get_request("/", user=AnonymousUser()), get_request("/", user=AnonymousUser()),
self.identifier, self.identifier,
{ {"username": "foo"},
"info": {"username": "foo"},
},
{},
) )
action, _ = flow_manager.get_action() action, _ = flow_manager.get_action()
self.assertEqual(action, Action.ENROLL) self.assertEqual(action, Action.ENROLL)

View File

@ -1,237 +0,0 @@
"""Test Source flow_manager group update stage"""
from django.test import RequestFactory
from authentik.core.models import Group, SourceGroupMatchingModes
from authentik.core.sources.flow_manager import PLAN_CONTEXT_SOURCE_GROUPS, GroupUpdateStage
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import GroupOAuthSourceConnection, OAuthSource
class TestSourceFlowManager(FlowTestCase):
"""Test Source flow_manager group update stage"""
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.authentication_flow = create_test_flow()
self.enrollment_flow = create_test_flow()
self.source: OAuthSource = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
authentication_flow=self.authentication_flow,
enrollment_flow=self.enrollment_flow,
)
self.identifier = generate_id()
self.user = create_test_admin_user()
def test_nonexistant_group(self):
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(
group=Group.objects.get(name="group 1"), source=self.source
).exists()
)
def test_nonexistant_group_name_link(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK
self.source.save()
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(
group=Group.objects.get(name="group 1"), source=self.source
).exists()
)
def test_existant_group_name_link(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK
self.source.save()
group = Group.objects.create(name="group 1")
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(group=group, source=self.source).exists()
)
def test_nonexistant_group_name_deny(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_DENY
self.source.save()
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertTrue(Group.objects.filter(name="group 1").exists())
self.assertTrue(self.user.ak_groups.filter(name="group 1").exists())
self.assertTrue(
GroupOAuthSourceConnection.objects.filter(
group=Group.objects.get(name="group 1"), source=self.source
).exists()
)
def test_existant_group_name_deny(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_DENY
self.source.save()
group = Group.objects.create(name="group 1")
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"group 1": {
"name": "group 1",
},
},
},
),
),
request=request,
)
self.assertFalse(stage.handle_groups())
self.assertFalse(self.user.ak_groups.filter(name="group 1").exists())
self.assertFalse(
GroupOAuthSourceConnection.objects.filter(group=group, source=self.source).exists()
)
def test_group_updates(self):
self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK
self.source.save()
other_group = Group.objects.create(name="other group")
old_group = Group.objects.create(name="old group")
new_group = Group.objects.create(name="new group")
self.user.ak_groups.set([other_group, old_group])
GroupOAuthSourceConnection.objects.create(
group=old_group, source=self.source, identifier=old_group.name
)
GroupOAuthSourceConnection.objects.create(
group=new_group, source=self.source, identifier=new_group.name
)
request = self.factory.get("/")
stage = GroupUpdateStage(
FlowExecutorView(
current_stage=in_memory_stage(
GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection
),
plan=FlowPlan(
flow_pk=generate_id(),
context={
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_PENDING_USER: self.user,
PLAN_CONTEXT_SOURCE_GROUPS: {
"new group": {
"name": "new group",
},
},
},
),
),
request=request,
)
self.assertTrue(stage.handle_groups())
self.assertFalse(self.user.ak_groups.filter(name="old group").exists())
self.assertTrue(self.user.ak_groups.filter(name="other group").exists())
self.assertTrue(self.user.ak_groups.filter(name="new group").exists())
self.assertEqual(self.user.ak_groups.count(), 2)

View File

@ -1,72 +0,0 @@
"""Test Source Property mappings"""
from django.test import TestCase
from authentik.core.models import Group, PropertyMapping, Source, User
from authentik.core.sources.mapper import SourceMapper
from authentik.lib.generators import generate_id
class ProxySource(Source):
@property
def property_mapping_type(self):
return PropertyMapping
def get_base_user_properties(self, **kwargs):
return {
"username": kwargs.get("username", None),
"email": kwargs.get("email", "default@authentik"),
}
def get_base_group_properties(self, **kwargs):
return {"name": kwargs.get("name", None)}
class Meta:
proxy = True
class TestSourcePropertyMappings(TestCase):
"""Test Source PropertyMappings"""
def test_base_properties(self):
source = ProxySource.objects.create(name=generate_id(), slug=generate_id(), enabled=True)
mapper = SourceMapper(source)
user_base_properties = mapper.get_base_properties(User, username="test1")
self.assertEqual(
user_base_properties,
{
"username": "test1",
"email": "default@authentik",
"path": f"goauthentik.io/sources/{source.slug}",
},
)
group_base_properties = mapper.get_base_properties(Group)
self.assertEqual(group_base_properties, {"name": None})
def test_build_properties(self):
source = ProxySource.objects.create(name=generate_id(), slug=generate_id(), enabled=True)
mapper = SourceMapper(source)
source.user_property_mappings.add(
PropertyMapping.objects.create(
name=generate_id(),
expression="""
return {"username": data.get("username", None), "email": None}
""",
)
)
properties = mapper.build_object_properties(
object_type=User, user=None, request=None, username="test1", data={"username": "test2"}
)
self.assertEqual(
properties,
{
"username": "test2",
"path": f"goauthentik.io/sources/{source.slug}",
"attributes": {},
},
)

View File

@ -6,6 +6,7 @@ from django.conf import settings
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.urls import path from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import RedirectView
from authentik.core.api.applications import ApplicationViewSet from authentik.core.api.applications import ApplicationViewSet
from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet
@ -17,13 +18,9 @@ from authentik.core.api.sources import SourceViewSet, UserSourceConnectionViewSe
from authentik.core.api.tokens import TokenViewSet from authentik.core.api.tokens import TokenViewSet
from authentik.core.api.transactional_applications import TransactionalApplicationView from authentik.core.api.transactional_applications import TransactionalApplicationView
from authentik.core.api.users import UserViewSet from authentik.core.api.users import UserViewSet
from authentik.core.views.apps import RedirectToAppLaunch from authentik.core.views import apps
from authentik.core.views.debug import AccessDeniedView from authentik.core.views.debug import AccessDeniedView
from authentik.core.views.interface import ( from authentik.core.views.interface import InterfaceView
BrandDefaultRedirectView,
InterfaceView,
RootRedirectView,
)
from authentik.core.views.session import EndSessionView from authentik.core.views.session import EndSessionView
from authentik.flows.views.interface import FlowInterfaceView from authentik.flows.views.interface import FlowInterfaceView
from authentik.root.asgi_middleware import SessionMiddleware from authentik.root.asgi_middleware import SessionMiddleware
@ -33,24 +30,26 @@ from authentik.root.middleware import ChannelsLoggingMiddleware
urlpatterns = [ urlpatterns = [
path( path(
"", "",
login_required(RootRedirectView.as_view()), login_required(
RedirectView.as_view(pattern_name="authentik_core:if-user", query_string=True)
),
name="root-redirect", name="root-redirect",
), ),
path( path(
# We have to use this format since everything else uses application/o or application/saml # We have to use this format since everything else uses applications/o or applications/saml
"application/launch/<slug:application_slug>/", "application/launch/<slug:application_slug>/",
RedirectToAppLaunch.as_view(), apps.RedirectToAppLaunch.as_view(),
name="application-launch", name="application-launch",
), ),
# Interfaces # Interfaces
path( path(
"if/admin/", "if/admin/",
ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/admin.html")), ensure_csrf_cookie(InterfaceView.as_view(template_name="if/admin.html")),
name="if-admin", name="if-admin",
), ),
path( path(
"if/user/", "if/user/",
ensure_csrf_cookie(BrandDefaultRedirectView.as_view(template_name="if/user.html")), ensure_csrf_cookie(InterfaceView.as_view(template_name="if/user.html")),
name="if-user", name="if-user",
), ),
path( path(

View File

@ -8,6 +8,7 @@ from django.views import View
from authentik.core.models import Application from authentik.core.models import Application
from authentik.flows.challenge import ( from authentik.flows.challenge import (
ChallengeResponse, ChallengeResponse,
ChallengeTypes,
HttpChallengeResponse, HttpChallengeResponse,
RedirectChallenge, RedirectChallenge,
) )
@ -73,6 +74,7 @@ class RedirectToAppStage(ChallengeStageView):
raise Http404 raise Http404
return RedirectChallenge( return RedirectChallenge(
instance={ instance={
"type": ChallengeTypes.REDIRECT.value,
"to": launch, "to": launch,
} }
) )

View File

@ -3,42 +3,13 @@
from json import dumps from json import dumps
from typing import Any from typing import Any
from django.http import HttpRequest from django.views.generic.base import TemplateView
from django.http.response import HttpResponse
from django.shortcuts import redirect
from django.utils.translation import gettext as _
from django.views.generic.base import RedirectView, TemplateView
from rest_framework.request import Request from rest_framework.request import Request
from authentik import get_build_hash from authentik import get_build_hash
from authentik.admin.tasks import LOCAL_VERSION from authentik.admin.tasks import LOCAL_VERSION
from authentik.api.v3.config import ConfigView from authentik.api.v3.config import ConfigView
from authentik.brands.api import CurrentBrandSerializer from authentik.brands.api import CurrentBrandSerializer
from authentik.brands.models import Brand
from authentik.core.models import UserTypes
from authentik.policies.denied import AccessDeniedResponse
class RootRedirectView(RedirectView):
"""Root redirect view, redirect to brand's default application if set"""
pattern_name = "authentik_core:if-user"
query_string = True
def redirect_to_app(self, request: HttpRequest):
if request.user.is_authenticated and request.user.type == UserTypes.EXTERNAL:
brand: Brand = request.brand
if brand.default_application:
return redirect(
"authentik_core:application-launch",
application_slug=brand.default_application.slug,
)
return None
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
if redirect_response := RootRedirectView().redirect_to_app(request):
return redirect_response
return super().dispatch(request, *args, **kwargs)
class InterfaceView(TemplateView): class InterfaceView(TemplateView):
@ -52,20 +23,3 @@ class InterfaceView(TemplateView):
kwargs["build"] = get_build_hash() kwargs["build"] = get_build_hash()
kwargs["url_kwargs"] = self.kwargs kwargs["url_kwargs"] = self.kwargs
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
class BrandDefaultRedirectView(InterfaceView):
"""By default redirect to default app"""
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
if request.user.is_authenticated and request.user.type == UserTypes.EXTERNAL:
brand: Brand = request.brand
if brand.default_application:
return redirect(
"authentik_core:application-launch",
application_slug=brand.default_application.slug,
)
response = AccessDeniedResponse(self.request)
response.error_message = _("Interface can only be accessed by internal users.")
return response
return super().dispatch(request, *args, **kwargs)

View File

@ -35,7 +35,6 @@ from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger() LOGGER = get_logger()
@ -266,7 +265,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
], ],
responses={200: CertificateDataSerializer(many=False)}, responses={200: CertificateDataSerializer(many=False)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def view_certificate(self, request: Request, pk: str) -> Response: def view_certificate(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs certificate and log access""" """Return certificate-key pairs certificate and log access"""
certificate: CertificateKeyPair = self.get_object() certificate: CertificateKeyPair = self.get_object()
@ -296,7 +295,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
], ],
responses={200: CertificateDataSerializer(many=False)}, responses={200: CertificateDataSerializer(many=False)},
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def view_private_key(self, request: Request, pk: str) -> Response: def view_private_key(self, request: Request, pk: str) -> Response:
"""Return certificate-key pairs private key and log access""" """Return certificate-key pairs private key and log access"""
certificate: CertificateKeyPair = self.get_object() certificate: CertificateKeyPair = self.get_object()

View File

@ -76,7 +76,7 @@ class CertificateBuilder:
.subject_name( .subject_name(
x509.Name( x509.Name(
[ [
x509.NameAttribute(NameOID.COMMON_NAME, self.common_name[:64]), x509.NameAttribute(NameOID.COMMON_NAME, self.common_name),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"), x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"),
] ]

View File

@ -214,46 +214,6 @@ class TestCrypto(APITestCase):
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertIn("Content-Disposition", response) self.assertIn("Content-Disposition", response)
def test_certificate_download_denied(self):
"""Test certificate export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-certificate",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_private_key_download_denied(self):
"""Test private_key export (download)"""
self.client.logout()
keypair = create_test_cert()
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-view-private-key",
kwargs={"pk": keypair.pk},
),
data={"download": True},
)
self.assertEqual(403, response.status_code)
def test_used_by(self): def test_used_by(self):
"""Test used_by endpoint""" """Test used_by endpoint"""
self.client.force_login(create_test_admin_user()) self.client.force_login(create_test_admin_user())
@ -286,26 +246,6 @@ class TestCrypto(APITestCase):
], ],
) )
def test_used_by_denied(self):
"""Test used_by endpoint"""
self.client.logout()
keypair = create_test_cert()
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://localhost",
signing_key=keypair,
)
response = self.client.get(
reverse(
"authentik_api:certificatekeypair-used-by",
kwargs={"pk": keypair.pk},
)
)
self.assertEqual(403, response.status_code)
def test_discovery(self): def test_discovery(self):
"""Test certificate discovery""" """Test certificate discovery"""
name = generate_id() name = generate_id()

View File

@ -1,11 +1,12 @@
"""Enterprise API Views""" """Enterprise API Views"""
from dataclasses import asdict
from datetime import timedelta from datetime import timedelta
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, inline_serializer from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, IntegerField from rest_framework.fields import CharField, IntegerField
@ -29,7 +30,7 @@ class EnterpriseRequiredMixin:
def validate(self, attrs: dict) -> dict: def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists""" """Check that a valid license exists"""
if not LicenseKey.cached_summary().status.is_valid: if not LicenseKey.cached_summary().has_license:
raise ValidationError(_("Enterprise is required to create/update this object.")) raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs) return super().validate(attrs)
@ -86,7 +87,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
}, },
) )
@action(detail=False, methods=["GET"]) @action(detail=False, methods=["GET"])
def install_id(self, request: Request) -> Response: def get_install_id(self, request: Request) -> Response:
"""Get install_id""" """Get install_id"""
return Response( return Response(
data={ data={
@ -99,22 +100,12 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
responses={ responses={
200: LicenseSummarySerializer(), 200: LicenseSummarySerializer(),
}, },
parameters=[
OpenApiParameter(
name="cached",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.BOOL,
default=True,
)
],
) )
@action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated]) @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated])
def summary(self, request: Request) -> Response: def summary(self, request: Request) -> Response:
"""Get the total license status""" """Get the total license status"""
summary = LicenseKey.cached_summary() response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary()))
if request.query_params.get("cached", "true").lower() == "false": response.is_valid(raise_exception=True)
summary = LicenseKey.get_total().summary()
response = LicenseSummarySerializer(instance=summary)
return Response(response.data) return Response(response.data)
@permission_required(None, ["authentik_enterprise.view_license"]) @permission_required(None, ["authentik_enterprise.view_license"])
@ -137,7 +128,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
forecast_for_months = 12 forecast_for_months = 12
response = LicenseForecastSerializer( response = LicenseForecastSerializer(
data={ data={
"internal_users": LicenseKey.get_internal_user_count(), "internal_users": LicenseKey.get_default_user_count(),
"external_users": LicenseKey.get_external_user_count(), "external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months), "forecasted_internal_users": (internal_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months), "forecasted_external_users": (external_in_last_month * forecast_for_months),

View File

@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
"""Actual enterprise check, cached""" """Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
return LicenseKey.cached_summary().status.is_valid return LicenseKey.cached_summary().valid

View File

@ -3,37 +3,24 @@
from base64 import b64decode from base64 import b64decode
from binascii import Error from binascii import Error
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from time import mktime from time import mktime
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import DaciteError, from_dict from dacite import from_dict
from django.core.cache import cache from django.core.cache import cache
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.timezone import now from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.fields import ( from rest_framework.fields import BooleanField, DateTimeField, IntegerField
ChoiceField,
DateTimeField,
IntegerField,
ListField,
)
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes from authentik.core.models import User, UserTypes
from authentik.enterprise.models import ( from authentik.enterprise.models import License, LicenseUsage
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_EXPIRY_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.tenants.utils import get_unique_identifier from authentik.tenants.utils import get_unique_identifier
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
@ -55,9 +42,6 @@ def get_license_aud() -> str:
class LicenseFlags(Enum): class LicenseFlags(Enum):
"""License flags""" """License flags"""
TRIAL = "trial"
NON_PRODUCTION = "non_production"
@dataclass @dataclass
class LicenseSummary: class LicenseSummary:
@ -65,9 +49,12 @@ class LicenseSummary:
internal_users: int internal_users: int
external_users: int external_users: int
status: LicenseUsageStatus valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
latest_valid: datetime latest_valid: datetime
license_flags: list[LicenseFlags] has_license: bool
class LicenseSummarySerializer(PassiveSerializer): class LicenseSummarySerializer(PassiveSerializer):
@ -75,9 +62,12 @@ class LicenseSummarySerializer(PassiveSerializer):
internal_users = IntegerField(required=True) internal_users = IntegerField(required=True)
external_users = IntegerField(required=True) external_users = IntegerField(required=True)
status = ChoiceField(choices=LicenseUsageStatus.choices) valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
latest_valid = DateTimeField() latest_valid = DateTimeField()
license_flags = ListField(child=ChoiceField(choices=tuple(x.value for x in LicenseFlags))) has_license = BooleanField()
@dataclass @dataclass
@ -90,10 +80,10 @@ class LicenseKey:
name: str name: str
internal_users: int = 0 internal_users: int = 0
external_users: int = 0 external_users: int = 0
license_flags: list[LicenseFlags] = field(default_factory=list) flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod @staticmethod
def validate(jwt: str, check_expiry=True) -> "LicenseKey": def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT""" """Validate the license from a given JWT"""
try: try:
headers = get_unverified_header(jwt) headers = get_unverified_header(jwt)
@ -117,28 +107,26 @@ class LicenseKey:
our_cert.public_key(), our_cert.public_key(),
algorithms=["ES512"], algorithms=["ES512"],
audience=get_license_aud(), audience=get_license_aud(),
options={"verify_exp": check_expiry, "verify_signature": check_expiry},
), ),
) )
except PyJWTError: except PyJWTError:
unverified = decode(jwt, options={"verify_signature": False})
if unverified["aud"] != get_license_aud():
raise ValidationError("Invalid Install ID in license") from None
raise ValidationError("Unable to verify license") from None raise ValidationError("Unable to verify license") from None
return body return body
@staticmethod @staticmethod
def get_total() -> "LicenseKey": def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses""" """Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in License.objects.all(): for lic in active_licenses:
total.internal_users += lic.internal_users total.internal_users += lic.internal_users
total.external_users += lic.external_users total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple())) exp_ts = int(mktime(lic.expiry.timetuple()))
if total.exp == 0: if total.exp == 0:
total.exp = exp_ts total.exp = exp_ts
total.exp = max(total.exp, exp_ts) if exp_ts <= total.exp:
total.license_flags.extend(lic.status.license_flags) total.exp = exp_ts
total.flags.extend(lic.status.flags)
return total return total
@staticmethod @staticmethod
@ -147,7 +135,7 @@ class LicenseKey:
return User.objects.all().exclude_anonymous().exclude(is_active=False) return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod @staticmethod
def get_internal_user_count(): def get_default_user_count():
"""Get current default user count""" """Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@ -156,73 +144,59 @@ class LicenseKey:
"""Get current external user count""" """Get current external user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count() return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()
def _last_valid_date(self): def is_valid(self) -> bool:
last_valid_date = ( """Check if the given license body covers all users
LicenseUsage.objects.order_by("-record_date")
.filter(status=LicenseUsageStatus.VALID)
.first()
)
if not last_valid_date:
return datetime.fromtimestamp(0, UTC)
return last_valid_date.record_date
def status(self) -> LicenseUsageStatus: Only checks the current count, no historical data is checked"""
"""Check if the given license body covers all users, and is valid.""" default_users = self.get_default_user_count()
last_valid = self._last_valid_date() if default_users > self.internal_users:
if self.exp == 0 and not License.objects.exists(): return False
return LicenseUsageStatus.UNLICENSED active_users = self.get_external_user_count()
_now = now() if active_users > self.external_users:
# Check limit-exceeded based status return False
internal_users = self.get_internal_user_count() return True
external_users = self.get_external_user_count()
if internal_users > self.internal_users or external_users > self.external_users:
if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
return LicenseUsageStatus.READ_ONLY
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_USER
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
# Check expiry based status
if datetime.fromtimestamp(self.exp, UTC) < _now:
if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta(
weeks=THRESHOLD_READ_ONLY_WEEKS
):
return LicenseUsageStatus.READ_ONLY
return LicenseUsageStatus.EXPIRED
# Expiry warning
if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta(
weeks=THRESHOLD_WARNING_EXPIRY_WEEKS
):
return LicenseUsageStatus.EXPIRY_SOON
return LicenseUsageStatus.VALID
def record_usage(self): def record_usage(self):
"""Capture the current validity status and metrics and save them""" """Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8) threshold = now() - timedelta(hours=8)
usage = ( if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first() LicenseUsage.objects.create(
) user_count=self.get_default_user_count(),
if not usage:
usage = LicenseUsage.objects.create(
internal_user_count=self.get_internal_user_count(),
external_user_count=self.get_external_user_count(), external_user_count=self.get_external_user_count(),
status=self.status(), within_limits=self.is_valid(),
) )
summary = asdict(self.summary()) summary = asdict(self.summary())
# Also cache the latest summary for the middleware # Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE) cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return usage return summary
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
def summary(self) -> LicenseSummary: def summary(self) -> LicenseSummary:
"""Summary of license status""" """Summary of license status"""
status = self.status() has_license = License.objects.all().count() > 0
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
latest_valid = datetime.fromtimestamp(self.exp) latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary( return LicenseSummary(
show_admin_warning=show_admin_warning and has_license,
show_user_warning=show_user_warning and has_license,
read_only=read_only and has_license,
latest_valid=latest_valid, latest_valid=latest_valid,
internal_users=self.internal_users, internal_users=self.internal_users,
external_users=self.external_users, external_users=self.external_users,
status=status, valid=self.is_valid(),
license_flags=self.license_flags, has_license=has_license,
) )
@staticmethod @staticmethod
@ -231,8 +205,4 @@ class LicenseKey:
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE) summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary: if not summary:
return LicenseKey.get_total().summary() return LicenseKey.get_total().summary()
try: return from_dict(LicenseSummary, summary)
return from_dict(LicenseSummary, summary)
except DaciteError:
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
return LicenseKey.get_total().summary()

View File

@ -8,7 +8,6 @@ from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsageStatus
from authentik.flows.views.executor import FlowExecutorView from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
@ -44,7 +43,7 @@ class EnterpriseMiddleware:
cached_status = LicenseKey.cached_summary() cached_status = LicenseKey.cached_summary()
if not cached_status: if not cached_status:
return True return True
if cached_status.status == LicenseUsageStatus.READ_ONLY: if cached_status.read_only:
return False return False
return True return True
@ -54,10 +53,10 @@ class EnterpriseMiddleware:
if request.method.lower() in ["get", "head", "options", "trace"]: if request.method.lower() in ["get", "head", "options", "trace"]:
return True return True
# Always allow requests to manage licenses # Always allow requests to manage licenses
if request.resolver_match._func_path == class_to_path(LicenseViewSet): if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
return True return True
# Flow executor is mounted as an API path but explicitly allowed # Flow executor is mounted as an API path but explicitly allowed
if request.resolver_match._func_path == class_to_path(FlowExecutorView): if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
return True return True
# Only apply these restrictions to the API # Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names: if "authentik_api" not in request.resolver_match.app_names:

View File

@ -1,68 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-08 14:15
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage")
db_alias = schema_editor.connection.alias
for usage in LicenseUsage.objects.using(db_alias).all():
usage.status = "valid" if usage.within_limits else "limit_exceeded_admin"
usage.save()
class Migration(migrations.Migration):
dependencies = [
("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"),
]
operations = [
migrations.AddField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
default=None,
null=True,
),
preserve_default=False,
),
migrations.RunPython(migrate_license_usage),
migrations.RemoveField(
model_name="licenseusage",
name="within_limits",
),
migrations.AlterField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
),
preserve_default=False,
),
migrations.RenameField(
model_name="licenseusage",
old_name="user_count",
new_name="internal_user_count",
),
]

View File

@ -17,17 +17,6 @@ if TYPE_CHECKING:
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
def usage_expiry():
"""Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
THRESHOLD_WARNING_ADMIN_WEEKS = 2
THRESHOLD_WARNING_USER_WEEKS = 4
THRESHOLD_WARNING_EXPIRY_WEEKS = 2
THRESHOLD_READ_ONLY_WEEKS = 6
class License(SerializerModel): class License(SerializerModel):
"""An authentik enterprise license""" """An authentik enterprise license"""
@ -50,7 +39,7 @@ class License(SerializerModel):
"""Get parsed license status""" """Get parsed license status"""
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
return LicenseKey.validate(self.key, check_expiry=False) return LicenseKey.validate(self.key)
class Meta: class Meta:
indexes = (HashIndex(fields=("key",)),) indexes = (HashIndex(fields=("key",)),)
@ -58,23 +47,9 @@ class License(SerializerModel):
verbose_name_plural = _("Licenses") verbose_name_plural = _("Licenses")
class LicenseUsageStatus(models.TextChoices): def usage_expiry():
"""License states an instance/tenant can be in""" """Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
UNLICENSED = "unlicensed"
VALID = "valid"
EXPIRED = "expired"
EXPIRY_SOON = "expiry_soon"
# User limit exceeded, 2 week threshold, show message in admin interface
LIMIT_EXCEEDED_ADMIN = "limit_exceeded_admin"
# User limit exceeded, 4 week threshold, show message in user interface
LIMIT_EXCEEDED_USER = "limit_exceeded_user"
READ_ONLY = "read_only"
@property
def is_valid(self) -> bool:
"""Quickly check if a license is valid"""
return self in [LicenseUsageStatus.VALID, LicenseUsageStatus.EXPIRY_SOON]
class LicenseUsage(ExpiringModel): class LicenseUsage(ExpiringModel):
@ -84,9 +59,9 @@ class LicenseUsage(ExpiringModel):
usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
internal_user_count = models.BigIntegerField() user_count = models.BigIntegerField()
external_user_count = models.BigIntegerField() external_user_count = models.BigIntegerField()
status = models.TextField(choices=LicenseUsageStatus.choices) within_limits = models.BooleanField()
record_date = models.DateTimeField(auto_now_add=True) record_date = models.DateTimeField(auto_now_add=True)

View File

@ -13,7 +13,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
def check_license(self): def check_license(self):
"""Check license""" """Check license"""
if not LicenseKey.get_total().status().is_valid: if not LicenseKey.get_total().is_valid():
return PolicyResult(False, _("Enterprise required to access this feature.")) return PolicyResult(False, _("Enterprise required to access this feature."))
if self.request.user.type != UserTypes.INTERNAL: if self.request.user.type != UserTypes.INTERNAL:
return PolicyResult(False, _("Feature only accessible for internal users.")) return PolicyResult(False, _("Feature only accessible for internal users."))

View File

@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import ( from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
google_workspace_sync,
google_workspace_sync_objects,
)
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@ -55,4 +52,3 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
search_fields = ["name"] search_fields = ["name"]
ordering = ["name"] ordering = ["name"]
sync_single_task = google_workspace_sync sync_single_task = google_workspace_sync
sync_objects_task = google_workspace_sync_objects

View File

@ -181,7 +181,7 @@ class GoogleWorkspaceProviderMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-google-workspace-form" return "ak-property-mapping-google-workspace-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:

View File

@ -6,10 +6,7 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.api import EnterpriseRequiredMixin from authentik.enterprise.api import EnterpriseRequiredMixin
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import ( from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
microsoft_entra_sync,
microsoft_entra_sync_objects,
)
from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin from authentik.lib.sync.outgoing.api import OutgoingSyncProviderStatusMixin
@ -53,4 +50,3 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
search_fields = ["name"] search_fields = ["name"]
ordering = ["name"] ordering = ["name"]
sync_single_task = microsoft_entra_sync sync_single_task = microsoft_entra_sync
sync_objects_task = microsoft_entra_sync_objects

View File

@ -170,7 +170,7 @@ class MicrosoftEntraProviderMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-microsoft-entra-form" return "ak-property-mapping-microsoft-entra-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:

View File

@ -34,12 +34,6 @@ class ConnectionTokenSerializer(EnterpriseRequiredMixin, ModelSerializer):
] ]
class ConnectionTokenOwnerFilter(OwnerFilter):
"""Owner filter for connection tokens (checks session's user)"""
owner_key = "session__user"
class ConnectionTokenViewSet( class ConnectionTokenViewSet(
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.UpdateModelMixin, mixins.UpdateModelMixin,
@ -56,9 +50,4 @@ class ConnectionTokenViewSet(
search_fields = ["endpoint__name", "provider__name"] search_fields = ["endpoint__name", "provider__name"]
ordering = ["endpoint__name", "provider__name"] ordering = ["endpoint__name", "provider__name"]
permission_classes = [OwnerSuperuserPermissions] permission_classes = [OwnerSuperuserPermissions]
filter_backends = [ filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ConnectionTokenOwnerFilter,
DjangoFilterBackend,
OrderingFilter,
SearchFilter,
]

View File

@ -1,20 +0,0 @@
# Generated by Django 5.0.8 on 2024-08-12 12:54
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_rac", "0004_alter_connectiontoken_expires"),
]
operations = [
migrations.AlterModelOptions(
name="racpropertymapping",
options={
"verbose_name": "RAC Provider Property Mapping",
"verbose_name_plural": "RAC Provider Property Mappings",
},
),
]

View File

@ -125,7 +125,7 @@ class RACPropertyMapping(PropertyMapping):
@property @property
def component(self) -> str: def component(self) -> str:
return "ak-property-mapping-provider-rac-form" return "ak-property-mapping-rac-form"
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:
@ -136,8 +136,8 @@ class RACPropertyMapping(PropertyMapping):
return RACPropertyMappingSerializer return RACPropertyMappingSerializer
class Meta: class Meta:
verbose_name = _("RAC Provider Property Mapping") verbose_name = _("RAC Property Mapping")
verbose_name_plural = _("RAC Provider Property Mappings") verbose_name_plural = _("RAC Property Mappings")
class ConnectionToken(ExpiringModel): class ConnectionToken(ExpiringModel):

View File

@ -21,8 +21,6 @@ from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint
@receiver(user_logged_out) @receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_): def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Disconnect any open RAC connections""" """Disconnect any open RAC connections"""
if not request.session or not request.session.session_key:
return
layer = get_channel_layer() layer = get_channel_layer()
async_to_sync(layer.group_send)( async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_SESSION RAC_CLIENT_GROUP_SESSION

View File

@ -5,6 +5,7 @@ from channels.sessions import CookieMiddleware
from django.urls import path from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from authentik.core.channels import TokenOutpostMiddleware
from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet
@ -12,7 +13,6 @@ from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet
from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer
from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer
from authentik.enterprise.providers.rac.views import RACInterface, RACStartView from authentik.enterprise.providers.rac.views import RACInterface, RACStartView
from authentik.outposts.channels import TokenOutpostMiddleware
from authentik.root.asgi_middleware import SessionMiddleware from authentik.root.asgi_middleware import SessionMiddleware
from authentik.root.middleware import ChannelsLoggingMiddleware from authentik.root.middleware import ChannelsLoggingMiddleware
@ -44,7 +44,7 @@ websocket_urlpatterns = [
api_urlpatterns = [ api_urlpatterns = [
("providers/rac", RACProviderViewSet), ("providers/rac", RACProviderViewSet),
("propertymappings/provider/rac", RACPropertyMappingViewSet), ("propertymappings/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet), ("rac/endpoints", EndpointViewSet),
("rac/connection_tokens", ConnectionTokenViewSet), ("rac/connection_tokens", ConnectionTokenViewSet),
] ]

View File

@ -3,7 +3,7 @@
from datetime import datetime from datetime import datetime
from django.core.cache import cache from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_save from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver from django.dispatch import receiver
from django.utils.timezone import get_current_timezone from django.utils.timezone import get_current_timezone
@ -27,9 +27,3 @@ def post_save_license(sender: type[License], instance: License, **_):
"""Trigger license usage calculation when license is saved""" """Trigger license usage calculation when license is saved"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
enterprise_update_usage.delay() enterprise_update_usage.delay()
@receiver(post_delete, sender=License)
def post_delete_license(sender: type[License], instance: License, **_):
"""Clear license cache when license is deleted"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)

View File

@ -9,26 +9,10 @@ from django.utils.timezone import now
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import ( from authentik.enterprise.models import License
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
# Valid license expiry _exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
# Valid license expiry, expires soon
expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
# Invalid license expiry, recently expired
expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
# Invalid license expiry, expired longer ago
expiry_expired_read_only = int(
mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
)
class TestEnterpriseLicense(TestCase): class TestEnterpriseLicense(TestCase):
@ -39,7 +23,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",
exp=expiry_valid, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, internal_users=100,
external_users=100, external_users=100,
@ -49,7 +33,7 @@ class TestEnterpriseLicense(TestCase):
def test_valid(self): def test_valid(self):
"""Check license verification""" """Check license verification"""
lic = License.objects.create(key=generate_id()) lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.status().is_valid) self.assertTrue(lic.status.is_valid())
self.assertEqual(lic.internal_users, 100) self.assertEqual(lic.internal_users, 100)
def test_invalid(self): def test_invalid(self):
@ -62,7 +46,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock( MagicMock(
return_value=LicenseKey( return_value=LicenseKey(
aud="", aud="",
exp=expiry_valid, exp=_exp,
name=generate_id(), name=generate_id(),
internal_users=100, internal_users=100,
external_users=100, external_users=100,
@ -72,186 +56,11 @@ class TestEnterpriseLicense(TestCase):
def test_valid_multiple(self): def test_valid_multiple(self):
"""Check license verification""" """Check license verification"""
lic = License.objects.create(key=generate_id()) lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.status().is_valid) self.assertTrue(lic.status.is_valid())
lic2 = License.objects.create(key=generate_id()) lic2 = License.objects.create(key=generate_id())
self.assertTrue(lic2.status.status().is_valid) self.assertTrue(lic2.status.is_valid())
total = LicenseKey.get_total() total = LicenseKey.get_total()
self.assertEqual(total.internal_users, 200) self.assertEqual(total.internal_users, 200)
self.assertEqual(total.external_users, 200) self.assertEqual(total.external_users, 200)
self.assertEqual(total.exp, expiry_valid) self.assertEqual(total.exp, _exp)
self.assertTrue(total.status().is_valid) self.assertTrue(total.is_valid())
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_user_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_admin_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired_read_only,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_expired(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_soon,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_soon(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON)

View File

@ -1,217 +0,0 @@
"""read only tests"""
from datetime import timedelta
from unittest.mock import MagicMock, patch
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.flows.models import (
FlowDesignation,
FlowStageBinding,
)
from authentik.flows.tests import FlowTestCase
from authentik.lib.generators import generate_id
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.password import BACKEND_INBUILT
from authentik.stages.password.models import PasswordStage
from authentik.stages.user_login.models import UserLoginStage
class TestReadOnly(FlowTestCase):
"""Test read_only"""
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_login(self):
"""Test flow, ensure login is still possible with read only mode"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
flow = create_test_flow(
FlowDesignation.AUTHENTICATION,
)
ident_stage = IdentificationStage.objects.create(
name=generate_id(),
user_fields=[UserFields.E_MAIL],
pretend_user_exists=False,
)
FlowStageBinding.objects.create(
target=flow,
stage=ident_stage,
order=0,
)
password_stage = PasswordStage.objects.create(
name=generate_id(), backends=[BACKEND_INBUILT]
)
FlowStageBinding.objects.create(
target=flow,
stage=password_stage,
order=1,
)
login_stage = UserLoginStage.objects.create(
name=generate_id(),
)
FlowStageBinding.objects.create(
target=flow,
stage=login_stage,
order=2,
)
user = create_test_user()
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(exec_url)
self.assertStageResponse(
response,
flow,
component="ak-stage-identification",
password_fields=False,
primary_action="Log in",
sources=[],
show_source_labels=False,
user_fields=[UserFields.E_MAIL],
)
response = self.client.post(exec_url, {"uid_field": user.email}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-password")
response = self.client.post(exec_url, {"password": user.username}, follow=True)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_licenses(self):
"""Test that managing licenses is still possible"""
license = License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Reading is always allowed
response = self.client.get(reverse("authentik_api:license-list"))
self.assertEqual(response.status_code, 200)
# Writing should also be allowed
response = self.client.patch(
reverse("authentik_api:license-detail", kwargs={"pk": license.pk})
)
self.assertEqual(response.status_code, 200)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_flows(self):
"""Test flow"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Read only is still allowed
response = self.client.get(reverse("authentik_api:flow-list"))
self.assertEqual(response.status_code, 200)
flow = create_test_flow()
# Writing is not
response = self.client.patch(
reverse("authentik_api:flow-detail", kwargs={"slug": flow.slug})
)
self.assertJSONEqual(
response.content,
{"detail": "Request denied due to expired/invalid license.", "code": "denied_license"},
)
self.assertEqual(response.status_code, 400)

View File

@ -69,5 +69,8 @@ class NotificationViewSet(
@action(detail=False, methods=["post"]) @action(detail=False, methods=["post"])
def mark_all_seen(self, request: Request) -> Response: def mark_all_seen(self, request: Request) -> Response:
"""Mark all the user's notifications as seen""" """Mark all the user's notifications as seen"""
Notification.objects.filter(user=request.user, seen=False).update(seen=True) notifications = Notification.objects.filter(user=request.user)
for notification in notifications:
notification.seen = True
Notification.objects.bulk_update(notifications, ["seen"])
return Response({}, status=204) return Response({}, status=204)

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, TypedDict
from django.http import HttpRequest from django.http import HttpRequest
from geoip2.errors import GeoIP2Error from geoip2.errors import GeoIP2Error
from geoip2.models import ASN from geoip2.models import ASN
from sentry_sdk import start_span from sentry_sdk import Hub
from authentik.events.context_processors.mmdb import MMDBContextProcessor from authentik.events.context_processors.mmdb import MMDBContextProcessor
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -48,7 +48,7 @@ class ASNContextProcessor(MMDBContextProcessor):
def asn(self, ip_address: str) -> ASN | None: def asn(self, ip_address: str) -> ASN | None:
"""Wrapper for Reader.asn""" """Wrapper for Reader.asn"""
with start_span( with Hub.current.start_span(
op="authentik.events.asn.asn", op="authentik.events.asn.asn",
description=ip_address, description=ip_address,
): ):

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, TypedDict
from django.http import HttpRequest from django.http import HttpRequest
from geoip2.errors import GeoIP2Error from geoip2.errors import GeoIP2Error
from geoip2.models import City from geoip2.models import City
from sentry_sdk import start_span from sentry_sdk.hub import Hub
from authentik.events.context_processors.mmdb import MMDBContextProcessor from authentik.events.context_processors.mmdb import MMDBContextProcessor
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -49,7 +49,7 @@ class GeoIPContextProcessor(MMDBContextProcessor):
def city(self, ip_address: str) -> City | None: def city(self, ip_address: str) -> City | None:
"""Wrapper for Reader.city""" """Wrapper for Reader.city"""
with start_span( with Hub.current.start_span(
op="authentik.events.geo.city", op="authentik.events.geo.city",
description=ip_address, description=ip_address,
): ):

View File

@ -35,7 +35,6 @@ IGNORED_MODELS = tuple(
_CTX_OVERWRITE_USER = ContextVar[User | None]("authentik_events_log_overwrite_user", default=None) _CTX_OVERWRITE_USER = ContextVar[User | None]("authentik_events_log_overwrite_user", default=None)
_CTX_IGNORE = ContextVar[bool]("authentik_events_log_ignore", default=False) _CTX_IGNORE = ContextVar[bool]("authentik_events_log_ignore", default=False)
_CTX_REQUEST = ContextVar[HttpRequest | None]("authentik_events_log_request", default=None)
def should_log_model(model: Model) -> bool: def should_log_model(model: Model) -> bool:
@ -150,13 +149,11 @@ class AuditMiddleware:
m2m_changed.disconnect(dispatch_uid=request.request_id) m2m_changed.disconnect(dispatch_uid=request.request_id)
def __call__(self, request: HttpRequest) -> HttpResponse: def __call__(self, request: HttpRequest) -> HttpResponse:
_CTX_REQUEST.set(request)
self.connect(request) self.connect(request)
response = self.get_response(request) response = self.get_response(request)
self.disconnect(request) self.disconnect(request)
_CTX_REQUEST.set(None)
return response return response
def process_exception(self, request: HttpRequest, exception: Exception): def process_exception(self, request: HttpRequest, exception: Exception):
@ -170,7 +167,7 @@ class AuditMiddleware:
thread = EventNewThread( thread = EventNewThread(
EventAction.SUSPICIOUS_REQUEST, EventAction.SUSPICIOUS_REQUEST,
request, request,
message=exception_to_string(exception), message=str(exception),
) )
thread.run() thread.run()
elif before_send({}, {"exc_info": (None, exception, None)}) is not None: elif before_send({}, {"exc_info": (None, exception, None)}) is not None:
@ -195,8 +192,6 @@ class AuditMiddleware:
return return
if _CTX_IGNORE.get(): if _CTX_IGNORE.get():
return return
if request.request_id != _CTX_REQUEST.get().request_id:
return
user = self.get_user(request) user = self.get_user(request)
action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED
@ -210,8 +205,6 @@ class AuditMiddleware:
return return
if _CTX_IGNORE.get(): if _CTX_IGNORE.get():
return return
if request.request_id != _CTX_REQUEST.get().request_id:
return
user = self.get_user(request) user = self.get_user(request)
EventNewThread( EventNewThread(
@ -237,8 +230,6 @@ class AuditMiddleware:
return return
if _CTX_IGNORE.get(): if _CTX_IGNORE.get():
return return
if request.request_id != _CTX_REQUEST.get().request_id:
return
user = self.get_user(request) user = self.get_user(request)
EventNewThread( EventNewThread(

View File

@ -238,8 +238,6 @@ class Event(SerializerModel, ExpiringModel):
"args": cleanse_dict(QueryDict(request.META.get("QUERY_STRING", ""))), "args": cleanse_dict(QueryDict(request.META.get("QUERY_STRING", ""))),
"user_agent": request.META.get("HTTP_USER_AGENT", ""), "user_agent": request.META.get("HTTP_USER_AGENT", ""),
} }
if hasattr(request, "request_id"):
self.context["http_request"]["request_id"] = request.request_id
# Special case for events created during flow execution # Special case for events created during flow execution
# since they keep the http query within a wrapped query # since they keep the http query within a wrapped query
if QS_QUERY in self.context["http_request"]["args"]: if QS_QUERY in self.context["http_request"]["args"]:

View File

@ -13,7 +13,7 @@ from authentik.events.apps import SYSTEM_TASK_STATUS
from authentik.events.models import Event, EventAction, SystemTask from authentik.events.models import Event, EventAction, SystemTask
from authentik.events.tasks import event_notification_handler, gdpr_cleanup from authentik.events.tasks import event_notification_handler, gdpr_cleanup
from authentik.flows.models import Stage from authentik.flows.models import Stage
from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.root.monitoring import monitoring_set from authentik.root.monitoring import monitoring_set
from authentik.stages.invitation.models import Invitation from authentik.stages.invitation.models import Invitation
@ -38,9 +38,6 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
# Save the login method used # Save the login method used
kwargs[PLAN_CONTEXT_METHOD] = flow_plan.context[PLAN_CONTEXT_METHOD] kwargs[PLAN_CONTEXT_METHOD] = flow_plan.context[PLAN_CONTEXT_METHOD]
kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {})
if PLAN_CONTEXT_OUTPOST in flow_plan.context:
# Save outpost context
kwargs[PLAN_CONTEXT_OUTPOST] = flow_plan.context[PLAN_CONTEXT_OUTPOST]
event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user) event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user)
request.session[SESSION_LOGIN_EVENT] = event request.session[SESSION_LOGIN_EVENT] = event
@ -78,10 +75,7 @@ def on_login_failed(
**kwargs, **kwargs,
): ):
"""Failed Login, authentik custom event""" """Failed Login, authentik custom event"""
user = User.objects.filter(username=credentials.get("username")).first() Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **kwargs).from_http(request)
Event.new(EventAction.LOGIN_FAILED, **credentials, stage=stage, **kwargs).from_http(
request, user
)
@receiver(invitation_used) @receiver(invitation_used)

View File

@ -2,8 +2,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from django.urls import reverse from django.test import TestCase
from rest_framework.test import APITestCase
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.events.models import ( from authentik.events.models import (
@ -11,7 +10,6 @@ from authentik.events.models import (
EventAction, EventAction,
Notification, Notification,
NotificationRule, NotificationRule,
NotificationSeverity,
NotificationTransport, NotificationTransport,
NotificationWebhookMapping, NotificationWebhookMapping,
TransportMode, TransportMode,
@ -22,7 +20,7 @@ from authentik.policies.exceptions import PolicyException
from authentik.policies.models import PolicyBinding from authentik.policies.models import PolicyBinding
class TestEventsNotifications(APITestCase): class TestEventsNotifications(TestCase):
"""Test Event Notifications""" """Test Event Notifications"""
def setUp(self) -> None: def setUp(self) -> None:
@ -133,15 +131,3 @@ class TestEventsNotifications(APITestCase):
Notification.objects.all().delete() Notification.objects.all().delete()
Event.new(EventAction.CUSTOM_PREFIX).save() Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(Notification.objects.first().body, "foo") self.assertEqual(Notification.objects.first().body, "foo")
def test_api_mark_all_seen(self):
"""Test mark_all_seen"""
self.client.force_login(self.user)
Notification.objects.create(
severity=NotificationSeverity.NOTICE, body="foo", user=self.user, seen=False
)
response = self.client.post(reverse("authentik_api:notification-mark-all-seen"))
self.assertEqual(response.status_code, 204)
self.assertFalse(Notification.objects.filter(body="foo", seen=False).exists())

View File

@ -37,7 +37,6 @@ from authentik.lib.utils.file import (
) )
from authentik.lib.views import bad_request_message from authentik.lib.views import bad_request_message
from authentik.rbac.decorators import permission_required from authentik.rbac.decorators import permission_required
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger() LOGGER = get_logger()
@ -282,7 +281,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
400: OpenApiResponse(description="Flow not applicable"), 400: OpenApiResponse(description="Flow not applicable"),
}, },
) )
@action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) @action(detail=True, pagination_class=None, filter_backends=[])
def execute(self, request: Request, slug: str): def execute(self, request: Request, slug: str):
"""Execute flow for current user""" """Execute flow for current user"""
# Because we pre-plan the flow here, and not in the planner, we need to manually clear # Because we pre-plan the flow here, and not in the planner, we need to manually clear

View File

@ -32,6 +32,14 @@ class FlowLayout(models.TextChoices):
SIDEBAR_RIGHT = "sidebar_right" SIDEBAR_RIGHT = "sidebar_right"
class ChallengeTypes(Enum):
"""Currently defined challenge types"""
NATIVE = "native"
SHELL = "shell"
REDIRECT = "redirect"
class ErrorDetailSerializer(PassiveSerializer): class ErrorDetailSerializer(PassiveSerializer):
"""Serializer for rest_framework's error messages""" """Serializer for rest_framework's error messages"""
@ -52,6 +60,9 @@ class Challenge(PassiveSerializer):
"""Challenge that gets sent to the client based on which stage """Challenge that gets sent to the client based on which stage
is currently active""" is currently active"""
type = ChoiceField(
choices=[(x.value, x.name) for x in ChallengeTypes],
)
flow_info = ContextualFlowInfo(required=False) flow_info = ContextualFlowInfo(required=False)
component = CharField(default="") component = CharField(default="")
@ -85,6 +96,7 @@ class FlowErrorChallenge(Challenge):
"""Challenge class when an unhandled error occurs during a stage. Normal users """Challenge class when an unhandled error occurs during a stage. Normal users
are shown an error message, superusers are shown a full stacktrace.""" are shown an error message, superusers are shown a full stacktrace."""
type = CharField(default=ChallengeTypes.NATIVE.value)
component = CharField(default="ak-stage-flow-error") component = CharField(default="ak-stage-flow-error")
request_id = CharField() request_id = CharField()

Some files were not shown because too many files have changed in this diff Show More